arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff shows the changes between two publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +2 -1
- phoenix/auth.py +27 -2
- phoenix/config.py +1594 -81
- phoenix/db/README.md +546 -28
- phoenix/db/bulk_inserter.py +119 -116
- phoenix/db/engines.py +140 -33
- phoenix/db/facilitator.py +22 -1
- phoenix/db/helpers.py +818 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +133 -1
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +41 -18
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +364 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/trace_retention.py +7 -6
- phoenix/experiments/functions.py +69 -19
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +60 -0
- phoenix/server/api/dataloaders/__init__.py +36 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/exceptions.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +263 -83
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +61 -19
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +5 -2
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/input_types/SpanSort.py +3 -2
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +15 -20
- phoenix/server/api/mutations/chat_mutations.py +106 -37
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +11 -9
- phoenix/server/api/mutations/project_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
- phoenix/server/api/mutations/trace_mutations.py +3 -3
- phoenix/server/api/mutations/user_mutations.py +55 -26
- phoenix/server/api/queries.py +501 -617
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +141 -87
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +349 -101
- phoenix/server/api/routers/v1/__init__.py +22 -4
- phoenix/server/api/routers/v1/annotation_configs.py +19 -30
- phoenix/server/api/routers/v1/annotations.py +455 -13
- phoenix/server/api/routers/v1/datasets.py +355 -68
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +20 -28
- phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
- phoenix/server/api/routers/v1/experiment_runs.py +335 -59
- phoenix/server/api/routers/v1/experiments.py +475 -47
- phoenix/server/api/routers/v1/projects.py +16 -50
- phoenix/server/api/routers/v1/prompts.py +50 -39
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +156 -96
- phoenix/server/api/routers/v1/traces.py +51 -77
- phoenix/server/api/routers/v1/users.py +64 -24
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +257 -93
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/Dataset.py +199 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +215 -68
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +120 -70
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/GenerativeProvider.py +1 -1
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +218 -185
- phoenix/server/api/types/ProjectSession.py +146 -29
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +130 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +184 -53
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +128 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +154 -36
- phoenix/server/authorization.py +5 -4
- phoenix/server/bearer_auth.py +13 -5
- phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
- phoenix/server/daemons/generative_model_store.py +61 -9
- phoenix/server/daemons/span_cost_calculator.py +10 -8
- phoenix/server/dml_event.py +13 -0
- phoenix/server/email/sender.py +29 -2
- phoenix/server/grpc_server.py +9 -9
- phoenix/server/jwt_store.py +8 -6
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +9 -3
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +43 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +51 -53
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
- phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +7 -1
- phoenix/server/thread_server.py +1 -2
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +44 -8
- phoenix/settings.py +2 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/query.py +2 -0
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
- phoenix/server/static/assets/pages-Creyamao.js +0 -8612
- phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
- phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
- phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
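To reproduce this comparison locally instead of relying on the rendered diff below, the following is a minimal sketch assuming `pip` is on PATH and PyPI is reachable; the directory layout is illustrative, not part of the release.

    # Sketch: download both arize-phoenix wheels and unpack them for a local diff.
    import subprocess
    import zipfile
    from pathlib import Path

    VERSIONS = ("11.23.1", "12.28.1")

    for version in VERSIONS:
        dest = Path("wheels") / version
        dest.mkdir(parents=True, exist_ok=True)
        # Download only the arize-phoenix wheel for this version, without dependencies.
        subprocess.run(
            ["pip", "download", f"arize-phoenix=={version}", "--no-deps", "-d", str(dest)],
            check=True,
        )
        # A wheel is a zip archive; unpack it so the two trees can be compared.
        for wheel in dest.glob("*.whl"):
            with zipfile.ZipFile(wheel) as zf:
                zf.extractall(dest / "unpacked")

    # Compare the unpacked trees with any diff tool, for example:
    #   diff -ruN wheels/11.23.1/unpacked wheels/12.28.1/unpacked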
phoenix/server/api/queries.py
CHANGED
@@ -1,14 +1,14 @@
 import re
 from collections import defaultdict
 from datetime import datetime
-from typing import Any, Iterable, Iterator, Optional, Union
+from typing import Any, Iterable, Iterator, Literal, Optional, Union
 from typing import cast as type_cast

 import numpy as np
 import numpy.typing as npt
 import strawberry
 from sqlalchemy import ColumnElement, String, and_, case, cast, func, select, text
-from sqlalchemy.orm import
+from sqlalchemy.orm import joinedload, load_only
 from starlette.authentication import UnauthenticatedUser
 from strawberry import ID, UNSET
 from strawberry.relay import Connection, GlobalID, Node
@@ -22,7 +22,10 @@ from phoenix.config import (
 )
 from phoenix.db import models
 from phoenix.db.constants import DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
-from phoenix.db.helpers import
+from phoenix.db.helpers import (
+    SupportedSQLDialect,
+    exclude_experiment_projects,
+)
 from phoenix.db.models import LatencyMs
 from phoenix.pointcloud.clustering import Hdbscan
 from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin
@@ -46,8 +49,10 @@ from phoenix.server.api.input_types.ProjectSort import ProjectColumn, ProjectSor
 from phoenix.server.api.input_types.PromptFilter import PromptFilter
 from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_annotation_config
 from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
-from phoenix.server.api.types.Dataset import Dataset
+from phoenix.server.api.types.Dataset import Dataset
 from phoenix.server.api.types.DatasetExample import DatasetExample
+from phoenix.server.api.types.DatasetLabel import DatasetLabel
+from phoenix.server.api.types.DatasetSplit import DatasetSplit
 from phoenix.server.api.types.Dimension import to_gql_dimension
 from phoenix.server.api.types.EmbeddingDimension import (
     DEFAULT_CLUSTER_SELECTION_EPSILON,
@@ -57,14 +62,24 @@ from phoenix.server.api.types.EmbeddingDimension import (
 )
 from phoenix.server.api.types.Event import create_event_id, unpack_event_id
 from phoenix.server.api.types.Experiment import Experiment
-from phoenix.server.api.types.ExperimentComparison import
-
+from phoenix.server.api.types.ExperimentComparison import (
+    ExperimentComparison,
+)
+from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
+    ExperimentRepeatedRunGroup,
+    parse_experiment_repeated_run_group_node_id,
+)
+from phoenix.server.api.types.ExperimentRun import ExperimentRun
 from phoenix.server.api.types.Functionality import Functionality
-from phoenix.server.api.types.GenerativeModel import GenerativeModel
+from phoenix.server.api.types.GenerativeModel import GenerativeModel
 from phoenix.server.api.types.GenerativeProvider import GenerativeProvider, GenerativeProviderKey
 from phoenix.server.api.types.InferenceModel import InferenceModel
 from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
-from phoenix.server.api.types.node import
+from phoenix.server.api.types.node import (
+    from_global_id,
+    from_global_id_with_expected_type,
+    is_global_id,
+)
 from phoenix.server.api.types.pagination import (
     ConnectionArgs,
     Cursor,
@@ -74,21 +89,21 @@ from phoenix.server.api.types.pagination import (
 )
 from phoenix.server.api.types.PlaygroundModel import PlaygroundModel
 from phoenix.server.api.types.Project import Project
-from phoenix.server.api.types.ProjectSession import ProjectSession
+from phoenix.server.api.types.ProjectSession import ProjectSession
 from phoenix.server.api.types.ProjectTraceRetentionPolicy import ProjectTraceRetentionPolicy
-from phoenix.server.api.types.Prompt import Prompt
-from phoenix.server.api.types.PromptLabel import PromptLabel
+from phoenix.server.api.types.Prompt import Prompt
+from phoenix.server.api.types.PromptLabel import PromptLabel
 from phoenix.server.api.types.PromptVersion import PromptVersion, to_gql_prompt_version
-from phoenix.server.api.types.PromptVersionTag import PromptVersionTag
+from phoenix.server.api.types.PromptVersionTag import PromptVersionTag
 from phoenix.server.api.types.ServerStatus import ServerStatus
 from phoenix.server.api.types.SortDir import SortDir
 from phoenix.server.api.types.Span import Span
-from phoenix.server.api.types.SpanAnnotation import SpanAnnotation
+from phoenix.server.api.types.SpanAnnotation import SpanAnnotation
 from phoenix.server.api.types.SystemApiKey import SystemApiKey
 from phoenix.server.api.types.Trace import Trace
-from phoenix.server.api.types.TraceAnnotation import TraceAnnotation
-from phoenix.server.api.types.User import User
-from phoenix.server.api.types.UserApiKey import UserApiKey
+from phoenix.server.api.types.TraceAnnotation import TraceAnnotation
+from phoenix.server.api.types.User import User
+from phoenix.server.api.types.UserApiKey import UserApiKey
 from phoenix.server.api.types.UserRole import UserRole
 from phoenix.server.api.types.ValidationResult import ValidationResult

@@ -108,29 +123,52 @@ class DbTableStats:


 @strawberry.type
-class
- [4 deleted lines not rendered]
+class ExperimentRunMetricComparison:
+    num_runs_improved: int = strawberry.field(
+        description=(
+            "The number of runs in which the base experiment improved "
+            "on the best run in any compare experiment."
+        )
+    )
+    num_runs_regressed: int = strawberry.field(
+        description=(
+            "The number of runs in which the base experiment regressed "
+            "on the best run in any compare experiment."
+        )
+    )
+    num_runs_equal: int = strawberry.field(
+        description=(
+            "The number of runs in which the base experiment is equal to the best run "
+            "in any compare experiment."
+        )
+    )
+    num_total_runs: strawberry.Private[int]

-    @strawberry.
- [7 deleted lines not rendered]
+    @strawberry.field(
+        description=(
+            "The number of runs in the base experiment that could not be compared, either because "
+            "the base experiment run was missing a value or because all compare experiment runs "
+            "were missing values."
+        )
+    )  # type: ignore[misc]
+    def num_runs_without_comparison(self) -> int:
+        return (
+            self.num_total_runs
+            - self.num_runs_improved
+            - self.num_runs_regressed
+            - self.num_runs_equal
+        )


 @strawberry.type
-class
- [5 deleted lines not rendered]
+class ExperimentRunMetricComparisons:
+    latency: ExperimentRunMetricComparison
+    total_token_count: ExperimentRunMetricComparison
+    prompt_token_count: ExperimentRunMetricComparison
+    completion_token_count: ExperimentRunMetricComparison
+    total_cost: ExperimentRunMetricComparison
+    prompt_cost: ExperimentRunMetricComparison
+    completion_cost: ExperimentRunMetricComparison


 @strawberry.type
@@ -150,7 +188,17 @@ class Query:
     async def generative_models(
         self,
         info: Info[Context, None],
-
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[GenerativeModel]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
         async with info.context.db() as session:
             result = await session.scalars(
                 select(models.GenerativeModel)
@@ -160,17 +208,16 @@ class Query:
                     models.GenerativeModel.provider.nullslast(),
                     models.GenerativeModel.name,
                 )
-                .options(joinedload(models.GenerativeModel.token_prices))
             )
-
-            return
+            data = [GenerativeModel(id=model.id, db_record=model) for model in result.unique()]
+            return connection_from_list(data=data, args=args)

     @strawberry.field
     async def playground_models(self, input: Optional[ModelsInput] = None) -> list[PlaygroundModel]:
         if input is not None and input.provider_key is not None:
             supported_model_names = PLAYGROUND_CLIENT_REGISTRY.list_models(input.provider_key)
             supported_models = [
-                PlaygroundModel(
+                PlaygroundModel(name_value=model_name, provider_key_value=input.provider_key)
                 for model_name in supported_model_names
             ]
             return supported_models
@@ -179,7 +226,9 @@ class Query:
         all_models: list[PlaygroundModel] = []
         for provider_key, model_name in registered_models:
             if model_name is not None and provider_key is not None:
-                all_models.append(
+                all_models.append(
+                    PlaygroundModel(name_value=model_name, provider_key_value=provider_key)
+                )
         return all_models

     @strawberry.field
@@ -223,7 +272,7 @@ class Query:
         )
         async with info.context.db() as session:
             users = await session.stream_scalars(stmt)
-            data = [
+            data = [User(id=user.id, db_record=user) async for user in users]
         return connection_from_list(data=data, args=args)

     @strawberry.field
@@ -253,7 +302,7 @@ class Query:
         )
         async with info.context.db() as session:
             api_keys = await session.scalars(stmt)
-            return [
+            return [UserApiKey(id=api_key.id, db_record=api_key) for api_key in api_keys]

     @strawberry.field(permission_classes=[IsAdmin])  # type: ignore
     async def system_api_keys(self, info: Info[Context, None]) -> list[SystemApiKey]:
@@ -265,16 +314,7 @@ class Query:
         )
         async with info.context.db() as session:
             api_keys = await session.scalars(stmt)
-            return [
-                SystemApiKey(
-                    id_attr=api_key.id,
-                    name=api_key.name,
-                    description=api_key.description,
-                    created_at=api_key.created_at,
-                    expires_at=api_key.expires_at,
-                )
-                for api_key in api_keys
-            ]
+            return [SystemApiKey(id=api_key.id, db_record=api_key) for api_key in api_keys]

     @strawberry.field
     async def projects(
@@ -315,13 +355,7 @@ class Query:
             stmt = exclude_experiment_projects(stmt)
         async with info.context.db() as session:
             projects = await session.stream_scalars(stmt)
-            data = [
-                Project(
-                    project_rowid=project.id,
-                    db_project=project,
-                )
-                async for project in projects
-            ]
+            data = [Project(id=project.id, db_record=project) async for project in projects]
         return connection_from_list(data=data, args=args)

     @strawberry.field
@@ -350,11 +384,39 @@ class Query:
             sort_col = getattr(models.Dataset, sort.col.value)
             stmt = stmt.order_by(sort_col.desc() if sort.dir is SortDir.desc else sort_col.asc())
         if filter:
-
+            # Apply name filter
+            if filter.col and filter.value:
+                stmt = stmt.where(
+                    getattr(models.Dataset, filter.col.value).ilike(f"%{filter.value}%")
+                )
+
+            # Apply label filter
+            if filter.filter_labels and filter.filter_labels is not UNSET:
+                label_rowids = []
+                for label_id in filter.filter_labels:
+                    try:
+                        label_rowid = from_global_id_with_expected_type(
+                            global_id=GlobalID.from_id(label_id),
+                            expected_type_name="DatasetLabel",
+                        )
+                        label_rowids.append(label_rowid)
+                    except ValueError:
+                        continue  # Skip invalid label IDs
+
+                if label_rowids:
+                    # Join with the junction table to filter by labels
+                    stmt = (
+                        stmt.join(
+                            models.DatasetsDatasetLabel,
+                            models.Dataset.id == models.DatasetsDatasetLabel.dataset_id,
+                        )
+                        .where(models.DatasetsDatasetLabel.dataset_label_id.in_(label_rowids))
+                        .distinct()
+                    )
         async with info.context.db() as session:
             datasets = await session.scalars(stmt)
             return connection_from_list(
-                data=[
+                data=[Dataset(id=dataset.id, db_record=dataset) for dataset in datasets], args=args
             )

     @strawberry.field
@@ -413,6 +475,7 @@ class Query:
                 )
             )
         ).all()
+
         if not experiments or len(experiments) < len(experiment_rowids):
             raise NotFound("Unable to resolve one or more experiment IDs.")
         num_datasets = len(set(experiment.dataset_id for experiment in experiments))
@@ -421,37 +484,19 @@ class Query:
         base_experiment = next(
             experiment for experiment in experiments if experiment.id == base_experiment_rowid
         )
- [5 deleted lines not rendered]
-            )
-            .where(
-                and_(
-                    models.DatasetExampleRevision.dataset_version_id
-                    <= base_experiment.dataset_version_id,
-                    models.DatasetExample.dataset_id == base_experiment.dataset_id,
-                )
-            )
-            .group_by(models.DatasetExampleRevision.dataset_example_id)
-            .scalar_subquery()
-        )
+
+        # Use ExperimentDatasetExample to pull down examples.
+        # Splits are mutable and should not be used for comparison.
+        # The comparison should only occur against examples which were assigned to the same
+        # splits at the time of execution of the ExperimentRun.
         examples_query = (
             select(models.DatasetExample)
-            .
-            .
-                models.DatasetExampleRevision,
-                onclause=and_(
-                    models.DatasetExample.id
-                    == models.DatasetExampleRevision.dataset_example_id,
-                    models.DatasetExampleRevision.id.in_(revision_ids),
-                    models.DatasetExampleRevision.revision_kind != "DELETE",
-                ),
-            )
+            .join(models.ExperimentDatasetExample)
+            .where(models.ExperimentDatasetExample.experiment_id == base_experiment_rowid)
             .order_by(models.DatasetExample.id.desc())
             .limit(page_size + 1)
         )
+
         if cursor is not None:
             examples_query = examples_query.where(models.DatasetExample.id < cursor.rowid)

@@ -490,15 +535,17 @@ class Query:

         cursors_and_nodes = []
         for example in examples:
-
+            repeated_run_groups = []
             for experiment_id in experiment_rowids:
- [5 deleted lines not rendered]
+                repeated_run_groups.append(
+                    ExperimentRepeatedRunGroup(
+                        experiment_rowid=experiment_id,
+                        dataset_example_rowid=example.id,
+                        cached_runs=[
+                            ExperimentRun(id=run.id, db_record=run)
                             for run in sorted(
-                                runs[example.id][experiment_id],
+                                runs[example.id][experiment_id],
+                                key=lambda run: run.repetition_number,
                             )
                         ],
                     )
|
|
|
506
553
|
experiment_comparison = ExperimentComparison(
|
|
507
554
|
id_attr=example.id,
|
|
508
555
|
example=DatasetExample(
|
|
509
|
-
|
|
510
|
-
|
|
556
|
+
id=example.id,
|
|
557
|
+
db_record=example,
|
|
511
558
|
version_id=base_experiment.dataset_version_id,
|
|
512
559
|
),
|
|
513
|
-
|
|
560
|
+
repeated_run_groups=repeated_run_groups,
|
|
514
561
|
)
|
|
515
562
|
cursors_and_nodes.append((Cursor(rowid=example.id), experiment_comparison))
|
|
516
563
|
|
|
@@ -521,12 +568,12 @@ class Query:
         )

     @strawberry.field
-    async def
+    async def experiment_run_metric_comparisons(
         self,
         info: Info[Context, None],
         base_experiment_id: GlobalID,
         compare_experiment_ids: list[GlobalID],
-    ) ->
+    ) -> ExperimentRunMetricComparisons:
         if base_experiment_id in compare_experiment_ids:
             raise BadRequest("Compare experiment IDs cannot contain the base experiment ID")
         if not compare_experiment_ids:
@@ -553,375 +600,256 @@ class Query:
             raise BadRequest(f"Invalid compare experiment ID: {compare_experiment_id}")

         base_experiment_runs = (
-            select(
+            select(
+                models.ExperimentRun.dataset_example_id,
+                func.min(models.ExperimentRun.start_time).label("start_time"),
+                func.min(models.ExperimentRun.end_time).label("end_time"),
+                func.sum(models.SpanCost.total_tokens).label("total_tokens"),
+                func.sum(models.SpanCost.prompt_tokens).label("prompt_tokens"),
+                func.sum(models.SpanCost.completion_tokens).label("completion_tokens"),
+                func.sum(models.SpanCost.total_cost).label("total_cost"),
+                func.sum(models.SpanCost.prompt_cost).label("prompt_cost"),
+                func.sum(models.SpanCost.completion_cost).label("completion_cost"),
+            )
+            .select_from(models.ExperimentRun)
+            .join(
+                models.Trace,
+                onclause=models.ExperimentRun.trace_id == models.Trace.trace_id,
+                isouter=True,
+            )
+            .join(
+                models.SpanCost,
+                onclause=models.Trace.id == models.SpanCost.trace_rowid,
+                isouter=True,
+            )
             .where(models.ExperimentRun.experiment_id == base_experiment_rowid)
+            .group_by(models.ExperimentRun.dataset_example_id)
             .subquery()
             .alias("base_experiment_runs")
         )
-
-        base_experiment_span_costs = (
+        compare_experiment_runs = (
             select(
-                models.
-                func.
- [3 deleted lines not rendered]
-                ),
-                func.
+                models.ExperimentRun.dataset_example_id,
+                func.min(
+                    LatencyMs(models.ExperimentRun.start_time, models.ExperimentRun.end_time)
+                ).label("min_latency_ms"),
+                func.min(models.SpanCost.total_tokens).label("min_total_tokens"),
+                func.min(models.SpanCost.prompt_tokens).label("min_prompt_tokens"),
+                func.min(models.SpanCost.completion_tokens).label("min_completion_tokens"),
+                func.min(models.SpanCost.total_cost).label("min_total_cost"),
+                func.min(models.SpanCost.prompt_cost).label("min_prompt_cost"),
+                func.min(models.SpanCost.completion_cost).label("min_completion_cost"),
             )
-            .select_from(models.
-            .group_by(
-                models.SpanCost.trace_rowid,
-            )
-            .subquery()
-            .alias("base_experiment_span_costs")
-        )
-
-        query = (
-            select()  # add selected columns below
-            .select_from(base_experiment_runs)
+            .select_from(models.ExperimentRun)
             .join(
-
-                onclause=
+                models.Trace,
+                onclause=models.ExperimentRun.trace_id == models.Trace.trace_id,
                 isouter=True,
             )
             .join(
-
-                onclause=
+                models.SpanCost,
+                onclause=models.Trace.id == models.SpanCost.trace_rowid,
                 isouter=True,
             )
+            .where(
+                models.ExperimentRun.experiment_id.in_(compare_experiment_rowids),
+            )
+            .group_by(models.ExperimentRun.dataset_example_id)
+            .subquery()
+            .alias("comp_exp_run_mins")
         )

         base_experiment_run_latency = LatencyMs(
             base_experiment_runs.c.start_time, base_experiment_runs.c.end_time
         ).label("base_experiment_run_latency_ms")
-        base_experiment_run_prompt_token_count = base_experiment_span_costs.c.prompt_tokens
-        base_experiment_run_completion_token_count = base_experiment_span_costs.c.completion_tokens
-        base_experiment_run_total_token_count = base_experiment_span_costs.c.total_tokens
-        base_experiment_run_total_cost = base_experiment_span_costs.c.total_cost
-
-        for compare_experiment_index, compare_experiment_rowid in enumerate(
-            compare_experiment_rowids
-        ):
-            compare_experiment_runs = (
-                select(models.ExperimentRun)
-                .where(models.ExperimentRun.experiment_id == compare_experiment_rowid)
-                .subquery()
-                .alias(f"comp_exp_{compare_experiment_index}_runs")
-            )
-            compare_experiment_traces = aliased(
-                models.Trace, name=f"comp_exp_{compare_experiment_index}_traces"
-            )
-            compare_experiment_span_costs = (
-                select(
-                    models.SpanCost.trace_rowid,
-                    func.coalesce(func.sum(models.SpanCost.total_tokens), 0).label("total_tokens"),
-                    func.coalesce(func.sum(models.SpanCost.prompt_tokens), 0).label(
-                        "prompt_tokens"
-                    ),
-                    func.coalesce(func.sum(models.SpanCost.completion_tokens), 0).label(
-                        "completion_tokens"
-                    ),
-                    func.coalesce(func.sum(models.SpanCost.total_cost), 0).label("total_cost"),
-                )
-                .select_from(models.SpanCost)
-                .group_by(models.SpanCost.trace_rowid)
-                .subquery()
-                .alias(f"comp_exp_{compare_experiment_index}_span_costs")
-            )
-            compare_experiment_run_latency = LatencyMs(
-                compare_experiment_runs.c.start_time, compare_experiment_runs.c.end_time
-            ).label(f"comp_exp_{compare_experiment_index}_run_latency_ms")
-            compare_experiment_run_prompt_token_count = (
-                compare_experiment_span_costs.c.prompt_tokens
-            )
-            compare_experiment_run_completion_token_count = (
-                compare_experiment_span_costs.c.completion_tokens
-            )
-            compare_experiment_run_total_token_count = compare_experiment_span_costs.c.total_tokens
-            compare_experiment_run_total_cost = compare_experiment_span_costs.c.total_cost
-
-            query = (
-                query.add_columns(
-                    _count_rows(
-                        base_experiment_run_latency < compare_experiment_run_latency,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_increased_latency"),
-                    _count_rows(
-                        base_experiment_run_latency > compare_experiment_run_latency,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_decreased_latency"),
-                    _count_rows(
-                        base_experiment_run_latency == compare_experiment_run_latency,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_equal_latency"),
-                    _count_rows(
-                        base_experiment_run_prompt_token_count
-                        < compare_experiment_run_prompt_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_increased_prompt_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_prompt_token_count
-                        > compare_experiment_run_prompt_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_decreased_prompt_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_prompt_token_count
-                        == compare_experiment_run_prompt_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_equal_prompt_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_completion_token_count
-                        < compare_experiment_run_completion_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_increased_completion_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_completion_token_count
-                        > compare_experiment_run_completion_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_decreased_completion_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_completion_token_count
-                        == compare_experiment_run_completion_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_equal_completion_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_total_token_count
-                        < compare_experiment_run_total_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_increased_total_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_total_token_count
-                        > compare_experiment_run_total_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_decreased_total_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_total_token_count
-                        == compare_experiment_run_total_token_count,
-                    ).label(
-                        f"comp_exp_{compare_experiment_index}_num_runs_equal_total_token_count"
-                    ),
-                    _count_rows(
-                        base_experiment_run_total_cost < compare_experiment_run_total_cost,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_increased_total_cost"),
-                    _count_rows(
-                        base_experiment_run_total_cost > compare_experiment_run_total_cost,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_decreased_total_cost"),
-                    _count_rows(
-                        base_experiment_run_total_cost == compare_experiment_run_total_cost,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_equal_total_cost"),
-                )
-                .join(
-                    compare_experiment_runs,
-                    onclause=base_experiment_runs.c.dataset_example_id
-                    == compare_experiment_runs.c.dataset_example_id,
-                    isouter=True,
-                )
-                .join(
-                    compare_experiment_traces,
-                    onclause=compare_experiment_runs.c.trace_id
-                    == compare_experiment_traces.trace_id,
-                    isouter=True,
-                )
-                .join(
-                    compare_experiment_span_costs,
-                    onclause=compare_experiment_traces.id
-                    == compare_experiment_span_costs.c.trace_rowid,
-                    isouter=True,
-                )
-            )

- [94 deleted lines not rendered]
+        comparisons_query = (
+            select(
+                func.count().label("num_base_experiment_runs"),
+                _comparison_count_expression(
+                    base_column=base_experiment_run_latency,
+                    compare_column=compare_experiment_runs.c.min_latency_ms,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_latency_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_run_latency,
+                    compare_column=compare_experiment_runs.c.min_latency_ms,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_latency_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_run_latency,
+                    compare_column=compare_experiment_runs.c.min_latency_ms,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_latency_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_tokens,
+                    compare_column=compare_experiment_runs.c.min_total_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_total_token_count_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_tokens,
+                    compare_column=compare_experiment_runs.c.min_total_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_total_token_count_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_tokens,
+                    compare_column=compare_experiment_runs.c.min_total_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_total_token_count_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_tokens,
+                    compare_column=compare_experiment_runs.c.min_prompt_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_prompt_token_count_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_tokens,
+                    compare_column=compare_experiment_runs.c.min_prompt_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_prompt_token_count_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_tokens,
+                    compare_column=compare_experiment_runs.c.min_prompt_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_prompt_token_count_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_tokens,
+                    compare_column=compare_experiment_runs.c.min_completion_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_completion_token_count_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_tokens,
+                    compare_column=compare_experiment_runs.c.min_completion_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_completion_token_count_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_tokens,
+                    compare_column=compare_experiment_runs.c.min_completion_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_completion_token_count_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_cost,
+                    compare_column=compare_experiment_runs.c.min_total_cost,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_total_cost_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_cost,
+                    compare_column=compare_experiment_runs.c.min_total_cost,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_total_cost_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_cost,
+                    compare_column=compare_experiment_runs.c.min_total_cost,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_total_cost_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_cost,
+                    compare_column=compare_experiment_runs.c.min_prompt_cost,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_prompt_cost_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_cost,
+                    compare_column=compare_experiment_runs.c.min_prompt_cost,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_prompt_cost_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_cost,
+                    compare_column=compare_experiment_runs.c.min_prompt_cost,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_prompt_cost_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_cost,
+                    compare_column=compare_experiment_runs.c.min_completion_cost,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_completion_cost_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_cost,
+                    compare_column=compare_experiment_runs.c.min_completion_cost,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_completion_cost_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_cost,
+                    compare_column=compare_experiment_runs.c.min_completion_cost,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_completion_cost_is_equal"),
             )
-            .subquery()
-            .alias("base_experiment_runs")
-        )
-        base_experiment_run_annotations = aliased(
-            models.ExperimentRunAnnotation, name="base_experiment_run_annotations"
-        )
-        query = (
-            select(base_experiment_run_annotations.name)
             .select_from(base_experiment_runs)
             .join(
-
-                onclause=base_experiment_runs.c.
-                ==
+                compare_experiment_runs,
+                onclause=base_experiment_runs.c.dataset_example_id
+                == compare_experiment_runs.c.dataset_example_id,
                 isouter=True,
             )
-            .group_by(base_experiment_run_annotations.name)
-            .order_by(base_experiment_run_annotations.name)
        )
-
-            compare_experiment_rowids
-        ):
-            compare_experiment_runs = (
-                select(models.ExperimentRun)
-                .where(
-                    models.ExperimentRun.experiment_id == compare_experiment_rowid,
-                )
-                .subquery()
-                .alias(f"comp_exp_{compare_experiment_index}_runs")
-            )
-            compare_experiment_run_annotations = aliased(
-                models.ExperimentRunAnnotation,
-                name=f"comp_exp_{compare_experiment_index}_run_annotations",
-            )
-            query = (
-                query.add_columns(
-                    _count_rows(
-                        base_experiment_run_annotations.score
-                        < compare_experiment_run_annotations.score,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_increased_score"),
-                    _count_rows(
-                        base_experiment_run_annotations.score
-                        > compare_experiment_run_annotations.score,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_decreased_score"),
-                    _count_rows(
-                        base_experiment_run_annotations.score
-                        == compare_experiment_run_annotations.score,
-                    ).label(f"comp_exp_{compare_experiment_index}_num_runs_equal_score"),
-                )
-                .join(
-                    compare_experiment_runs,
-                    onclause=base_experiment_runs.c.dataset_example_id
-                    == compare_experiment_runs.c.dataset_example_id,
-                    isouter=True,
-                )
-                .join(
-                    compare_experiment_run_annotations,
-                    onclause=compare_experiment_runs.c.id
-                    == compare_experiment_run_annotations.experiment_run_id,
-                    isouter=True,
-                )
-                .where(
-                    base_experiment_run_annotations.name == compare_experiment_run_annotations.name
-                )
-            )
+
         async with info.context.db() as session:
-            result = (await session.execute(
+            result = (await session.execute(comparisons_query)).first()
             assert result is not None
- [24 deleted lines not rendered]
+
+        return ExperimentRunMetricComparisons(
+            latency=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_latency_improved,
+                num_runs_regressed=result.num_latency_regressed,
+                num_runs_equal=result.num_latency_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            total_token_count=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_total_token_count_improved,
+                num_runs_regressed=result.num_total_token_count_regressed,
+                num_runs_equal=result.num_total_token_count_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            prompt_token_count=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_prompt_token_count_improved,
+                num_runs_regressed=result.num_prompt_token_count_regressed,
+                num_runs_equal=result.num_prompt_token_count_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            completion_token_count=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_completion_token_count_improved,
+                num_runs_regressed=result.num_completion_token_count_regressed,
+                num_runs_equal=result.num_completion_token_count_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            total_cost=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_total_cost_improved,
+                num_runs_regressed=result.num_total_cost_regressed,
+                num_runs_equal=result.num_total_cost_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            prompt_cost=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_prompt_cost_improved,
+                num_runs_regressed=result.num_prompt_cost_regressed,
+                num_runs_equal=result.num_prompt_cost_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            completion_cost=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_completion_cost_improved,
+                num_runs_regressed=result.num_completion_cost_regressed,
+                num_runs_equal=result.num_completion_cost_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+        )

     @strawberry.field
     async def validate_experiment_run_filter_condition(
@@ -959,136 +887,51 @@ class Query:
         return InferenceModel()

     @strawberry.field
-    async def node(self, id:
-
+    async def node(self, id: strawberry.ID, info: Info[Context, None]) -> Node:
+        if not is_global_id(id):
+            try:
+                experiment_rowid, dataset_example_rowid = (
+                    parse_experiment_repeated_run_group_node_id(id)
+                )
+            except Exception:
+                raise NotFound(f"Unknown node: {id}")
+            return ExperimentRepeatedRunGroup(
+                experiment_rowid=experiment_rowid,
+                dataset_example_rowid=dataset_example_rowid,
+            )
+
+        global_id = GlobalID.from_id(id)
+        type_name, node_id = from_global_id(global_id)
         if type_name == "Dimension":
             dimension = info.context.model.scalar_dimensions[node_id]
             return to_gql_dimension(node_id, dimension)
         elif type_name == "EmbeddingDimension":
             embedding_dimension = info.context.model.embedding_dimensions[node_id]
             return to_gql_embedding_dimension(node_id, embedding_dimension)
-        elif type_name ==
-
-
-
-                if project is None:
-                    raise NotFound(f"Unknown project: {id}")
-                return Project(
-                    project_rowid=project.id,
-                    db_project=project,
-                )
-        elif type_name == "Trace":
-            trace_stmt = select(models.Trace).filter_by(id=node_id)
-            async with info.context.db() as session:
-                trace = await session.scalar(trace_stmt)
-                if trace is None:
-                    raise NotFound(f"Unknown trace: {id}")
-                return Trace(trace_rowid=trace.id, db_trace=trace)
+        elif type_name == Project.__name__:
+            return Project(id=node_id)
+        elif type_name == Trace.__name__:
+            return Trace(id=node_id)
         elif type_name == Span.__name__:
-
-                select(models.Span)
-                .options(
-                    joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id)
-                )
-                .where(models.Span.id == node_id)
-            )
-            async with info.context.db() as session:
-                span = await session.scalar(span_stmt)
-                if span is None:
-                    raise NotFound(f"Unknown span: {id}")
-                return Span(span_rowid=span.id, db_span=span)
+            return Span(id=node_id)
         elif type_name == Dataset.__name__:
-
-            async with info.context.db() as session:
-                if (dataset := await session.scalar(dataset_stmt)) is None:
-                    raise NotFound(f"Unknown dataset: {id}")
-                return to_gql_dataset(dataset)
+            return Dataset(id=node_id)
         elif type_name == DatasetExample.__name__:
-
-
-
-                .where(models.DatasetExampleRevision.dataset_example_id == example_id)
-                .scalar_subquery()
-            )
-            async with info.context.db() as session:
-                example = await session.scalar(
-                    select(models.DatasetExample)
-                    .join(
-                        models.DatasetExampleRevision,
-                        onclause=models.DatasetExampleRevision.dataset_example_id
-                        == models.DatasetExample.id,
-                    )
-                    .where(
-                        and_(
-                            models.DatasetExample.id == example_id,
-                            models.DatasetExampleRevision.id == latest_revision_id,
-                            models.DatasetExampleRevision.revision_kind != "DELETE",
-                        )
-                    )
-                )
-                if not example:
-                    raise NotFound(f"Unknown dataset example: {id}")
-                return DatasetExample(
-                    id_attr=example.id,
-                    created_at=example.created_at,
-                )
+            return DatasetExample(id=node_id)
+        elif type_name == DatasetSplit.__name__:
+            return DatasetSplit(id=node_id)
         elif type_name == Experiment.__name__:
-
-                experiment = await session.scalar(
-                    select(models.Experiment).where(models.Experiment.id == node_id)
-                )
-                if not experiment:
-                    raise NotFound(f"Unknown experiment: {id}")
-                return Experiment(
-                    id_attr=experiment.id,
-                    name=experiment.name,
-                    project_name=experiment.project_name,
-                    description=experiment.description,
-                    created_at=experiment.created_at,
-                    updated_at=experiment.updated_at,
-                    metadata=experiment.metadata_,
-                )
+            return Experiment(id=node_id)
         elif type_name == ExperimentRun.__name__:
-
-                if not (
-                    run := await session.scalar(
-                        select(models.ExperimentRun)
-                        .where(models.ExperimentRun.id == node_id)
-                        .options(
-                            joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
-                        )
-                    )
-                ):
-                    raise NotFound(f"Unknown experiment run: {id}")
-                return to_gql_experiment_run(run)
+            return ExperimentRun(id=node_id)
         elif type_name == User.__name__:
             if int((user := info.context.user).identity) != node_id and not user.is_admin:
                 raise Unauthorized(MSG_ADMIN_ONLY)
-
-                if not (
-                    user := await session.scalar(
-                        select(models.User).where(models.User.id == node_id)
-                    )
-                ):
-                    raise NotFound(f"Unknown user: {id}")
-                return to_gql_user(user)
+            return User(id=node_id)
         elif type_name == ProjectSession.__name__:
-
-                if not (
-                    project_session := await session.scalar(
-                        select(models.ProjectSession).filter_by(id=node_id)
-                    )
-                ):
-                    raise NotFound(f"Unknown user: {id}")
-                return to_gql_project_session(project_session)
+            return ProjectSession(id=node_id)
         elif type_name == Prompt.__name__:
-
-                if orm_prompt := await session.scalar(
-                    select(models.Prompt).where(models.Prompt.id == node_id)
-                ):
-                    return to_gql_prompt_from_orm(orm_prompt)
-                else:
-                    raise NotFound(f"Unknown prompt: {id}")
+            return Prompt(id=node_id)
         elif type_name == PromptVersion.__name__:
             async with info.context.db() as session:
                 if orm_prompt_version := await session.scalar(
@@ -1098,51 +941,17 @@ class Query:
                 else:
                     raise NotFound(f"Unknown prompt version: {id}")
         elif type_name == PromptLabel.__name__:
-
-                if not (
-                    prompt_label := await session.scalar(
-                        select(models.PromptLabel).where(models.PromptLabel.id == node_id)
-                    )
-                ):
-                    raise NotFound(f"Unknown prompt label: {id}")
-                return to_gql_prompt_label(prompt_label)
+            return PromptLabel(id=node_id)
         elif type_name == PromptVersionTag.__name__:
-
-                if not (prompt_version_tag := await session.get(models.PromptVersionTag, node_id)):
-                    raise NotFound(f"Unknown prompt version tag: {id}")
-                return to_gql_prompt_version_tag(prompt_version_tag)
+            return PromptVersionTag(id=node_id)
         elif type_name == ProjectTraceRetentionPolicy.__name__:
-
-                db_policy = await session.scalar(
-                    select(models.ProjectTraceRetentionPolicy).filter_by(id=node_id)
-                )
-                if not db_policy:
-                    raise NotFound(f"Unknown project trace retention policy: {id}")
-                return ProjectTraceRetentionPolicy(id=db_policy.id, db_policy=db_policy)
+            return ProjectTraceRetentionPolicy(id=node_id)
         elif type_name == SpanAnnotation.__name__:
-
-                span_annotation = await session.get(models.SpanAnnotation, node_id)
-                if not span_annotation:
-                    raise NotFound(f"Unknown span annotation: {id}")
-                return to_gql_span_annotation(span_annotation)
+            return SpanAnnotation(id=node_id)
        elif type_name == TraceAnnotation.__name__:
-
-                trace_annotation = await session.get(models.TraceAnnotation, node_id)
-                if not trace_annotation:
-                    raise NotFound(f"Unknown trace annotation: {id}")
-                return to_gql_trace_annotation(trace_annotation)
+            return TraceAnnotation(id=node_id)
         elif type_name == GenerativeModel.__name__:
-
-                stmt = (
-                    select(models.GenerativeModel)
-                    .where(models.GenerativeModel.deleted_at.is_(None))
-                    .where(models.GenerativeModel.id == node_id)
-                    .options(joinedload(models.GenerativeModel.token_prices))
-                )
-                model = await session.scalar(stmt)
-                if not model:
-                    raise NotFound(f"Unknown model: {id}")
-                return to_gql_generative_model(model)
+            return GenerativeModel(id=node_id)
         raise NotFound(f"Unknown node type: {type_name}")

     @strawberry.field
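The rewritten node() resolver above no longer queries the database per node type; most branches now construct a lightweight GraphQL type from the decoded row id and defer record loading to field resolvers. A minimal sketch of that dispatch pattern, using hypothetical names rather than the package's own types, might look like this:

```python
from dataclasses import dataclass


# Lightweight node types that carry only their row id; the database record
# is assumed to be loaded lazily by whatever resolves their fields later.
@dataclass
class ProjectNode:
    id: int


@dataclass
class TraceNode:
    id: int


NODE_TYPES = {"Project": ProjectNode, "Trace": TraceNode}


def resolve_node(type_name: str, node_id: int):
    # Dispatch a decoded global ID to its node type, failing fast on unknown types.
    if (node_type := NODE_TYPES.get(type_name)) is None:
        raise LookupError(f"Unknown node type: {type_name}")
    return node_type(id=node_id)


print(resolve_node("Project", 42))  # ProjectNode(id=42)
```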
@@ -1154,16 +963,7 @@ class Query:
             return None
         if isinstance(user, UnauthenticatedUser):
             return None
-
-            if (
-                user := await session.scalar(
-                    select(models.User)
-                    .where(models.User.id == int(user.identity))
-                    .options(joinedload(models.User.role))
-                )
-            ) is None:
-                return None
-            return to_gql_user(user)
+        return User(id=int(user.identity))

     @strawberry.field
     async def prompts(
@@ -1174,6 +974,7 @@ class Query:
         after: Optional[CursorString] = UNSET,
         before: Optional[CursorString] = UNSET,
         filter: Optional[PromptFilter] = UNSET,
+        labelIds: Optional[list[GlobalID]] = UNSET,
     ) -> Connection[Prompt]:
         args = ConnectionArgs(
             first=first,
@@ -1190,9 +991,21 @@ class Query:
             stmt = stmt.where(column.ilike(f"%{filter.value}%")).order_by(
                 models.Prompt.updated_at.desc()
             )
+        if labelIds:
+            stmt = stmt.join(models.PromptPromptLabel).where(
+                models.PromptPromptLabel.prompt_label_id.in_(
+                    from_global_id_with_expected_type(
+                        global_id=label_id, expected_type_name="PromptLabel"
+                    )
+                    for label_id in labelIds
+                )
+            )
+            stmt = stmt.distinct()
         async with info.context.db() as session:
             orm_prompts = await session.stream_scalars(stmt)
-            data = [
+            data = [
+                Prompt(id=orm_prompt.id, db_record=orm_prompt) async for orm_prompt in orm_prompts
+            ]
             return connection_from_list(
                 data=data,
                 args=args,
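The new labelIds argument above filters prompts by joining through the prompt-to-label association table and de-duplicating the result. The sketch below reproduces that query shape with standalone SQLAlchemy 2.0 models; the model definitions are illustrative stand-ins, and the label row ids are assumed to be already decoded (the resolver itself decodes them from global IDs):

```python
from sqlalchemy import ForeignKey, String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Prompt(Base):
    __tablename__ = "prompts"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String)


class PromptLabel(Base):
    __tablename__ = "prompt_labels"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String)


class PromptPromptLabel(Base):
    __tablename__ = "prompts_prompt_labels"
    id: Mapped[int] = mapped_column(primary_key=True)
    prompt_id: Mapped[int] = mapped_column(ForeignKey("prompts.id"))
    prompt_label_id: Mapped[int] = mapped_column(ForeignKey("prompt_labels.id"))


def prompts_filtered_by_labels(label_rowids: list[int]):
    # Join through the association table, keep prompts carrying any of the
    # requested labels, and de-duplicate since a prompt may match several labels.
    stmt = select(Prompt)
    if label_rowids:
        stmt = (
            stmt.join(PromptPromptLabel, PromptPromptLabel.prompt_id == Prompt.id)
            .where(PromptPromptLabel.prompt_label_id.in_(label_rowids))
            .distinct()
        )
    return stmt


print(prompts_filtered_by_labels([1, 2]))  # renders the SELECT ... JOIN ... WHERE ... IN (...) statement
```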
@@ -1215,7 +1028,58 @@ class Query:
         )
         async with info.context.db() as session:
             prompt_labels = await session.stream_scalars(select(models.PromptLabel))
-            data = [
+            data = [
+                PromptLabel(id=prompt_label.id, db_record=prompt_label)
+                async for prompt_label in prompt_labels
+            ]
+            return connection_from_list(
+                data=data,
+                args=args,
+            )
+
+    @strawberry.field
+    async def dataset_labels(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[DatasetLabel]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        async with info.context.db() as session:
+            dataset_labels = await session.scalars(
+                select(models.DatasetLabel).order_by(models.DatasetLabel.name.asc())
+            )
+            data = [
+                DatasetLabel(id=dataset_label.id, db_record=dataset_label)
+                for dataset_label in dataset_labels
+            ]
+            return connection_from_list(data=data, args=args)
+
+    @strawberry.field
+    async def dataset_splits(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[DatasetSplit]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        async with info.context.db() as session:
+            splits = await session.stream_scalars(select(models.DatasetSplit))
+            data = [DatasetSplit(id=split.id, db_record=split) async for split in splits]
             return connection_from_list(
                 data=data,
                 args=args,
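The new dataset_labels and dataset_splits queries above both expose Relay-style first/last/after/before arguments and hand an in-memory list to the package's connection helper. As a rough sketch only, assuming offset-style string cursors (the real helper also returns cursors and page info), the basic first/after slicing looks like this:

```python
from typing import Optional, Sequence, TypeVar

T = TypeVar("T")


def paginate(data: Sequence[T], first: Optional[int] = 50, after: Optional[str] = None) -> list[T]:
    # Offset-style cursor for illustration: "after" is the index of the last
    # item already seen, so the page starts at the next element.
    start = int(after) + 1 if after is not None else 0
    stop = start + first if first is not None else None
    return list(data[start:stop])


splits = ["train", "validation", "test", "holdout"]
print(paginate(splits, first=2))              # ['train', 'validation']
print(paginate(splits, first=2, after="1"))   # ['test', 'holdout']
```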
@@ -1486,7 +1350,7 @@ class Query:
         async with info.context.db() as session:
             span_rowid = await session.scalar(stmt)
             if span_rowid:
-                return Span(
+                return Span(id=span_rowid)
         return None

     @strawberry.field
@@ -1499,7 +1363,7 @@ class Query:
         async with info.context.db() as session:
             trace_rowid = await session.scalar(stmt)
             if trace_rowid:
-                return Trace(
+                return Trace(id=trace_rowid)
         return None

     @strawberry.field
@@ -1512,7 +1376,7 @@ class Query:
         async with info.context.db() as session:
             session_row = await session.scalar(stmt)
             if session_row:
-                return
+                return ProjectSession(id=session_row.id, db_record=session_row)
         return None


@@ -1550,16 +1414,36 @@ def _longest_matching_prefix(s: str, prefixes: Iterable[str]) -> str:
     return longest


-def
-
-
+def _comparison_count_expression(
+    *,
+    base_column: ColumnElement[Any],
+    compare_column: ColumnElement[Any],
+    optimization_direction: Literal["maximize", "minimize"],
+    comparison_type: Literal["improvement", "regression", "equality"],
+) -> ColumnElement[int]:
     """
-
+    Given a base and compare column, returns an expression counting the number of
+    improvements, regressions, or equalities given the optimization direction.
     """
+    if optimization_direction == "maximize":
+        raise NotImplementedError
+
+    if comparison_type == "improvement":
+        condition = compare_column > base_column
+    elif comparison_type == "regression":
+        condition = compare_column < base_column
+    elif comparison_type == "equality":
+        condition = compare_column == base_column
+    else:
+        assert_never(comparison_type)
+
     return func.coalesce(
         func.sum(
             case(
-                (
+                (
+                    condition,
+                    1,
+                ),
                 else_=0,
             )
         ),
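The helper added above builds a conditional-count aggregate: sum 1 for each row satisfying the comparison, 0 otherwise, coalesced so an empty group still produces a number. A minimal standalone SQLAlchemy sketch of that expression shape, with hypothetical column names and without tying the comparison to a particular improvement/regression convention, might be:

```python
from sqlalchemy import case, column, func


def count_where(condition):
    # Conditional count: 1 per row matching the condition, 0 otherwise,
    # coalesced to 0 so the aggregate is well-defined even with no rows.
    return func.coalesce(func.sum(case((condition, 1), else_=0)), 0)


base = column("base_latency_ms")
compare = column("compare_latency_ms")

num_lower = count_where(compare < base)
num_higher = count_where(compare > base)
num_equal = count_where(compare == base)

# Renders roughly as: coalesce(sum(CASE WHEN compare_latency_ms < base_latency_ms THEN 1 ELSE 0 END), 0)
print(num_lower)
```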