arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221)
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/queries.py
@@ -1,14 +1,14 @@
 import re
 from collections import defaultdict
 from datetime import datetime
-from typing import Any, Iterable, Iterator, Optional, Union
+from typing import Any, Iterable, Iterator, Literal, Optional, Union
 from typing import cast as type_cast
 
 import numpy as np
 import numpy.typing as npt
 import strawberry
 from sqlalchemy import ColumnElement, String, and_, case, cast, func, select, text
-from sqlalchemy.orm import aliased, joinedload, load_only
+from sqlalchemy.orm import joinedload, load_only
 from starlette.authentication import UnauthenticatedUser
 from strawberry import ID, UNSET
 from strawberry.relay import Connection, GlobalID, Node
@@ -22,7 +22,10 @@ from phoenix.config import (
 )
 from phoenix.db import models
 from phoenix.db.constants import DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
-from phoenix.db.helpers import SupportedSQLDialect, exclude_experiment_projects
+from phoenix.db.helpers import (
+    SupportedSQLDialect,
+    exclude_experiment_projects,
+)
 from phoenix.db.models import LatencyMs
 from phoenix.pointcloud.clustering import Hdbscan
 from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin
@@ -46,8 +49,10 @@ from phoenix.server.api.input_types.ProjectSort import ProjectColumn, ProjectSort
 from phoenix.server.api.input_types.PromptFilter import PromptFilter
 from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_annotation_config
 from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
-from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
+from phoenix.server.api.types.Dataset import Dataset
 from phoenix.server.api.types.DatasetExample import DatasetExample
+from phoenix.server.api.types.DatasetLabel import DatasetLabel
+from phoenix.server.api.types.DatasetSplit import DatasetSplit
 from phoenix.server.api.types.Dimension import to_gql_dimension
 from phoenix.server.api.types.EmbeddingDimension import (
     DEFAULT_CLUSTER_SELECTION_EPSILON,
@@ -57,14 +62,24 @@ from phoenix.server.api.types.EmbeddingDimension import (
 )
 from phoenix.server.api.types.Event import create_event_id, unpack_event_id
 from phoenix.server.api.types.Experiment import Experiment
-from phoenix.server.api.types.ExperimentComparison import ExperimentComparison, RunComparisonItem
-from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
+from phoenix.server.api.types.ExperimentComparison import (
+    ExperimentComparison,
+)
+from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
+    ExperimentRepeatedRunGroup,
+    parse_experiment_repeated_run_group_node_id,
+)
+from phoenix.server.api.types.ExperimentRun import ExperimentRun
 from phoenix.server.api.types.Functionality import Functionality
-from phoenix.server.api.types.GenerativeModel import GenerativeModel, to_gql_generative_model
+from phoenix.server.api.types.GenerativeModel import GenerativeModel
 from phoenix.server.api.types.GenerativeProvider import GenerativeProvider, GenerativeProviderKey
 from phoenix.server.api.types.InferenceModel import InferenceModel
 from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
-from phoenix.server.api.types.node import from_global_id, from_global_id_with_expected_type
+from phoenix.server.api.types.node import (
+    from_global_id,
+    from_global_id_with_expected_type,
+    is_global_id,
+)
 from phoenix.server.api.types.pagination import (
     ConnectionArgs,
     Cursor,
@@ -74,21 +89,21 @@ from phoenix.server.api.types.pagination import (
 )
 from phoenix.server.api.types.PlaygroundModel import PlaygroundModel
 from phoenix.server.api.types.Project import Project
-from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
+from phoenix.server.api.types.ProjectSession import ProjectSession
 from phoenix.server.api.types.ProjectTraceRetentionPolicy import ProjectTraceRetentionPolicy
-from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
-from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
+from phoenix.server.api.types.Prompt import Prompt
+from phoenix.server.api.types.PromptLabel import PromptLabel
 from phoenix.server.api.types.PromptVersion import PromptVersion, to_gql_prompt_version
-from phoenix.server.api.types.PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
+from phoenix.server.api.types.PromptVersionTag import PromptVersionTag
 from phoenix.server.api.types.ServerStatus import ServerStatus
 from phoenix.server.api.types.SortDir import SortDir
 from phoenix.server.api.types.Span import Span
-from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
+from phoenix.server.api.types.SpanAnnotation import SpanAnnotation
 from phoenix.server.api.types.SystemApiKey import SystemApiKey
 from phoenix.server.api.types.Trace import Trace
-from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
-from phoenix.server.api.types.User import User, to_gql_user
-from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
+from phoenix.server.api.types.TraceAnnotation import TraceAnnotation
+from phoenix.server.api.types.User import User
+from phoenix.server.api.types.UserApiKey import UserApiKey
 from phoenix.server.api.types.UserRole import UserRole
 from phoenix.server.api.types.ValidationResult import ValidationResult
 
@@ -108,29 +123,52 @@ class DbTableStats:
 
 
 @strawberry.type
-class MetricCounts:
-    num_increases: int
-    num_decreases: int
-    num_equal: int
-
+class ExperimentRunMetricComparison:
+    num_runs_improved: int = strawberry.field(
+        description=(
+            "The number of runs in which the base experiment improved "
+            "on the best run in any compare experiment."
+        )
+    )
+    num_runs_regressed: int = strawberry.field(
+        description=(
+            "The number of runs in which the base experiment regressed "
+            "on the best run in any compare experiment."
+        )
+    )
+    num_runs_equal: int = strawberry.field(
+        description=(
+            "The number of runs in which the base experiment is equal to the best run "
+            "in any compare experiment."
+        )
+    )
+    num_total_runs: strawberry.Private[int]
 
-@strawberry.type
-class CompareExperimentRunMetricCounts:
-    compare_experiment_id: GlobalID
-    latency: MetricCounts
-    prompt_token_count: MetricCounts
-    completion_token_count: MetricCounts
-    total_token_count: MetricCounts
-    total_cost: MetricCounts
+    @strawberry.field(
+        description=(
+            "The number of runs in the base experiment that could not be compared, either because "
+            "the base experiment run was missing a value or because all compare experiment runs "
+            "were missing values."
+        )
+    )  # type: ignore[misc]
+    def num_runs_without_comparison(self) -> int:
+        return (
+            self.num_total_runs
+            - self.num_runs_improved
+            - self.num_runs_regressed
+            - self.num_runs_equal
+        )
 
 
 @strawberry.type
-class CompareExperimentRunAnnotationMetricCounts:
-    annotation_name: str
-    compare_experiment_id: GlobalID
-    num_increases: int
-    num_decreases: int
-    num_equal: int
+class ExperimentRunMetricComparisons:
+    latency: ExperimentRunMetricComparison
+    total_token_count: ExperimentRunMetricComparison
+    prompt_token_count: ExperimentRunMetricComparison
+    completion_token_count: ExperimentRunMetricComparison
+    total_cost: ExperimentRunMetricComparison
+    prompt_cost: ExperimentRunMetricComparison
+    completion_cost: ExperimentRunMetricComparison
 
 
 @strawberry.type
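Note: num_runs_without_comparison is a derived field, computed as the remainder once the three counted outcomes are subtracted from the private total. A minimal sketch of that invariant in plain Python (hypothetical tallies, strawberry machinery omitted):

    # Hypothetical counts for one metric, mirroring ExperimentRunMetricComparison.
    num_total_runs = 10       # every run in the base experiment (private field)
    num_runs_improved = 4     # base run beat the best compare run
    num_runs_regressed = 3    # base run lost to the best compare run
    num_runs_equal = 1        # base run tied the best compare run

    # Whatever is left over could not be compared: one side was missing a value.
    num_runs_without_comparison = (
        num_total_runs - num_runs_improved - num_runs_regressed - num_runs_equal
    )
    assert num_runs_without_comparison == 2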
@@ -150,7 +188,17 @@ class Query:
     async def generative_models(
         self,
         info: Info[Context, None],
-    ) -> list[GenerativeModel]:
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[GenerativeModel]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
         async with info.context.db() as session:
             result = await session.scalars(
                 select(models.GenerativeModel)
@@ -160,17 +208,16 @@ class Query:
                     models.GenerativeModel.provider.nullslast(),
                     models.GenerativeModel.name,
                 )
-                .options(joinedload(models.GenerativeModel.token_prices))
             )
-
-        return [to_gql_generative_model(model) for model in result.unique()]
+            data = [GenerativeModel(id=model.id, db_record=model) for model in result.unique()]
+            return connection_from_list(data=data, args=args)
 
     @strawberry.field
     async def playground_models(self, input: Optional[ModelsInput] = None) -> list[PlaygroundModel]:
         if input is not None and input.provider_key is not None:
             supported_model_names = PLAYGROUND_CLIENT_REGISTRY.list_models(input.provider_key)
             supported_models = [
-                PlaygroundModel(name=model_name, provider_key=input.provider_key)
+                PlaygroundModel(name_value=model_name, provider_key_value=input.provider_key)
                 for model_name in supported_model_names
             ]
             return supported_models
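Note: generative_models now returns a relay-style Connection instead of a flat list, so callers must unwrap edges and can page with first/after. A hedged client-side sketch (the /graphql path, port 6006, and the queried sub-fields are assumptions, not shown in this diff):

    import httpx

    # Hypothetical paginated query mirroring the new resolver signature.
    query = """
    query ($first: Int, $after: String) {
      generativeModels(first: $first, after: $after) {
        edges { node { id name } }
        pageInfo { hasNextPage endCursor }
      }
    }
    """
    resp = httpx.post(
        "http://localhost:6006/graphql",
        json={"query": query, "variables": {"first": 10, "after": None}},
    )
    models = [e["node"] for e in resp.json()["data"]["generativeModels"]["edges"]]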
@@ -179,7 +226,9 @@ class Query:
         all_models: list[PlaygroundModel] = []
         for provider_key, model_name in registered_models:
             if model_name is not None and provider_key is not None:
-                all_models.append(PlaygroundModel(name=model_name, provider_key=provider_key))
+                all_models.append(
+                    PlaygroundModel(name_value=model_name, provider_key_value=provider_key)
+                )
         return all_models
 
     @strawberry.field
@@ -223,7 +272,7 @@ class Query:
         )
         async with info.context.db() as session:
             users = await session.stream_scalars(stmt)
-            data = [to_gql_user(user) async for user in users]
+            data = [User(id=user.id, db_record=user) async for user in users]
             return connection_from_list(data=data, args=args)
 
     @strawberry.field
@@ -253,7 +302,7 @@ class Query:
         )
         async with info.context.db() as session:
             api_keys = await session.scalars(stmt)
-            return [to_gql_api_key(api_key) for api_key in api_keys]
+            return [UserApiKey(id=api_key.id, db_record=api_key) for api_key in api_keys]
 
     @strawberry.field(permission_classes=[IsAdmin])  # type: ignore
     async def system_api_keys(self, info: Info[Context, None]) -> list[SystemApiKey]:
@@ -265,16 +314,7 @@ class Query:
         )
         async with info.context.db() as session:
             api_keys = await session.scalars(stmt)
-            return [
-                SystemApiKey(
-                    id_attr=api_key.id,
-                    name=api_key.name,
-                    description=api_key.description,
-                    created_at=api_key.created_at,
-                    expires_at=api_key.expires_at,
-                )
-                for api_key in api_keys
-            ]
+            return [SystemApiKey(id=api_key.id, db_record=api_key) for api_key in api_keys]
 
     @strawberry.field
     async def projects(
@@ -315,13 +355,7 @@ class Query:
             stmt = exclude_experiment_projects(stmt)
         async with info.context.db() as session:
             projects = await session.stream_scalars(stmt)
-            data = [
-                Project(
-                    project_rowid=project.id,
-                    db_project=project,
-                )
-                async for project in projects
-            ]
+            data = [Project(id=project.id, db_record=project) async for project in projects]
             return connection_from_list(data=data, args=args)
 
     @strawberry.field
@@ -350,11 +384,39 @@ class Query:
             sort_col = getattr(models.Dataset, sort.col.value)
             stmt = stmt.order_by(sort_col.desc() if sort.dir is SortDir.desc else sort_col.asc())
         if filter:
-            stmt = stmt.where(getattr(models.Dataset, filter.col.value).ilike(f"%{filter.value}%"))
+            # Apply name filter
+            if filter.col and filter.value:
+                stmt = stmt.where(
+                    getattr(models.Dataset, filter.col.value).ilike(f"%{filter.value}%")
+                )
+
+            # Apply label filter
+            if filter.filter_labels and filter.filter_labels is not UNSET:
+                label_rowids = []
+                for label_id in filter.filter_labels:
+                    try:
+                        label_rowid = from_global_id_with_expected_type(
+                            global_id=GlobalID.from_id(label_id),
+                            expected_type_name="DatasetLabel",
+                        )
+                        label_rowids.append(label_rowid)
+                    except ValueError:
+                        continue  # Skip invalid label IDs
+
+                if label_rowids:
+                    # Join with the junction table to filter by labels
+                    stmt = (
+                        stmt.join(
+                            models.DatasetsDatasetLabel,
+                            models.Dataset.id == models.DatasetsDatasetLabel.dataset_id,
+                        )
+                        .where(models.DatasetsDatasetLabel.dataset_label_id.in_(label_rowids))
+                        .distinct()
+                    )
         async with info.context.db() as session:
             datasets = await session.scalars(stmt)
             return connection_from_list(
-                data=[to_gql_dataset(dataset) for dataset in datasets], args=args
+                data=[Dataset(id=dataset.id, db_record=dataset) for dataset in datasets], args=args
             )
 
     @strawberry.field
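Note: the trailing .distinct() in the label filter is load-bearing: a dataset that carries several of the requested labels would otherwise appear once per matching junction row. A standalone sketch of the same join-then-dedupe pattern (toy tables, not the actual Phoenix models):

    import sqlalchemy as sa

    metadata = sa.MetaData()
    datasets = sa.Table("datasets", metadata, sa.Column("id", sa.Integer, primary_key=True))
    junction = sa.Table(
        "datasets_dataset_labels",
        metadata,
        sa.Column("dataset_id", sa.Integer),
        sa.Column("dataset_label_id", sa.Integer),
    )

    label_rowids = [1, 2]  # decoded from DatasetLabel global IDs
    stmt = (
        sa.select(datasets)
        .join(junction, datasets.c.id == junction.c.dataset_id)
        .where(junction.c.dataset_label_id.in_(label_rowids))
        .distinct()  # one row per dataset, even when it matches multiple labels
    )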
@@ -413,6 +475,7 @@ class Query:
                 )
             )
         ).all()
+
         if not experiments or len(experiments) < len(experiment_rowids):
             raise NotFound("Unable to resolve one or more experiment IDs.")
         num_datasets = len(set(experiment.dataset_id for experiment in experiments))
@@ -421,37 +484,19 @@ class Query:
         base_experiment = next(
             experiment for experiment in experiments if experiment.id == base_experiment_rowid
         )
-        revision_ids = (
-            select(func.max(models.DatasetExampleRevision.id))
-            .join(
-                models.DatasetExample,
-                models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id,
-            )
-            .where(
-                and_(
-                    models.DatasetExampleRevision.dataset_version_id
-                    <= base_experiment.dataset_version_id,
-                    models.DatasetExample.dataset_id == base_experiment.dataset_id,
-                )
-            )
-            .group_by(models.DatasetExampleRevision.dataset_example_id)
-            .scalar_subquery()
-        )
+
+        # Use ExperimentDatasetExample to pull down examples.
+        # Splits are mutable and should not be used for comparison.
+        # The comparison should only occur against examples which were assigned to the same
+        # splits at the time of execution of the ExperimentRun.
         examples_query = (
             select(models.DatasetExample)
-            .distinct(models.DatasetExample.id)
-            .join(
-                models.DatasetExampleRevision,
-                onclause=and_(
-                    models.DatasetExample.id
-                    == models.DatasetExampleRevision.dataset_example_id,
-                    models.DatasetExampleRevision.id.in_(revision_ids),
-                    models.DatasetExampleRevision.revision_kind != "DELETE",
-                ),
-            )
+            .join(models.ExperimentDatasetExample)
+            .where(models.ExperimentDatasetExample.experiment_id == base_experiment_rowid)
             .order_by(models.DatasetExample.id.desc())
             .limit(page_size + 1)
         )
+
         if cursor is not None:
             examples_query = examples_query.where(models.DatasetExample.id < cursor.rowid)
 
@@ -490,15 +535,17 @@ class Query:
 
         cursors_and_nodes = []
         for example in examples:
-            run_comparison_items = []
+            repeated_run_groups = []
             for experiment_id in experiment_rowids:
-                run_comparison_items.append(
-                    RunComparisonItem(
-                        experiment_id=GlobalID(Experiment.__name__, str(experiment_id)),
-                        runs=[
-                            to_gql_experiment_run(run)
+                repeated_run_groups.append(
+                    ExperimentRepeatedRunGroup(
+                        experiment_rowid=experiment_id,
+                        dataset_example_rowid=example.id,
+                        cached_runs=[
+                            ExperimentRun(id=run.id, db_record=run)
                             for run in sorted(
-                                runs[example.id][experiment_id], key=lambda run: run.id
+                                runs[example.id][experiment_id],
+                                key=lambda run: run.repetition_number,
                             )
                         ],
                     )
@@ -506,11 +553,11 @@ class Query:
             experiment_comparison = ExperimentComparison(
                 id_attr=example.id,
                 example=DatasetExample(
-                    id_attr=example.id,
-                    created_at=example.created_at,
+                    id=example.id,
+                    db_record=example,
                     version_id=base_experiment.dataset_version_id,
                 ),
-                run_comparison_items=run_comparison_items,
+                repeated_run_groups=repeated_run_groups,
             )
             cursors_and_nodes.append((Cursor(rowid=example.id), experiment_comparison))
 
@@ -521,12 +568,12 @@ class Query:
         )
 
     @strawberry.field
-    async def compare_experiment_run_metric_counts(
+    async def experiment_run_metric_comparisons(
         self,
         info: Info[Context, None],
         base_experiment_id: GlobalID,
         compare_experiment_ids: list[GlobalID],
-    ) -> list[CompareExperimentRunMetricCounts]:
+    ) -> ExperimentRunMetricComparisons:
         if base_experiment_id in compare_experiment_ids:
             raise BadRequest("Compare experiment IDs cannot contain the base experiment ID")
         if not compare_experiment_ids:
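Note: this is a breaking schema change. compareExperimentRunMetricCounts is renamed here, and its annotation counterpart is deleted in the next hunk; both are replaced by a single experimentRunMetricComparisons field that returns one object rather than a per-experiment list. A hedged sketch of the new query shape, with field names inferred from the strawberry types above via the default camelCase converter:

    # Hypothetical GraphQL document for the renamed field.
    query = """
    query ($baseId: GlobalID!, $compareIds: [GlobalID!]!) {
      experimentRunMetricComparisons(
        baseExperimentId: $baseId
        compareExperimentIds: $compareIds
      ) {
        latency { numRunsImproved numRunsRegressed numRunsEqual numRunsWithoutComparison }
        totalCost { numRunsImproved numRunsRegressed numRunsEqual numRunsWithoutComparison }
      }
    }
    """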
@@ -553,375 +600,256 @@ class Query:
553
600
  raise BadRequest(f"Invalid compare experiment ID: {compare_experiment_id}")
554
601
 
555
602
  base_experiment_runs = (
556
- select(models.ExperimentRun)
603
+ select(
604
+ models.ExperimentRun.dataset_example_id,
605
+ func.min(models.ExperimentRun.start_time).label("start_time"),
606
+ func.min(models.ExperimentRun.end_time).label("end_time"),
607
+ func.sum(models.SpanCost.total_tokens).label("total_tokens"),
608
+ func.sum(models.SpanCost.prompt_tokens).label("prompt_tokens"),
609
+ func.sum(models.SpanCost.completion_tokens).label("completion_tokens"),
610
+ func.sum(models.SpanCost.total_cost).label("total_cost"),
611
+ func.sum(models.SpanCost.prompt_cost).label("prompt_cost"),
612
+ func.sum(models.SpanCost.completion_cost).label("completion_cost"),
613
+ )
614
+ .select_from(models.ExperimentRun)
615
+ .join(
616
+ models.Trace,
617
+ onclause=models.ExperimentRun.trace_id == models.Trace.trace_id,
618
+ isouter=True,
619
+ )
620
+ .join(
621
+ models.SpanCost,
622
+ onclause=models.Trace.id == models.SpanCost.trace_rowid,
623
+ isouter=True,
624
+ )
557
625
  .where(models.ExperimentRun.experiment_id == base_experiment_rowid)
626
+ .group_by(models.ExperimentRun.dataset_example_id)
558
627
  .subquery()
559
628
  .alias("base_experiment_runs")
560
629
  )
561
- base_experiment_traces = aliased(models.Trace, name="base_experiment_traces")
562
- base_experiment_span_costs = (
630
+ compare_experiment_runs = (
563
631
  select(
564
- models.SpanCost.trace_rowid,
565
- func.coalesce(func.sum(models.SpanCost.total_tokens), 0).label("total_tokens"),
566
- func.coalesce(func.sum(models.SpanCost.prompt_tokens), 0).label("prompt_tokens"),
567
- func.coalesce(func.sum(models.SpanCost.completion_tokens), 0).label(
568
- "completion_tokens"
569
- ),
570
- func.coalesce(func.sum(models.SpanCost.total_cost), 0).label("total_cost"),
632
+ models.ExperimentRun.dataset_example_id,
633
+ func.min(
634
+ LatencyMs(models.ExperimentRun.start_time, models.ExperimentRun.end_time)
635
+ ).label("min_latency_ms"),
636
+ func.min(models.SpanCost.total_tokens).label("min_total_tokens"),
637
+ func.min(models.SpanCost.prompt_tokens).label("min_prompt_tokens"),
638
+ func.min(models.SpanCost.completion_tokens).label("min_completion_tokens"),
639
+ func.min(models.SpanCost.total_cost).label("min_total_cost"),
640
+ func.min(models.SpanCost.prompt_cost).label("min_prompt_cost"),
641
+ func.min(models.SpanCost.completion_cost).label("min_completion_cost"),
571
642
  )
572
- .select_from(models.SpanCost)
573
- .group_by(
574
- models.SpanCost.trace_rowid,
575
- )
576
- .subquery()
577
- .alias("base_experiment_span_costs")
578
- )
579
-
580
- query = (
581
- select() # add selected columns below
582
- .select_from(base_experiment_runs)
643
+ .select_from(models.ExperimentRun)
583
644
  .join(
584
- base_experiment_traces,
585
- onclause=base_experiment_runs.c.trace_id == base_experiment_traces.trace_id,
645
+ models.Trace,
646
+ onclause=models.ExperimentRun.trace_id == models.Trace.trace_id,
586
647
  isouter=True,
587
648
  )
588
649
  .join(
589
- base_experiment_span_costs,
590
- onclause=base_experiment_traces.id == base_experiment_span_costs.c.trace_rowid,
650
+ models.SpanCost,
651
+ onclause=models.Trace.id == models.SpanCost.trace_rowid,
591
652
  isouter=True,
592
653
  )
654
+ .where(
655
+ models.ExperimentRun.experiment_id.in_(compare_experiment_rowids),
656
+ )
657
+ .group_by(models.ExperimentRun.dataset_example_id)
658
+ .subquery()
659
+ .alias("comp_exp_run_mins")
593
660
  )
594
661
 
595
662
  base_experiment_run_latency = LatencyMs(
596
663
  base_experiment_runs.c.start_time, base_experiment_runs.c.end_time
597
664
  ).label("base_experiment_run_latency_ms")
598
- base_experiment_run_prompt_token_count = base_experiment_span_costs.c.prompt_tokens
599
- base_experiment_run_completion_token_count = base_experiment_span_costs.c.completion_tokens
600
- base_experiment_run_total_token_count = base_experiment_span_costs.c.total_tokens
601
- base_experiment_run_total_cost = base_experiment_span_costs.c.total_cost
602
-
603
- for compare_experiment_index, compare_experiment_rowid in enumerate(
604
- compare_experiment_rowids
605
- ):
606
- compare_experiment_runs = (
607
- select(models.ExperimentRun)
608
- .where(models.ExperimentRun.experiment_id == compare_experiment_rowid)
609
- .subquery()
610
- .alias(f"comp_exp_{compare_experiment_index}_runs")
611
- )
612
- compare_experiment_traces = aliased(
613
- models.Trace, name=f"comp_exp_{compare_experiment_index}_traces"
614
- )
615
- compare_experiment_span_costs = (
616
- select(
617
- models.SpanCost.trace_rowid,
618
- func.coalesce(func.sum(models.SpanCost.total_tokens), 0).label("total_tokens"),
619
- func.coalesce(func.sum(models.SpanCost.prompt_tokens), 0).label(
620
- "prompt_tokens"
621
- ),
622
- func.coalesce(func.sum(models.SpanCost.completion_tokens), 0).label(
623
- "completion_tokens"
624
- ),
625
- func.coalesce(func.sum(models.SpanCost.total_cost), 0).label("total_cost"),
626
- )
627
- .select_from(models.SpanCost)
628
- .group_by(models.SpanCost.trace_rowid)
629
- .subquery()
630
- .alias(f"comp_exp_{compare_experiment_index}_span_costs")
631
- )
632
- compare_experiment_run_latency = LatencyMs(
633
- compare_experiment_runs.c.start_time, compare_experiment_runs.c.end_time
634
- ).label(f"comp_exp_{compare_experiment_index}_run_latency_ms")
635
- compare_experiment_run_prompt_token_count = (
636
- compare_experiment_span_costs.c.prompt_tokens
637
- )
638
- compare_experiment_run_completion_token_count = (
639
- compare_experiment_span_costs.c.completion_tokens
640
- )
641
- compare_experiment_run_total_token_count = compare_experiment_span_costs.c.total_tokens
642
- compare_experiment_run_total_cost = compare_experiment_span_costs.c.total_cost
643
-
644
- query = (
645
- query.add_columns(
646
- _count_rows(
647
- base_experiment_run_latency < compare_experiment_run_latency,
648
- ).label(f"comp_exp_{compare_experiment_index}_num_runs_increased_latency"),
649
- _count_rows(
650
- base_experiment_run_latency > compare_experiment_run_latency,
651
- ).label(f"comp_exp_{compare_experiment_index}_num_runs_decreased_latency"),
652
- _count_rows(
653
- base_experiment_run_latency == compare_experiment_run_latency,
654
- ).label(f"comp_exp_{compare_experiment_index}_num_runs_equal_latency"),
655
- _count_rows(
656
- base_experiment_run_prompt_token_count
657
- < compare_experiment_run_prompt_token_count,
658
- ).label(
659
- f"comp_exp_{compare_experiment_index}_num_runs_increased_prompt_token_count"
660
- ),
661
- _count_rows(
662
- base_experiment_run_prompt_token_count
663
- > compare_experiment_run_prompt_token_count,
664
- ).label(
665
- f"comp_exp_{compare_experiment_index}_num_runs_decreased_prompt_token_count"
666
- ),
667
- _count_rows(
668
- base_experiment_run_prompt_token_count
669
- == compare_experiment_run_prompt_token_count,
670
- ).label(
671
- f"comp_exp_{compare_experiment_index}_num_runs_equal_prompt_token_count"
672
- ),
673
- _count_rows(
674
- base_experiment_run_completion_token_count
675
- < compare_experiment_run_completion_token_count,
676
- ).label(
677
- f"comp_exp_{compare_experiment_index}_num_runs_increased_completion_token_count"
678
- ),
679
- _count_rows(
680
- base_experiment_run_completion_token_count
681
- > compare_experiment_run_completion_token_count,
682
- ).label(
683
- f"comp_exp_{compare_experiment_index}_num_runs_decreased_completion_token_count"
684
- ),
685
- _count_rows(
686
- base_experiment_run_completion_token_count
687
- == compare_experiment_run_completion_token_count,
688
- ).label(
689
- f"comp_exp_{compare_experiment_index}_num_runs_equal_completion_token_count"
690
- ),
691
- _count_rows(
692
- base_experiment_run_total_token_count
693
- < compare_experiment_run_total_token_count,
694
- ).label(
695
- f"comp_exp_{compare_experiment_index}_num_runs_increased_total_token_count"
696
- ),
697
- _count_rows(
698
- base_experiment_run_total_token_count
699
- > compare_experiment_run_total_token_count,
700
- ).label(
701
- f"comp_exp_{compare_experiment_index}_num_runs_decreased_total_token_count"
702
- ),
703
- _count_rows(
704
- base_experiment_run_total_token_count
705
- == compare_experiment_run_total_token_count,
706
- ).label(
707
- f"comp_exp_{compare_experiment_index}_num_runs_equal_total_token_count"
708
- ),
709
- _count_rows(
710
- base_experiment_run_total_cost < compare_experiment_run_total_cost,
711
- ).label(f"comp_exp_{compare_experiment_index}_num_runs_increased_total_cost"),
712
- _count_rows(
713
- base_experiment_run_total_cost > compare_experiment_run_total_cost,
714
- ).label(f"comp_exp_{compare_experiment_index}_num_runs_decreased_total_cost"),
715
- _count_rows(
716
- base_experiment_run_total_cost == compare_experiment_run_total_cost,
717
- ).label(f"comp_exp_{compare_experiment_index}_num_runs_equal_total_cost"),
718
- )
719
- .join(
720
- compare_experiment_runs,
721
- onclause=base_experiment_runs.c.dataset_example_id
722
- == compare_experiment_runs.c.dataset_example_id,
723
- isouter=True,
724
- )
725
- .join(
726
- compare_experiment_traces,
727
- onclause=compare_experiment_runs.c.trace_id
728
- == compare_experiment_traces.trace_id,
729
- isouter=True,
730
- )
731
- .join(
732
- compare_experiment_span_costs,
733
- onclause=compare_experiment_traces.id
734
- == compare_experiment_span_costs.c.trace_rowid,
735
- isouter=True,
736
- )
737
- )
738
665
 
739
- async with info.context.db() as session:
740
- result = (await session.execute(query)).first()
741
- assert result is not None
742
-
743
- num_columns_per_compare_experiment = len(query.columns) // len(compare_experiment_ids)
744
- counts = []
745
- for compare_experiment_index, compare_experiment_id in enumerate(compare_experiment_ids):
746
- start_index = compare_experiment_index * num_columns_per_compare_experiment
747
- end_index = start_index + num_columns_per_compare_experiment
748
- (
749
- num_runs_with_increased_latency,
750
- num_runs_with_decreased_latency,
751
- num_runs_with_equal_latency,
752
- num_runs_with_increased_prompt_token_count,
753
- num_runs_with_decreased_prompt_token_count,
754
- num_runs_with_equal_prompt_token_count,
755
- num_runs_with_increased_completion_token_count,
756
- num_runs_with_decreased_completion_token_count,
757
- num_runs_with_equal_completion_token_count,
758
- num_runs_with_increased_total_token_count,
759
- num_runs_with_decreased_total_token_count,
760
- num_runs_with_equal_total_token_count,
761
- num_runs_with_increased_total_cost,
762
- num_runs_with_decreased_total_cost,
763
- num_runs_with_equal_total_cost,
764
- ) = result[start_index:end_index]
765
- counts.append(
766
- CompareExperimentRunMetricCounts(
767
- compare_experiment_id=compare_experiment_id,
768
- latency=MetricCounts(
769
- num_increases=num_runs_with_increased_latency,
770
- num_decreases=num_runs_with_decreased_latency,
771
- num_equal=num_runs_with_equal_latency,
772
- ),
773
- prompt_token_count=MetricCounts(
774
- num_increases=num_runs_with_increased_prompt_token_count,
775
- num_decreases=num_runs_with_decreased_prompt_token_count,
776
- num_equal=num_runs_with_equal_prompt_token_count,
777
- ),
778
- completion_token_count=MetricCounts(
779
- num_increases=num_runs_with_increased_completion_token_count,
780
- num_decreases=num_runs_with_decreased_completion_token_count,
781
- num_equal=num_runs_with_equal_completion_token_count,
782
- ),
783
- total_token_count=MetricCounts(
784
- num_increases=num_runs_with_increased_total_token_count,
785
- num_decreases=num_runs_with_decreased_total_token_count,
786
- num_equal=num_runs_with_equal_total_token_count,
787
- ),
788
- total_cost=MetricCounts(
789
- num_increases=num_runs_with_increased_total_cost,
790
- num_decreases=num_runs_with_decreased_total_cost,
791
- num_equal=num_runs_with_equal_total_cost,
792
- ),
793
- )
794
- )
795
- return counts
796
-
797
- @strawberry.field
798
- async def compare_experiment_run_annotation_metric_counts(
799
- self,
800
- info: Info[Context, None],
801
- base_experiment_id: GlobalID,
802
- compare_experiment_ids: list[GlobalID],
803
- ) -> list[CompareExperimentRunAnnotationMetricCounts]:
804
- if base_experiment_id in compare_experiment_ids:
805
- raise BadRequest("Compare experiment IDs cannot contain the base experiment ID")
806
- if not compare_experiment_ids:
807
- raise BadRequest("At least one compare experiment ID must be provided")
808
- if len(set(compare_experiment_ids)) < len(compare_experiment_ids):
809
- raise BadRequest("Compare experiment IDs must be unique")
810
-
811
- try:
812
- base_experiment_rowid = from_global_id_with_expected_type(
813
- base_experiment_id, models.Experiment.__name__
814
- )
815
- except ValueError:
816
- raise BadRequest(f"Invalid base experiment ID: {base_experiment_id}")
817
-
818
- compare_experiment_rowids = []
819
- for compare_experiment_id in compare_experiment_ids:
820
- try:
821
- compare_experiment_rowids.append(
822
- from_global_id_with_expected_type(
823
- compare_experiment_id, models.Experiment.__name__
824
- )
825
- )
826
- except ValueError:
827
- raise BadRequest(f"Invalid compare experiment ID: {compare_experiment_id}")
828
-
829
- base_experiment_runs = (
830
- select(models.ExperimentRun)
831
- .where(
832
- models.ExperimentRun.experiment_id == base_experiment_rowid,
666
+ comparisons_query = (
667
+ select(
668
+ func.count().label("num_base_experiment_runs"),
669
+ _comparison_count_expression(
670
+ base_column=base_experiment_run_latency,
671
+ compare_column=compare_experiment_runs.c.min_latency_ms,
672
+ optimization_direction="minimize",
673
+ comparison_type="improvement",
674
+ ).label("num_latency_improved"),
675
+ _comparison_count_expression(
676
+ base_column=base_experiment_run_latency,
677
+ compare_column=compare_experiment_runs.c.min_latency_ms,
678
+ optimization_direction="minimize",
679
+ comparison_type="regression",
680
+ ).label("num_latency_regressed"),
681
+ _comparison_count_expression(
682
+ base_column=base_experiment_run_latency,
683
+ compare_column=compare_experiment_runs.c.min_latency_ms,
684
+ optimization_direction="minimize",
685
+ comparison_type="equality",
686
+ ).label("num_latency_is_equal"),
687
+ _comparison_count_expression(
688
+ base_column=base_experiment_runs.c.total_tokens,
689
+ compare_column=compare_experiment_runs.c.min_total_tokens,
690
+ optimization_direction="minimize",
691
+ comparison_type="improvement",
692
+ ).label("num_total_token_count_improved"),
693
+ _comparison_count_expression(
694
+ base_column=base_experiment_runs.c.total_tokens,
695
+ compare_column=compare_experiment_runs.c.min_total_tokens,
696
+ optimization_direction="minimize",
697
+ comparison_type="regression",
698
+ ).label("num_total_token_count_regressed"),
699
+ _comparison_count_expression(
700
+ base_column=base_experiment_runs.c.total_tokens,
701
+ compare_column=compare_experiment_runs.c.min_total_tokens,
702
+ optimization_direction="minimize",
703
+ comparison_type="equality",
704
+ ).label("num_total_token_count_is_equal"),
705
+ _comparison_count_expression(
706
+ base_column=base_experiment_runs.c.prompt_tokens,
707
+ compare_column=compare_experiment_runs.c.min_prompt_tokens,
708
+ optimization_direction="minimize",
709
+ comparison_type="improvement",
710
+ ).label("num_prompt_token_count_improved"),
711
+ _comparison_count_expression(
712
+ base_column=base_experiment_runs.c.prompt_tokens,
713
+ compare_column=compare_experiment_runs.c.min_prompt_tokens,
714
+ optimization_direction="minimize",
715
+ comparison_type="regression",
716
+ ).label("num_prompt_token_count_regressed"),
717
+ _comparison_count_expression(
718
+ base_column=base_experiment_runs.c.prompt_tokens,
719
+ compare_column=compare_experiment_runs.c.min_prompt_tokens,
720
+ optimization_direction="minimize",
721
+ comparison_type="equality",
722
+ ).label("num_prompt_token_count_is_equal"),
723
+ _comparison_count_expression(
724
+ base_column=base_experiment_runs.c.completion_tokens,
725
+ compare_column=compare_experiment_runs.c.min_completion_tokens,
726
+ optimization_direction="minimize",
727
+ comparison_type="improvement",
728
+ ).label("num_completion_token_count_improved"),
729
+ _comparison_count_expression(
730
+ base_column=base_experiment_runs.c.completion_tokens,
731
+ compare_column=compare_experiment_runs.c.min_completion_tokens,
732
+ optimization_direction="minimize",
733
+ comparison_type="regression",
734
+ ).label("num_completion_token_count_regressed"),
735
+ _comparison_count_expression(
736
+ base_column=base_experiment_runs.c.completion_tokens,
737
+ compare_column=compare_experiment_runs.c.min_completion_tokens,
738
+ optimization_direction="minimize",
739
+ comparison_type="equality",
740
+ ).label("num_completion_token_count_is_equal"),
741
+ _comparison_count_expression(
742
+ base_column=base_experiment_runs.c.total_cost,
743
+ compare_column=compare_experiment_runs.c.min_total_cost,
744
+ optimization_direction="minimize",
745
+ comparison_type="improvement",
746
+ ).label("num_total_cost_improved"),
747
+ _comparison_count_expression(
748
+ base_column=base_experiment_runs.c.total_cost,
749
+ compare_column=compare_experiment_runs.c.min_total_cost,
750
+ optimization_direction="minimize",
751
+ comparison_type="regression",
752
+ ).label("num_total_cost_regressed"),
753
+ _comparison_count_expression(
754
+ base_column=base_experiment_runs.c.total_cost,
755
+ compare_column=compare_experiment_runs.c.min_total_cost,
756
+ optimization_direction="minimize",
757
+ comparison_type="equality",
758
+ ).label("num_total_cost_is_equal"),
759
+ _comparison_count_expression(
760
+ base_column=base_experiment_runs.c.prompt_cost,
761
+ compare_column=compare_experiment_runs.c.min_prompt_cost,
762
+ optimization_direction="minimize",
763
+ comparison_type="improvement",
764
+ ).label("num_prompt_cost_improved"),
765
+ _comparison_count_expression(
766
+ base_column=base_experiment_runs.c.prompt_cost,
767
+ compare_column=compare_experiment_runs.c.min_prompt_cost,
768
+ optimization_direction="minimize",
769
+ comparison_type="regression",
770
+ ).label("num_prompt_cost_regressed"),
771
+ _comparison_count_expression(
772
+ base_column=base_experiment_runs.c.prompt_cost,
773
+ compare_column=compare_experiment_runs.c.min_prompt_cost,
774
+ optimization_direction="minimize",
775
+ comparison_type="equality",
776
+ ).label("num_prompt_cost_is_equal"),
777
+ _comparison_count_expression(
778
+ base_column=base_experiment_runs.c.completion_cost,
779
+ compare_column=compare_experiment_runs.c.min_completion_cost,
780
+ optimization_direction="minimize",
781
+ comparison_type="improvement",
782
+ ).label("num_completion_cost_improved"),
783
+ _comparison_count_expression(
784
+ base_column=base_experiment_runs.c.completion_cost,
785
+ compare_column=compare_experiment_runs.c.min_completion_cost,
786
+ optimization_direction="minimize",
787
+ comparison_type="regression",
788
+ ).label("num_completion_cost_regressed"),
789
+ _comparison_count_expression(
790
+ base_column=base_experiment_runs.c.completion_cost,
791
+ compare_column=compare_experiment_runs.c.min_completion_cost,
792
+ optimization_direction="minimize",
793
+ comparison_type="equality",
794
+ ).label("num_completion_cost_is_equal"),
833
795
  )
834
- .subquery()
835
- .alias("base_experiment_runs")
836
- )
837
- base_experiment_run_annotations = aliased(
838
- models.ExperimentRunAnnotation, name="base_experiment_run_annotations"
839
- )
840
- query = (
841
- select(base_experiment_run_annotations.name)
842
796
  .select_from(base_experiment_runs)
843
797
  .join(
844
- base_experiment_run_annotations,
845
- onclause=base_experiment_runs.c.id
846
- == base_experiment_run_annotations.experiment_run_id,
798
+ compare_experiment_runs,
799
+ onclause=base_experiment_runs.c.dataset_example_id
800
+ == compare_experiment_runs.c.dataset_example_id,
847
801
  isouter=True,
848
802
  )
849
- .group_by(base_experiment_run_annotations.name)
850
- .order_by(base_experiment_run_annotations.name)
851
803
  )
852
- for compare_experiment_index, compare_experiment_rowid in enumerate(
853
- compare_experiment_rowids
854
- ):
855
- compare_experiment_runs = (
856
- select(models.ExperimentRun)
857
- .where(
858
- models.ExperimentRun.experiment_id == compare_experiment_rowid,
859
- )
860
- .subquery()
861
- .alias(f"comp_exp_{compare_experiment_index}_runs")
862
- )
863
- compare_experiment_run_annotations = aliased(
864
- models.ExperimentRunAnnotation,
865
- name=f"comp_exp_{compare_experiment_index}_run_annotations",
866
- )
867
- query = (
868
- query.add_columns(
869
- _count_rows(
870
- base_experiment_run_annotations.score
871
- < compare_experiment_run_annotations.score,
872
- ).label(f"comp_exp_{compare_experiment_index}_num_runs_increased_score"),
873
- _count_rows(
874
- base_experiment_run_annotations.score
875
- > compare_experiment_run_annotations.score,
876
- ).label(f"comp_exp_{compare_experiment_index}_num_runs_decreased_score"),
877
- _count_rows(
878
-                    base_experiment_run_annotations.score
-                    == compare_experiment_run_annotations.score,
-                ).label(f"comp_exp_{compare_experiment_index}_num_runs_equal_score"),
-            )
-            .join(
-                compare_experiment_runs,
-                onclause=base_experiment_runs.c.dataset_example_id
-                == compare_experiment_runs.c.dataset_example_id,
-                isouter=True,
-            )
-            .join(
-                compare_experiment_run_annotations,
-                onclause=compare_experiment_runs.c.id
-                == compare_experiment_run_annotations.experiment_run_id,
-                isouter=True,
-            )
-            .where(
-                base_experiment_run_annotations.name == compare_experiment_run_annotations.name
-            )
-        )
+
         async with info.context.db() as session:
-            result = (await session.execute(query)).all()
+            result = (await session.execute(comparisons_query)).first()
         assert result is not None
-        num_columns_per_compare_experiment = (len(query.columns) - 1) // len(compare_experiment_ids)
-        metric_counts = []
-        for record in result:
-            annotation_name, *counts = record
-            for compare_experiment_index, compare_experiment_id in enumerate(
-                compare_experiment_ids
-            ):
-                start_index = compare_experiment_index * num_columns_per_compare_experiment
-                end_index = start_index + num_columns_per_compare_experiment
-                (
-                    num_runs_with_increased_score,
-                    num_runs_with_decreased_score,
-                    num_runs_with_equal_score,
-                ) = counts[start_index:end_index]
-                metric_counts.append(
-                    CompareExperimentRunAnnotationMetricCounts(
-                        annotation_name=annotation_name,
-                        compare_experiment_id=compare_experiment_id,
-                        num_increases=num_runs_with_increased_score,
-                        num_decreases=num_runs_with_decreased_score,
-                        num_equal=num_runs_with_equal_score,
-                    )
-                )
-        return metric_counts
+
+        return ExperimentRunMetricComparisons(
+            latency=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_latency_improved,
+                num_runs_regressed=result.num_latency_regressed,
+                num_runs_equal=result.num_latency_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            total_token_count=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_total_token_count_improved,
+                num_runs_regressed=result.num_total_token_count_regressed,
+                num_runs_equal=result.num_total_token_count_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            prompt_token_count=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_prompt_token_count_improved,
+                num_runs_regressed=result.num_prompt_token_count_regressed,
+                num_runs_equal=result.num_prompt_token_count_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            completion_token_count=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_completion_token_count_improved,
+                num_runs_regressed=result.num_completion_token_count_regressed,
+                num_runs_equal=result.num_completion_token_count_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            total_cost=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_total_cost_improved,
+                num_runs_regressed=result.num_total_cost_regressed,
+                num_runs_equal=result.num_total_cost_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            prompt_cost=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_prompt_cost_improved,
+                num_runs_regressed=result.num_prompt_cost_regressed,
+                num_runs_equal=result.num_prompt_cost_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            completion_cost=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_completion_cost_improved,
+                num_runs_regressed=result.num_completion_cost_regressed,
+                num_runs_equal=result.num_completion_cost_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+        )

     @strawberry.field
     async def validate_experiment_run_filter_condition(
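Note on the hunk above: instead of returning per-annotation `CompareExperimentRunAnnotationMetricCounts` rows, the resolver now reads one aggregated row whose labeled columns (`num_latency_improved`, `num_base_experiment_runs`, and so on) come from `comparisons_query`, which is constructed earlier in the file and falls outside this excerpt. Below is a minimal sketch of how such labeled count columns can be assembled with SQLAlchemy; the table and column names are illustrative only, not the package's actual query.

```python
# Illustrative sketch only -- not Phoenix's actual query construction.
from sqlalchemy import Column, Float, Integer, MetaData, Table, case, func, select

metadata = MetaData()
run_pairs = Table(  # hypothetical pre-joined base/compare run pairs
    "run_pairs",
    metadata,
    Column("base_run_id", Integer),
    Column("base_latency", Float),
    Column("compare_latency", Float),
)

def count_where(condition):
    # Conditional count: SUM(CASE WHEN condition THEN 1 ELSE 0 END),
    # with COALESCE guarding the zero-row case.
    return func.coalesce(func.sum(case((condition, 1), else_=0)), 0)

stmt = select(
    func.count(run_pairs.c.base_run_id).label("num_base_experiment_runs"),
    # Which direction counts as an "improvement" depends on the metric; the
    # package parameterizes this by optimization direction (see the
    # _comparison_count_expression helper at the end of this diff).
    count_where(run_pairs.c.compare_latency < run_pairs.c.base_latency).label("num_latency_improved"),
    count_where(run_pairs.c.compare_latency > run_pairs.c.base_latency).label("num_latency_regressed"),
    count_where(run_pairs.c.compare_latency == run_pairs.c.base_latency).label("num_latency_is_equal"),
)
# `(await session.execute(stmt)).first()` would then expose attributes such as
# `row.num_latency_improved`, matching how `result` is consumed above.
```

Because every tally is a conditional SUM, a single pass over the joined run pairs yields all of the improved/regressed/equal counts at once, which is why one `.first()` row suffices.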
@@ -959,136 +887,51 @@ class Query:
         return InferenceModel()

     @strawberry.field
-    async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
-        type_name, node_id = from_global_id(id)
+    async def node(self, id: strawberry.ID, info: Info[Context, None]) -> Node:
+        if not is_global_id(id):
+            try:
+                experiment_rowid, dataset_example_rowid = (
+                    parse_experiment_repeated_run_group_node_id(id)
+                )
+            except Exception:
+                raise NotFound(f"Unknown node: {id}")
+            return ExperimentRepeatedRunGroup(
+                experiment_rowid=experiment_rowid,
+                dataset_example_rowid=dataset_example_rowid,
+            )
+
+        global_id = GlobalID.from_id(id)
+        type_name, node_id = from_global_id(global_id)
         if type_name == "Dimension":
             dimension = info.context.model.scalar_dimensions[node_id]
             return to_gql_dimension(node_id, dimension)
         elif type_name == "EmbeddingDimension":
             embedding_dimension = info.context.model.embedding_dimensions[node_id]
             return to_gql_embedding_dimension(node_id, embedding_dimension)
-        elif type_name == "Project":
-            project_stmt = select(models.Project).filter_by(id=node_id)
-            async with info.context.db() as session:
-                project = await session.scalar(project_stmt)
-            if project is None:
-                raise NotFound(f"Unknown project: {id}")
-            return Project(
-                project_rowid=project.id,
-                db_project=project,
-            )
-        elif type_name == "Trace":
-            trace_stmt = select(models.Trace).filter_by(id=node_id)
-            async with info.context.db() as session:
-                trace = await session.scalar(trace_stmt)
-            if trace is None:
-                raise NotFound(f"Unknown trace: {id}")
-            return Trace(trace_rowid=trace.id, db_trace=trace)
+        elif type_name == Project.__name__:
+            return Project(id=node_id)
+        elif type_name == Trace.__name__:
+            return Trace(id=node_id)
         elif type_name == Span.__name__:
-            span_stmt = (
-                select(models.Span)
-                .options(
-                    joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id)
-                )
-                .where(models.Span.id == node_id)
-            )
-            async with info.context.db() as session:
-                span = await session.scalar(span_stmt)
-            if span is None:
-                raise NotFound(f"Unknown span: {id}")
-            return Span(span_rowid=span.id, db_span=span)
+            return Span(id=node_id)
         elif type_name == Dataset.__name__:
-            dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id)
-            async with info.context.db() as session:
-                if (dataset := await session.scalar(dataset_stmt)) is None:
-                    raise NotFound(f"Unknown dataset: {id}")
-            return to_gql_dataset(dataset)
+            return Dataset(id=node_id)
         elif type_name == DatasetExample.__name__:
-            example_id = node_id
-            latest_revision_id = (
-                select(func.max(models.DatasetExampleRevision.id))
-                .where(models.DatasetExampleRevision.dataset_example_id == example_id)
-                .scalar_subquery()
-            )
-            async with info.context.db() as session:
-                example = await session.scalar(
-                    select(models.DatasetExample)
-                    .join(
-                        models.DatasetExampleRevision,
-                        onclause=models.DatasetExampleRevision.dataset_example_id
-                        == models.DatasetExample.id,
-                    )
-                    .where(
-                        and_(
-                            models.DatasetExample.id == example_id,
-                            models.DatasetExampleRevision.id == latest_revision_id,
-                            models.DatasetExampleRevision.revision_kind != "DELETE",
-                        )
-                    )
-                )
-            if not example:
-                raise NotFound(f"Unknown dataset example: {id}")
-            return DatasetExample(
-                id_attr=example.id,
-                created_at=example.created_at,
-            )
+            return DatasetExample(id=node_id)
+        elif type_name == DatasetSplit.__name__:
+            return DatasetSplit(id=node_id)
         elif type_name == Experiment.__name__:
-            async with info.context.db() as session:
-                experiment = await session.scalar(
-                    select(models.Experiment).where(models.Experiment.id == node_id)
-                )
-            if not experiment:
-                raise NotFound(f"Unknown experiment: {id}")
-            return Experiment(
-                id_attr=experiment.id,
-                name=experiment.name,
-                project_name=experiment.project_name,
-                description=experiment.description,
-                created_at=experiment.created_at,
-                updated_at=experiment.updated_at,
-                metadata=experiment.metadata_,
-            )
+            return Experiment(id=node_id)
         elif type_name == ExperimentRun.__name__:
-            async with info.context.db() as session:
-                if not (
-                    run := await session.scalar(
-                        select(models.ExperimentRun)
-                        .where(models.ExperimentRun.id == node_id)
-                        .options(
-                            joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
-                        )
-                    )
-                ):
-                    raise NotFound(f"Unknown experiment run: {id}")
-            return to_gql_experiment_run(run)
+            return ExperimentRun(id=node_id)
         elif type_name == User.__name__:
             if int((user := info.context.user).identity) != node_id and not user.is_admin:
                 raise Unauthorized(MSG_ADMIN_ONLY)
-            async with info.context.db() as session:
-                if not (
-                    user := await session.scalar(
-                        select(models.User).where(models.User.id == node_id)
-                    )
-                ):
-                    raise NotFound(f"Unknown user: {id}")
-            return to_gql_user(user)
+            return User(id=node_id)
         elif type_name == ProjectSession.__name__:
-            async with info.context.db() as session:
-                if not (
-                    project_session := await session.scalar(
-                        select(models.ProjectSession).filter_by(id=node_id)
-                    )
-                ):
-                    raise NotFound(f"Unknown user: {id}")
-            return to_gql_project_session(project_session)
+            return ProjectSession(id=node_id)
         elif type_name == Prompt.__name__:
-            async with info.context.db() as session:
-                if orm_prompt := await session.scalar(
-                    select(models.Prompt).where(models.Prompt.id == node_id)
-                ):
-                    return to_gql_prompt_from_orm(orm_prompt)
-                else:
-                    raise NotFound(f"Unknown prompt: {id}")
+            return Prompt(id=node_id)
         elif type_name == PromptVersion.__name__:
             async with info.context.db() as session:
                 if orm_prompt_version := await session.scalar(
@@ -1098,51 +941,17 @@ class Query:
                 else:
                     raise NotFound(f"Unknown prompt version: {id}")
         elif type_name == PromptLabel.__name__:
-            async with info.context.db() as session:
-                if not (
-                    prompt_label := await session.scalar(
-                        select(models.PromptLabel).where(models.PromptLabel.id == node_id)
-                    )
-                ):
-                    raise NotFound(f"Unknown prompt label: {id}")
-            return to_gql_prompt_label(prompt_label)
+            return PromptLabel(id=node_id)
         elif type_name == PromptVersionTag.__name__:
-            async with info.context.db() as session:
-                if not (prompt_version_tag := await session.get(models.PromptVersionTag, node_id)):
-                    raise NotFound(f"Unknown prompt version tag: {id}")
-            return to_gql_prompt_version_tag(prompt_version_tag)
+            return PromptVersionTag(id=node_id)
         elif type_name == ProjectTraceRetentionPolicy.__name__:
-            async with info.context.db() as session:
-                db_policy = await session.scalar(
-                    select(models.ProjectTraceRetentionPolicy).filter_by(id=node_id)
-                )
-            if not db_policy:
-                raise NotFound(f"Unknown project trace retention policy: {id}")
-            return ProjectTraceRetentionPolicy(id=db_policy.id, db_policy=db_policy)
+            return ProjectTraceRetentionPolicy(id=node_id)
         elif type_name == SpanAnnotation.__name__:
-            async with info.context.db() as session:
-                span_annotation = await session.get(models.SpanAnnotation, node_id)
-            if not span_annotation:
-                raise NotFound(f"Unknown span annotation: {id}")
-            return to_gql_span_annotation(span_annotation)
+            return SpanAnnotation(id=node_id)
         elif type_name == TraceAnnotation.__name__:
-            async with info.context.db() as session:
-                trace_annotation = await session.get(models.TraceAnnotation, node_id)
-            if not trace_annotation:
-                raise NotFound(f"Unknown trace annotation: {id}")
-            return to_gql_trace_annotation(trace_annotation)
+            return TraceAnnotation(id=node_id)
         elif type_name == GenerativeModel.__name__:
-            async with info.context.db() as session:
-                stmt = (
-                    select(models.GenerativeModel)
-                    .where(models.GenerativeModel.deleted_at.is_(None))
-                    .where(models.GenerativeModel.id == node_id)
-                    .options(joinedload(models.GenerativeModel.token_prices))
-                )
-                model = await session.scalar(stmt)
-            if not model:
-                raise NotFound(f"Unknown model: {id}")
-            return to_gql_generative_model(model)
+            return GenerativeModel(id=node_id)
         raise NotFound(f"Unknown node type: {type_name}")

     @strawberry.field
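The rewritten `node` resolver accepts a plain `strawberry.ID` and supports two ID schemes: a non-Relay format handled by the `parse_experiment_repeated_run_group_node_id` helper, and standard Relay global IDs decoded via `GlobalID.from_id`. Assuming `GlobalID` here is strawberry's Relay implementation (which the `from_id` call suggests), the round trip looks like this sketch:

```python
# Sketch of the Relay global-ID round trip the resolver relies on.
from strawberry.relay import GlobalID

gid = GlobalID(type_name="Dataset", node_id="7")
encoded = str(gid)                  # base64 string a client passes as `id`
parsed = GlobalID.from_id(encoded)  # what the resolver reconstructs
assert (parsed.type_name, int(parsed.node_id)) == ("Dataset", 7)
```

Matching branches against `Project.__name__` and friends, rather than string literals, keeps the dispatch table in sync with the GraphQL type names automatically.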
@@ -1154,16 +963,7 @@ class Query:
             return None
         if isinstance(user, UnauthenticatedUser):
             return None
-        async with info.context.db() as session:
-            if (
-                user := await session.scalar(
-                    select(models.User)
-                    .where(models.User.id == int(user.identity))
-                    .options(joinedload(models.User.role))
-                )
-            ) is None:
-                return None
-        return to_gql_user(user)
+        return User(id=int(user.identity))

     @strawberry.field
     async def prompts(
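As with the `node` branches above, the viewer resolver now returns an id-only `User` rather than eagerly fetching the ORM row. A minimal sketch of that pattern, assuming (this is not shown in the excerpt) that field-level resolvers load data on demand, presumably through dataloaders:

```python
# Illustrative id-only node type -- not Phoenix's actual User type.
import strawberry

async def load_email(user_rowid: int) -> str:
    # Stand-in for a dataloader keyed by row id; a real one would batch
    # sibling lookups into a single query.
    return f"user{user_rowid}@example.com"

@strawberry.type
class User:
    id: strawberry.ID

    @strawberry.field
    async def email(self) -> str:
        # Data is fetched only if the client actually selects this field.
        return await load_email(int(self.id))
```

The payoff is that constructing a node is free; the database is touched only for fields the query actually selects.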
@@ -1174,6 +974,7 @@ class Query:
         after: Optional[CursorString] = UNSET,
         before: Optional[CursorString] = UNSET,
         filter: Optional[PromptFilter] = UNSET,
+        labelIds: Optional[list[GlobalID]] = UNSET,
     ) -> Connection[Prompt]:
         args = ConnectionArgs(
             first=first,
@@ -1190,9 +991,21 @@ class Query:
             stmt = stmt.where(column.ilike(f"%{filter.value}%")).order_by(
                 models.Prompt.updated_at.desc()
             )
+        if labelIds:
+            stmt = stmt.join(models.PromptPromptLabel).where(
+                models.PromptPromptLabel.prompt_label_id.in_(
+                    from_global_id_with_expected_type(
+                        global_id=label_id, expected_type_name="PromptLabel"
+                    )
+                    for label_id in labelIds
+                )
+            )
+            stmt = stmt.distinct()
         async with info.context.db() as session:
             orm_prompts = await session.stream_scalars(stmt)
-            data = [to_gql_prompt_from_orm(orm_prompt) async for orm_prompt in orm_prompts]
+            data = [
+                Prompt(id=orm_prompt.id, db_record=orm_prompt) async for orm_prompt in orm_prompts
+            ]
         return connection_from_list(
             data=data,
             args=args,
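The two hunks above add label filtering to `prompts`. A hypothetical client call exercising the new `labelIds` argument follows; the endpoint URL, the `name` selection, and the example ID (base64 of `PromptLabel:1`) are assumptions for illustration:

```python
# Hypothetical usage of the new `labelIds` argument.
import httpx

query = """
{
  prompts(first: 10, labelIds: ["UHJvbXB0TGFiZWw6MQ=="]) {
    edges { node { id name } }
  }
}
"""
resp = httpx.post("http://localhost:6006/graphql", json={"query": query})
resp.raise_for_status()
print(resp.json()["data"]["prompts"]["edges"])
```

Note the `.distinct()` in the hunk: the join against `models.PromptPromptLabel` produces one row per matching label, so a prompt carrying several of the requested labels would otherwise appear more than once.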
@@ -1215,7 +1028,58 @@ class Query:
         )
         async with info.context.db() as session:
             prompt_labels = await session.stream_scalars(select(models.PromptLabel))
-            data = [to_gql_prompt_label(prompt_label) async for prompt_label in prompt_labels]
+            data = [
+                PromptLabel(id=prompt_label.id, db_record=prompt_label)
+                async for prompt_label in prompt_labels
+            ]
+        return connection_from_list(
+            data=data,
+            args=args,
+        )
+
+    @strawberry.field
+    async def dataset_labels(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[DatasetLabel]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        async with info.context.db() as session:
+            dataset_labels = await session.scalars(
+                select(models.DatasetLabel).order_by(models.DatasetLabel.name.asc())
+            )
+            data = [
+                DatasetLabel(id=dataset_label.id, db_record=dataset_label)
+                for dataset_label in dataset_labels
+            ]
+        return connection_from_list(data=data, args=args)
+
+    @strawberry.field
+    async def dataset_splits(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[DatasetSplit]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        async with info.context.db() as session:
+            splits = await session.stream_scalars(select(models.DatasetSplit))
+            data = [DatasetSplit(id=split.id, db_record=split) async for split in splits]
         return connection_from_list(
             data=data,
             args=args,
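Both new connection fields (`dataset_labels`, `dataset_splits`) funnel through `connection_from_list` with `first`/`last`/`after`/`before` arguments. As a rough mental model only (this is not Phoenix's implementation, and its `CursorString` encoding may differ), a list-backed connection helper behaves like:

```python
# Toy offset-cursor pagination over an in-memory list.
from base64 import b64decode, b64encode

def encode_cursor(index: int) -> str:
    return b64encode(f"cursor:{index}".encode()).decode()

def decode_cursor(cursor: str) -> int:
    return int(b64decode(cursor).decode().split(":", 1)[1])

def paginate(data: list, first: int, after: str | None = None):
    start = decode_cursor(after) + 1 if after else 0
    window = data[start : start + first]
    edges = [(encode_cursor(start + i), item) for i, item in enumerate(window)]
    has_next = start + first < len(data)
    return edges, has_next

edges, has_next = paginate(list("abcdefg"), first=3)                       # page 1
edges, has_next = paginate(list("abcdefg"), first=3, after=edges[-1][0])   # page 2
```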
@@ -1486,7 +1350,7 @@ class Query:
         async with info.context.db() as session:
             span_rowid = await session.scalar(stmt)
         if span_rowid:
-            return Span(span_rowid=span_rowid)
+            return Span(id=span_rowid)
         return None

     @strawberry.field
@@ -1499,7 +1363,7 @@ class Query:
         async with info.context.db() as session:
             trace_rowid = await session.scalar(stmt)
         if trace_rowid:
-            return Trace(trace_rowid=trace_rowid)
+            return Trace(id=trace_rowid)
         return None

     @strawberry.field
@@ -1512,7 +1376,7 @@ class Query:
         async with info.context.db() as session:
             session_row = await session.scalar(stmt)
         if session_row:
-            return to_gql_project_session(session_row)
+            return ProjectSession(id=session_row.id, db_record=session_row)
         return None

@@ -1550,16 +1414,36 @@ def _longest_matching_prefix(s: str, prefixes: Iterable[str]) -> str:
     return longest


-def _count_rows(
-    condition: ColumnElement[Any],
-) -> ColumnElement[Any]:
+def _comparison_count_expression(
+    *,
+    base_column: ColumnElement[Any],
+    compare_column: ColumnElement[Any],
+    optimization_direction: Literal["maximize", "minimize"],
+    comparison_type: Literal["improvement", "regression", "equality"],
+) -> ColumnElement[int]:
     """
-    Returns an expression that counts the number of rows satisfying the condition.
+    Given a base and compare column, returns an expression counting the number of
+    improvements, regressions, or equalities given the optimization direction.
     """
+    if optimization_direction == "maximize":
+        raise NotImplementedError
+
+    if comparison_type == "improvement":
+        condition = compare_column > base_column
+    elif comparison_type == "regression":
+        condition = compare_column < base_column
+    elif comparison_type == "equality":
+        condition = compare_column == base_column
+    else:
+        assert_never(comparison_type)
+
     return func.coalesce(
         func.sum(
             case(
-                (condition, 1),
+                (
+                    condition,
+                    1,
+                ),
                 else_=0,
             )
         ),