arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (276) hide show
  1. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
  2. arize_phoenix-12.28.1.dist-info/RECORD +499 -0
  3. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +5 -4
  12. phoenix/auth.py +39 -2
  13. phoenix/config.py +1763 -91
  14. phoenix/datetime_utils.py +120 -2
  15. phoenix/db/README.md +595 -25
  16. phoenix/db/bulk_inserter.py +145 -103
  17. phoenix/db/engines.py +140 -33
  18. phoenix/db/enums.py +3 -12
  19. phoenix/db/facilitator.py +302 -35
  20. phoenix/db/helpers.py +1000 -65
  21. phoenix/db/iam_auth.py +64 -0
  22. phoenix/db/insertion/dataset.py +135 -2
  23. phoenix/db/insertion/document_annotation.py +9 -6
  24. phoenix/db/insertion/evaluation.py +2 -3
  25. phoenix/db/insertion/helpers.py +17 -2
  26. phoenix/db/insertion/session_annotation.py +176 -0
  27. phoenix/db/insertion/span.py +15 -11
  28. phoenix/db/insertion/span_annotation.py +3 -4
  29. phoenix/db/insertion/trace_annotation.py +3 -4
  30. phoenix/db/insertion/types.py +50 -20
  31. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  32. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  33. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  34. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  35. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  36. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  37. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  38. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  39. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  40. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  41. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  42. phoenix/db/models.py +669 -56
  43. phoenix/db/pg_config.py +10 -0
  44. phoenix/db/types/model_provider.py +4 -0
  45. phoenix/db/types/token_price_customization.py +29 -0
  46. phoenix/db/types/trace_retention.py +23 -15
  47. phoenix/experiments/evaluators/utils.py +3 -3
  48. phoenix/experiments/functions.py +160 -52
  49. phoenix/experiments/tracing.py +2 -2
  50. phoenix/experiments/types.py +1 -1
  51. phoenix/inferences/inferences.py +1 -2
  52. phoenix/server/api/auth.py +38 -7
  53. phoenix/server/api/auth_messages.py +46 -0
  54. phoenix/server/api/context.py +100 -4
  55. phoenix/server/api/dataloaders/__init__.py +79 -5
  56. phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
  57. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  58. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  59. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  60. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  61. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  62. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  63. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  64. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  65. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  66. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  67. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  68. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  69. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  70. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  71. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  72. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  73. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  74. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  75. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  76. phoenix/server/api/dataloaders/record_counts.py +37 -10
  77. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  78. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  79. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  80. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  81. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  82. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  83. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  84. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
  85. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  86. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  87. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  88. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
  89. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  90. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  91. phoenix/server/api/dataloaders/span_costs.py +29 -0
  92. phoenix/server/api/dataloaders/table_fields.py +2 -2
  93. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  94. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  95. phoenix/server/api/dataloaders/types.py +29 -0
  96. phoenix/server/api/exceptions.py +11 -1
  97. phoenix/server/api/helpers/dataset_helpers.py +5 -1
  98. phoenix/server/api/helpers/playground_clients.py +1243 -292
  99. phoenix/server/api/helpers/playground_registry.py +2 -2
  100. phoenix/server/api/helpers/playground_spans.py +8 -4
  101. phoenix/server/api/helpers/playground_users.py +26 -0
  102. phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
  103. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  104. phoenix/server/api/helpers/prompts/models.py +205 -22
  105. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  106. phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
  107. phoenix/server/api/input_types/CreateProjectInput.py +27 -0
  108. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  109. phoenix/server/api/input_types/DatasetFilter.py +17 -0
  110. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  111. phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
  112. phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
  113. phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
  114. phoenix/server/api/input_types/PromptFilter.py +14 -0
  115. phoenix/server/api/input_types/PromptVersionInput.py +52 -1
  116. phoenix/server/api/input_types/SpanSort.py +44 -7
  117. phoenix/server/api/input_types/TimeBinConfig.py +23 -0
  118. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  119. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  120. phoenix/server/api/mutations/__init__.py +10 -0
  121. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  122. phoenix/server/api/mutations/api_key_mutations.py +19 -23
  123. phoenix/server/api/mutations/chat_mutations.py +154 -47
  124. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  125. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  126. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  127. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  128. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  129. phoenix/server/api/mutations/model_mutations.py +210 -0
  130. phoenix/server/api/mutations/project_mutations.py +49 -10
  131. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  132. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  133. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  134. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  135. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  136. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  137. phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
  138. phoenix/server/api/mutations/trace_mutations.py +47 -3
  139. phoenix/server/api/mutations/user_mutations.py +66 -41
  140. phoenix/server/api/queries.py +768 -293
  141. phoenix/server/api/routers/__init__.py +2 -2
  142. phoenix/server/api/routers/auth.py +154 -88
  143. phoenix/server/api/routers/ldap.py +229 -0
  144. phoenix/server/api/routers/oauth2.py +369 -106
  145. phoenix/server/api/routers/v1/__init__.py +24 -4
  146. phoenix/server/api/routers/v1/annotation_configs.py +23 -31
  147. phoenix/server/api/routers/v1/annotations.py +481 -17
  148. phoenix/server/api/routers/v1/datasets.py +395 -81
  149. phoenix/server/api/routers/v1/documents.py +142 -0
  150. phoenix/server/api/routers/v1/evaluations.py +24 -31
  151. phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
  152. phoenix/server/api/routers/v1/experiment_runs.py +337 -59
  153. phoenix/server/api/routers/v1/experiments.py +479 -48
  154. phoenix/server/api/routers/v1/models.py +7 -0
  155. phoenix/server/api/routers/v1/projects.py +18 -49
  156. phoenix/server/api/routers/v1/prompts.py +54 -40
  157. phoenix/server/api/routers/v1/sessions.py +108 -0
  158. phoenix/server/api/routers/v1/spans.py +1091 -81
  159. phoenix/server/api/routers/v1/traces.py +132 -78
  160. phoenix/server/api/routers/v1/users.py +389 -0
  161. phoenix/server/api/routers/v1/utils.py +3 -7
  162. phoenix/server/api/subscriptions.py +305 -88
  163. phoenix/server/api/types/Annotation.py +90 -23
  164. phoenix/server/api/types/ApiKey.py +13 -17
  165. phoenix/server/api/types/AuthMethod.py +1 -0
  166. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  167. phoenix/server/api/types/CostBreakdown.py +12 -0
  168. phoenix/server/api/types/Dataset.py +226 -72
  169. phoenix/server/api/types/DatasetExample.py +88 -18
  170. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  171. phoenix/server/api/types/DatasetLabel.py +57 -0
  172. phoenix/server/api/types/DatasetSplit.py +98 -0
  173. phoenix/server/api/types/DatasetVersion.py +49 -4
  174. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  175. phoenix/server/api/types/Experiment.py +264 -59
  176. phoenix/server/api/types/ExperimentComparison.py +5 -10
  177. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  178. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  179. phoenix/server/api/types/ExperimentRun.py +169 -65
  180. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  181. phoenix/server/api/types/GenerativeModel.py +245 -3
  182. phoenix/server/api/types/GenerativeProvider.py +70 -11
  183. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  184. phoenix/server/api/types/ModelInterface.py +16 -0
  185. phoenix/server/api/types/PlaygroundModel.py +20 -0
  186. phoenix/server/api/types/Project.py +1278 -216
  187. phoenix/server/api/types/ProjectSession.py +188 -28
  188. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  189. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  190. phoenix/server/api/types/Prompt.py +119 -39
  191. phoenix/server/api/types/PromptLabel.py +42 -25
  192. phoenix/server/api/types/PromptVersion.py +11 -8
  193. phoenix/server/api/types/PromptVersionTag.py +65 -25
  194. phoenix/server/api/types/ServerStatus.py +6 -0
  195. phoenix/server/api/types/Span.py +167 -123
  196. phoenix/server/api/types/SpanAnnotation.py +189 -42
  197. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  198. phoenix/server/api/types/SpanCostSummary.py +10 -0
  199. phoenix/server/api/types/SystemApiKey.py +65 -1
  200. phoenix/server/api/types/TokenPrice.py +16 -0
  201. phoenix/server/api/types/TokenUsage.py +3 -3
  202. phoenix/server/api/types/Trace.py +223 -51
  203. phoenix/server/api/types/TraceAnnotation.py +149 -50
  204. phoenix/server/api/types/User.py +137 -32
  205. phoenix/server/api/types/UserApiKey.py +73 -26
  206. phoenix/server/api/types/node.py +10 -0
  207. phoenix/server/api/types/pagination.py +11 -2
  208. phoenix/server/app.py +290 -45
  209. phoenix/server/authorization.py +38 -3
  210. phoenix/server/bearer_auth.py +34 -24
  211. phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
  212. phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
  213. phoenix/server/cost_tracking/helpers.py +68 -0
  214. phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
  215. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  216. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  217. phoenix/server/daemons/__init__.py +0 -0
  218. phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
  219. phoenix/server/daemons/generative_model_store.py +103 -0
  220. phoenix/server/daemons/span_cost_calculator.py +99 -0
  221. phoenix/server/dml_event.py +17 -0
  222. phoenix/server/dml_event_handler.py +5 -0
  223. phoenix/server/email/sender.py +56 -3
  224. phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
  225. phoenix/server/email/types.py +11 -0
  226. phoenix/server/experiments/__init__.py +0 -0
  227. phoenix/server/experiments/utils.py +14 -0
  228. phoenix/server/grpc_server.py +11 -11
  229. phoenix/server/jwt_store.py +17 -15
  230. phoenix/server/ldap.py +1449 -0
  231. phoenix/server/main.py +26 -10
  232. phoenix/server/oauth2.py +330 -12
  233. phoenix/server/prometheus.py +66 -6
  234. phoenix/server/rate_limiters.py +4 -9
  235. phoenix/server/retention.py +33 -20
  236. phoenix/server/session_filters.py +49 -0
  237. phoenix/server/static/.vite/manifest.json +55 -51
  238. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  239. phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
  240. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  241. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  242. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  243. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  244. phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
  245. phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
  246. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  247. phoenix/server/templates/index.html +40 -6
  248. phoenix/server/thread_server.py +1 -2
  249. phoenix/server/types.py +14 -4
  250. phoenix/server/utils.py +74 -0
  251. phoenix/session/client.py +56 -3
  252. phoenix/session/data_extractor.py +5 -0
  253. phoenix/session/evaluation.py +14 -5
  254. phoenix/session/session.py +45 -9
  255. phoenix/settings.py +5 -0
  256. phoenix/trace/attributes.py +80 -13
  257. phoenix/trace/dsl/helpers.py +90 -1
  258. phoenix/trace/dsl/query.py +8 -6
  259. phoenix/trace/projects.py +5 -0
  260. phoenix/utilities/template_formatters.py +1 -1
  261. phoenix/version.py +1 -1
  262. arize_phoenix-10.0.4.dist-info/RECORD +0 -405
  263. phoenix/server/api/types/Evaluation.py +0 -39
  264. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  265. phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
  266. phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
  267. phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
  268. phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
  269. phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
  270. phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
  271. phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
  272. phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
  273. phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
  274. phoenix/utilities/deprecation.py +0 -31
  275. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  276. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,12 +1,14 @@
1
+ import re
1
2
  from collections import defaultdict
2
3
  from datetime import datetime
3
- from typing import Iterable, Iterator, Optional, Union, cast
4
+ from typing import Any, Iterable, Iterator, Literal, Optional, Union
5
+ from typing import cast as type_cast
4
6
 
5
7
  import numpy as np
6
8
  import numpy.typing as npt
7
9
  import strawberry
8
- from sqlalchemy import and_, distinct, func, select, text
9
- from sqlalchemy.orm import joinedload
10
+ from sqlalchemy import ColumnElement, String, and_, case, cast, func, select, text
11
+ from sqlalchemy.orm import joinedload, load_only
10
12
  from starlette.authentication import UnauthenticatedUser
11
13
  from strawberry import ID, UNSET
12
14
  from strawberry.relay import Connection, GlobalID, Node
@@ -18,19 +20,17 @@ from phoenix.config import (
18
20
  get_env_database_allocated_storage_capacity_gibibytes,
19
21
  getenv,
20
22
  )
21
- from phoenix.db import enums, models
23
+ from phoenix.db import models
22
24
  from phoenix.db.constants import DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
23
- from phoenix.db.helpers import SupportedSQLDialect, exclude_experiment_projects
24
- from phoenix.db.models import DatasetExample as OrmExample
25
- from phoenix.db.models import DatasetExampleRevision as OrmRevision
26
- from phoenix.db.models import DatasetVersion as OrmVersion
27
- from phoenix.db.models import Experiment as OrmExperiment
28
- from phoenix.db.models import ExperimentRun as OrmExperimentRun
29
- from phoenix.db.models import Trace as OrmTrace
25
+ from phoenix.db.helpers import (
26
+ SupportedSQLDialect,
27
+ exclude_experiment_projects,
28
+ )
29
+ from phoenix.db.models import LatencyMs
30
30
  from phoenix.pointcloud.clustering import Hdbscan
31
31
  from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin
32
32
  from phoenix.server.api.context import Context
33
- from phoenix.server.api.exceptions import NotFound, Unauthorized
33
+ from phoenix.server.api.exceptions import BadRequest, NotFound, Unauthorized
34
34
  from phoenix.server.api.helpers import ensure_list
35
35
  from phoenix.server.api.helpers.experiment_run_filters import (
36
36
  ExperimentRunFilterConditionSyntaxError,
@@ -41,14 +41,18 @@ from phoenix.server.api.helpers.playground_clients import initialize_playground_
41
41
  from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
42
42
  from phoenix.server.api.input_types.ClusterInput import ClusterInput
43
43
  from phoenix.server.api.input_types.Coordinates import InputCoordinate2D, InputCoordinate3D
44
+ from phoenix.server.api.input_types.DatasetFilter import DatasetFilter
44
45
  from phoenix.server.api.input_types.DatasetSort import DatasetSort
45
46
  from phoenix.server.api.input_types.InvocationParameters import InvocationParameter
46
47
  from phoenix.server.api.input_types.ProjectFilter import ProjectFilter
47
48
  from phoenix.server.api.input_types.ProjectSort import ProjectColumn, ProjectSort
49
+ from phoenix.server.api.input_types.PromptFilter import PromptFilter
48
50
  from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_annotation_config
49
51
  from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
50
- from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
52
+ from phoenix.server.api.types.Dataset import Dataset
51
53
  from phoenix.server.api.types.DatasetExample import DatasetExample
54
+ from phoenix.server.api.types.DatasetLabel import DatasetLabel
55
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit
52
56
  from phoenix.server.api.types.Dimension import to_gql_dimension
53
57
  from phoenix.server.api.types.EmbeddingDimension import (
54
58
  DEFAULT_CLUSTER_SELECTION_EPSILON,
@@ -58,30 +62,48 @@ from phoenix.server.api.types.EmbeddingDimension import (
58
62
  )
59
63
  from phoenix.server.api.types.Event import create_event_id, unpack_event_id
60
64
  from phoenix.server.api.types.Experiment import Experiment
61
- from phoenix.server.api.types.ExperimentComparison import ExperimentComparison, RunComparisonItem
62
- from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
65
+ from phoenix.server.api.types.ExperimentComparison import (
66
+ ExperimentComparison,
67
+ )
68
+ from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
69
+ ExperimentRepeatedRunGroup,
70
+ parse_experiment_repeated_run_group_node_id,
71
+ )
72
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun
63
73
  from phoenix.server.api.types.Functionality import Functionality
64
74
  from phoenix.server.api.types.GenerativeModel import GenerativeModel
65
75
  from phoenix.server.api.types.GenerativeProvider import GenerativeProvider, GenerativeProviderKey
76
+ from phoenix.server.api.types.InferenceModel import InferenceModel
66
77
  from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
67
- from phoenix.server.api.types.Model import Model
68
- from phoenix.server.api.types.node import from_global_id, from_global_id_with_expected_type
69
- from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list
78
+ from phoenix.server.api.types.node import (
79
+ from_global_id,
80
+ from_global_id_with_expected_type,
81
+ is_global_id,
82
+ )
83
+ from phoenix.server.api.types.pagination import (
84
+ ConnectionArgs,
85
+ Cursor,
86
+ CursorString,
87
+ connection_from_cursors_and_nodes,
88
+ connection_from_list,
89
+ )
90
+ from phoenix.server.api.types.PlaygroundModel import PlaygroundModel
70
91
  from phoenix.server.api.types.Project import Project
71
- from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
92
+ from phoenix.server.api.types.ProjectSession import ProjectSession
72
93
  from phoenix.server.api.types.ProjectTraceRetentionPolicy import ProjectTraceRetentionPolicy
73
- from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
74
- from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
94
+ from phoenix.server.api.types.Prompt import Prompt
95
+ from phoenix.server.api.types.PromptLabel import PromptLabel
75
96
  from phoenix.server.api.types.PromptVersion import PromptVersion, to_gql_prompt_version
76
- from phoenix.server.api.types.PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
97
+ from phoenix.server.api.types.PromptVersionTag import PromptVersionTag
98
+ from phoenix.server.api.types.ServerStatus import ServerStatus
77
99
  from phoenix.server.api.types.SortDir import SortDir
78
100
  from phoenix.server.api.types.Span import Span
79
- from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
101
+ from phoenix.server.api.types.SpanAnnotation import SpanAnnotation
80
102
  from phoenix.server.api.types.SystemApiKey import SystemApiKey
81
103
  from phoenix.server.api.types.Trace import Trace
82
- from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
83
- from phoenix.server.api.types.User import User, to_gql_user
84
- from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
104
+ from phoenix.server.api.types.TraceAnnotation import TraceAnnotation
105
+ from phoenix.server.api.types.User import User
106
+ from phoenix.server.api.types.UserApiKey import UserApiKey
85
107
  from phoenix.server.api.types.UserRole import UserRole
86
108
  from phoenix.server.api.types.ValidationResult import ValidationResult
87
109
 
@@ -100,6 +122,55 @@ class DbTableStats:
100
122
  num_bytes: float
101
123
 
102
124
 
125
+ @strawberry.type
126
+ class ExperimentRunMetricComparison:
127
+ num_runs_improved: int = strawberry.field(
128
+ description=(
129
+ "The number of runs in which the base experiment improved "
130
+ "on the best run in any compare experiment."
131
+ )
132
+ )
133
+ num_runs_regressed: int = strawberry.field(
134
+ description=(
135
+ "The number of runs in which the base experiment regressed "
136
+ "on the best run in any compare experiment."
137
+ )
138
+ )
139
+ num_runs_equal: int = strawberry.field(
140
+ description=(
141
+ "The number of runs in which the base experiment is equal to the best run "
142
+ "in any compare experiment."
143
+ )
144
+ )
145
+ num_total_runs: strawberry.Private[int]
146
+
147
+ @strawberry.field(
148
+ description=(
149
+ "The number of runs in the base experiment that could not be compared, either because "
150
+ "the base experiment run was missing a value or because all compare experiment runs "
151
+ "were missing values."
152
+ )
153
+ ) # type: ignore[misc]
154
+ def num_runs_without_comparison(self) -> int:
155
+ return (
156
+ self.num_total_runs
157
+ - self.num_runs_improved
158
+ - self.num_runs_regressed
159
+ - self.num_runs_equal
160
+ )
161
+
162
+
163
+ @strawberry.type
164
+ class ExperimentRunMetricComparisons:
165
+ latency: ExperimentRunMetricComparison
166
+ total_token_count: ExperimentRunMetricComparison
167
+ prompt_token_count: ExperimentRunMetricComparison
168
+ completion_token_count: ExperimentRunMetricComparison
169
+ total_cost: ExperimentRunMetricComparison
170
+ prompt_cost: ExperimentRunMetricComparison
171
+ completion_cost: ExperimentRunMetricComparison
172
+
173
+
103
174
  @strawberry.type
104
175
  class Query:
105
176
  @strawberry.field
@@ -114,20 +185,50 @@ class Query:
114
185
  ]
115
186
 
116
187
  @strawberry.field
117
- async def models(self, input: Optional[ModelsInput] = None) -> list[GenerativeModel]:
188
+ async def generative_models(
189
+ self,
190
+ info: Info[Context, None],
191
+ first: Optional[int] = 50,
192
+ last: Optional[int] = UNSET,
193
+ after: Optional[CursorString] = UNSET,
194
+ before: Optional[CursorString] = UNSET,
195
+ ) -> Connection[GenerativeModel]:
196
+ args = ConnectionArgs(
197
+ first=first,
198
+ after=after if isinstance(after, CursorString) else None,
199
+ last=last,
200
+ before=before if isinstance(before, CursorString) else None,
201
+ )
202
+ async with info.context.db() as session:
203
+ result = await session.scalars(
204
+ select(models.GenerativeModel)
205
+ .where(models.GenerativeModel.deleted_at.is_(None))
206
+ .order_by(
207
+ models.GenerativeModel.is_built_in.asc(), # display custom models first
208
+ models.GenerativeModel.provider.nullslast(),
209
+ models.GenerativeModel.name,
210
+ )
211
+ )
212
+ data = [GenerativeModel(id=model.id, db_record=model) for model in result.unique()]
213
+ return connection_from_list(data=data, args=args)
214
+
215
+ @strawberry.field
216
+ async def playground_models(self, input: Optional[ModelsInput] = None) -> list[PlaygroundModel]:
118
217
  if input is not None and input.provider_key is not None:
119
218
  supported_model_names = PLAYGROUND_CLIENT_REGISTRY.list_models(input.provider_key)
120
219
  supported_models = [
121
- GenerativeModel(name=model_name, provider_key=input.provider_key)
220
+ PlaygroundModel(name_value=model_name, provider_key_value=input.provider_key)
122
221
  for model_name in supported_model_names
123
222
  ]
124
223
  return supported_models
125
224
 
126
225
  registered_models = PLAYGROUND_CLIENT_REGISTRY.list_all_models()
127
- all_models: list[GenerativeModel] = []
226
+ all_models: list[PlaygroundModel] = []
128
227
  for provider_key, model_name in registered_models:
129
228
  if model_name is not None and provider_key is not None:
130
- all_models.append(GenerativeModel(name=model_name, provider_key=provider_key))
229
+ all_models.append(
230
+ PlaygroundModel(name_value=model_name, provider_key_value=provider_key)
231
+ )
131
232
  return all_models
132
233
 
133
234
  @strawberry.field
@@ -165,13 +266,13 @@ class Query:
165
266
  stmt = (
166
267
  select(models.User)
167
268
  .join(models.UserRole)
168
- .where(models.UserRole.name != enums.UserRole.SYSTEM.value)
269
+ .where(models.UserRole.name != "SYSTEM")
169
270
  .order_by(models.User.email)
170
271
  .options(joinedload(models.User.role))
171
272
  )
172
273
  async with info.context.db() as session:
173
274
  users = await session.stream_scalars(stmt)
174
- data = [to_gql_user(user) async for user in users]
275
+ data = [User(id=user.id, db_record=user) async for user in users]
175
276
  return connection_from_list(data=data, args=args)
176
277
 
177
278
  @strawberry.field
@@ -181,7 +282,7 @@ class Query:
181
282
  ) -> list[UserRole]:
182
283
  async with info.context.db() as session:
183
284
  roles = await session.scalars(
184
- select(models.UserRole).where(models.UserRole.name != enums.UserRole.SYSTEM.value)
285
+ select(models.UserRole).where(models.UserRole.name != "SYSTEM")
185
286
  )
186
287
  return [
187
288
  UserRole(
@@ -197,11 +298,11 @@ class Query:
197
298
  select(models.ApiKey)
198
299
  .join(models.User)
199
300
  .join(models.UserRole)
200
- .where(models.UserRole.name != enums.UserRole.SYSTEM.value)
301
+ .where(models.UserRole.name != "SYSTEM")
201
302
  )
202
303
  async with info.context.db() as session:
203
304
  api_keys = await session.scalars(stmt)
204
- return [to_gql_api_key(api_key) for api_key in api_keys]
305
+ return [UserApiKey(id=api_key.id, db_record=api_key) for api_key in api_keys]
205
306
 
206
307
  @strawberry.field(permission_classes=[IsAdmin]) # type: ignore
207
308
  async def system_api_keys(self, info: Info[Context, None]) -> list[SystemApiKey]:
@@ -209,20 +310,11 @@ class Query:
209
310
  select(models.ApiKey)
210
311
  .join(models.User)
211
312
  .join(models.UserRole)
212
- .where(models.UserRole.name == enums.UserRole.SYSTEM.value)
313
+ .where(models.UserRole.name == "SYSTEM")
213
314
  )
214
315
  async with info.context.db() as session:
215
316
  api_keys = await session.scalars(stmt)
216
- return [
217
- SystemApiKey(
218
- id_attr=api_key.id,
219
- name=api_key.name,
220
- description=api_key.description,
221
- created_at=api_key.created_at,
222
- expires_at=api_key.expires_at,
223
- )
224
- for api_key in api_keys
225
- ]
317
+ return [SystemApiKey(id=api_key.id, db_record=api_key) for api_key in api_keys]
226
318
 
227
319
  @strawberry.field
228
320
  async def projects(
@@ -263,13 +355,7 @@ class Query:
263
355
  stmt = exclude_experiment_projects(stmt)
264
356
  async with info.context.db() as session:
265
357
  projects = await session.stream_scalars(stmt)
266
- data = [
267
- Project(
268
- project_rowid=project.id,
269
- db_project=project,
270
- )
271
- async for project in projects
272
- ]
358
+ data = [Project(id=project.id, db_record=project) async for project in projects]
273
359
  return connection_from_list(data=data, args=args)
274
360
 
275
361
  @strawberry.field
@@ -285,6 +371,7 @@ class Query:
285
371
  after: Optional[CursorString] = UNSET,
286
372
  before: Optional[CursorString] = UNSET,
287
373
  sort: Optional[DatasetSort] = UNSET,
374
+ filter: Optional[DatasetFilter] = UNSET,
288
375
  ) -> Connection[Dataset]:
289
376
  args = ConnectionArgs(
290
377
  first=first,
@@ -296,10 +383,40 @@ class Query:
296
383
  if sort:
297
384
  sort_col = getattr(models.Dataset, sort.col.value)
298
385
  stmt = stmt.order_by(sort_col.desc() if sort.dir is SortDir.desc else sort_col.asc())
386
+ if filter:
387
+ # Apply name filter
388
+ if filter.col and filter.value:
389
+ stmt = stmt.where(
390
+ getattr(models.Dataset, filter.col.value).ilike(f"%{filter.value}%")
391
+ )
392
+
393
+ # Apply label filter
394
+ if filter.filter_labels and filter.filter_labels is not UNSET:
395
+ label_rowids = []
396
+ for label_id in filter.filter_labels:
397
+ try:
398
+ label_rowid = from_global_id_with_expected_type(
399
+ global_id=GlobalID.from_id(label_id),
400
+ expected_type_name="DatasetLabel",
401
+ )
402
+ label_rowids.append(label_rowid)
403
+ except ValueError:
404
+ continue # Skip invalid label IDs
405
+
406
+ if label_rowids:
407
+ # Join with the junction table to filter by labels
408
+ stmt = (
409
+ stmt.join(
410
+ models.DatasetsDatasetLabel,
411
+ models.Dataset.id == models.DatasetsDatasetLabel.dataset_id,
412
+ )
413
+ .where(models.DatasetsDatasetLabel.dataset_label_id.in_(label_rowids))
414
+ .distinct()
415
+ )
299
416
  async with info.context.db() as session:
300
417
  datasets = await session.scalars(stmt)
301
418
  return connection_from_list(
302
- data=[to_gql_dataset(dataset) for dataset in datasets], args=args
419
+ data=[Dataset(id=dataset.id, db_record=dataset) for dataset in datasets], args=args
303
420
  )
304
421
 
305
422
  @strawberry.field
@@ -310,122 +427,429 @@ class Query:
310
427
  async def compare_experiments(
311
428
  self,
312
429
  info: Info[Context, None],
313
- experiment_ids: list[GlobalID],
430
+ base_experiment_id: GlobalID,
431
+ compare_experiment_ids: list[GlobalID],
432
+ first: Optional[int] = 50,
433
+ after: Optional[CursorString] = UNSET,
314
434
  filter_condition: Optional[str] = UNSET,
315
- ) -> list[ExperimentComparison]:
316
- experiment_ids_ = [
317
- from_global_id_with_expected_type(experiment_id, OrmExperiment.__name__)
318
- for experiment_id in experiment_ids
319
- ]
320
- if len(set(experiment_ids_)) != len(experiment_ids_):
321
- raise ValueError("Experiment IDs must be unique.")
435
+ ) -> Connection[ExperimentComparison]:
436
+ if base_experiment_id in compare_experiment_ids:
437
+ raise BadRequest("Compare experiment IDs cannot contain the base experiment ID")
438
+ if len(set(compare_experiment_ids)) < len(compare_experiment_ids):
439
+ raise BadRequest("Compare experiment IDs must be unique")
440
+
441
+ try:
442
+ base_experiment_rowid = from_global_id_with_expected_type(
443
+ base_experiment_id, models.Experiment.__name__
444
+ )
445
+ except ValueError:
446
+ raise BadRequest(f"Invalid base experiment ID: {base_experiment_id}")
447
+
448
+ compare_experiment_rowids = []
449
+ for compare_experiment_id in compare_experiment_ids:
450
+ try:
451
+ compare_experiment_rowids.append(
452
+ from_global_id_with_expected_type(
453
+ compare_experiment_id, models.Experiment.__name__
454
+ )
455
+ )
456
+ except ValueError:
457
+ raise BadRequest(f"Invalid compare experiment ID: {compare_experiment_id}")
458
+
459
+ experiment_rowids = [base_experiment_rowid, *compare_experiment_rowids]
460
+
461
+ cursor = Cursor.from_string(after) if after else None
462
+ page_size = first or 50
322
463
 
323
464
  async with info.context.db() as session:
324
- validation_result = (
325
- await session.execute(
465
+ experiments = (
466
+ await session.scalars(
326
467
  select(
327
- func.count(distinct(OrmVersion.dataset_id)),
328
- func.max(OrmVersion.dataset_id),
329
- func.max(OrmVersion.id),
330
- func.count(OrmExperiment.id),
331
- )
332
- .select_from(OrmVersion)
333
- .join(
334
- OrmExperiment,
335
- OrmExperiment.dataset_version_id == OrmVersion.id,
336
- )
337
- .where(
338
- OrmExperiment.id.in_(experiment_ids_),
468
+ models.Experiment,
339
469
  )
340
- )
341
- ).first()
342
- if validation_result is None:
343
- raise ValueError("No experiments could be found for input IDs.")
344
-
345
- num_datasets, dataset_id, version_id, num_resolved_experiment_ids = validation_result
346
- if num_datasets != 1:
347
- raise ValueError("Experiments must belong to the same dataset.")
348
- if num_resolved_experiment_ids != len(experiment_ids_):
349
- raise ValueError("Unable to resolve one or more experiment IDs.")
350
-
351
- revision_ids = (
352
- select(func.max(OrmRevision.id))
353
- .join(OrmExample, OrmExample.id == OrmRevision.dataset_example_id)
354
- .where(
355
- and_(
356
- OrmRevision.dataset_version_id <= version_id,
357
- OrmExample.dataset_id == dataset_id,
470
+ .where(models.Experiment.id.in_(experiment_rowids))
471
+ .options(
472
+ load_only(
473
+ models.Experiment.dataset_id, models.Experiment.dataset_version_id
474
+ )
358
475
  )
359
476
  )
360
- .group_by(OrmRevision.dataset_example_id)
361
- .scalar_subquery()
477
+ ).all()
478
+
479
+ if not experiments or len(experiments) < len(experiment_rowids):
480
+ raise NotFound("Unable to resolve one or more experiment IDs.")
481
+ num_datasets = len(set(experiment.dataset_id for experiment in experiments))
482
+ if num_datasets > 1:
483
+ raise BadRequest("Experiments must belong to the same dataset.")
484
+ base_experiment = next(
485
+ experiment for experiment in experiments if experiment.id == base_experiment_rowid
362
486
  )
487
+
488
+ # Use ExperimentDatasetExample to pull down examples.
489
+ # Splits are mutable and should not be used for comparison.
490
+ # The comparison should only occur against examples which were assigned to the same
491
+ # splits at the time of execution of the ExperimentRun.
363
492
  examples_query = (
364
- select(OrmExample)
365
- .distinct(OrmExample.id)
366
- .join(
367
- OrmRevision,
368
- onclause=and_(
369
- OrmExample.id == OrmRevision.dataset_example_id,
370
- OrmRevision.id.in_(revision_ids),
371
- OrmRevision.revision_kind != "DELETE",
372
- ),
373
- )
374
- .order_by(OrmExample.id.desc())
493
+ select(models.DatasetExample)
494
+ .join(models.ExperimentDatasetExample)
495
+ .where(models.ExperimentDatasetExample.experiment_id == base_experiment_rowid)
496
+ .order_by(models.DatasetExample.id.desc())
497
+ .limit(page_size + 1)
375
498
  )
376
499
 
500
+ if cursor is not None:
501
+ examples_query = examples_query.where(models.DatasetExample.id < cursor.rowid)
502
+
377
503
  if filter_condition:
378
504
  examples_query = update_examples_query_with_filter_condition(
379
505
  query=examples_query,
380
506
  filter_condition=filter_condition,
381
- experiment_ids=experiment_ids_,
507
+ experiment_ids=experiment_rowids,
382
508
  )
383
509
 
384
510
  examples = (await session.scalars(examples_query)).all()
511
+ has_next_page = len(examples) > page_size
512
+ examples = examples[:page_size]
385
513
 
386
514
  ExampleID: TypeAlias = int
387
515
  ExperimentID: TypeAlias = int
388
- runs: defaultdict[ExampleID, defaultdict[ExperimentID, list[OrmExperimentRun]]] = (
516
+ runs: defaultdict[ExampleID, defaultdict[ExperimentID, list[models.ExperimentRun]]] = (
389
517
  defaultdict(lambda: defaultdict(list))
390
518
  )
391
519
  async for run in await session.stream_scalars(
392
- select(OrmExperimentRun)
520
+ select(models.ExperimentRun)
393
521
  .where(
394
522
  and_(
395
- OrmExperimentRun.dataset_example_id.in_(example.id for example in examples),
396
- OrmExperimentRun.experiment_id.in_(experiment_ids_),
523
+ models.ExperimentRun.dataset_example_id.in_(
524
+ example.id for example in examples
525
+ ),
526
+ models.ExperimentRun.experiment_id.in_(experiment_rowids),
397
527
  )
398
528
  )
399
- .options(joinedload(OrmExperimentRun.trace).load_only(OrmTrace.trace_id))
529
+ .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
530
+ .order_by(
531
+ models.ExperimentRun.repetition_number.asc()
532
+ ) # repetitions are not currently implemented, but this ensures that the repetitions will be properly ordered once implemented # noqa: E501
400
533
  ):
401
534
  runs[run.dataset_example_id][run.experiment_id].append(run)
402
535
 
403
- experiment_comparisons = []
536
+ cursors_and_nodes = []
404
537
  for example in examples:
405
- run_comparison_items = []
406
- for experiment_id in experiment_ids_:
407
- run_comparison_items.append(
408
- RunComparisonItem(
409
- experiment_id=GlobalID(Experiment.__name__, str(experiment_id)),
410
- runs=[
411
- to_gql_experiment_run(run)
538
+ repeated_run_groups = []
539
+ for experiment_id in experiment_rowids:
540
+ repeated_run_groups.append(
541
+ ExperimentRepeatedRunGroup(
542
+ experiment_rowid=experiment_id,
543
+ dataset_example_rowid=example.id,
544
+ cached_runs=[
545
+ ExperimentRun(id=run.id, db_record=run)
412
546
  for run in sorted(
413
- runs[example.id][experiment_id], key=lambda run: run.id
547
+ runs[example.id][experiment_id],
548
+ key=lambda run: run.repetition_number,
414
549
  )
415
550
  ],
416
551
  )
417
552
  )
418
- experiment_comparisons.append(
419
- ExperimentComparison(
420
- example=DatasetExample(
421
- id_attr=example.id,
422
- created_at=example.created_at,
423
- version_id=version_id,
424
- ),
425
- run_comparison_items=run_comparison_items,
553
+ experiment_comparison = ExperimentComparison(
554
+ id_attr=example.id,
555
+ example=DatasetExample(
556
+ id=example.id,
557
+ db_record=example,
558
+ version_id=base_experiment.dataset_version_id,
559
+ ),
560
+ repeated_run_groups=repeated_run_groups,
561
+ )
562
+ cursors_and_nodes.append((Cursor(rowid=example.id), experiment_comparison))
563
+
564
+ return connection_from_cursors_and_nodes(
565
+ cursors_and_nodes=cursors_and_nodes,
566
+ has_previous_page=False, # set to false since we are only doing forward pagination (https://relay.dev/graphql/connections.htm#sec-undefined.PageInfo.Fields) # noqa: E501
567
+ has_next_page=has_next_page,
568
+ )
569
+
570
+ @strawberry.field
571
+ async def experiment_run_metric_comparisons(
572
+ self,
573
+ info: Info[Context, None],
574
+ base_experiment_id: GlobalID,
575
+ compare_experiment_ids: list[GlobalID],
576
+ ) -> ExperimentRunMetricComparisons:
577
+ if base_experiment_id in compare_experiment_ids:
578
+ raise BadRequest("Compare experiment IDs cannot contain the base experiment ID")
579
+ if not compare_experiment_ids:
580
+ raise BadRequest("At least one compare experiment ID must be provided")
581
+ if len(set(compare_experiment_ids)) < len(compare_experiment_ids):
582
+ raise BadRequest("Compare experiment IDs must be unique")
583
+
584
+ try:
585
+ base_experiment_rowid = from_global_id_with_expected_type(
586
+ base_experiment_id, models.Experiment.__name__
587
+ )
588
+ except ValueError:
589
+ raise BadRequest(f"Invalid base experiment ID: {base_experiment_id}")
590
+
591
+ compare_experiment_rowids = []
592
+ for compare_experiment_id in compare_experiment_ids:
593
+ try:
594
+ compare_experiment_rowids.append(
595
+ from_global_id_with_expected_type(
596
+ compare_experiment_id, models.Experiment.__name__
597
+ )
426
598
  )
599
+ except ValueError:
600
+ raise BadRequest(f"Invalid compare experiment ID: {compare_experiment_id}")
601
+
602
+ base_experiment_runs = (
603
+ select(
604
+ models.ExperimentRun.dataset_example_id,
605
+ func.min(models.ExperimentRun.start_time).label("start_time"),
606
+ func.min(models.ExperimentRun.end_time).label("end_time"),
607
+ func.sum(models.SpanCost.total_tokens).label("total_tokens"),
608
+ func.sum(models.SpanCost.prompt_tokens).label("prompt_tokens"),
609
+ func.sum(models.SpanCost.completion_tokens).label("completion_tokens"),
610
+ func.sum(models.SpanCost.total_cost).label("total_cost"),
611
+ func.sum(models.SpanCost.prompt_cost).label("prompt_cost"),
612
+ func.sum(models.SpanCost.completion_cost).label("completion_cost"),
427
613
  )
428
- return experiment_comparisons
614
+ .select_from(models.ExperimentRun)
615
+ .join(
616
+ models.Trace,
617
+ onclause=models.ExperimentRun.trace_id == models.Trace.trace_id,
618
+ isouter=True,
619
+ )
620
+ .join(
621
+ models.SpanCost,
622
+ onclause=models.Trace.id == models.SpanCost.trace_rowid,
623
+ isouter=True,
624
+ )
625
+ .where(models.ExperimentRun.experiment_id == base_experiment_rowid)
626
+ .group_by(models.ExperimentRun.dataset_example_id)
627
+ .subquery()
628
+ .alias("base_experiment_runs")
629
+ )
630
+ compare_experiment_runs = (
631
+ select(
632
+ models.ExperimentRun.dataset_example_id,
633
+ func.min(
634
+ LatencyMs(models.ExperimentRun.start_time, models.ExperimentRun.end_time)
635
+ ).label("min_latency_ms"),
636
+ func.min(models.SpanCost.total_tokens).label("min_total_tokens"),
637
+ func.min(models.SpanCost.prompt_tokens).label("min_prompt_tokens"),
638
+ func.min(models.SpanCost.completion_tokens).label("min_completion_tokens"),
639
+ func.min(models.SpanCost.total_cost).label("min_total_cost"),
640
+ func.min(models.SpanCost.prompt_cost).label("min_prompt_cost"),
641
+ func.min(models.SpanCost.completion_cost).label("min_completion_cost"),
642
+ )
643
+ .select_from(models.ExperimentRun)
644
+ .join(
645
+ models.Trace,
646
+ onclause=models.ExperimentRun.trace_id == models.Trace.trace_id,
647
+ isouter=True,
648
+ )
649
+ .join(
650
+ models.SpanCost,
651
+ onclause=models.Trace.id == models.SpanCost.trace_rowid,
652
+ isouter=True,
653
+ )
654
+ .where(
655
+ models.ExperimentRun.experiment_id.in_(compare_experiment_rowids),
656
+ )
657
+ .group_by(models.ExperimentRun.dataset_example_id)
658
+ .subquery()
659
+ .alias("comp_exp_run_mins")
660
+ )
661
+
662
+ base_experiment_run_latency = LatencyMs(
663
+ base_experiment_runs.c.start_time, base_experiment_runs.c.end_time
664
+ ).label("base_experiment_run_latency_ms")
665
+
666
+ comparisons_query = (
667
+ select(
668
+ func.count().label("num_base_experiment_runs"),
669
+ _comparison_count_expression(
670
+ base_column=base_experiment_run_latency,
671
+ compare_column=compare_experiment_runs.c.min_latency_ms,
672
+ optimization_direction="minimize",
673
+ comparison_type="improvement",
674
+ ).label("num_latency_improved"),
675
+ _comparison_count_expression(
676
+ base_column=base_experiment_run_latency,
677
+ compare_column=compare_experiment_runs.c.min_latency_ms,
678
+ optimization_direction="minimize",
679
+ comparison_type="regression",
680
+ ).label("num_latency_regressed"),
681
+ _comparison_count_expression(
682
+ base_column=base_experiment_run_latency,
683
+ compare_column=compare_experiment_runs.c.min_latency_ms,
684
+ optimization_direction="minimize",
685
+ comparison_type="equality",
686
+ ).label("num_latency_is_equal"),
687
+ _comparison_count_expression(
688
+ base_column=base_experiment_runs.c.total_tokens,
689
+ compare_column=compare_experiment_runs.c.min_total_tokens,
690
+ optimization_direction="minimize",
691
+ comparison_type="improvement",
692
+ ).label("num_total_token_count_improved"),
693
+ _comparison_count_expression(
694
+ base_column=base_experiment_runs.c.total_tokens,
695
+ compare_column=compare_experiment_runs.c.min_total_tokens,
696
+ optimization_direction="minimize",
697
+ comparison_type="regression",
698
+ ).label("num_total_token_count_regressed"),
699
+ _comparison_count_expression(
700
+ base_column=base_experiment_runs.c.total_tokens,
701
+ compare_column=compare_experiment_runs.c.min_total_tokens,
702
+ optimization_direction="minimize",
703
+ comparison_type="equality",
704
+ ).label("num_total_token_count_is_equal"),
705
+ _comparison_count_expression(
706
+ base_column=base_experiment_runs.c.prompt_tokens,
707
+ compare_column=compare_experiment_runs.c.min_prompt_tokens,
708
+ optimization_direction="minimize",
709
+ comparison_type="improvement",
710
+ ).label("num_prompt_token_count_improved"),
711
+ _comparison_count_expression(
712
+ base_column=base_experiment_runs.c.prompt_tokens,
713
+ compare_column=compare_experiment_runs.c.min_prompt_tokens,
714
+ optimization_direction="minimize",
715
+ comparison_type="regression",
716
+ ).label("num_prompt_token_count_regressed"),
717
+ _comparison_count_expression(
718
+ base_column=base_experiment_runs.c.prompt_tokens,
719
+ compare_column=compare_experiment_runs.c.min_prompt_tokens,
720
+ optimization_direction="minimize",
721
+ comparison_type="equality",
722
+ ).label("num_prompt_token_count_is_equal"),
723
+ _comparison_count_expression(
724
+ base_column=base_experiment_runs.c.completion_tokens,
725
+ compare_column=compare_experiment_runs.c.min_completion_tokens,
726
+ optimization_direction="minimize",
727
+ comparison_type="improvement",
728
+ ).label("num_completion_token_count_improved"),
729
+ _comparison_count_expression(
730
+ base_column=base_experiment_runs.c.completion_tokens,
731
+ compare_column=compare_experiment_runs.c.min_completion_tokens,
732
+ optimization_direction="minimize",
733
+ comparison_type="regression",
734
+ ).label("num_completion_token_count_regressed"),
735
+ _comparison_count_expression(
736
+ base_column=base_experiment_runs.c.completion_tokens,
737
+ compare_column=compare_experiment_runs.c.min_completion_tokens,
738
+ optimization_direction="minimize",
739
+ comparison_type="equality",
740
+ ).label("num_completion_token_count_is_equal"),
741
+ _comparison_count_expression(
742
+ base_column=base_experiment_runs.c.total_cost,
743
+ compare_column=compare_experiment_runs.c.min_total_cost,
744
+ optimization_direction="minimize",
745
+ comparison_type="improvement",
746
+ ).label("num_total_cost_improved"),
747
+ _comparison_count_expression(
748
+ base_column=base_experiment_runs.c.total_cost,
749
+ compare_column=compare_experiment_runs.c.min_total_cost,
750
+ optimization_direction="minimize",
751
+ comparison_type="regression",
752
+ ).label("num_total_cost_regressed"),
753
+ _comparison_count_expression(
754
+ base_column=base_experiment_runs.c.total_cost,
755
+ compare_column=compare_experiment_runs.c.min_total_cost,
756
+ optimization_direction="minimize",
757
+ comparison_type="equality",
758
+ ).label("num_total_cost_is_equal"),
759
+ _comparison_count_expression(
760
+ base_column=base_experiment_runs.c.prompt_cost,
761
+ compare_column=compare_experiment_runs.c.min_prompt_cost,
762
+ optimization_direction="minimize",
763
+ comparison_type="improvement",
764
+ ).label("num_prompt_cost_improved"),
765
+ _comparison_count_expression(
766
+ base_column=base_experiment_runs.c.prompt_cost,
767
+ compare_column=compare_experiment_runs.c.min_prompt_cost,
768
+ optimization_direction="minimize",
769
+ comparison_type="regression",
770
+ ).label("num_prompt_cost_regressed"),
771
+ _comparison_count_expression(
772
+ base_column=base_experiment_runs.c.prompt_cost,
773
+ compare_column=compare_experiment_runs.c.min_prompt_cost,
774
+ optimization_direction="minimize",
775
+ comparison_type="equality",
776
+ ).label("num_prompt_cost_is_equal"),
777
+ _comparison_count_expression(
778
+ base_column=base_experiment_runs.c.completion_cost,
779
+ compare_column=compare_experiment_runs.c.min_completion_cost,
780
+ optimization_direction="minimize",
781
+ comparison_type="improvement",
782
+ ).label("num_completion_cost_improved"),
783
+ _comparison_count_expression(
784
+ base_column=base_experiment_runs.c.completion_cost,
785
+ compare_column=compare_experiment_runs.c.min_completion_cost,
786
+ optimization_direction="minimize",
787
+ comparison_type="regression",
788
+ ).label("num_completion_cost_regressed"),
789
+ _comparison_count_expression(
790
+ base_column=base_experiment_runs.c.completion_cost,
791
+ compare_column=compare_experiment_runs.c.min_completion_cost,
792
+ optimization_direction="minimize",
793
+ comparison_type="equality",
794
+ ).label("num_completion_cost_is_equal"),
795
+ )
796
+ .select_from(base_experiment_runs)
797
+ .join(
798
+ compare_experiment_runs,
799
+ onclause=base_experiment_runs.c.dataset_example_id
800
+ == compare_experiment_runs.c.dataset_example_id,
801
+ isouter=True,
802
+ )
803
+ )
804
+
805
+ async with info.context.db() as session:
806
+ result = (await session.execute(comparisons_query)).first()
807
+ assert result is not None
808
+
809
+ return ExperimentRunMetricComparisons(
810
+ latency=ExperimentRunMetricComparison(
811
+ num_runs_improved=result.num_latency_improved,
812
+ num_runs_regressed=result.num_latency_regressed,
813
+ num_runs_equal=result.num_latency_is_equal,
814
+ num_total_runs=result.num_base_experiment_runs,
815
+ ),
816
+ total_token_count=ExperimentRunMetricComparison(
817
+ num_runs_improved=result.num_total_token_count_improved,
818
+ num_runs_regressed=result.num_total_token_count_regressed,
819
+ num_runs_equal=result.num_total_token_count_is_equal,
820
+ num_total_runs=result.num_base_experiment_runs,
821
+ ),
822
+ prompt_token_count=ExperimentRunMetricComparison(
823
+ num_runs_improved=result.num_prompt_token_count_improved,
824
+ num_runs_regressed=result.num_prompt_token_count_regressed,
825
+ num_runs_equal=result.num_prompt_token_count_is_equal,
826
+ num_total_runs=result.num_base_experiment_runs,
827
+ ),
828
+ completion_token_count=ExperimentRunMetricComparison(
829
+ num_runs_improved=result.num_completion_token_count_improved,
830
+ num_runs_regressed=result.num_completion_token_count_regressed,
831
+ num_runs_equal=result.num_completion_token_count_is_equal,
832
+ num_total_runs=result.num_base_experiment_runs,
833
+ ),
834
+ total_cost=ExperimentRunMetricComparison(
835
+ num_runs_improved=result.num_total_cost_improved,
836
+ num_runs_regressed=result.num_total_cost_regressed,
837
+ num_runs_equal=result.num_total_cost_is_equal,
838
+ num_total_runs=result.num_base_experiment_runs,
839
+ ),
840
+ prompt_cost=ExperimentRunMetricComparison(
841
+ num_runs_improved=result.num_prompt_cost_improved,
842
+ num_runs_regressed=result.num_prompt_cost_regressed,
843
+ num_runs_equal=result.num_prompt_cost_is_equal,
844
+ num_total_runs=result.num_base_experiment_runs,
845
+ ),
846
+ completion_cost=ExperimentRunMetricComparison(
847
+ num_runs_improved=result.num_completion_cost_improved,
848
+ num_runs_regressed=result.num_completion_cost_regressed,
849
+ num_runs_equal=result.num_completion_cost_is_equal,
850
+ num_total_runs=result.num_base_experiment_runs,
851
+ ),
852
+ )
429
853
 
430
854
  @strawberry.field
431
855
  async def validate_experiment_run_filter_condition(
@@ -437,7 +861,7 @@ class Query:
437
861
  compile_sqlalchemy_filter_condition(
438
862
  filter_condition=condition,
439
863
  experiment_ids=[
440
- from_global_id_with_expected_type(experiment_id, OrmExperiment.__name__)
864
+ from_global_id_with_expected_type(experiment_id, models.Experiment.__name__)
441
865
  for experiment_id in experiment_ids
442
866
  ],
443
867
  )
@@ -459,140 +883,55 @@ class Query:
459
883
  )
460
884
 
461
885
  @strawberry.field
462
- def model(self) -> Model:
463
- return Model()
886
+ def model(self) -> InferenceModel:
887
+ return InferenceModel()
464
888
 
465
889
  @strawberry.field
466
- async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
467
- type_name, node_id = from_global_id(id)
890
+ async def node(self, id: strawberry.ID, info: Info[Context, None]) -> Node:
891
+ if not is_global_id(id):
892
+ try:
893
+ experiment_rowid, dataset_example_rowid = (
894
+ parse_experiment_repeated_run_group_node_id(id)
895
+ )
896
+ except Exception:
897
+ raise NotFound(f"Unknown node: {id}")
898
+ return ExperimentRepeatedRunGroup(
899
+ experiment_rowid=experiment_rowid,
900
+ dataset_example_rowid=dataset_example_rowid,
901
+ )
902
+
903
+ global_id = GlobalID.from_id(id)
904
+ type_name, node_id = from_global_id(global_id)
468
905
  if type_name == "Dimension":
469
906
  dimension = info.context.model.scalar_dimensions[node_id]
470
907
  return to_gql_dimension(node_id, dimension)
471
908
  elif type_name == "EmbeddingDimension":
472
909
  embedding_dimension = info.context.model.embedding_dimensions[node_id]
473
910
  return to_gql_embedding_dimension(node_id, embedding_dimension)
474
- elif type_name == "Project":
475
- project_stmt = select(models.Project).filter_by(id=node_id)
476
- async with info.context.db() as session:
477
- project = await session.scalar(project_stmt)
478
- if project is None:
479
- raise NotFound(f"Unknown project: {id}")
480
- return Project(
481
- project_rowid=project.id,
482
- db_project=project,
483
- )
484
- elif type_name == "Trace":
485
- trace_stmt = select(models.Trace).filter_by(id=node_id)
486
- async with info.context.db() as session:
487
- trace = await session.scalar(trace_stmt)
488
- if trace is None:
489
- raise NotFound(f"Unknown trace: {id}")
490
- return Trace(trace_rowid=trace.id, db_trace=trace)
911
+ elif type_name == Project.__name__:
912
+ return Project(id=node_id)
913
+ elif type_name == Trace.__name__:
914
+ return Trace(id=node_id)
491
915
  elif type_name == Span.__name__:
492
- span_stmt = (
493
- select(models.Span)
494
- .options(
495
- joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id)
496
- )
497
- .where(models.Span.id == node_id)
498
- )
499
- async with info.context.db() as session:
500
- span = await session.scalar(span_stmt)
501
- if span is None:
502
- raise NotFound(f"Unknown span: {id}")
503
- return Span(span_rowid=span.id, db_span=span)
916
+ return Span(id=node_id)
504
917
  elif type_name == Dataset.__name__:
505
- dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id)
506
- async with info.context.db() as session:
507
- if (dataset := await session.scalar(dataset_stmt)) is None:
508
- raise NotFound(f"Unknown dataset: {id}")
509
- return to_gql_dataset(dataset)
918
+ return Dataset(id=node_id)
510
919
  elif type_name == DatasetExample.__name__:
511
- example_id = node_id
512
- latest_revision_id = (
513
- select(func.max(models.DatasetExampleRevision.id))
514
- .where(models.DatasetExampleRevision.dataset_example_id == example_id)
515
- .scalar_subquery()
516
- )
517
- async with info.context.db() as session:
518
- example = await session.scalar(
519
- select(models.DatasetExample)
520
- .join(
521
- models.DatasetExampleRevision,
522
- onclause=models.DatasetExampleRevision.dataset_example_id
523
- == models.DatasetExample.id,
524
- )
525
- .where(
526
- and_(
527
- models.DatasetExample.id == example_id,
528
- models.DatasetExampleRevision.id == latest_revision_id,
529
- models.DatasetExampleRevision.revision_kind != "DELETE",
530
- )
531
- )
532
- )
533
- if not example:
534
- raise NotFound(f"Unknown dataset example: {id}")
535
- return DatasetExample(
536
- id_attr=example.id,
537
- created_at=example.created_at,
538
- )
920
+ return DatasetExample(id=node_id)
921
+ elif type_name == DatasetSplit.__name__:
922
+ return DatasetSplit(id=node_id)
539
923
  elif type_name == Experiment.__name__:
540
- async with info.context.db() as session:
541
- experiment = await session.scalar(
542
- select(models.Experiment).where(models.Experiment.id == node_id)
543
- )
544
- if not experiment:
545
- raise NotFound(f"Unknown experiment: {id}")
546
- return Experiment(
547
- id_attr=experiment.id,
548
- name=experiment.name,
549
- project_name=experiment.project_name,
550
- description=experiment.description,
551
- created_at=experiment.created_at,
552
- updated_at=experiment.updated_at,
553
- metadata=experiment.metadata_,
554
- )
924
+ return Experiment(id=node_id)
555
925
  elif type_name == ExperimentRun.__name__:
556
- async with info.context.db() as session:
557
- if not (
558
- run := await session.scalar(
559
- select(models.ExperimentRun)
560
- .where(models.ExperimentRun.id == node_id)
561
- .options(
562
- joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
563
- )
564
- )
565
- ):
566
- raise NotFound(f"Unknown experiment run: {id}")
567
- return to_gql_experiment_run(run)
926
+ return ExperimentRun(id=node_id)
568
927
  elif type_name == User.__name__:
569
928
  if int((user := info.context.user).identity) != node_id and not user.is_admin:
570
929
  raise Unauthorized(MSG_ADMIN_ONLY)
571
- async with info.context.db() as session:
572
- if not (
573
- user := await session.scalar(
574
- select(models.User).where(models.User.id == node_id)
575
- )
576
- ):
577
- raise NotFound(f"Unknown user: {id}")
578
- return to_gql_user(user)
930
+ return User(id=node_id)
579
931
  elif type_name == ProjectSession.__name__:
580
- async with info.context.db() as session:
581
- if not (
582
- project_session := await session.scalar(
583
- select(models.ProjectSession).filter_by(id=node_id)
584
- )
585
- ):
586
- raise NotFound(f"Unknown user: {id}")
587
- return to_gql_project_session(project_session)
932
+ return ProjectSession(id=node_id)
588
933
  elif type_name == Prompt.__name__:
589
- async with info.context.db() as session:
590
- if orm_prompt := await session.scalar(
591
- select(models.Prompt).where(models.Prompt.id == node_id)
592
- ):
593
- return to_gql_prompt_from_orm(orm_prompt)
594
- else:
595
- raise NotFound(f"Unknown prompt: {id}")
934
+ return Prompt(id=node_id)
596
935
  elif type_name == PromptVersion.__name__:
597
936
  async with info.context.db() as session:
598
937
  if orm_prompt_version := await session.scalar(
@@ -602,39 +941,17 @@ class Query:
602
941
  else:
603
942
  raise NotFound(f"Unknown prompt version: {id}")
604
943
  elif type_name == PromptLabel.__name__:
605
- async with info.context.db() as session:
606
- if not (
607
- prompt_label := await session.scalar(
608
- select(models.PromptLabel).where(models.PromptLabel.id == node_id)
609
- )
610
- ):
611
- raise NotFound(f"Unknown prompt label: {id}")
612
- return to_gql_prompt_label(prompt_label)
944
+ return PromptLabel(id=node_id)
613
945
  elif type_name == PromptVersionTag.__name__:
614
- async with info.context.db() as session:
615
- if not (prompt_version_tag := await session.get(models.PromptVersionTag, node_id)):
616
- raise NotFound(f"Unknown prompt version tag: {id}")
617
- return to_gql_prompt_version_tag(prompt_version_tag)
946
+ return PromptVersionTag(id=node_id)
618
947
  elif type_name == ProjectTraceRetentionPolicy.__name__:
619
- async with info.context.db() as session:
620
- db_policy = await session.scalar(
621
- select(models.ProjectTraceRetentionPolicy).filter_by(id=node_id)
622
- )
623
- if not db_policy:
624
- raise NotFound(f"Unknown project trace retention policy: {id}")
625
- return ProjectTraceRetentionPolicy(id=db_policy.id, db_policy=db_policy)
948
+ return ProjectTraceRetentionPolicy(id=node_id)
626
949
  elif type_name == SpanAnnotation.__name__:
627
- async with info.context.db() as session:
628
- span_annotation = await session.get(models.SpanAnnotation, node_id)
629
- if not span_annotation:
630
- raise NotFound(f"Unknown span annotation: {id}")
631
- return to_gql_span_annotation(span_annotation)
950
+ return SpanAnnotation(id=node_id)
632
951
  elif type_name == TraceAnnotation.__name__:
633
- async with info.context.db() as session:
634
- trace_annotation = await session.get(models.TraceAnnotation, node_id)
635
- if not trace_annotation:
636
- raise NotFound(f"Unknown trace annotation: {id}")
637
- return to_gql_trace_annotation(trace_annotation)
952
+ return TraceAnnotation(id=node_id)
953
+ elif type_name == GenerativeModel.__name__:
954
+ return GenerativeModel(id=node_id)
638
955
  raise NotFound(f"Unknown node type: {type_name}")
639
956
 
640
957
  @strawberry.field
@@ -646,16 +963,7 @@ class Query:
646
963
  return None
647
964
  if isinstance(user, UnauthenticatedUser):
648
965
  return None
649
- async with info.context.db() as session:
650
- if (
651
- user := await session.scalar(
652
- select(models.User)
653
- .where(models.User.id == int(user.identity))
654
- .options(joinedload(models.User.role))
655
- )
656
- ) is None:
657
- return None
658
- return to_gql_user(user)
966
+ return User(id=int(user.identity))
659
967
 
660
968
  @strawberry.field
661
969
  async def prompts(
@@ -665,6 +973,8 @@ class Query:
665
973
  last: Optional[int] = UNSET,
666
974
  after: Optional[CursorString] = UNSET,
667
975
  before: Optional[CursorString] = UNSET,
976
+ filter: Optional[PromptFilter] = UNSET,
977
+ labelIds: Optional[list[GlobalID]] = UNSET,
668
978
  ) -> Connection[Prompt]:
669
979
  args = ConnectionArgs(
670
980
  first=first,
@@ -673,9 +983,29 @@ class Query:
673
983
  before=before if isinstance(before, CursorString) else None,
674
984
  )
675
985
  stmt = select(models.Prompt)
986
+ if filter:
987
+ column = getattr(models.Prompt, filter.col.value)
988
+ # Cast Identifier columns to String for ilike operations
989
+ if filter.col.value == "name":
990
+ column = cast(column, String)
991
+ stmt = stmt.where(column.ilike(f"%{filter.value}%")).order_by(
992
+ models.Prompt.updated_at.desc()
993
+ )
994
+ if labelIds:
995
+ stmt = stmt.join(models.PromptPromptLabel).where(
996
+ models.PromptPromptLabel.prompt_label_id.in_(
997
+ from_global_id_with_expected_type(
998
+ global_id=label_id, expected_type_name="PromptLabel"
999
+ )
1000
+ for label_id in labelIds
1001
+ )
1002
+ )
1003
+ stmt = stmt.distinct()
676
1004
  async with info.context.db() as session:
677
1005
  orm_prompts = await session.stream_scalars(stmt)
678
- data = [to_gql_prompt_from_orm(orm_prompt) async for orm_prompt in orm_prompts]
1006
+ data = [
1007
+ Prompt(id=orm_prompt.id, db_record=orm_prompt) async for orm_prompt in orm_prompts
1008
+ ]
679
1009
  return connection_from_list(
680
1010
  data=data,
681
1011
  args=args,
@@ -698,7 +1028,58 @@ class Query:
698
1028
  )
699
1029
  async with info.context.db() as session:
700
1030
  prompt_labels = await session.stream_scalars(select(models.PromptLabel))
701
- data = [to_gql_prompt_label(prompt_label) async for prompt_label in prompt_labels]
1031
+ data = [
1032
+ PromptLabel(id=prompt_label.id, db_record=prompt_label)
1033
+ async for prompt_label in prompt_labels
1034
+ ]
1035
+ return connection_from_list(
1036
+ data=data,
1037
+ args=args,
1038
+ )
1039
+
1040
@strawberry.field
async def dataset_labels(
    self,
    info: Info[Context, None],
    first: Optional[int] = 50,
    last: Optional[int] = UNSET,
    after: Optional[CursorString] = UNSET,
    before: Optional[CursorString] = UNSET,
) -> Connection[DatasetLabel]:
    """
    Return a paginated connection over every dataset label, ordered by name ascending.

    `after` / `before` are forwarded to the pagination helper only when they are
    real `CursorString` values; any other value (e.g. UNSET) is passed as `None`.
    """
    pagination = ConnectionArgs(
        first=first,
        after=after if isinstance(after, CursorString) else None,
        last=last,
        before=before if isinstance(before, CursorString) else None,
    )
    by_name = select(models.DatasetLabel).order_by(models.DatasetLabel.name.asc())
    async with info.context.db() as session:
        rows = await session.scalars(by_name)
        # Materialize while the session is still open.
        labels = [DatasetLabel(id=row.id, db_record=row) for row in rows]
    return connection_from_list(data=labels, args=pagination)
1064
+
1065
+ @strawberry.field
1066
+ async def dataset_splits(
1067
+ self,
1068
+ info: Info[Context, None],
1069
+ first: Optional[int] = 50,
1070
+ last: Optional[int] = UNSET,
1071
+ after: Optional[CursorString] = UNSET,
1072
+ before: Optional[CursorString] = UNSET,
1073
+ ) -> Connection[DatasetSplit]:
1074
+ args = ConnectionArgs(
1075
+ first=first,
1076
+ after=after if isinstance(after, CursorString) else None,
1077
+ last=last,
1078
+ before=before if isinstance(before, CursorString) else None,
1079
+ )
1080
+ async with info.context.db() as session:
1081
+ splits = await session.stream_scalars(select(models.DatasetSplit))
1082
+ data = [DatasetSplit(id=split.id, db_record=split) async for split in splits]
702
1083
  return connection_from_list(
703
1084
  data=data,
704
1085
  args=args,
@@ -921,16 +1302,17 @@ class Query:
921
1302
  # stats = cast(Iterable[tuple[str, int]], await session.execute(stmt))
922
1303
  # stats = _consolidate_sqlite_db_table_stats(stats)
923
1304
  elif info.context.db.dialect is SupportedSQLDialect.POSTGRESQL:
924
- stmt = text(f"""\
1305
+ nspname = getenv(ENV_PHOENIX_SQL_DATABASE_SCHEMA) or "public"
1306
+ stmt = text("""\
925
1307
  SELECT c.relname, pg_total_relation_size(c.oid)
926
1308
  FROM pg_class as c
927
1309
  INNER JOIN pg_namespace as n ON n.oid = c.relnamespace
928
1310
  WHERE c.relkind = 'r'
929
- AND n.nspname = '{getenv(ENV_PHOENIX_SQL_DATABASE_SCHEMA) or "public"}';
930
- """)
1311
+ AND n.nspname = :nspname;
1312
+ """).bindparams(nspname=nspname)
931
1313
  try:
932
1314
  async with info.context.db() as session:
933
- stats = cast(Iterable[tuple[str, int]], await session.execute(stmt))
1315
+ stats = type_cast(Iterable[tuple[str, int]], await session.execute(stmt))
934
1316
  except Exception:
935
1317
  # TODO: temporary workaround until we can reproduce the error
936
1318
  return []
@@ -941,6 +1323,62 @@ class Query:
941
1323
  for table_name, num_bytes in stats
942
1324
  ]
943
1325
 
1326
@strawberry.field
async def server_status(
    self,
    info: Info[Context, None],
) -> ServerStatus:
    """
    Report server status flags.

    `insufficient_storage` mirrors the database handle's
    `should_not_insert_or_update` attribute.
    """
    db_handle = info.context.db
    return ServerStatus(insufficient_storage=db_handle.should_not_insert_or_update)
1334
+
1335
@strawberry.field
def validate_regular_expression(self, regex: str) -> ValidationResult:
    """
    Check whether `regex` compiles with Python's `re` module.

    Returns `is_valid=True` with no message on success. Otherwise returns
    `is_valid=False` with the text of the raised `re.error`.
    """
    try:
        re.compile(regex)
    except re.error as compile_error:
        return ValidationResult(is_valid=False, error_message=str(compile_error))
    return ValidationResult(is_valid=True, error_message=None)
1342
+
1343
@strawberry.field
async def get_span_by_otel_id(
    self,
    info: Info[Context, None],
    span_id: str,
) -> Optional[Span]:
    """
    Find a span by its `span_id` column value (the OTel span id string).

    Returns a `Span` keyed by the matching database row id, or `None` when no
    row matches (or the row id is falsy).
    """
    lookup = select(models.Span.id).where(models.Span.span_id == span_id)
    async with info.context.db() as session:
        found_id = await session.scalar(lookup)
        return Span(id=found_id) if found_id else None
1355
+
1356
@strawberry.field
async def get_trace_by_otel_id(
    self,
    info: Info[Context, None],
    trace_id: str,
) -> Optional[Trace]:
    """
    Find a trace by its `trace_id` column value (the OTel trace id string).

    Returns a `Trace` keyed by the matching database row id, or `None` when no
    row matches (or the row id is falsy).
    """
    lookup = select(models.Trace.id).where(models.Trace.trace_id == trace_id)
    async with info.context.db() as session:
        found_id = await session.scalar(lookup)
        return Trace(id=found_id) if found_id else None
1368
+
1369
@strawberry.field
async def get_project_session_by_id(
    self,
    info: Info[Context, None],
    session_id: str,
) -> Optional[ProjectSession]:
    """
    Find a project session by its `session_id` column value.

    Returns a `ProjectSession` wrapping the loaded ORM row, or `None` when no
    row matches.
    """
    lookup = select(models.ProjectSession).where(
        models.ProjectSession.session_id == session_id
    )
    async with info.context.db() as db_session:
        record = await db_session.scalar(lookup)
        if not record:
            return None
        # Build the GraphQL node while the DB session is still open.
        return ProjectSession(id=record.id, db_record=record)
1381
+
944
1382
 
945
1383
  def _consolidate_sqlite_db_table_stats(
946
1384
  stats: Iterable[tuple[str, int]],
@@ -974,3 +1412,40 @@ def _longest_matching_prefix(s: str, prefixes: Iterable[str]) -> str:
974
1412
  if s.startswith(prefix) and len(prefix) > len(longest):
975
1413
  longest = prefix
976
1414
  return longest
1415
+
1416
+
1417
+ def _comparison_count_expression(
1418
+ *,
1419
+ base_column: ColumnElement[Any],
1420
+ compare_column: ColumnElement[Any],
1421
+ optimization_direction: Literal["maximize", "minimize"],
1422
+ comparison_type: Literal["improvement", "regression", "equality"],
1423
+ ) -> ColumnElement[int]:
1424
+ """
1425
+ Given a base and compare column, returns an expression counting the number of
1426
+ improvements, regressions, or equalities given the optimization direction.
1427
+ """
1428
+ if optimization_direction == "maximize":
1429
+ raise NotImplementedError
1430
+
1431
+ if comparison_type == "improvement":
1432
+ condition = compare_column > base_column
1433
+ elif comparison_type == "regression":
1434
+ condition = compare_column < base_column
1435
+ elif comparison_type == "equality":
1436
+ condition = compare_column == base_column
1437
+ else:
1438
+ assert_never(comparison_type)
1439
+
1440
+ return func.coalesce(
1441
+ func.sum(
1442
+ case(
1443
+ (
1444
+ condition,
1445
+ 1,
1446
+ ),
1447
+ else_=0,
1448
+ )
1449
+ ),
1450
+ 0,
1451
+ )