arize-phoenix 3.16.1__py3-none-any.whl → 7.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (338)
  1. arize_phoenix-7.7.0.dist-info/METADATA +261 -0
  2. arize_phoenix-7.7.0.dist-info/RECORD +345 -0
  3. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.0.dist-info}/WHEEL +1 -1
  4. arize_phoenix-7.7.0.dist-info/entry_points.txt +3 -0
  5. phoenix/__init__.py +86 -14
  6. phoenix/auth.py +309 -0
  7. phoenix/config.py +675 -45
  8. phoenix/core/model.py +32 -30
  9. phoenix/core/model_schema.py +102 -109
  10. phoenix/core/model_schema_adapter.py +48 -45
  11. phoenix/datetime_utils.py +24 -3
  12. phoenix/db/README.md +54 -0
  13. phoenix/db/__init__.py +4 -0
  14. phoenix/db/alembic.ini +85 -0
  15. phoenix/db/bulk_inserter.py +294 -0
  16. phoenix/db/engines.py +208 -0
  17. phoenix/db/enums.py +20 -0
  18. phoenix/db/facilitator.py +113 -0
  19. phoenix/db/helpers.py +159 -0
  20. phoenix/db/insertion/constants.py +2 -0
  21. phoenix/db/insertion/dataset.py +227 -0
  22. phoenix/db/insertion/document_annotation.py +171 -0
  23. phoenix/db/insertion/evaluation.py +191 -0
  24. phoenix/db/insertion/helpers.py +98 -0
  25. phoenix/db/insertion/span.py +193 -0
  26. phoenix/db/insertion/span_annotation.py +158 -0
  27. phoenix/db/insertion/trace_annotation.py +158 -0
  28. phoenix/db/insertion/types.py +256 -0
  29. phoenix/db/migrate.py +86 -0
  30. phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
  31. phoenix/db/migrations/env.py +114 -0
  32. phoenix/db/migrations/script.py.mako +26 -0
  33. phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
  34. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
  35. phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
  36. phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
  37. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  38. phoenix/db/models.py +807 -0
  39. phoenix/exceptions.py +5 -1
  40. phoenix/experiments/__init__.py +6 -0
  41. phoenix/experiments/evaluators/__init__.py +29 -0
  42. phoenix/experiments/evaluators/base.py +158 -0
  43. phoenix/experiments/evaluators/code_evaluators.py +184 -0
  44. phoenix/experiments/evaluators/llm_evaluators.py +473 -0
  45. phoenix/experiments/evaluators/utils.py +236 -0
  46. phoenix/experiments/functions.py +772 -0
  47. phoenix/experiments/tracing.py +86 -0
  48. phoenix/experiments/types.py +726 -0
  49. phoenix/experiments/utils.py +25 -0
  50. phoenix/inferences/__init__.py +0 -0
  51. phoenix/{datasets → inferences}/errors.py +6 -5
  52. phoenix/{datasets → inferences}/fixtures.py +49 -42
  53. phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
  54. phoenix/{datasets → inferences}/schema.py +11 -11
  55. phoenix/{datasets → inferences}/validation.py +13 -14
  56. phoenix/logging/__init__.py +3 -0
  57. phoenix/logging/_config.py +90 -0
  58. phoenix/logging/_filter.py +6 -0
  59. phoenix/logging/_formatter.py +69 -0
  60. phoenix/metrics/__init__.py +5 -4
  61. phoenix/metrics/binning.py +4 -3
  62. phoenix/metrics/metrics.py +2 -1
  63. phoenix/metrics/mixins.py +7 -6
  64. phoenix/metrics/retrieval_metrics.py +2 -1
  65. phoenix/metrics/timeseries.py +5 -4
  66. phoenix/metrics/wrappers.py +9 -3
  67. phoenix/pointcloud/clustering.py +5 -5
  68. phoenix/pointcloud/pointcloud.py +7 -5
  69. phoenix/pointcloud/projectors.py +5 -6
  70. phoenix/pointcloud/umap_parameters.py +53 -52
  71. phoenix/server/api/README.md +28 -0
  72. phoenix/server/api/auth.py +44 -0
  73. phoenix/server/api/context.py +152 -9
  74. phoenix/server/api/dataloaders/__init__.py +91 -0
  75. phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
  76. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  77. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  78. phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
  79. phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
  80. phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
  81. phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
  82. phoenix/server/api/dataloaders/document_evaluations.py +31 -0
  83. phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
  84. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
  85. phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
  86. phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
  87. phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
  88. phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
  89. phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
  90. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
  91. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  92. phoenix/server/api/dataloaders/record_counts.py +116 -0
  93. phoenix/server/api/dataloaders/session_io.py +79 -0
  94. phoenix/server/api/dataloaders/session_num_traces.py +30 -0
  95. phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
  96. phoenix/server/api/dataloaders/session_token_usages.py +41 -0
  97. phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
  98. phoenix/server/api/dataloaders/span_annotations.py +26 -0
  99. phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
  100. phoenix/server/api/dataloaders/span_descendants.py +57 -0
  101. phoenix/server/api/dataloaders/span_projects.py +33 -0
  102. phoenix/server/api/dataloaders/token_counts.py +124 -0
  103. phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
  104. phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
  105. phoenix/server/api/dataloaders/user_roles.py +30 -0
  106. phoenix/server/api/dataloaders/users.py +33 -0
  107. phoenix/server/api/exceptions.py +48 -0
  108. phoenix/server/api/helpers/__init__.py +12 -0
  109. phoenix/server/api/helpers/dataset_helpers.py +217 -0
  110. phoenix/server/api/helpers/experiment_run_filters.py +763 -0
  111. phoenix/server/api/helpers/playground_clients.py +948 -0
  112. phoenix/server/api/helpers/playground_registry.py +70 -0
  113. phoenix/server/api/helpers/playground_spans.py +455 -0
  114. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  115. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  116. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  117. phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
  118. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  119. phoenix/server/api/input_types/ClusterInput.py +2 -2
  120. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  121. phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
  122. phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
  123. phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
  124. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  125. phoenix/server/api/input_types/DatasetSort.py +17 -0
  126. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  127. phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
  128. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  129. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  130. phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
  131. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  132. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  133. phoenix/server/api/input_types/Granularity.py +1 -1
  134. phoenix/server/api/input_types/InvocationParameters.py +162 -0
  135. phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
  136. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  137. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  138. phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
  139. phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
  140. phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
  141. phoenix/server/api/input_types/SpanSort.py +134 -69
  142. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  143. phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
  144. phoenix/server/api/input_types/UserRoleInput.py +9 -0
  145. phoenix/server/api/mutations/__init__.py +28 -0
  146. phoenix/server/api/mutations/api_key_mutations.py +167 -0
  147. phoenix/server/api/mutations/chat_mutations.py +593 -0
  148. phoenix/server/api/mutations/dataset_mutations.py +591 -0
  149. phoenix/server/api/mutations/experiment_mutations.py +75 -0
  150. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
  151. phoenix/server/api/mutations/project_mutations.py +57 -0
  152. phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
  153. phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
  154. phoenix/server/api/mutations/user_mutations.py +329 -0
  155. phoenix/server/api/openapi/__init__.py +0 -0
  156. phoenix/server/api/openapi/main.py +17 -0
  157. phoenix/server/api/openapi/schema.py +16 -0
  158. phoenix/server/api/queries.py +738 -0
  159. phoenix/server/api/routers/__init__.py +11 -0
  160. phoenix/server/api/routers/auth.py +284 -0
  161. phoenix/server/api/routers/embeddings.py +26 -0
  162. phoenix/server/api/routers/oauth2.py +488 -0
  163. phoenix/server/api/routers/v1/__init__.py +64 -0
  164. phoenix/server/api/routers/v1/datasets.py +1017 -0
  165. phoenix/server/api/routers/v1/evaluations.py +362 -0
  166. phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
  167. phoenix/server/api/routers/v1/experiment_runs.py +167 -0
  168. phoenix/server/api/routers/v1/experiments.py +308 -0
  169. phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
  170. phoenix/server/api/routers/v1/spans.py +267 -0
  171. phoenix/server/api/routers/v1/traces.py +208 -0
  172. phoenix/server/api/routers/v1/utils.py +95 -0
  173. phoenix/server/api/schema.py +44 -241
  174. phoenix/server/api/subscriptions.py +597 -0
  175. phoenix/server/api/types/Annotation.py +21 -0
  176. phoenix/server/api/types/AnnotationSummary.py +55 -0
  177. phoenix/server/api/types/AnnotatorKind.py +16 -0
  178. phoenix/server/api/types/ApiKey.py +27 -0
  179. phoenix/server/api/types/AuthMethod.py +9 -0
  180. phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
  181. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
  182. phoenix/server/api/types/Cluster.py +25 -24
  183. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  184. phoenix/server/api/types/DataQualityMetric.py +31 -13
  185. phoenix/server/api/types/Dataset.py +288 -63
  186. phoenix/server/api/types/DatasetExample.py +85 -0
  187. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  188. phoenix/server/api/types/DatasetVersion.py +14 -0
  189. phoenix/server/api/types/Dimension.py +32 -31
  190. phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
  191. phoenix/server/api/types/EmbeddingDimension.py +56 -49
  192. phoenix/server/api/types/Evaluation.py +25 -31
  193. phoenix/server/api/types/EvaluationSummary.py +30 -50
  194. phoenix/server/api/types/Event.py +20 -20
  195. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  196. phoenix/server/api/types/Experiment.py +152 -0
  197. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  198. phoenix/server/api/types/ExperimentComparison.py +17 -0
  199. phoenix/server/api/types/ExperimentRun.py +119 -0
  200. phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
  201. phoenix/server/api/types/GenerativeModel.py +9 -0
  202. phoenix/server/api/types/GenerativeProvider.py +85 -0
  203. phoenix/server/api/types/Inferences.py +80 -0
  204. phoenix/server/api/types/InferencesRole.py +23 -0
  205. phoenix/server/api/types/LabelFraction.py +7 -0
  206. phoenix/server/api/types/MimeType.py +2 -2
  207. phoenix/server/api/types/Model.py +54 -54
  208. phoenix/server/api/types/PerformanceMetric.py +8 -5
  209. phoenix/server/api/types/Project.py +407 -142
  210. phoenix/server/api/types/ProjectSession.py +139 -0
  211. phoenix/server/api/types/Segments.py +4 -4
  212. phoenix/server/api/types/Span.py +221 -176
  213. phoenix/server/api/types/SpanAnnotation.py +43 -0
  214. phoenix/server/api/types/SpanIOValue.py +15 -0
  215. phoenix/server/api/types/SystemApiKey.py +9 -0
  216. phoenix/server/api/types/TemplateLanguage.py +10 -0
  217. phoenix/server/api/types/TimeSeries.py +19 -15
  218. phoenix/server/api/types/TokenUsage.py +11 -0
  219. phoenix/server/api/types/Trace.py +154 -0
  220. phoenix/server/api/types/TraceAnnotation.py +45 -0
  221. phoenix/server/api/types/UMAPPoints.py +7 -7
  222. phoenix/server/api/types/User.py +60 -0
  223. phoenix/server/api/types/UserApiKey.py +45 -0
  224. phoenix/server/api/types/UserRole.py +15 -0
  225. phoenix/server/api/types/node.py +4 -112
  226. phoenix/server/api/types/pagination.py +156 -57
  227. phoenix/server/api/utils.py +34 -0
  228. phoenix/server/app.py +864 -115
  229. phoenix/server/bearer_auth.py +163 -0
  230. phoenix/server/dml_event.py +136 -0
  231. phoenix/server/dml_event_handler.py +256 -0
  232. phoenix/server/email/__init__.py +0 -0
  233. phoenix/server/email/sender.py +97 -0
  234. phoenix/server/email/templates/__init__.py +0 -0
  235. phoenix/server/email/templates/password_reset.html +19 -0
  236. phoenix/server/email/types.py +11 -0
  237. phoenix/server/grpc_server.py +102 -0
  238. phoenix/server/jwt_store.py +505 -0
  239. phoenix/server/main.py +305 -116
  240. phoenix/server/oauth2.py +52 -0
  241. phoenix/server/openapi/__init__.py +0 -0
  242. phoenix/server/prometheus.py +111 -0
  243. phoenix/server/rate_limiters.py +188 -0
  244. phoenix/server/static/.vite/manifest.json +87 -0
  245. phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
  246. phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
  247. phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
  248. phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
  249. phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
  250. phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
  251. phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
  252. phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
  253. phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
  254. phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
  255. phoenix/server/telemetry.py +68 -0
  256. phoenix/server/templates/index.html +82 -23
  257. phoenix/server/thread_server.py +3 -3
  258. phoenix/server/types.py +275 -0
  259. phoenix/services.py +27 -18
  260. phoenix/session/client.py +743 -68
  261. phoenix/session/data_extractor.py +31 -7
  262. phoenix/session/evaluation.py +3 -9
  263. phoenix/session/session.py +263 -219
  264. phoenix/settings.py +22 -0
  265. phoenix/trace/__init__.py +2 -22
  266. phoenix/trace/attributes.py +338 -0
  267. phoenix/trace/dsl/README.md +116 -0
  268. phoenix/trace/dsl/filter.py +663 -213
  269. phoenix/trace/dsl/helpers.py +73 -21
  270. phoenix/trace/dsl/query.py +574 -201
  271. phoenix/trace/exporter.py +24 -19
  272. phoenix/trace/fixtures.py +368 -32
  273. phoenix/trace/otel.py +71 -219
  274. phoenix/trace/projects.py +3 -2
  275. phoenix/trace/schemas.py +33 -11
  276. phoenix/trace/span_evaluations.py +21 -16
  277. phoenix/trace/span_json_decoder.py +6 -4
  278. phoenix/trace/span_json_encoder.py +2 -2
  279. phoenix/trace/trace_dataset.py +47 -32
  280. phoenix/trace/utils.py +21 -4
  281. phoenix/utilities/__init__.py +0 -26
  282. phoenix/utilities/client.py +132 -0
  283. phoenix/utilities/deprecation.py +31 -0
  284. phoenix/utilities/error_handling.py +3 -2
  285. phoenix/utilities/json.py +109 -0
  286. phoenix/utilities/logging.py +8 -0
  287. phoenix/utilities/project.py +2 -2
  288. phoenix/utilities/re.py +49 -0
  289. phoenix/utilities/span_store.py +0 -23
  290. phoenix/utilities/template_formatters.py +99 -0
  291. phoenix/version.py +1 -1
  292. arize_phoenix-3.16.1.dist-info/METADATA +0 -495
  293. arize_phoenix-3.16.1.dist-info/RECORD +0 -178
  294. phoenix/core/project.py +0 -619
  295. phoenix/core/traces.py +0 -96
  296. phoenix/experimental/evals/__init__.py +0 -73
  297. phoenix/experimental/evals/evaluators.py +0 -413
  298. phoenix/experimental/evals/functions/__init__.py +0 -4
  299. phoenix/experimental/evals/functions/classify.py +0 -453
  300. phoenix/experimental/evals/functions/executor.py +0 -353
  301. phoenix/experimental/evals/functions/generate.py +0 -138
  302. phoenix/experimental/evals/functions/processing.py +0 -76
  303. phoenix/experimental/evals/models/__init__.py +0 -14
  304. phoenix/experimental/evals/models/anthropic.py +0 -175
  305. phoenix/experimental/evals/models/base.py +0 -170
  306. phoenix/experimental/evals/models/bedrock.py +0 -221
  307. phoenix/experimental/evals/models/litellm.py +0 -134
  308. phoenix/experimental/evals/models/openai.py +0 -448
  309. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  310. phoenix/experimental/evals/models/vertex.py +0 -173
  311. phoenix/experimental/evals/models/vertexai.py +0 -186
  312. phoenix/experimental/evals/retrievals.py +0 -96
  313. phoenix/experimental/evals/templates/__init__.py +0 -50
  314. phoenix/experimental/evals/templates/default_templates.py +0 -472
  315. phoenix/experimental/evals/templates/template.py +0 -195
  316. phoenix/experimental/evals/utils/__init__.py +0 -172
  317. phoenix/experimental/evals/utils/threads.py +0 -27
  318. phoenix/server/api/helpers.py +0 -11
  319. phoenix/server/api/routers/evaluation_handler.py +0 -109
  320. phoenix/server/api/routers/span_handler.py +0 -70
  321. phoenix/server/api/routers/trace_handler.py +0 -60
  322. phoenix/server/api/types/DatasetRole.py +0 -23
  323. phoenix/server/static/index.css +0 -6
  324. phoenix/server/static/index.js +0 -7447
  325. phoenix/storage/span_store/__init__.py +0 -23
  326. phoenix/storage/span_store/text_file.py +0 -85
  327. phoenix/trace/dsl/missing.py +0 -60
  328. phoenix/trace/langchain/__init__.py +0 -3
  329. phoenix/trace/langchain/instrumentor.py +0 -35
  330. phoenix/trace/llama_index/__init__.py +0 -3
  331. phoenix/trace/llama_index/callback.py +0 -102
  332. phoenix/trace/openai/__init__.py +0 -3
  333. phoenix/trace/openai/instrumentor.py +0 -30
  334. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/IP_NOTICE +0 -0
  335. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/LICENSE +0 -0
  336. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  337. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  338. /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
@@ -0,0 +1,597 @@
1
+ import asyncio
2
+ import logging
3
+ from collections.abc import AsyncIterator, Iterator
4
+ from datetime import datetime, timedelta, timezone
5
+ from typing import (
6
+ Any,
7
+ AsyncGenerator,
8
+ Coroutine,
9
+ Iterable,
10
+ Mapping,
11
+ Optional,
12
+ Sequence,
13
+ TypeVar,
14
+ cast,
15
+ )
16
+
17
+ import strawberry
18
+ from openinference.instrumentation import safe_json_dumps
19
+ from openinference.semconv.trace import SpanAttributes
20
+ from sqlalchemy import and_, func, insert, select
21
+ from sqlalchemy.orm import load_only
22
+ from strawberry.relay.types import GlobalID
23
+ from strawberry.types import Info
24
+ from typing_extensions import TypeAlias, assert_never
25
+
26
+ from phoenix.datetime_utils import local_now, normalize_datetime
27
+ from phoenix.db import models
28
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly
29
+ from phoenix.server.api.context import Context
30
+ from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
31
+ from phoenix.server.api.helpers.playground_clients import (
32
+ PlaygroundStreamingClient,
33
+ initialize_playground_clients,
34
+ )
35
+ from phoenix.server.api.helpers.playground_registry import (
36
+ PLAYGROUND_CLIENT_REGISTRY,
37
+ )
38
+ from phoenix.server.api.helpers.playground_spans import (
39
+ get_db_experiment_run,
40
+ get_db_span,
41
+ get_db_trace,
42
+ streaming_llm_span,
43
+ )
44
+ from phoenix.server.api.input_types.ChatCompletionInput import (
45
+ ChatCompletionInput,
46
+ ChatCompletionOverDatasetInput,
47
+ )
48
+ from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
49
+ from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
50
+ ChatCompletionSubscriptionError,
51
+ ChatCompletionSubscriptionExperiment,
52
+ ChatCompletionSubscriptionPayload,
53
+ ChatCompletionSubscriptionResult,
54
+ )
55
+ from phoenix.server.api.types.Dataset import Dataset
56
+ from phoenix.server.api.types.DatasetExample import DatasetExample
57
+ from phoenix.server.api.types.DatasetVersion import DatasetVersion
58
+ from phoenix.server.api.types.Experiment import to_gql_experiment
59
+ from phoenix.server.api.types.ExperimentRun import to_gql_experiment_run
60
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
61
+ from phoenix.server.api.types.Span import to_gql_span
62
+ from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
63
+ from phoenix.server.dml_event import SpanInsertEvent
64
+ from phoenix.server.types import DbSessionFactory
65
+ from phoenix.utilities.template_formatters import (
66
+ FStringTemplateFormatter,
67
+ MustacheTemplateFormatter,
68
+ NoOpFormatter,
69
+ TemplateFormatter,
70
+ TemplateFormatterError,
71
+ )
72
+
73
# Type variable for generic helpers; its uses are outside this section.
GenericType = TypeVar("GenericType")

logger = logging.getLogger(__name__)

# Runs once at import time.
# NOTE(review): presumably registers the provider clients that
# PLAYGROUND_CLIENT_REGISTRY hands out below — confirm in playground_clients.
initialize_playground_clients()

# (role, content, tool_call_id, tool_calls): the field order the subscriptions
# use when they flatten the input messages into tuples.
ChatCompletionMessage: TypeAlias = tuple[
    ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]
]
# Dataset examples are identified by their relay global ID.
DatasetExampleID: TypeAlias = GlobalID
# (dataset example ID, span or None, experiment run).
ChatCompletionResult: TypeAlias = tuple[
    DatasetExampleID, Optional[models.Span], models.ExperimentRun
]
# Async generator that yields subscription payloads.
ChatStream: TypeAlias = AsyncGenerator[ChatCompletionSubscriptionPayload, None]
# Name of the project row under which playground traces are stored.
PLAYGROUND_PROJECT_NAME = "playground"
88
+
89
+
90
@strawberry.type
class Subscription:
    # GraphQL subscriptions backing the prompt playground.
    #
    # Both subscriptions are async generators of subscription payloads and are
    # guarded by the IsNotReadOnly and IsLocked permission classes.

    @strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def chat_completion(
        self, info: Info[Context, None], input: ChatCompletionInput
    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
        # Streams one chat completion.
        #
        # Yields, in order: every chunk produced by the LLM client, an error
        # payload if the span recorded a status message, and finally a result
        # payload wrapping the span that was written to the database.
        #
        # Raises BadRequest when no client is registered for the requested
        # provider/model or when the client cannot be constructed.
        provider_key = input.model.provider_key
        llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
        if llm_client_class is None:
            raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
        try:
            llm_client = llm_client_class(
                model=input.model,
                api_key=input.api_key,
            )
        except CustomGraphQLError:
            # Already a client-facing error; let it through untouched.
            raise
        except Exception as error:
            # Chain the original failure so it is kept as the explicit cause.
            raise BadRequest(
                f"Failed to connect to LLM API for {provider_key.value} {input.model.name}: "
                f"{str(error)}"
            ) from error

        # Flatten the input messages into ChatCompletionMessage tuples; values
        # of an unexpected type for the tool fields are replaced with None.
        messages = [
            (
                message.role,
                message.content,
                message.tool_call_id if isinstance(message.tool_call_id, str) else None,
                message.tool_calls if isinstance(message.tool_calls, list) else None,
            )
            for message in input.messages
        ]
        attributes = None
        if template_options := input.template:
            messages = list(
                _formatted_messages(
                    messages=messages,
                    template_language=template_options.language,
                    template_variables=template_options.variables,
                )
            )
            # Record the template variables on the span as a JSON string.
            attributes = {PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(template_options.variables)}
        invocation_parameters = llm_client.construct_invocation_parameters(
            input.invocation_parameters
        )
        async with streaming_llm_span(
            input=input,
            messages=messages,
            invocation_parameters=invocation_parameters,
            attributes=attributes,
        ) as span:
            async for chunk in llm_client.chat_completion_create(
                messages=messages, tools=input.tools or [], **invocation_parameters
            ):
                span.add_response_chunk(chunk)
                yield chunk
            span.set_attributes(llm_client.attributes)
        if span.status_message is not None:
            yield ChatCompletionSubscriptionError(message=span.status_message)
        async with info.context.db() as session:
            # Look up the playground project, inserting it on first use.
            if (
                playground_project_id := await session.scalar(
                    select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
                )
            ) is None:
                playground_project_id = await session.scalar(
                    insert(models.Project)
                    .returning(models.Project.id)
                    .values(
                        name=PLAYGROUND_PROJECT_NAME,
                        description="Traces from prompt playground",
                    )
                )
            db_trace = get_db_trace(span, playground_project_id)
            db_span = get_db_span(span, db_trace)
            session.add(db_span)
            await session.flush()
        # Announce the span insert for the playground project.
        info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
        yield ChatCompletionSubscriptionResult(span=to_gql_span(db_span))

    @strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def chat_completion_over_dataset(
        self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
        # Runs the chat completion once per example of a dataset version and
        # records the run as an experiment.
        #
        # Yields the experiment payload first, then the payloads of the
        # per-example streams interleaved with the payloads produced when
        # queued results are written to the database.
        #
        # Raises BadRequest for an unknown provider/model or a client that
        # cannot be constructed, and NotFound when the dataset, the version,
        # or any non-deleted example cannot be found.
        provider_key = input.model.provider_key
        llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
        if llm_client_class is None:
            raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
        try:
            llm_client = llm_client_class(
                model=input.model,
                api_key=input.api_key,
            )
        except CustomGraphQLError:
            # Already a client-facing error; let it through untouched.
            raise
        except Exception as error:
            # Chain the original failure so it is kept as the explicit cause.
            raise BadRequest(
                f"Failed to connect to LLM API for {provider_key.value} {input.model.name}: "
                f"{str(error)}"
            ) from error

        dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
        version_id = (
            from_global_id_with_expected_type(
                global_id=input.dataset_version_id, expected_type_name=DatasetVersion.__name__
            )
            if input.dataset_version_id
            else None
        )
        async with info.context.db() as session:
            if (
                dataset := await session.scalar(
                    select(models.Dataset).where(models.Dataset.id == dataset_id)
                )
            ) is None:
                raise NotFound(f"Could not find dataset with ID {dataset_id}")
            if version_id is None:
                # No version requested: use the dataset's highest version ID.
                if (
                    resolved_version_id := await session.scalar(
                        select(models.DatasetVersion.id)
                        .where(models.DatasetVersion.dataset_id == dataset_id)
                        .order_by(models.DatasetVersion.id.desc())
                        .limit(1)
                    )
                ) is None:
                    raise NotFound(f"No versions found for dataset with ID {dataset_id}")
            else:
                # Check that the requested version belongs to this dataset.
                if (
                    resolved_version_id := await session.scalar(
                        select(models.DatasetVersion.id).where(
                            and_(
                                models.DatasetVersion.dataset_id == dataset_id,
                                models.DatasetVersion.id == version_id,
                            )
                        )
                    )
                ) is None:
                    raise NotFound(f"Could not find dataset version with ID {version_id}")
            # Subquery: for each example, the highest revision ID whose version
            # is at or before the resolved version.
            revision_ids = (
                select(func.max(models.DatasetExampleRevision.id))
                .join(models.DatasetExample)
                .where(
                    and_(
                        models.DatasetExample.dataset_id == dataset_id,
                        models.DatasetExampleRevision.dataset_version_id <= resolved_version_id,
                    )
                )
                .group_by(models.DatasetExampleRevision.dataset_example_id)
            )
            # Load those revisions in ascending example-ID order, skipping
            # examples whose selected revision is a deletion.
            if not (
                revisions := [
                    rev
                    async for rev in await session.stream_scalars(
                        select(models.DatasetExampleRevision)
                        .where(
                            and_(
                                models.DatasetExampleRevision.id.in_(revision_ids),
                                models.DatasetExampleRevision.revision_kind != "DELETE",
                            )
                        )
                        .order_by(models.DatasetExampleRevision.dataset_example_id.asc())
                        .options(
                            load_only(
                                models.DatasetExampleRevision.dataset_example_id,
                                models.DatasetExampleRevision.input,
                            )
                        )
                    )
                ]
            ):
                raise NotFound("No examples found for the given dataset and version")
            # Look up the playground project, inserting it on first use.
            if (
                playground_project_id := await session.scalar(
                    select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
                )
            ) is None:
                playground_project_id = await session.scalar(
                    insert(models.Project)
                    .returning(models.Project.id)
                    .values(
                        name=PLAYGROUND_PROJECT_NAME,
                        description="Traces from prompt playground",
                    )
                )
            experiment = models.Experiment(
                # Reuse the ID decoded above instead of decoding it again.
                dataset_id=dataset_id,
                dataset_version_id=resolved_version_id,
                name=input.experiment_name or _default_playground_experiment_name(),
                description=input.experiment_description
                or _default_playground_experiment_description(dataset_name=dataset.name),
                repetitions=1,
                metadata_=input.experiment_metadata
                or _default_playground_experiment_metadata(
                    dataset_name=dataset.name,
                    dataset_id=input.dataset_id,
                    version_id=GlobalID(DatasetVersion.__name__, str(resolved_version_id)),
                ),
                project_name=PLAYGROUND_PROJECT_NAME,
            )
            session.add(experiment)
            await session.flush()
        yield ChatCompletionSubscriptionExperiment(
            experiment=to_gql_experiment(experiment)
        )  # eagerly yields experiment so it can be linked by consumers of the subscription

        # Queue handed to every per-example stream; it is drained below and
        # written to the database in batches.
        results: asyncio.Queue[ChatCompletionResult] = asyncio.Queue()
        not_started: list[tuple[DatasetExampleID, ChatStream]] = [
            (
                GlobalID(DatasetExample.__name__, str(revision.dataset_example_id)),
                _stream_chat_completion_over_dataset_example(
                    input=input,
                    llm_client=llm_client,
                    revision=revision,
                    results=results,
                    experiment_id=experiment.id,
                    project_id=playground_project_id,
                ),
            )
            # Built in reverse because `pop()` below takes from the end of the
            # list; this starts the examples in the ascending ID order in which
            # the revisions were queried.
            for revision in reversed(revisions)
        ]
        # Active streams with their pending task. An entry whose example ID is
        # None is a result-writing stream, not an example stream.
        in_progress: list[
            tuple[
                Optional[DatasetExampleID],
                ChatStream,
                asyncio.Task[ChatCompletionSubscriptionPayload],
            ]
        ] = []
        max_in_progress = 3
        write_batch_size = 10
        write_interval = timedelta(seconds=10)
        last_write_time = datetime.now()
        while not_started or in_progress:
            # Top up the set of active streams.
            while not_started and len(in_progress) < max_in_progress:
                ex_id, stream = not_started.pop()
                task = _create_task_with_timeout(stream)
                in_progress.append((ex_id, stream, task))
            async_tasks_to_run = [task for _, _, task in in_progress]
            completed_tasks, _ = await asyncio.wait(
                async_tasks_to_run, return_when=asyncio.FIRST_COMPLETED
            )
            for completed_task in completed_tasks:
                # Recomputed for every task because entries may have been
                # deleted or appended earlier in this loop.
                idx = [task for _, _, task in in_progress].index(completed_task)
                example_id, stream, _ = in_progress[idx]
                try:
                    yield completed_task.result()
                except StopAsyncIteration:
                    del in_progress[idx]  # removes exhausted stream
                except asyncio.TimeoutError:
                    del in_progress[idx]  # removes timed-out stream
                    if example_id is not None:
                        yield ChatCompletionSubscriptionError(
                            message="Playground task timed out", dataset_example_id=example_id
                        )
                except Exception as error:
                    del in_progress[idx]  # removes failed stream
                    if example_id is not None:
                        yield ChatCompletionSubscriptionError(
                            message="An unexpected error occurred", dataset_example_id=example_id
                        )
                    logger.exception(error)
                else:
                    # The stream produced a payload; schedule a new task for it.
                    task = _create_task_with_timeout(stream)
                    in_progress[idx] = (example_id, stream, task)

                # Start a write of the queued results once enough have
                # accumulated or enough time has passed, but never while
                # another write stream is still active.
                exceeded_write_batch_size = results.qsize() >= write_batch_size
                exceeded_write_interval = datetime.now() - last_write_time > write_interval
                write_already_in_progress = any(
                    _is_result_payloads_stream(stream) for _, stream, _ in in_progress
                )
                if (
                    not results.empty()
                    and (exceeded_write_batch_size or exceeded_write_interval)
                    and not write_already_in_progress
                ):
                    result_payloads_stream = _chat_completion_result_payloads(
                        db=info.context.db, results=_drain_no_wait(results)
                    )
                    task = _create_task_with_timeout(result_payloads_stream)
                    in_progress.append((None, result_payloads_stream, task))
                    last_write_time = datetime.now()
        # Write whatever is still queued after every stream has finished.
        if remaining_results := await _drain(results):
            async for result_payload in _chat_completion_result_payloads(
                db=info.context.db, results=remaining_results
            ):
                yield result_payload
+
376
+
377
async def _stream_chat_completion_over_dataset_example(
    *,
    input: ChatCompletionOverDatasetInput,
    llm_client: PlaygroundStreamingClient,
    revision: models.DatasetExampleRevision,
    results: asyncio.Queue[ChatCompletionResult],
    experiment_id: int,
    project_id: int,
) -> ChatStream:
    """
    Streams a chat completion for one dataset example revision.

    The input messages are formatted with the revision's input as template
    variables and sent to `llm_client`. Every chunk is recorded on the span,
    tagged with the example's global ID and yielded. Once streaming ends, an
    `(example_id, db_span, db_run)` tuple is put on `results`, and a
    `ChatCompletionSubscriptionError` is yielded if the span carries a status
    message.

    If formatting raises `TemplateFormatterError`, an error payload is
    yielded, an `ExperimentRun` holding the error (with no span or trace) is
    put on `results`, and the stream ends without calling the LLM.
    """

    def _utc_now() -> datetime:
        return cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))

    example_id = GlobalID(DatasetExample.__name__, str(revision.dataset_example_id))
    invocation_parameters = llm_client.construct_invocation_parameters(input.invocation_parameters)
    messages = []
    for message in input.messages:
        # anything that is not a str / list (e.g. an unset marker) is replaced by None
        tool_call_id = message.tool_call_id if isinstance(message.tool_call_id, str) else None
        tool_calls = message.tool_calls if isinstance(message.tool_calls, list) else None
        messages.append((message.role, message.content, tool_call_id, tool_calls))
    try:
        format_start_time = _utc_now()
        formatted = _formatted_messages(
            messages=messages,
            template_language=input.template_language,
            template_variables=revision.input,
        )
        messages = list(formatted)  # formatting is lazy, so errors surface here
    except TemplateFormatterError as error:
        format_end_time = _utc_now()
        yield ChatCompletionSubscriptionError(message=str(error), dataset_example_id=example_id)
        failed_run = models.ExperimentRun(
            experiment_id=experiment_id,
            dataset_example_id=revision.dataset_example_id,
            trace_id=None,
            output={},
            repetition_number=1,
            start_time=format_start_time,
            end_time=format_end_time,
            error=str(error),
            trace=None,
        )
        await results.put((example_id, None, failed_run))
        return
    span_attributes = {PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(revision.input)}
    async with streaming_llm_span(
        input=input,
        messages=messages,
        invocation_parameters=invocation_parameters,
        attributes=span_attributes,
    ) as span:
        chunks = llm_client.chat_completion_create(
            messages=messages, tools=input.tools or [], **invocation_parameters
        )
        async for chunk in chunks:
            span.add_response_chunk(chunk)
            chunk.dataset_example_id = example_id
            yield chunk
        span.set_attributes(llm_client.attributes)
    db_trace = get_db_trace(span, project_id)
    db_span = get_db_span(span, db_trace)
    db_run = get_db_experiment_run(
        db_span,
        db_trace,
        experiment_id=experiment_id,
        example_id=revision.dataset_example_id,
    )
    await results.put((example_id, db_span, db_run))
    status_message = span.status_message
    if status_message is not None:
        yield ChatCompletionSubscriptionError(
            message=status_message, dataset_example_id=example_id
        )
450
+
451
+
452
async def _chat_completion_result_payloads(
    *,
    db: DbSessionFactory,
    results: Sequence[ChatCompletionResult],
) -> ChatStream:
    """
    Adds the spans and runs in `results` to a database session, flushes it,
    and then yields one `ChatCompletionSubscriptionResult` per result.

    Nothing is yielded, and no session is opened, when `results` is empty.
    """
    if not results:
        return
    async with db() as session:
        for _, db_span, db_run in results:
            if db_span:  # a result may have no span
                session.add(db_span)
            session.add(db_run)
        await session.flush()
    for example_id, db_span, db_run in results:
        gql_span = to_gql_span(db_span) if db_span else None
        gql_run = to_gql_experiment_run(db_run)
        yield ChatCompletionSubscriptionResult(
            span=gql_span,
            experiment_run=gql_run,
            dataset_example_id=example_id,
        )
471
+
472
+
473
def _is_result_payloads_stream(
    stream: ChatStream,
) -> bool:
    """
    Tells whether `stream` is an async generator created by calling
    `_chat_completion_result_payloads`, by comparing the generator's code
    object with that function's code object.
    """
    stream_code = stream.ag_code
    result_payloads_code = _chat_completion_result_payloads.__code__
    return stream_code == result_payloads_code
481
+
482
+
483
def _create_task_with_timeout(
    iterable: AsyncIterator[GenericType], timeout_in_seconds: int = 90
) -> asyncio.Task[GenericType]:
    """
    Schedules a task that produces the next item of `iterable`, giving up
    with a timeout after `timeout_in_seconds` seconds.
    """
    next_item = _as_coroutine(iterable)
    next_item_with_timeout = _wait_for(
        next_item,
        timeout=timeout_in_seconds,
        timeout_message="Playground task timed out",
    )
    return asyncio.create_task(next_item_with_timeout)
493
+
494
+
495
async def _wait_for(
    coro: Coroutine[None, None, GenericType],
    timeout: float,
    timeout_message: Optional[str] = None,
) -> GenericType:
    """
    A function that imitates asyncio.wait_for, but allows the task to be
    cancelled with a custom message.

    Args:
        coro: the coroutine to run as a task.
        timeout: number of seconds to wait for the task to finish.
        timeout_message: message passed to `Task.cancel` on timeout.

    Returns:
        The result of `coro` if it finishes in time.

    Raises:
        asyncio.TimeoutError: if the task had to be cancelled due to timeout.
        asyncio.CancelledError: if the caller is cancelled while waiting, in
            which case the wrapped task is cancelled as well.
        Exception: whatever `coro` itself raises.
    """
    task = asyncio.create_task(coro)
    try:
        await asyncio.wait([task], timeout=timeout)
    except asyncio.CancelledError:
        # asyncio.wait leaves the awaited task running when the waiter is
        # cancelled; cancel it here so it is not orphaned.
        task.cancel()
        raise
    if task.done():
        return task.result()
    task.cancel(msg=timeout_message)
    try:
        # the task may still finish or fail while handling its cancellation
        return await task
    except asyncio.CancelledError:
        raise asyncio.TimeoutError()
516
+
517
+
518
async def _drain(queue: asyncio.Queue[GenericType]) -> list[GenericType]:
    """
    Takes items off `queue` until it reports empty and returns them in the
    order they were retrieved.
    """
    drained: list[GenericType] = []
    while True:
        if queue.empty():
            return drained
        drained.append(await queue.get())
523
+
524
+
525
def _drain_no_wait(queue: asyncio.Queue[GenericType]) -> list[GenericType]:
    """
    Synchronously takes every immediately available item off `queue` and
    returns them in the order they were retrieved.
    """
    drained: list[GenericType] = []
    try:
        while True:
            drained.append(queue.get_nowait())
    except asyncio.QueueEmpty:
        # raised once nothing is left, which ends the loop
        return drained
533
+
534
+
535
async def _as_coroutine(iterable: AsyncIterator[GenericType]) -> GenericType:
    """
    Awaits and returns the next item of `iterable`. Whatever `__anext__`
    raises (e.g. `StopAsyncIteration` on exhaustion) propagates.
    """
    next_item = await iterable.__anext__()
    return next_item
537
+
538
+
539
def _formatted_messages(
    *,
    messages: Iterable[ChatCompletionMessage],
    template_language: TemplateLanguage,
    template_variables: Mapping[str, Any],
) -> Iterator[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]]:
    """
    Formats the messages using the given template options.

    Each message is a `(role, template, tool_call_id, tool_calls)` tuple. The
    returned iterator yields the same tuples with the template replaced by its
    formatted text. Formatting is lazy: formatter errors are raised while the
    result is being iterated, not when this function is called.

    An empty `messages` iterable yields an empty iterator.
    """
    message_rows = list(messages)
    if not message_rows:
        # zip(*[]) produces nothing, so the 4-way unpacking below would raise
        # ValueError; there is nothing to format anyway.
        return iter(())
    template_formatter = _template_formatter(template_language=template_language)
    roles, templates, tool_call_ids, tool_calls = zip(*message_rows)
    formatted_templates = (
        template_formatter.format(template, **template_variables) for template in templates
    )
    return zip(roles, formatted_templates, tool_call_ids, tool_calls)
561
+
562
+
563
def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
    """
    Returns a new formatter instance matching the template language:
    a no-op formatter for NONE, an f-string formatter for F_STRING and a
    mustache formatter for MUSTACHE.
    """
    if template_language is TemplateLanguage.NONE:
        return NoOpFormatter()
    if template_language is TemplateLanguage.F_STRING:
        return FStringTemplateFormatter()
    if template_language is TemplateLanguage.MUSTACHE:
        return MustacheTemplateFormatter()
    # every known member is handled above
    assert_never(template_language)
574
+
575
+
576
def _default_playground_experiment_name() -> str:
    """
    Returns the fixed default name for a playground experiment.
    """
    default_name = "playground-experiment"
    return default_name
578
+
579
+
580
def _default_playground_experiment_description(dataset_name: str) -> str:
    """
    Returns the default experiment description, which quotes the dataset name.
    """
    return 'Playground experiment for dataset "{}"'.format(dataset_name)
582
+
583
+
584
def _default_playground_experiment_metadata(
    dataset_name: str, dataset_id: GlobalID, version_id: GlobalID
) -> dict[str, Any]:
    """
    Returns the default experiment metadata: the dataset name plus the
    string forms of the dataset ID and dataset version ID.
    """
    metadata: dict[str, Any] = dict(
        dataset_name=dataset_name,
        dataset_id=str(dataset_id),
        dataset_version_id=str(version_id),
    )
    return metadata
592
+
593
+
594
# Module-level shorthands for the span attribute names referenced in this file.
LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
@@ -0,0 +1,21 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+
5
+
6
@strawberry.interface
class Annotation:
    # GraphQL interface with four fields: a required name plus an optional
    # score, label and explanation.
    name: str = strawberry.field(
        description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.",
    )
    score: Optional[float] = strawberry.field(
        description="Value of the annotation in the form of a numeric score.",
    )
    label: Optional[str] = strawberry.field(
        description=(
            "Value of the annotation in the form of a string, "
            "e.g. 'helpful' or 'not helpful'. "
            "Note that the label is not necessarily binary."
        ),
    )
    explanation: Optional[str] = strawberry.field(
        description=(
            "The annotator's explanation for the annotation result "
            "(i.e. score or label, or both) given to the subject."
        ),
    )
+ )
@@ -0,0 +1,55 @@
1
+ from typing import Optional, Union, cast
2
+
3
+ import pandas as pd
4
+ import strawberry
5
+ from strawberry import Private
6
+
7
+ from phoenix.db import models
8
+ from phoenix.server.api.types.LabelFraction import LabelFraction
9
+
10
# Either of the two annotation ORM models: span-level or trace-level.
AnnotationType = Union[models.SpanAnnotation, models.TraceAnnotation]
11
+
12
+
13
@strawberry.type
class AnnotationSummary:
    # Summary values resolved from a dataframe. The resolvers read the columns
    # `record_count`, `label`, `label_count`, `score_count` and `score_sum`.
    df: Private[pd.DataFrame]

    def __init__(self, dataframe: pd.DataFrame) -> None:
        self.df = dataframe

    @strawberry.field
    def count(self) -> int:
        # sum of the `record_count` column
        total_records = self.df.record_count.sum()
        return cast(int, total_records)

    @strawberry.field
    def labels(self) -> list[str]:
        # the non-null values of the `label` column
        present_labels = self.df.label.dropna()
        return present_labels.tolist()

    @strawberry.field
    def label_fractions(self) -> list[LabelFraction]:
        # each labelled row's `label_count` divided by the total `label_count`
        total_label_count = self.df.label_count.sum()
        if not total_label_count:
            return []
        has_label = self.df.label.notna()
        labelled_rows = self.df.loc[has_label, ["label", "label_count"]]
        fractions = []
        for row in labelled_rows.itertuples():
            fractions.append(
                LabelFraction(
                    label=cast(str, row.label),
                    fraction=row.label_count / total_label_count,
                )
            )
        return fractions

    @strawberry.field
    def mean_score(self) -> Optional[float]:
        # total `score_sum` over total `score_count`; None when the count is zero
        total_score_count = self.df.score_count.sum()
        if not total_score_count:
            return None
        total_score = self.df.score_sum.sum()
        return cast(float, total_score / total_score_count)

    @strawberry.field
    def score_count(self) -> int:
        # sum of the `score_count` column
        total_score_count = self.df.score_count.sum()
        return cast(int, total_score_count)

    @strawberry.field
    def label_count(self) -> int:
        # sum of the `label_count` column
        total_label_count = self.df.label_count.sum()
        return cast(int, total_label_count)
@@ -0,0 +1,16 @@
1
+ from enum import Enum
2
+
3
+ import strawberry
4
+
5
+
6
@strawberry.enum
class ExperimentRunAnnotatorKind(Enum):
    # Annotator kinds for experiment runs; each value equals its member name.
    LLM = "LLM"
    HUMAN = "HUMAN"
    CODE = "CODE"
11
+
12
+
13
@strawberry.enum
class AnnotatorKind(Enum):
    # Annotator kinds; unlike ExperimentRunAnnotatorKind there is no CODE member.
    LLM = "LLM"
    HUMAN = "HUMAN"