arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/queries.py
CHANGED
@@ -1,12 +1,14 @@
+import re
 from collections import defaultdict
 from datetime import datetime
-from typing import Iterable, Iterator, Optional, Union
+from typing import Any, Iterable, Iterator, Literal, Optional, Union
+from typing import cast as type_cast

 import numpy as np
 import numpy.typing as npt
 import strawberry
-from sqlalchemy import and_,
-from sqlalchemy.orm import joinedload
+from sqlalchemy import ColumnElement, String, and_, case, cast, func, select, text
+from sqlalchemy.orm import joinedload, load_only
 from starlette.authentication import UnauthenticatedUser
 from strawberry import ID, UNSET
 from strawberry.relay import Connection, GlobalID, Node
@@ -18,19 +20,17 @@ from phoenix.config (
     get_env_database_allocated_storage_capacity_gibibytes,
     getenv,
 )
-from phoenix.db import
+from phoenix.db import models
 from phoenix.db.constants import DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
-from phoenix.db.helpers import
-
-
-
-from phoenix.db.models import
-from phoenix.db.models import ExperimentRun as OrmExperimentRun
-from phoenix.db.models import Trace as OrmTrace
+from phoenix.db.helpers import (
+    SupportedSQLDialect,
+    exclude_experiment_projects,
+)
+from phoenix.db.models import LatencyMs
 from phoenix.pointcloud.clustering import Hdbscan
 from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin
 from phoenix.server.api.context import Context
-from phoenix.server.api.exceptions import NotFound, Unauthorized
+from phoenix.server.api.exceptions import BadRequest, NotFound, Unauthorized
 from phoenix.server.api.helpers import ensure_list
 from phoenix.server.api.helpers.experiment_run_filters import (
     ExperimentRunFilterConditionSyntaxError,
@@ -41,14 +41,18 @@ from phoenix.server.api.helpers.playground_clients import initialize_playground_
 from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
 from phoenix.server.api.input_types.ClusterInput import ClusterInput
 from phoenix.server.api.input_types.Coordinates import InputCoordinate2D, InputCoordinate3D
+from phoenix.server.api.input_types.DatasetFilter import DatasetFilter
 from phoenix.server.api.input_types.DatasetSort import DatasetSort
 from phoenix.server.api.input_types.InvocationParameters import InvocationParameter
 from phoenix.server.api.input_types.ProjectFilter import ProjectFilter
 from phoenix.server.api.input_types.ProjectSort import ProjectColumn, ProjectSort
+from phoenix.server.api.input_types.PromptFilter import PromptFilter
 from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_annotation_config
 from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
-from phoenix.server.api.types.Dataset import Dataset
+from phoenix.server.api.types.Dataset import Dataset
 from phoenix.server.api.types.DatasetExample import DatasetExample
+from phoenix.server.api.types.DatasetLabel import DatasetLabel
+from phoenix.server.api.types.DatasetSplit import DatasetSplit
 from phoenix.server.api.types.Dimension import to_gql_dimension
 from phoenix.server.api.types.EmbeddingDimension import (
     DEFAULT_CLUSTER_SELECTION_EPSILON,
@@ -58,30 +62,48 @@ from phoenix.server.api.types.EmbeddingDimension import (
 )
 from phoenix.server.api.types.Event import create_event_id, unpack_event_id
 from phoenix.server.api.types.Experiment import Experiment
-from phoenix.server.api.types.ExperimentComparison import
-
+from phoenix.server.api.types.ExperimentComparison import (
+    ExperimentComparison,
+)
+from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
+    ExperimentRepeatedRunGroup,
+    parse_experiment_repeated_run_group_node_id,
+)
+from phoenix.server.api.types.ExperimentRun import ExperimentRun
 from phoenix.server.api.types.Functionality import Functionality
 from phoenix.server.api.types.GenerativeModel import GenerativeModel
 from phoenix.server.api.types.GenerativeProvider import GenerativeProvider, GenerativeProviderKey
+from phoenix.server.api.types.InferenceModel import InferenceModel
 from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
-from phoenix.server.api.types.
-
-
+from phoenix.server.api.types.node import (
+    from_global_id,
+    from_global_id_with_expected_type,
+    is_global_id,
+)
+from phoenix.server.api.types.pagination import (
+    ConnectionArgs,
+    Cursor,
+    CursorString,
+    connection_from_cursors_and_nodes,
+    connection_from_list,
+)
+from phoenix.server.api.types.PlaygroundModel import PlaygroundModel
 from phoenix.server.api.types.Project import Project
-from phoenix.server.api.types.ProjectSession import ProjectSession
+from phoenix.server.api.types.ProjectSession import ProjectSession
 from phoenix.server.api.types.ProjectTraceRetentionPolicy import ProjectTraceRetentionPolicy
-from phoenix.server.api.types.Prompt import Prompt
-from phoenix.server.api.types.PromptLabel import PromptLabel
+from phoenix.server.api.types.Prompt import Prompt
+from phoenix.server.api.types.PromptLabel import PromptLabel
 from phoenix.server.api.types.PromptVersion import PromptVersion, to_gql_prompt_version
-from phoenix.server.api.types.PromptVersionTag import PromptVersionTag
+from phoenix.server.api.types.PromptVersionTag import PromptVersionTag
+from phoenix.server.api.types.ServerStatus import ServerStatus
 from phoenix.server.api.types.SortDir import SortDir
 from phoenix.server.api.types.Span import Span
-from phoenix.server.api.types.SpanAnnotation import SpanAnnotation
+from phoenix.server.api.types.SpanAnnotation import SpanAnnotation
 from phoenix.server.api.types.SystemApiKey import SystemApiKey
 from phoenix.server.api.types.Trace import Trace
-from phoenix.server.api.types.TraceAnnotation import TraceAnnotation
-from phoenix.server.api.types.User import User
-from phoenix.server.api.types.UserApiKey import UserApiKey
+from phoenix.server.api.types.TraceAnnotation import TraceAnnotation
+from phoenix.server.api.types.User import User
+from phoenix.server.api.types.UserApiKey import UserApiKey
 from phoenix.server.api.types.UserRole import UserRole
 from phoenix.server.api.types.ValidationResult import ValidationResult

@@ -100,6 +122,55 @@ class DbTableStats:
     num_bytes: float


+@strawberry.type
+class ExperimentRunMetricComparison:
+    num_runs_improved: int = strawberry.field(
+        description=(
+            "The number of runs in which the base experiment improved "
+            "on the best run in any compare experiment."
+        )
+    )
+    num_runs_regressed: int = strawberry.field(
+        description=(
+            "The number of runs in which the base experiment regressed "
+            "on the best run in any compare experiment."
+        )
+    )
+    num_runs_equal: int = strawberry.field(
+        description=(
+            "The number of runs in which the base experiment is equal to the best run "
+            "in any compare experiment."
+        )
+    )
+    num_total_runs: strawberry.Private[int]
+
+    @strawberry.field(
+        description=(
+            "The number of runs in the base experiment that could not be compared, either because "
+            "the base experiment run was missing a value or because all compare experiment runs "
+            "were missing values."
+        )
+    )  # type: ignore[misc]
+    def num_runs_without_comparison(self) -> int:
+        return (
+            self.num_total_runs
+            - self.num_runs_improved
+            - self.num_runs_regressed
+            - self.num_runs_equal
+        )
+
+
+@strawberry.type
+class ExperimentRunMetricComparisons:
+    latency: ExperimentRunMetricComparison
+    total_token_count: ExperimentRunMetricComparison
+    prompt_token_count: ExperimentRunMetricComparison
+    completion_token_count: ExperimentRunMetricComparison
+    total_cost: ExperimentRunMetricComparison
+    prompt_cost: ExperimentRunMetricComparison
+    completion_cost: ExperimentRunMetricComparison
+
+
 @strawberry.type
 class Query:
     @strawberry.field
@@ -114,20 +185,50 @@ class Query:
         ]

     @strawberry.field
-    async def
+    async def generative_models(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[GenerativeModel]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        async with info.context.db() as session:
+            result = await session.scalars(
+                select(models.GenerativeModel)
+                .where(models.GenerativeModel.deleted_at.is_(None))
+                .order_by(
+                    models.GenerativeModel.is_built_in.asc(),  # display custom models first
+                    models.GenerativeModel.provider.nullslast(),
+                    models.GenerativeModel.name,
+                )
+            )
+            data = [GenerativeModel(id=model.id, db_record=model) for model in result.unique()]
+            return connection_from_list(data=data, args=args)
+
+    @strawberry.field
+    async def playground_models(self, input: Optional[ModelsInput] = None) -> list[PlaygroundModel]:
         if input is not None and input.provider_key is not None:
             supported_model_names = PLAYGROUND_CLIENT_REGISTRY.list_models(input.provider_key)
             supported_models = [
-
+                PlaygroundModel(name_value=model_name, provider_key_value=input.provider_key)
                 for model_name in supported_model_names
             ]
             return supported_models

         registered_models = PLAYGROUND_CLIENT_REGISTRY.list_all_models()
-        all_models: list[
+        all_models: list[PlaygroundModel] = []
         for provider_key, model_name in registered_models:
             if model_name is not None and provider_key is not None:
-                all_models.append(
+                all_models.append(
+                    PlaygroundModel(name_value=model_name, provider_key_value=provider_key)
+                )
         return all_models

     @strawberry.field
@@ -165,13 +266,13 @@ class Query:
         stmt = (
             select(models.User)
             .join(models.UserRole)
-            .where(models.UserRole.name !=
+            .where(models.UserRole.name != "SYSTEM")
             .order_by(models.User.email)
             .options(joinedload(models.User.role))
         )
         async with info.context.db() as session:
             users = await session.stream_scalars(stmt)
-            data = [
+            data = [User(id=user.id, db_record=user) async for user in users]
         return connection_from_list(data=data, args=args)

     @strawberry.field
@@ -181,7 +282,7 @@ class Query:
     ) -> list[UserRole]:
         async with info.context.db() as session:
             roles = await session.scalars(
-                select(models.UserRole).where(models.UserRole.name !=
+                select(models.UserRole).where(models.UserRole.name != "SYSTEM")
             )
             return [
                 UserRole(
@@ -197,11 +298,11 @@ class Query:
             select(models.ApiKey)
             .join(models.User)
             .join(models.UserRole)
-            .where(models.UserRole.name !=
+            .where(models.UserRole.name != "SYSTEM")
         )
         async with info.context.db() as session:
             api_keys = await session.scalars(stmt)
-            return [
+            return [UserApiKey(id=api_key.id, db_record=api_key) for api_key in api_keys]

     @strawberry.field(permission_classes=[IsAdmin])  # type: ignore
     async def system_api_keys(self, info: Info[Context, None]) -> list[SystemApiKey]:
@@ -209,20 +310,11 @@ class Query:
             select(models.ApiKey)
             .join(models.User)
             .join(models.UserRole)
-            .where(models.UserRole.name ==
+            .where(models.UserRole.name == "SYSTEM")
         )
         async with info.context.db() as session:
             api_keys = await session.scalars(stmt)
-            return [
-                SystemApiKey(
-                    id_attr=api_key.id,
-                    name=api_key.name,
-                    description=api_key.description,
-                    created_at=api_key.created_at,
-                    expires_at=api_key.expires_at,
-                )
-                for api_key in api_keys
-            ]
+            return [SystemApiKey(id=api_key.id, db_record=api_key) for api_key in api_keys]

     @strawberry.field
     async def projects(
@@ -263,13 +355,7 @@ class Query:
             stmt = exclude_experiment_projects(stmt)
         async with info.context.db() as session:
             projects = await session.stream_scalars(stmt)
-            data = [
-                Project(
-                    project_rowid=project.id,
-                    db_project=project,
-                )
-                async for project in projects
-            ]
+            data = [Project(id=project.id, db_record=project) async for project in projects]
         return connection_from_list(data=data, args=args)

     @strawberry.field
@@ -285,6 +371,7 @@ class Query:
         after: Optional[CursorString] = UNSET,
         before: Optional[CursorString] = UNSET,
         sort: Optional[DatasetSort] = UNSET,
+        filter: Optional[DatasetFilter] = UNSET,
     ) -> Connection[Dataset]:
         args = ConnectionArgs(
             first=first,
@@ -296,10 +383,40 @@ class Query:
         if sort:
             sort_col = getattr(models.Dataset, sort.col.value)
             stmt = stmt.order_by(sort_col.desc() if sort.dir is SortDir.desc else sort_col.asc())
+        if filter:
+            # Apply name filter
+            if filter.col and filter.value:
+                stmt = stmt.where(
+                    getattr(models.Dataset, filter.col.value).ilike(f"%{filter.value}%")
+                )
+
+            # Apply label filter
+            if filter.filter_labels and filter.filter_labels is not UNSET:
+                label_rowids = []
+                for label_id in filter.filter_labels:
+                    try:
+                        label_rowid = from_global_id_with_expected_type(
+                            global_id=GlobalID.from_id(label_id),
+                            expected_type_name="DatasetLabel",
+                        )
+                        label_rowids.append(label_rowid)
+                    except ValueError:
+                        continue  # Skip invalid label IDs
+
+                if label_rowids:
+                    # Join with the junction table to filter by labels
+                    stmt = (
+                        stmt.join(
+                            models.DatasetsDatasetLabel,
+                            models.Dataset.id == models.DatasetsDatasetLabel.dataset_id,
+                        )
+                        .where(models.DatasetsDatasetLabel.dataset_label_id.in_(label_rowids))
+                        .distinct()
+                    )
         async with info.context.db() as session:
             datasets = await session.scalars(stmt)
             return connection_from_list(
-                data=[
+                data=[Dataset(id=dataset.id, db_record=dataset) for dataset in datasets], args=args
             )

     @strawberry.field
@@ -310,122 +427,429 @@
     async def compare_experiments(
         self,
         info: Info[Context, None],
-
+        base_experiment_id: GlobalID,
+        compare_experiment_ids: list[GlobalID],
+        first: Optional[int] = 50,
+        after: Optional[CursorString] = UNSET,
         filter_condition: Optional[str] = UNSET,
-    ) ->
-
-
-
-
-
-
+    ) -> Connection[ExperimentComparison]:
+        if base_experiment_id in compare_experiment_ids:
+            raise BadRequest("Compare experiment IDs cannot contain the base experiment ID")
+        if len(set(compare_experiment_ids)) < len(compare_experiment_ids):
+            raise BadRequest("Compare experiment IDs must be unique")
+
+        try:
+            base_experiment_rowid = from_global_id_with_expected_type(
+                base_experiment_id, models.Experiment.__name__
+            )
+        except ValueError:
+            raise BadRequest(f"Invalid base experiment ID: {base_experiment_id}")
+
+        compare_experiment_rowids = []
+        for compare_experiment_id in compare_experiment_ids:
+            try:
+                compare_experiment_rowids.append(
+                    from_global_id_with_expected_type(
+                        compare_experiment_id, models.Experiment.__name__
+                    )
+                )
+            except ValueError:
+                raise BadRequest(f"Invalid compare experiment ID: {compare_experiment_id}")
+
+        experiment_rowids = [base_experiment_rowid, *compare_experiment_rowids]
+
+        cursor = Cursor.from_string(after) if after else None
+        page_size = first or 50

         async with info.context.db() as session:
-
-            await session.
+            experiments = (
+                await session.scalars(
                     select(
-
-                        func.max(OrmVersion.dataset_id),
-                        func.max(OrmVersion.id),
-                        func.count(OrmExperiment.id),
-                    )
-                    .select_from(OrmVersion)
-                    .join(
-                        OrmExperiment,
-                        OrmExperiment.dataset_version_id == OrmVersion.id,
-                    )
-                    .where(
-                        OrmExperiment.id.in_(experiment_ids_),
+                        models.Experiment,
                     )
-
-
-
-
-
-            num_datasets, dataset_id, version_id, num_resolved_experiment_ids = validation_result
-            if num_datasets != 1:
-                raise ValueError("Experiments must belong to the same dataset.")
-            if num_resolved_experiment_ids != len(experiment_ids_):
-                raise ValueError("Unable to resolve one or more experiment IDs.")
-
-            revision_ids = (
-                select(func.max(OrmRevision.id))
-                .join(OrmExample, OrmExample.id == OrmRevision.dataset_example_id)
-                .where(
-                    and_(
-                        OrmRevision.dataset_version_id <= version_id,
-                        OrmExample.dataset_id == dataset_id,
+                    .where(models.Experiment.id.in_(experiment_rowids))
+                    .options(
+                        load_only(
+                            models.Experiment.dataset_id, models.Experiment.dataset_version_id
+                        )
                     )
                 )
-
-
+            ).all()
+
+            if not experiments or len(experiments) < len(experiment_rowids):
+                raise NotFound("Unable to resolve one or more experiment IDs.")
+            num_datasets = len(set(experiment.dataset_id for experiment in experiments))
+            if num_datasets > 1:
+                raise BadRequest("Experiments must belong to the same dataset.")
+            base_experiment = next(
+                experiment for experiment in experiments if experiment.id == base_experiment_rowid
             )
+
+            # Use ExperimentDatasetExample to pull down examples.
+            # Splits are mutable and should not be used for comparison.
+            # The comparison should only occur against examples which were assigned to the same
+            # splits at the time of execution of the ExperimentRun.
             examples_query = (
-                select(
-                .
-                .
-
-
-                        OrmExample.id == OrmRevision.dataset_example_id,
-                        OrmRevision.id.in_(revision_ids),
-                        OrmRevision.revision_kind != "DELETE",
-                    ),
-                )
-                .order_by(OrmExample.id.desc())
+                select(models.DatasetExample)
+                .join(models.ExperimentDatasetExample)
+                .where(models.ExperimentDatasetExample.experiment_id == base_experiment_rowid)
+                .order_by(models.DatasetExample.id.desc())
+                .limit(page_size + 1)
             )

+            if cursor is not None:
+                examples_query = examples_query.where(models.DatasetExample.id < cursor.rowid)
+
             if filter_condition:
                 examples_query = update_examples_query_with_filter_condition(
                     query=examples_query,
                     filter_condition=filter_condition,
-                    experiment_ids=
+                    experiment_ids=experiment_rowids,
                 )

             examples = (await session.scalars(examples_query)).all()
+            has_next_page = len(examples) > page_size
+            examples = examples[:page_size]

             ExampleID: TypeAlias = int
             ExperimentID: TypeAlias = int
-            runs: defaultdict[ExampleID, defaultdict[ExperimentID, list[
+            runs: defaultdict[ExampleID, defaultdict[ExperimentID, list[models.ExperimentRun]]] = (
                 defaultdict(lambda: defaultdict(list))
             )
             async for run in await session.stream_scalars(
-                select(
+                select(models.ExperimentRun)
                 .where(
                     and_(
-
-
+                        models.ExperimentRun.dataset_example_id.in_(
+                            example.id for example in examples
+                        ),
+                        models.ExperimentRun.experiment_id.in_(experiment_rowids),
                     )
                 )
-                .options(joinedload(
+                .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
+                .order_by(
+                    models.ExperimentRun.repetition_number.asc()
+                )  # repetitions are not currently implemented, but this ensures that the repetitions will be properly ordered once implemented  # noqa: E501
             ):
                 runs[run.dataset_example_id][run.experiment_id].append(run)

-
+            cursors_and_nodes = []
             for example in examples:
-
-                for experiment_id in
-
-
-
-
-
+                repeated_run_groups = []
+                for experiment_id in experiment_rowids:
+                    repeated_run_groups.append(
+                        ExperimentRepeatedRunGroup(
+                            experiment_rowid=experiment_id,
+                            dataset_example_rowid=example.id,
+                            cached_runs=[
+                                ExperimentRun(id=run.id, db_record=run)
                                 for run in sorted(
-                                    runs[example.id][experiment_id],
+                                    runs[example.id][experiment_id],
+                                    key=lambda run: run.repetition_number,
                                 )
                             ],
                         )
                     )
-
-
-
-
-
-
-
-
+                experiment_comparison = ExperimentComparison(
+                    id_attr=example.id,
+                    example=DatasetExample(
+                        id=example.id,
+                        db_record=example,
+                        version_id=base_experiment.dataset_version_id,
+                    ),
+                    repeated_run_groups=repeated_run_groups,
+                )
+                cursors_and_nodes.append((Cursor(rowid=example.id), experiment_comparison))
+
+            return connection_from_cursors_and_nodes(
+                cursors_and_nodes=cursors_and_nodes,
+                has_previous_page=False,  # set to false since we are only doing forward pagination (https://relay.dev/graphql/connections.htm#sec-undefined.PageInfo.Fields)  # noqa: E501
+                has_next_page=has_next_page,
+            )
+
+    @strawberry.field
+    async def experiment_run_metric_comparisons(
+        self,
+        info: Info[Context, None],
+        base_experiment_id: GlobalID,
+        compare_experiment_ids: list[GlobalID],
+    ) -> ExperimentRunMetricComparisons:
+        if base_experiment_id in compare_experiment_ids:
+            raise BadRequest("Compare experiment IDs cannot contain the base experiment ID")
+        if not compare_experiment_ids:
+            raise BadRequest("At least one compare experiment ID must be provided")
+        if len(set(compare_experiment_ids)) < len(compare_experiment_ids):
+            raise BadRequest("Compare experiment IDs must be unique")
+
+        try:
+            base_experiment_rowid = from_global_id_with_expected_type(
+                base_experiment_id, models.Experiment.__name__
+            )
+        except ValueError:
+            raise BadRequest(f"Invalid base experiment ID: {base_experiment_id}")
+
+        compare_experiment_rowids = []
+        for compare_experiment_id in compare_experiment_ids:
+            try:
+                compare_experiment_rowids.append(
+                    from_global_id_with_expected_type(
+                        compare_experiment_id, models.Experiment.__name__
+                    )
                 )
+            except ValueError:
+                raise BadRequest(f"Invalid compare experiment ID: {compare_experiment_id}")
+
+        base_experiment_runs = (
+            select(
+                models.ExperimentRun.dataset_example_id,
+                func.min(models.ExperimentRun.start_time).label("start_time"),
+                func.min(models.ExperimentRun.end_time).label("end_time"),
+                func.sum(models.SpanCost.total_tokens).label("total_tokens"),
+                func.sum(models.SpanCost.prompt_tokens).label("prompt_tokens"),
+                func.sum(models.SpanCost.completion_tokens).label("completion_tokens"),
+                func.sum(models.SpanCost.total_cost).label("total_cost"),
+                func.sum(models.SpanCost.prompt_cost).label("prompt_cost"),
+                func.sum(models.SpanCost.completion_cost).label("completion_cost"),
             )
-
+            .select_from(models.ExperimentRun)
+            .join(
+                models.Trace,
+                onclause=models.ExperimentRun.trace_id == models.Trace.trace_id,
+                isouter=True,
+            )
+            .join(
+                models.SpanCost,
+                onclause=models.Trace.id == models.SpanCost.trace_rowid,
+                isouter=True,
+            )
+            .where(models.ExperimentRun.experiment_id == base_experiment_rowid)
+            .group_by(models.ExperimentRun.dataset_example_id)
+            .subquery()
+            .alias("base_experiment_runs")
+        )
+        compare_experiment_runs = (
+            select(
+                models.ExperimentRun.dataset_example_id,
+                func.min(
+                    LatencyMs(models.ExperimentRun.start_time, models.ExperimentRun.end_time)
+                ).label("min_latency_ms"),
+                func.min(models.SpanCost.total_tokens).label("min_total_tokens"),
+                func.min(models.SpanCost.prompt_tokens).label("min_prompt_tokens"),
+                func.min(models.SpanCost.completion_tokens).label("min_completion_tokens"),
+                func.min(models.SpanCost.total_cost).label("min_total_cost"),
+                func.min(models.SpanCost.prompt_cost).label("min_prompt_cost"),
+                func.min(models.SpanCost.completion_cost).label("min_completion_cost"),
+            )
+            .select_from(models.ExperimentRun)
+            .join(
+                models.Trace,
+                onclause=models.ExperimentRun.trace_id == models.Trace.trace_id,
+                isouter=True,
+            )
+            .join(
+                models.SpanCost,
+                onclause=models.Trace.id == models.SpanCost.trace_rowid,
+                isouter=True,
+            )
+            .where(
+                models.ExperimentRun.experiment_id.in_(compare_experiment_rowids),
+            )
+            .group_by(models.ExperimentRun.dataset_example_id)
+            .subquery()
+            .alias("comp_exp_run_mins")
+        )
+
+        base_experiment_run_latency = LatencyMs(
+            base_experiment_runs.c.start_time, base_experiment_runs.c.end_time
+        ).label("base_experiment_run_latency_ms")
+
+        comparisons_query = (
+            select(
+                func.count().label("num_base_experiment_runs"),
+                _comparison_count_expression(
+                    base_column=base_experiment_run_latency,
+                    compare_column=compare_experiment_runs.c.min_latency_ms,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_latency_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_run_latency,
+                    compare_column=compare_experiment_runs.c.min_latency_ms,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_latency_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_run_latency,
+                    compare_column=compare_experiment_runs.c.min_latency_ms,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_latency_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_tokens,
+                    compare_column=compare_experiment_runs.c.min_total_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_total_token_count_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_tokens,
+                    compare_column=compare_experiment_runs.c.min_total_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_total_token_count_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_tokens,
+                    compare_column=compare_experiment_runs.c.min_total_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_total_token_count_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_tokens,
+                    compare_column=compare_experiment_runs.c.min_prompt_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_prompt_token_count_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_tokens,
+                    compare_column=compare_experiment_runs.c.min_prompt_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_prompt_token_count_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_tokens,
+                    compare_column=compare_experiment_runs.c.min_prompt_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_prompt_token_count_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_tokens,
+                    compare_column=compare_experiment_runs.c.min_completion_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_completion_token_count_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_tokens,
+                    compare_column=compare_experiment_runs.c.min_completion_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_completion_token_count_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_tokens,
+                    compare_column=compare_experiment_runs.c.min_completion_tokens,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_completion_token_count_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_cost,
+                    compare_column=compare_experiment_runs.c.min_total_cost,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_total_cost_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_cost,
+                    compare_column=compare_experiment_runs.c.min_total_cost,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_total_cost_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.total_cost,
+                    compare_column=compare_experiment_runs.c.min_total_cost,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_total_cost_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_cost,
+                    compare_column=compare_experiment_runs.c.min_prompt_cost,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_prompt_cost_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_cost,
+                    compare_column=compare_experiment_runs.c.min_prompt_cost,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_prompt_cost_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.prompt_cost,
+                    compare_column=compare_experiment_runs.c.min_prompt_cost,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_prompt_cost_is_equal"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_cost,
+                    compare_column=compare_experiment_runs.c.min_completion_cost,
+                    optimization_direction="minimize",
+                    comparison_type="improvement",
+                ).label("num_completion_cost_improved"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_cost,
+                    compare_column=compare_experiment_runs.c.min_completion_cost,
+                    optimization_direction="minimize",
+                    comparison_type="regression",
+                ).label("num_completion_cost_regressed"),
+                _comparison_count_expression(
+                    base_column=base_experiment_runs.c.completion_cost,
+                    compare_column=compare_experiment_runs.c.min_completion_cost,
+                    optimization_direction="minimize",
+                    comparison_type="equality",
+                ).label("num_completion_cost_is_equal"),
+            )
+            .select_from(base_experiment_runs)
+            .join(
+                compare_experiment_runs,
+                onclause=base_experiment_runs.c.dataset_example_id
+                == compare_experiment_runs.c.dataset_example_id,
+                isouter=True,
+            )
+        )
+
+        async with info.context.db() as session:
+            result = (await session.execute(comparisons_query)).first()
+            assert result is not None
+
+        return ExperimentRunMetricComparisons(
+            latency=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_latency_improved,
+                num_runs_regressed=result.num_latency_regressed,
+                num_runs_equal=result.num_latency_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            total_token_count=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_total_token_count_improved,
+                num_runs_regressed=result.num_total_token_count_regressed,
+                num_runs_equal=result.num_total_token_count_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            prompt_token_count=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_prompt_token_count_improved,
+                num_runs_regressed=result.num_prompt_token_count_regressed,
+                num_runs_equal=result.num_prompt_token_count_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            completion_token_count=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_completion_token_count_improved,
+                num_runs_regressed=result.num_completion_token_count_regressed,
+                num_runs_equal=result.num_completion_token_count_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            total_cost=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_total_cost_improved,
+                num_runs_regressed=result.num_total_cost_regressed,
+                num_runs_equal=result.num_total_cost_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            prompt_cost=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_prompt_cost_improved,
+                num_runs_regressed=result.num_prompt_cost_regressed,
+                num_runs_equal=result.num_prompt_cost_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+            completion_cost=ExperimentRunMetricComparison(
+                num_runs_improved=result.num_completion_cost_improved,
+                num_runs_regressed=result.num_completion_cost_regressed,
+                num_runs_equal=result.num_completion_cost_is_equal,
+                num_total_runs=result.num_base_experiment_runs,
+            ),
+        )

     @strawberry.field
     async def validate_experiment_run_filter_condition(
@@ -437,7 +861,7 @@ class Query:
|
|
|
437
861
|
compile_sqlalchemy_filter_condition(
|
|
438
862
|
filter_condition=condition,
|
|
439
863
|
experiment_ids=[
|
|
440
|
-
from_global_id_with_expected_type(experiment_id,
|
|
864
|
+
from_global_id_with_expected_type(experiment_id, models.Experiment.__name__)
|
|
441
865
|
for experiment_id in experiment_ids
|
|
442
866
|
],
|
|
443
867
|
)
|
|
@@ -459,140 +883,55 @@ class Query:
         )

     @strawberry.field
-    def model(self) ->
-        return
+    def model(self) -> InferenceModel:
+        return InferenceModel()

     @strawberry.field
-    async def node(self, id:
-
+    async def node(self, id: strawberry.ID, info: Info[Context, None]) -> Node:
+        if not is_global_id(id):
+            try:
+                experiment_rowid, dataset_example_rowid = (
+                    parse_experiment_repeated_run_group_node_id(id)
+                )
+            except Exception:
+                raise NotFound(f"Unknown node: {id}")
+            return ExperimentRepeatedRunGroup(
+                experiment_rowid=experiment_rowid,
+                dataset_example_rowid=dataset_example_rowid,
+            )
+
+        global_id = GlobalID.from_id(id)
+        type_name, node_id = from_global_id(global_id)
         if type_name == "Dimension":
             dimension = info.context.model.scalar_dimensions[node_id]
             return to_gql_dimension(node_id, dimension)
         elif type_name == "EmbeddingDimension":
             embedding_dimension = info.context.model.embedding_dimensions[node_id]
             return to_gql_embedding_dimension(node_id, embedding_dimension)
-        elif type_name ==
-
-
-
-            if project is None:
-                raise NotFound(f"Unknown project: {id}")
-            return Project(
-                project_rowid=project.id,
-                db_project=project,
-            )
-        elif type_name == "Trace":
-            trace_stmt = select(models.Trace).filter_by(id=node_id)
-            async with info.context.db() as session:
-                trace = await session.scalar(trace_stmt)
-            if trace is None:
-                raise NotFound(f"Unknown trace: {id}")
-            return Trace(trace_rowid=trace.id, db_trace=trace)
+        elif type_name == Project.__name__:
+            return Project(id=node_id)
+        elif type_name == Trace.__name__:
+            return Trace(id=node_id)
         elif type_name == Span.__name__:
-
-                select(models.Span)
-                .options(
-                    joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id)
-                )
-                .where(models.Span.id == node_id)
-            )
-            async with info.context.db() as session:
-                span = await session.scalar(span_stmt)
-            if span is None:
-                raise NotFound(f"Unknown span: {id}")
-            return Span(span_rowid=span.id, db_span=span)
+            return Span(id=node_id)
         elif type_name == Dataset.__name__:
-
-            async with info.context.db() as session:
-                if (dataset := await session.scalar(dataset_stmt)) is None:
-                    raise NotFound(f"Unknown dataset: {id}")
-            return to_gql_dataset(dataset)
+            return Dataset(id=node_id)
         elif type_name == DatasetExample.__name__:
-
-
-
-                .where(models.DatasetExampleRevision.dataset_example_id == example_id)
-                .scalar_subquery()
-            )
-            async with info.context.db() as session:
-                example = await session.scalar(
-                    select(models.DatasetExample)
-                    .join(
-                        models.DatasetExampleRevision,
-                        onclause=models.DatasetExampleRevision.dataset_example_id
-                        == models.DatasetExample.id,
-                    )
-                    .where(
-                        and_(
-                            models.DatasetExample.id == example_id,
-                            models.DatasetExampleRevision.id == latest_revision_id,
-                            models.DatasetExampleRevision.revision_kind != "DELETE",
-                        )
-                    )
-                )
-                if not example:
-                    raise NotFound(f"Unknown dataset example: {id}")
-            return DatasetExample(
-                id_attr=example.id,
-                created_at=example.created_at,
-            )
+            return DatasetExample(id=node_id)
+        elif type_name == DatasetSplit.__name__:
+            return DatasetSplit(id=node_id)
         elif type_name == Experiment.__name__:
-
-                experiment = await session.scalar(
-                    select(models.Experiment).where(models.Experiment.id == node_id)
-                )
-            if not experiment:
-                raise NotFound(f"Unknown experiment: {id}")
-            return Experiment(
-                id_attr=experiment.id,
-                name=experiment.name,
-                project_name=experiment.project_name,
-                description=experiment.description,
-                created_at=experiment.created_at,
-                updated_at=experiment.updated_at,
-                metadata=experiment.metadata_,
-            )
+            return Experiment(id=node_id)
         elif type_name == ExperimentRun.__name__:
-
-            if not (
-                run := await session.scalar(
-                    select(models.ExperimentRun)
-                    .where(models.ExperimentRun.id == node_id)
-                    .options(
-                        joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
-                    )
-                )
-            ):
-                raise NotFound(f"Unknown experiment run: {id}")
-            return to_gql_experiment_run(run)
+            return ExperimentRun(id=node_id)
         elif type_name == User.__name__:
             if int((user := info.context.user).identity) != node_id and not user.is_admin:
                 raise Unauthorized(MSG_ADMIN_ONLY)
-
-            if not (
-                user := await session.scalar(
-                    select(models.User).where(models.User.id == node_id)
-                )
-            ):
-                raise NotFound(f"Unknown user: {id}")
-            return to_gql_user(user)
+            return User(id=node_id)
         elif type_name == ProjectSession.__name__:
-
-            if not (
-                project_session := await session.scalar(
-                    select(models.ProjectSession).filter_by(id=node_id)
-                )
-            ):
-                raise NotFound(f"Unknown user: {id}")
-            return to_gql_project_session(project_session)
+            return ProjectSession(id=node_id)
         elif type_name == Prompt.__name__:
-
-            if orm_prompt := await session.scalar(
-                select(models.Prompt).where(models.Prompt.id == node_id)
-            ):
-                return to_gql_prompt_from_orm(orm_prompt)
-            else:
-                raise NotFound(f"Unknown prompt: {id}")
+            return Prompt(id=node_id)
         elif type_name == PromptVersion.__name__:
             async with info.context.db() as session:
                 if orm_prompt_version := await session.scalar(
@@ -602,39 +941,17 @@ class Query:
                 else:
                     raise NotFound(f"Unknown prompt version: {id}")
         elif type_name == PromptLabel.__name__:
-
-            if not (
-                prompt_label := await session.scalar(
-                    select(models.PromptLabel).where(models.PromptLabel.id == node_id)
-                )
-            ):
-                raise NotFound(f"Unknown prompt label: {id}")
-            return to_gql_prompt_label(prompt_label)
+            return PromptLabel(id=node_id)
         elif type_name == PromptVersionTag.__name__:
-
-            if not (prompt_version_tag := await session.get(models.PromptVersionTag, node_id)):
-                raise NotFound(f"Unknown prompt version tag: {id}")
-            return to_gql_prompt_version_tag(prompt_version_tag)
+            return PromptVersionTag(id=node_id)
         elif type_name == ProjectTraceRetentionPolicy.__name__:
-
-            db_policy = await session.scalar(
-                select(models.ProjectTraceRetentionPolicy).filter_by(id=node_id)
-            )
-            if not db_policy:
-                raise NotFound(f"Unknown project trace retention policy: {id}")
-            return ProjectTraceRetentionPolicy(id=db_policy.id, db_policy=db_policy)
+            return ProjectTraceRetentionPolicy(id=node_id)
         elif type_name == SpanAnnotation.__name__:
-
-            span_annotation = await session.get(models.SpanAnnotation, node_id)
-            if not span_annotation:
-                raise NotFound(f"Unknown span annotation: {id}")
-            return to_gql_span_annotation(span_annotation)
+            return SpanAnnotation(id=node_id)
         elif type_name == TraceAnnotation.__name__:
-
-
-
-                raise NotFound(f"Unknown trace annotation: {id}")
-            return to_gql_trace_annotation(trace_annotation)
+            return TraceAnnotation(id=node_id)
+        elif type_name == GenerativeModel.__name__:
+            return GenerativeModel(id=node_id)
         raise NotFound(f"Unknown node type: {type_name}")

     @strawberry.field
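The rewritten `node` resolver above stops eagerly loading ORM rows and instead returns lightweight GraphQL node objects that carry only a row id, dispatching on the type name encoded in the relay-style global ID. The standalone sketch below shows the general shape of that dispatch under the common base64 `"TypeName:rowid"` encoding; the encoding, the `resolve_node` helper, and the toy `Project`/`Trace` dataclasses are illustrative assumptions, not the package's actual helpers.

```python
import base64
from dataclasses import dataclass


def to_global_id(type_name: str, node_id: int) -> str:
    # Assumed relay-style opaque ID: base64("TypeName:rowid").
    return base64.b64encode(f"{type_name}:{node_id}".encode()).decode()


def from_global_id(global_id: str) -> tuple[str, int]:
    type_name, _, node_id = base64.b64decode(global_id).decode().partition(":")
    return type_name, int(node_id)


@dataclass
class Project:
    id: int  # lightweight node: only the row id; fields are resolved lazily


@dataclass
class Trace:
    id: int


_NODE_TYPES = {"Project": Project, "Trace": Trace}


def resolve_node(global_id: str):
    # Dispatch on the decoded type name, as the resolver above does with its elif chain.
    type_name, node_id = from_global_id(global_id)
    try:
        return _NODE_TYPES[type_name](id=node_id)
    except KeyError:
        raise LookupError(f"Unknown node type: {type_name}")


print(resolve_node(to_global_id("Project", 7)))  # Project(id=7)
```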
@@ -646,16 +963,7 @@ class Query:
             return None
         if isinstance(user, UnauthenticatedUser):
             return None
-
-        if (
-            user := await session.scalar(
-                select(models.User)
-                .where(models.User.id == int(user.identity))
-                .options(joinedload(models.User.role))
-            )
-        ) is None:
-            return None
-        return to_gql_user(user)
+        return User(id=int(user.identity))

     @strawberry.field
     async def prompts(
@@ -665,6 +973,8 @@ class Query:
         last: Optional[int] = UNSET,
         after: Optional[CursorString] = UNSET,
         before: Optional[CursorString] = UNSET,
+        filter: Optional[PromptFilter] = UNSET,
+        labelIds: Optional[list[GlobalID]] = UNSET,
     ) -> Connection[Prompt]:
         args = ConnectionArgs(
             first=first,
@@ -673,9 +983,29 @@ class Query:
             before=before if isinstance(before, CursorString) else None,
         )
         stmt = select(models.Prompt)
+        if filter:
+            column = getattr(models.Prompt, filter.col.value)
+            # Cast Identifier columns to String for ilike operations
+            if filter.col.value == "name":
+                column = cast(column, String)
+            stmt = stmt.where(column.ilike(f"%{filter.value}%")).order_by(
+                models.Prompt.updated_at.desc()
+            )
+        if labelIds:
+            stmt = stmt.join(models.PromptPromptLabel).where(
+                models.PromptPromptLabel.prompt_label_id.in_(
+                    from_global_id_with_expected_type(
+                        global_id=label_id, expected_type_name="PromptLabel"
+                    )
+                    for label_id in labelIds
+                )
+            )
+            stmt = stmt.distinct()
         async with info.context.db() as session:
             orm_prompts = await session.stream_scalars(stmt)
-            data = [
+            data = [
+                Prompt(id=orm_prompt.id, db_record=orm_prompt) async for orm_prompt in orm_prompts
+            ]
         return connection_from_list(
             data=data,
             args=args,
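The new `filter` and `labelIds` arguments above narrow the prompts query with a case-insensitive substring match and a join through the prompt-label association table. Below is a minimal SQLAlchemy sketch of the same query shape, assuming SQLAlchemy 2.x and using made-up `prompts` and `prompt_prompt_labels` tables rather than the real `phoenix.db.models` classes.

```python
from typing import Optional

import sqlalchemy as sa

# Illustrative tables only; the actual models live in phoenix.db.models.
metadata = sa.MetaData()
prompts = sa.Table(
    "prompts",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("name", sa.String),
    sa.Column("updated_at", sa.DateTime),
)
prompt_prompt_labels = sa.Table(
    "prompt_prompt_labels",
    metadata,
    sa.Column("prompt_id", sa.ForeignKey("prompts.id")),
    sa.Column("prompt_label_id", sa.Integer),
)


def build_prompt_query(substring: Optional[str], label_ids: list[int]) -> sa.Select:
    stmt = sa.select(prompts)
    if substring:
        # Case-insensitive substring match, newest first, mirroring the `filter` branch above.
        stmt = stmt.where(prompts.c.name.ilike(f"%{substring}%")).order_by(
            prompts.c.updated_at.desc()
        )
    if label_ids:
        # Join through the association table and deduplicate prompts that
        # carry more than one of the requested labels.
        stmt = (
            stmt.join(prompt_prompt_labels)
            .where(prompt_prompt_labels.c.prompt_label_id.in_(label_ids))
            .distinct()
        )
    return stmt


print(build_prompt_query("rag", [1, 2]))  # prints the generated SQL
```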
@@ -698,7 +1028,58 @@ class Query:
         )
         async with info.context.db() as session:
             prompt_labels = await session.stream_scalars(select(models.PromptLabel))
-            data = [
+            data = [
+                PromptLabel(id=prompt_label.id, db_record=prompt_label)
+                async for prompt_label in prompt_labels
+            ]
+        return connection_from_list(
+            data=data,
+            args=args,
+        )
+
+    @strawberry.field
+    async def dataset_labels(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[DatasetLabel]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        async with info.context.db() as session:
+            dataset_labels = await session.scalars(
+                select(models.DatasetLabel).order_by(models.DatasetLabel.name.asc())
+            )
+            data = [
+                DatasetLabel(id=dataset_label.id, db_record=dataset_label)
+                for dataset_label in dataset_labels
+            ]
+        return connection_from_list(data=data, args=args)
+
+    @strawberry.field
+    async def dataset_splits(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[DatasetSplit]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        async with info.context.db() as session:
+            splits = await session.stream_scalars(select(models.DatasetSplit))
+            data = [DatasetSplit(id=split.id, db_record=split) async for split in splits]
         return connection_from_list(
             data=data,
             args=args,
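The new `dataset_labels` and `dataset_splits` fields above follow the same relay-style pattern as the other connections: collect rows, then page them with `first`/`last`/`after`/`before` cursors via `connection_from_list`. The sketch below shows only the generic forward-pagination idea (opaque index cursors over a list); it is not the package's `ConnectionArgs`/`connection_from_list` implementation, and the helper names are invented for illustration.

```python
import base64
from typing import Optional, Sequence, TypeVar

T = TypeVar("T")


def encode_cursor(index: int) -> str:
    # Opaque cursor over a list index; real implementations may encode more state.
    return base64.b64encode(str(index).encode()).decode()


def decode_cursor(cursor: str) -> int:
    return int(base64.b64decode(cursor).decode())


def paginate(
    items: Sequence[T], first: int = 50, after: Optional[str] = None
) -> tuple[list[tuple[str, T]], bool]:
    """Return (edges, has_next_page) for a forward, cursor-based page."""
    start = decode_cursor(after) + 1 if after else 0
    window = items[start : start + first]
    edges = [(encode_cursor(start + i), item) for i, item in enumerate(window)]
    has_next_page = start + first < len(items)
    return edges, has_next_page


edges, has_next = paginate(list("abcdefg"), first=3)
print(edges, has_next)  # first 3 edges, has_next=True
edges, has_next = paginate(list("abcdefg"), first=3, after=edges[-1][0])
print(edges, has_next)  # next 3 edges, has_next=True
```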
@@ -921,16 +1302,17 @@ class Query:
             # stats = cast(Iterable[tuple[str, int]], await session.execute(stmt))
             # stats = _consolidate_sqlite_db_table_stats(stats)
         elif info.context.db.dialect is SupportedSQLDialect.POSTGRESQL:
-
+            nspname = getenv(ENV_PHOENIX_SQL_DATABASE_SCHEMA) or "public"
+            stmt = text("""\
             SELECT c.relname, pg_total_relation_size(c.oid)
             FROM pg_class as c
             INNER JOIN pg_namespace as n ON n.oid = c.relnamespace
             WHERE c.relkind = 'r'
-            AND n.nspname =
-            """)
+            AND n.nspname = :nspname;
+            """).bindparams(nspname=nspname)
             try:
                 async with info.context.db() as session:
-                    stats =
+                    stats = type_cast(Iterable[tuple[str, int]], await session.execute(stmt))
             except Exception:
                 # TODO: temporary workaround until we can reproduce the error
                 return []
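The PostgreSQL branch above now passes the schema name as a bound parameter instead of interpolating it into the SQL string. A small self-contained example of the same `text(...).bindparams(...)` pattern follows; the hard-coded `"public"` schema stands in for the configured value, and the statement is only printed here rather than executed against an async session.

```python
import sqlalchemy as sa

# Stand-in for the configured schema name used in the resolver above.
nspname = "public"

# Bound parameter keeps the schema name out of the SQL text itself.
stmt = sa.text(
    """\
    SELECT c.relname, pg_total_relation_size(c.oid)
    FROM pg_class AS c
    INNER JOIN pg_namespace AS n ON n.oid = c.relnamespace
    WHERE c.relkind = 'r'
    AND n.nspname = :nspname;
    """
).bindparams(nspname=nspname)

# With a real engine this would be `await session.execute(stmt)`.
print(stmt)
```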
@@ -941,6 +1323,62 @@ class Query:
             for table_name, num_bytes in stats
         ]

+    @strawberry.field
+    async def server_status(
+        self,
+        info: Info[Context, None],
+    ) -> ServerStatus:
+        return ServerStatus(
+            insufficient_storage=info.context.db.should_not_insert_or_update,
+        )
+
+    @strawberry.field
+    def validate_regular_expression(self, regex: str) -> ValidationResult:
+        try:
+            re.compile(regex)
+            return ValidationResult(is_valid=True, error_message=None)
+        except re.error as error:
+            return ValidationResult(is_valid=False, error_message=str(error))
+
+    @strawberry.field
+    async def get_span_by_otel_id(
+        self,
+        info: Info[Context, None],
+        span_id: str,
+    ) -> Optional[Span]:
+        stmt = select(models.Span.id).filter_by(span_id=span_id)
+        async with info.context.db() as session:
+            span_rowid = await session.scalar(stmt)
+        if span_rowid:
+            return Span(id=span_rowid)
+        return None
+
+    @strawberry.field
+    async def get_trace_by_otel_id(
+        self,
+        info: Info[Context, None],
+        trace_id: str,
+    ) -> Optional[Trace]:
+        stmt = select(models.Trace.id).where(models.Trace.trace_id == trace_id)
+        async with info.context.db() as session:
+            trace_rowid = await session.scalar(stmt)
+        if trace_rowid:
+            return Trace(id=trace_rowid)
+        return None
+
+    @strawberry.field
+    async def get_project_session_by_id(
+        self,
+        info: Info[Context, None],
+        session_id: str,
+    ) -> Optional[ProjectSession]:
+        stmt = select(models.ProjectSession).where(models.ProjectSession.session_id == session_id)
+        async with info.context.db() as session:
+            session_row = await session.scalar(stmt)
+        if session_row:
+            return ProjectSession(id=session_row.id, db_record=session_row)
+        return None
+

 def _consolidate_sqlite_db_table_stats(
     stats: Iterable[tuple[str, int]],
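The new `validate_regular_expression` field above reports whether a pattern compiles and, if not, surfaces the `re.error` message. A standalone sketch of the same check outside the GraphQL layer, with a stand-in `ValidationResult` dataclass in place of the server's GraphQL type:

```python
import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class ValidationResult:
    # Illustrative stand-in for the GraphQL ValidationResult type above.
    is_valid: bool
    error_message: Optional[str]


def validate_regular_expression(regex: str) -> ValidationResult:
    try:
        re.compile(regex)
        return ValidationResult(is_valid=True, error_message=None)
    except re.error as error:
        return ValidationResult(is_valid=False, error_message=str(error))


print(validate_regular_expression(r"\d+"))        # valid
print(validate_regular_expression(r"[unclosed"))  # invalid, with the re.error message
```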
@@ -974,3 +1412,40 @@ def _longest_matching_prefix(s: str, prefixes: Iterable[str]) -> str:
         if s.startswith(prefix) and len(prefix) > len(longest):
             longest = prefix
     return longest
+
+
+def _comparison_count_expression(
+    *,
+    base_column: ColumnElement[Any],
+    compare_column: ColumnElement[Any],
+    optimization_direction: Literal["maximize", "minimize"],
+    comparison_type: Literal["improvement", "regression", "equality"],
+) -> ColumnElement[int]:
+    """
+    Given a base and compare column, returns an expression counting the number of
+    improvements, regressions, or equalities given the optimization direction.
+    """
+    if optimization_direction == "maximize":
+        raise NotImplementedError
+
+    if comparison_type == "improvement":
+        condition = compare_column > base_column
+    elif comparison_type == "regression":
+        condition = compare_column < base_column
+    elif comparison_type == "equality":
+        condition = compare_column == base_column
+    else:
+        assert_never(comparison_type)
+
+    return func.coalesce(
+        func.sum(
+            case(
+                (
+                    condition,
+                    1,
+                ),
+                else_=0,
+            )
+        ),
+        0,
+    )