arize-phoenix 3.16.1__py3-none-any.whl → 7.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of arize-phoenix may have known issues; see the linked advisory for details.
- arize_phoenix-7.7.1.dist-info/METADATA +261 -0
- arize_phoenix-7.7.1.dist-info/RECORD +345 -0
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/WHEEL +1 -1
- arize_phoenix-7.7.1.dist-info/entry_points.txt +3 -0
- phoenix/__init__.py +86 -14
- phoenix/auth.py +309 -0
- phoenix/config.py +675 -45
- phoenix/core/model.py +32 -30
- phoenix/core/model_schema.py +102 -109
- phoenix/core/model_schema_adapter.py +48 -45
- phoenix/datetime_utils.py +24 -3
- phoenix/db/README.md +54 -0
- phoenix/db/__init__.py +4 -0
- phoenix/db/alembic.ini +85 -0
- phoenix/db/bulk_inserter.py +294 -0
- phoenix/db/engines.py +208 -0
- phoenix/db/enums.py +20 -0
- phoenix/db/facilitator.py +113 -0
- phoenix/db/helpers.py +159 -0
- phoenix/db/insertion/constants.py +2 -0
- phoenix/db/insertion/dataset.py +227 -0
- phoenix/db/insertion/document_annotation.py +171 -0
- phoenix/db/insertion/evaluation.py +191 -0
- phoenix/db/insertion/helpers.py +98 -0
- phoenix/db/insertion/span.py +193 -0
- phoenix/db/insertion/span_annotation.py +158 -0
- phoenix/db/insertion/trace_annotation.py +158 -0
- phoenix/db/insertion/types.py +256 -0
- phoenix/db/migrate.py +86 -0
- phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
- phoenix/db/migrations/env.py +114 -0
- phoenix/db/migrations/script.py.mako +26 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
- phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
- phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
- phoenix/db/models.py +807 -0
- phoenix/exceptions.py +5 -1
- phoenix/experiments/__init__.py +6 -0
- phoenix/experiments/evaluators/__init__.py +29 -0
- phoenix/experiments/evaluators/base.py +158 -0
- phoenix/experiments/evaluators/code_evaluators.py +184 -0
- phoenix/experiments/evaluators/llm_evaluators.py +473 -0
- phoenix/experiments/evaluators/utils.py +236 -0
- phoenix/experiments/functions.py +772 -0
- phoenix/experiments/tracing.py +86 -0
- phoenix/experiments/types.py +726 -0
- phoenix/experiments/utils.py +25 -0
- phoenix/inferences/__init__.py +0 -0
- phoenix/{datasets → inferences}/errors.py +6 -5
- phoenix/{datasets → inferences}/fixtures.py +49 -42
- phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
- phoenix/{datasets → inferences}/schema.py +11 -11
- phoenix/{datasets → inferences}/validation.py +13 -14
- phoenix/logging/__init__.py +3 -0
- phoenix/logging/_config.py +90 -0
- phoenix/logging/_filter.py +6 -0
- phoenix/logging/_formatter.py +69 -0
- phoenix/metrics/__init__.py +5 -4
- phoenix/metrics/binning.py +4 -3
- phoenix/metrics/metrics.py +2 -1
- phoenix/metrics/mixins.py +7 -6
- phoenix/metrics/retrieval_metrics.py +2 -1
- phoenix/metrics/timeseries.py +5 -4
- phoenix/metrics/wrappers.py +9 -3
- phoenix/pointcloud/clustering.py +5 -5
- phoenix/pointcloud/pointcloud.py +7 -5
- phoenix/pointcloud/projectors.py +5 -6
- phoenix/pointcloud/umap_parameters.py +53 -52
- phoenix/server/api/README.md +28 -0
- phoenix/server/api/auth.py +44 -0
- phoenix/server/api/context.py +152 -9
- phoenix/server/api/dataloaders/__init__.py +91 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
- phoenix/server/api/dataloaders/cache/__init__.py +3 -0
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
- phoenix/server/api/dataloaders/document_evaluations.py +31 -0
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
- phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
- phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/record_counts.py +116 -0
- phoenix/server/api/dataloaders/session_io.py +79 -0
- phoenix/server/api/dataloaders/session_num_traces.py +30 -0
- phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
- phoenix/server/api/dataloaders/session_token_usages.py +41 -0
- phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
- phoenix/server/api/dataloaders/span_annotations.py +26 -0
- phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +57 -0
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/token_counts.py +124 -0
- phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
- phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
- phoenix/server/api/dataloaders/user_roles.py +30 -0
- phoenix/server/api/dataloaders/users.py +33 -0
- phoenix/server/api/exceptions.py +48 -0
- phoenix/server/api/helpers/__init__.py +12 -0
- phoenix/server/api/helpers/dataset_helpers.py +217 -0
- phoenix/server/api/helpers/experiment_run_filters.py +763 -0
- phoenix/server/api/helpers/playground_clients.py +948 -0
- phoenix/server/api/helpers/playground_registry.py +70 -0
- phoenix/server/api/helpers/playground_spans.py +455 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
- phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
- phoenix/server/api/input_types/ClearProjectInput.py +15 -0
- phoenix/server/api/input_types/ClusterInput.py +2 -2
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
- phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
- phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
- phoenix/server/api/input_types/DimensionFilter.py +4 -4
- phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
- phoenix/server/api/input_types/Granularity.py +1 -1
- phoenix/server/api/input_types/InvocationParameters.py +162 -0
- phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
- phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
- phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
- phoenix/server/api/input_types/SpanSort.py +134 -69
- phoenix/server/api/input_types/TemplateOptions.py +10 -0
- phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
- phoenix/server/api/input_types/UserRoleInput.py +9 -0
- phoenix/server/api/mutations/__init__.py +28 -0
- phoenix/server/api/mutations/api_key_mutations.py +167 -0
- phoenix/server/api/mutations/chat_mutations.py +593 -0
- phoenix/server/api/mutations/dataset_mutations.py +591 -0
- phoenix/server/api/mutations/experiment_mutations.py +75 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
- phoenix/server/api/mutations/project_mutations.py +57 -0
- phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
- phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
- phoenix/server/api/mutations/user_mutations.py +329 -0
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +17 -0
- phoenix/server/api/openapi/schema.py +16 -0
- phoenix/server/api/queries.py +738 -0
- phoenix/server/api/routers/__init__.py +11 -0
- phoenix/server/api/routers/auth.py +284 -0
- phoenix/server/api/routers/embeddings.py +26 -0
- phoenix/server/api/routers/oauth2.py +488 -0
- phoenix/server/api/routers/v1/__init__.py +64 -0
- phoenix/server/api/routers/v1/datasets.py +1017 -0
- phoenix/server/api/routers/v1/evaluations.py +362 -0
- phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
- phoenix/server/api/routers/v1/experiment_runs.py +167 -0
- phoenix/server/api/routers/v1/experiments.py +308 -0
- phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
- phoenix/server/api/routers/v1/spans.py +267 -0
- phoenix/server/api/routers/v1/traces.py +208 -0
- phoenix/server/api/routers/v1/utils.py +95 -0
- phoenix/server/api/schema.py +44 -241
- phoenix/server/api/subscriptions.py +597 -0
- phoenix/server/api/types/Annotation.py +21 -0
- phoenix/server/api/types/AnnotationSummary.py +55 -0
- phoenix/server/api/types/AnnotatorKind.py +16 -0
- phoenix/server/api/types/ApiKey.py +27 -0
- phoenix/server/api/types/AuthMethod.py +9 -0
- phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
- phoenix/server/api/types/Cluster.py +25 -24
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/DataQualityMetric.py +31 -13
- phoenix/server/api/types/Dataset.py +288 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +32 -31
- phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
- phoenix/server/api/types/EmbeddingDimension.py +56 -49
- phoenix/server/api/types/Evaluation.py +25 -31
- phoenix/server/api/types/EvaluationSummary.py +30 -50
- phoenix/server/api/types/Event.py +20 -20
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +152 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +17 -0
- phoenix/server/api/types/ExperimentRun.py +119 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
- phoenix/server/api/types/GenerativeModel.py +9 -0
- phoenix/server/api/types/GenerativeProvider.py +85 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/LabelFraction.py +7 -0
- phoenix/server/api/types/MimeType.py +2 -2
- phoenix/server/api/types/Model.py +54 -54
- phoenix/server/api/types/PerformanceMetric.py +8 -5
- phoenix/server/api/types/Project.py +407 -142
- phoenix/server/api/types/ProjectSession.py +139 -0
- phoenix/server/api/types/Segments.py +4 -4
- phoenix/server/api/types/Span.py +221 -176
- phoenix/server/api/types/SpanAnnotation.py +43 -0
- phoenix/server/api/types/SpanIOValue.py +15 -0
- phoenix/server/api/types/SystemApiKey.py +9 -0
- phoenix/server/api/types/TemplateLanguage.py +10 -0
- phoenix/server/api/types/TimeSeries.py +19 -15
- phoenix/server/api/types/TokenUsage.py +11 -0
- phoenix/server/api/types/Trace.py +154 -0
- phoenix/server/api/types/TraceAnnotation.py +45 -0
- phoenix/server/api/types/UMAPPoints.py +7 -7
- phoenix/server/api/types/User.py +60 -0
- phoenix/server/api/types/UserApiKey.py +45 -0
- phoenix/server/api/types/UserRole.py +15 -0
- phoenix/server/api/types/node.py +4 -112
- phoenix/server/api/types/pagination.py +156 -57
- phoenix/server/api/utils.py +34 -0
- phoenix/server/app.py +864 -115
- phoenix/server/bearer_auth.py +163 -0
- phoenix/server/dml_event.py +136 -0
- phoenix/server/dml_event_handler.py +256 -0
- phoenix/server/email/__init__.py +0 -0
- phoenix/server/email/sender.py +97 -0
- phoenix/server/email/templates/__init__.py +0 -0
- phoenix/server/email/templates/password_reset.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +102 -0
- phoenix/server/jwt_store.py +505 -0
- phoenix/server/main.py +305 -116
- phoenix/server/oauth2.py +52 -0
- phoenix/server/openapi/__init__.py +0 -0
- phoenix/server/prometheus.py +111 -0
- phoenix/server/rate_limiters.py +188 -0
- phoenix/server/static/.vite/manifest.json +87 -0
- phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
- phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
- phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
- phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
- phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
- phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
- phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
- phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
- phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
- phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
- phoenix/server/telemetry.py +68 -0
- phoenix/server/templates/index.html +82 -23
- phoenix/server/thread_server.py +3 -3
- phoenix/server/types.py +275 -0
- phoenix/services.py +27 -18
- phoenix/session/client.py +743 -68
- phoenix/session/data_extractor.py +31 -7
- phoenix/session/evaluation.py +3 -9
- phoenix/session/session.py +263 -219
- phoenix/settings.py +22 -0
- phoenix/trace/__init__.py +2 -22
- phoenix/trace/attributes.py +338 -0
- phoenix/trace/dsl/README.md +116 -0
- phoenix/trace/dsl/filter.py +663 -213
- phoenix/trace/dsl/helpers.py +73 -21
- phoenix/trace/dsl/query.py +574 -201
- phoenix/trace/exporter.py +24 -19
- phoenix/trace/fixtures.py +368 -32
- phoenix/trace/otel.py +71 -219
- phoenix/trace/projects.py +3 -2
- phoenix/trace/schemas.py +33 -11
- phoenix/trace/span_evaluations.py +21 -16
- phoenix/trace/span_json_decoder.py +6 -4
- phoenix/trace/span_json_encoder.py +2 -2
- phoenix/trace/trace_dataset.py +47 -32
- phoenix/trace/utils.py +21 -4
- phoenix/utilities/__init__.py +0 -26
- phoenix/utilities/client.py +132 -0
- phoenix/utilities/deprecation.py +31 -0
- phoenix/utilities/error_handling.py +3 -2
- phoenix/utilities/json.py +109 -0
- phoenix/utilities/logging.py +8 -0
- phoenix/utilities/project.py +2 -2
- phoenix/utilities/re.py +49 -0
- phoenix/utilities/span_store.py +0 -23
- phoenix/utilities/template_formatters.py +99 -0
- phoenix/version.py +1 -1
- arize_phoenix-3.16.1.dist-info/METADATA +0 -495
- arize_phoenix-3.16.1.dist-info/RECORD +0 -178
- phoenix/core/project.py +0 -619
- phoenix/core/traces.py +0 -96
- phoenix/experimental/evals/__init__.py +0 -73
- phoenix/experimental/evals/evaluators.py +0 -413
- phoenix/experimental/evals/functions/__init__.py +0 -4
- phoenix/experimental/evals/functions/classify.py +0 -453
- phoenix/experimental/evals/functions/executor.py +0 -353
- phoenix/experimental/evals/functions/generate.py +0 -138
- phoenix/experimental/evals/functions/processing.py +0 -76
- phoenix/experimental/evals/models/__init__.py +0 -14
- phoenix/experimental/evals/models/anthropic.py +0 -175
- phoenix/experimental/evals/models/base.py +0 -170
- phoenix/experimental/evals/models/bedrock.py +0 -221
- phoenix/experimental/evals/models/litellm.py +0 -134
- phoenix/experimental/evals/models/openai.py +0 -448
- phoenix/experimental/evals/models/rate_limiters.py +0 -246
- phoenix/experimental/evals/models/vertex.py +0 -173
- phoenix/experimental/evals/models/vertexai.py +0 -186
- phoenix/experimental/evals/retrievals.py +0 -96
- phoenix/experimental/evals/templates/__init__.py +0 -50
- phoenix/experimental/evals/templates/default_templates.py +0 -472
- phoenix/experimental/evals/templates/template.py +0 -195
- phoenix/experimental/evals/utils/__init__.py +0 -172
- phoenix/experimental/evals/utils/threads.py +0 -27
- phoenix/server/api/helpers.py +0 -11
- phoenix/server/api/routers/evaluation_handler.py +0 -109
- phoenix/server/api/routers/span_handler.py +0 -70
- phoenix/server/api/routers/trace_handler.py +0 -60
- phoenix/server/api/types/DatasetRole.py +0 -23
- phoenix/server/static/index.css +0 -6
- phoenix/server/static/index.js +0 -7447
- phoenix/storage/span_store/__init__.py +0 -23
- phoenix/storage/span_store/text_file.py +0 -85
- phoenix/trace/dsl/missing.py +0 -60
- phoenix/trace/langchain/__init__.py +0 -3
- phoenix/trace/langchain/instrumentor.py +0 -35
- phoenix/trace/llama_index/__init__.py +0 -3
- phoenix/trace/llama_index/callback.py +0 -102
- phoenix/trace/openai/__init__.py +0 -3
- phoenix/trace/openai/instrumentor.py +0 -30
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/LICENSE +0 -0
- /phoenix/{datasets → db/insertion}/__init__.py +0 -0
- /phoenix/{experimental → db/migrations}/__init__.py +0 -0
- /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
|
@@ -0,0 +1,738 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
from typing import Optional, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import numpy.typing as npt
|
|
7
|
+
import strawberry
|
|
8
|
+
from sqlalchemy import and_, distinct, func, select
|
|
9
|
+
from sqlalchemy.orm import joinedload
|
|
10
|
+
from starlette.authentication import UnauthenticatedUser
|
|
11
|
+
from strawberry import ID, UNSET
|
|
12
|
+
from strawberry.relay import Connection, GlobalID, Node
|
|
13
|
+
from strawberry.types import Info
|
|
14
|
+
from typing_extensions import Annotated, TypeAlias
|
|
15
|
+
|
|
16
|
+
from phoenix.db import enums, models
|
|
17
|
+
from phoenix.db.models import (
|
|
18
|
+
DatasetExample as OrmExample,
|
|
19
|
+
)
|
|
20
|
+
from phoenix.db.models import (
|
|
21
|
+
DatasetExampleRevision as OrmRevision,
|
|
22
|
+
)
|
|
23
|
+
from phoenix.db.models import (
|
|
24
|
+
DatasetVersion as OrmVersion,
|
|
25
|
+
)
|
|
26
|
+
from phoenix.db.models import (
|
|
27
|
+
Experiment as OrmExperiment,
|
|
28
|
+
)
|
|
29
|
+
from phoenix.db.models import ExperimentRun as OrmExperimentRun
|
|
30
|
+
from phoenix.db.models import (
|
|
31
|
+
Trace as OrmTrace,
|
|
32
|
+
)
|
|
33
|
+
from phoenix.pointcloud.clustering import Hdbscan
|
|
34
|
+
from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin
|
|
35
|
+
from phoenix.server.api.context import Context
|
|
36
|
+
from phoenix.server.api.exceptions import NotFound, Unauthorized
|
|
37
|
+
from phoenix.server.api.helpers import ensure_list
|
|
38
|
+
from phoenix.server.api.helpers.experiment_run_filters import (
|
|
39
|
+
ExperimentRunFilterConditionSyntaxError,
|
|
40
|
+
compile_sqlalchemy_filter_condition,
|
|
41
|
+
update_examples_query_with_filter_condition,
|
|
42
|
+
)
|
|
43
|
+
from phoenix.server.api.helpers.playground_clients import initialize_playground_clients
|
|
44
|
+
from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
|
|
45
|
+
from phoenix.server.api.input_types.ClusterInput import ClusterInput
|
|
46
|
+
from phoenix.server.api.input_types.Coordinates import (
|
|
47
|
+
InputCoordinate2D,
|
|
48
|
+
InputCoordinate3D,
|
|
49
|
+
)
|
|
50
|
+
from phoenix.server.api.input_types.DatasetSort import DatasetSort
|
|
51
|
+
from phoenix.server.api.input_types.InvocationParameters import (
|
|
52
|
+
InvocationParameter,
|
|
53
|
+
)
|
|
54
|
+
from phoenix.server.api.subscriptions import PLAYGROUND_PROJECT_NAME
|
|
55
|
+
from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
|
|
56
|
+
from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
|
|
57
|
+
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
58
|
+
from phoenix.server.api.types.Dimension import to_gql_dimension
|
|
59
|
+
from phoenix.server.api.types.EmbeddingDimension import (
|
|
60
|
+
DEFAULT_CLUSTER_SELECTION_EPSILON,
|
|
61
|
+
DEFAULT_MIN_CLUSTER_SIZE,
|
|
62
|
+
DEFAULT_MIN_SAMPLES,
|
|
63
|
+
to_gql_embedding_dimension,
|
|
64
|
+
)
|
|
65
|
+
from phoenix.server.api.types.Event import create_event_id, unpack_event_id
|
|
66
|
+
from phoenix.server.api.types.Experiment import Experiment
|
|
67
|
+
from phoenix.server.api.types.ExperimentComparison import ExperimentComparison, RunComparisonItem
|
|
68
|
+
from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
|
|
69
|
+
from phoenix.server.api.types.Functionality import Functionality
|
|
70
|
+
from phoenix.server.api.types.GenerativeModel import GenerativeModel
|
|
71
|
+
from phoenix.server.api.types.GenerativeProvider import (
|
|
72
|
+
GenerativeProvider,
|
|
73
|
+
GenerativeProviderKey,
|
|
74
|
+
)
|
|
75
|
+
from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
|
|
76
|
+
from phoenix.server.api.types.Model import Model
|
|
77
|
+
from phoenix.server.api.types.node import from_global_id, from_global_id_with_expected_type
|
|
78
|
+
from phoenix.server.api.types.pagination import (
|
|
79
|
+
ConnectionArgs,
|
|
80
|
+
CursorString,
|
|
81
|
+
connection_from_list,
|
|
82
|
+
)
|
|
83
|
+
from phoenix.server.api.types.Project import Project
|
|
84
|
+
from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
|
|
85
|
+
from phoenix.server.api.types.SortDir import SortDir
|
|
86
|
+
from phoenix.server.api.types.Span import Span, to_gql_span
|
|
87
|
+
from phoenix.server.api.types.SystemApiKey import SystemApiKey
|
|
88
|
+
from phoenix.server.api.types.Trace import to_gql_trace
|
|
89
|
+
from phoenix.server.api.types.User import User, to_gql_user
|
|
90
|
+
from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
|
|
91
|
+
from phoenix.server.api.types.UserRole import UserRole
|
|
92
|
+
from phoenix.server.api.types.ValidationResult import ValidationResult
|
|
93
|
+
|
|
94
|
+
# Populate the playground client registry at import time so the Query fields
# below (model_providers, models, model_invocation_parameters) can enumerate
# providers and models from PLAYGROUND_CLIENT_REGISTRY.
initialize_playground_clients()
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@strawberry.input
class ModelsInput:
    """GraphQL input for filtering playground model queries.

    Both fields are optional; a query with no filter lists everything.
    """

    # Restrict results to a single generative-model provider when given.
    provider_key: Optional[GenerativeProviderKey]
    # Specific model name, used when resolving a provider-specific client.
    model_name: Optional[str] = None
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@strawberry.type
|
|
104
|
+
class Query:
|
|
105
|
+
@strawberry.field
async def model_providers(self) -> list[GenerativeProvider]:
    """List every generative-model provider registered with the playground."""
    providers: list[GenerativeProvider] = []
    for key in PLAYGROUND_CLIENT_REGISTRY.list_all_providers():
        providers.append(
            GenerativeProvider(
                name=key.value,
                key=key,
            )
        )
    return providers
|
|
115
|
+
|
|
116
|
+
@strawberry.field
async def models(self, input: Optional[ModelsInput] = None) -> list[GenerativeModel]:
    """List playground models, optionally restricted to one provider.

    With a provider filter, returns that provider's supported models;
    otherwise returns every registered (provider, model) pair that has
    both parts populated.
    """
    if input is not None and input.provider_key is not None:
        # Provider filter supplied: enumerate only that provider's models.
        names = PLAYGROUND_CLIENT_REGISTRY.list_models(input.provider_key)
        return [
            GenerativeModel(name=name, provider_key=input.provider_key)
            for name in names
        ]

    # No filter: walk the full registry, skipping incomplete entries.
    results: list[GenerativeModel] = []
    for key, name in PLAYGROUND_CLIENT_REGISTRY.list_all_models():
        if name is not None and key is not None:
            results.append(GenerativeModel(name=name, provider_key=key))
    return results
|
|
132
|
+
|
|
133
|
+
@strawberry.field
async def model_invocation_parameters(
    self, input: Optional[ModelsInput] = None
) -> list[InvocationParameter]:
    """Return the invocation parameters supported by the selected model.

    Yields an empty list when no input/provider is given or when no client
    is registered for the (provider, model) pair.
    """
    # Guard clauses: nothing to look up without a provider key.
    if input is None or input.provider_key is None:
        return []
    client = PLAYGROUND_CLIENT_REGISTRY.get_client(input.provider_key, input.model_name)
    if client is None:
        return []
    return client.supported_invocation_parameters()
|
|
149
|
+
|
|
150
|
+
@strawberry.field(permission_classes=[IsAdmin])  # type: ignore
async def users(
    self,
    info: Info[Context, None],
    first: Optional[int] = 50,
    last: Optional[int] = UNSET,
    after: Optional[CursorString] = UNSET,
    before: Optional[CursorString] = UNSET,
) -> Connection[User]:
    """Paginated listing of all non-system users, ordered by email (admin only)."""
    pagination = ConnectionArgs(
        first=first,
        after=after if isinstance(after, CursorString) else None,
        last=last,
        before=before if isinstance(before, CursorString) else None,
    )
    # Exclude the internal SYSTEM role; eager-load each user's role to avoid
    # a per-user lazy load during GQL conversion.
    query = (
        select(models.User)
        .join(models.UserRole)
        .where(models.UserRole.name != enums.UserRole.SYSTEM.value)
        .order_by(models.User.email)
        .options(joinedload(models.User.role))
    )
    async with info.context.db() as session:
        scalar_stream = await session.stream_scalars(query)
        gql_users = [to_gql_user(orm_user) async for orm_user in scalar_stream]
    return connection_from_list(data=gql_users, args=pagination)
|
|
176
|
+
|
|
177
|
+
@strawberry.field
async def user_roles(
    self,
    info: Info[Context, None],
) -> list[UserRole]:
    """List all assignable user roles, excluding the internal SYSTEM role."""
    role_query = select(models.UserRole).where(
        models.UserRole.name != enums.UserRole.SYSTEM.value
    )
    async with info.context.db() as session:
        orm_roles = await session.scalars(role_query)
        return [UserRole(id_attr=orm_role.id, name=orm_role.name) for orm_role in orm_roles]
|
|
193
|
+
|
|
194
|
+
@strawberry.field(permission_classes=[IsAdmin])  # type: ignore
async def user_api_keys(self, info: Info[Context, None]) -> list[UserApiKey]:
    """List API keys belonging to non-system users (admin only)."""
    # Join through User -> UserRole so SYSTEM-owned keys are filtered out.
    query = (
        select(models.ApiKey)
        .join(models.User)
        .join(models.UserRole)
        .where(models.UserRole.name != enums.UserRole.SYSTEM.value)
    )
    async with info.context.db() as session:
        rows = await session.scalars(query)
        return [to_gql_api_key(row) for row in rows]
|
|
205
|
+
|
|
206
|
+
@strawberry.field(permission_classes=[IsAdmin])  # type: ignore
async def system_api_keys(self, info: Info[Context, None]) -> list[SystemApiKey]:
    """List API keys owned by the SYSTEM user (admin only)."""
    query = (
        select(models.ApiKey)
        .join(models.User)
        .join(models.UserRole)
        .where(models.UserRole.name == enums.UserRole.SYSTEM.value)
    )
    async with info.context.db() as session:
        rows = await session.scalars(query)
        keys: list[SystemApiKey] = []
        for row in rows:
            keys.append(
                SystemApiKey(
                    id_attr=row.id,
                    name=row.name,
                    description=row.description,
                    created_at=row.created_at,
                    expires_at=row.expires_at,
                )
            )
        return keys
|
|
226
|
+
|
|
227
|
+
@strawberry.field
async def projects(
    self,
    info: Info[Context, None],
    first: Optional[int] = 50,
    last: Optional[int] = UNSET,
    after: Optional[CursorString] = UNSET,
    before: Optional[CursorString] = UNSET,
) -> Connection[Project]:
    """Paginated listing of projects, hiding experiment-backed projects.

    Projects whose name matches a non-playground experiment's project_name
    are excluded via an anti-join (outer join + IS NULL filter).
    """
    pagination = ConnectionArgs(
        first=first,
        after=after if isinstance(after, CursorString) else None,
        last=last,
        before=before if isinstance(before, CursorString) else None,
    )
    # Anti-join: keep only projects with no matching (non-playground)
    # experiment row.
    query = (
        select(models.Project)
        .outerjoin(
            models.Experiment,
            and_(
                models.Project.name == models.Experiment.project_name,
                models.Experiment.project_name != PLAYGROUND_PROJECT_NAME,
            ),
        )
        .where(models.Experiment.project_name.is_(None))
        .order_by(models.Project.id)
    )
    async with info.context.db() as session:
        scalar_stream = await session.stream_scalars(query)
        gql_projects: list[Project] = []
        async for orm_project in scalar_stream:
            gql_projects.append(
                Project(
                    id_attr=orm_project.id,
                    name=orm_project.name,
                    gradient_start_color=orm_project.gradient_start_color,
                    gradient_end_color=orm_project.gradient_end_color,
                )
            )
    return connection_from_list(data=gql_projects, args=pagination)
|
|
266
|
+
|
|
267
|
+
@strawberry.field
def projects_last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
    """Timestamp of the most recent Project table update, if any.

    Reads the in-memory last-updated map on the request context; returns
    None when no update has been recorded for the Project model.
    """
    return info.context.last_updated_at.get(models.Project)
|
|
270
|
+
|
|
271
|
+
@strawberry.field
async def datasets(
    self,
    info: Info[Context, None],
    first: Optional[int] = 50,
    last: Optional[int] = UNSET,
    after: Optional[CursorString] = UNSET,
    before: Optional[CursorString] = UNSET,
    sort: Optional[DatasetSort] = UNSET,
) -> Connection[Dataset]:
    """Paginated listing of datasets with optional column sorting."""
    pagination = ConnectionArgs(
        first=first,
        after=after if isinstance(after, CursorString) else None,
        last=last,
        before=before if isinstance(before, CursorString) else None,
    )
    query = select(models.Dataset)
    if sort:
        # Resolve the requested sort column dynamically and apply direction.
        column = getattr(models.Dataset, sort.col.value)
        ordering = column.desc() if sort.dir is SortDir.desc else column.asc()
        query = query.order_by(ordering)
    async with info.context.db() as session:
        orm_datasets = await session.scalars(query)
        gql_datasets = [to_gql_dataset(orm_dataset) for orm_dataset in orm_datasets]
    return connection_from_list(data=gql_datasets, args=pagination)
|
|
296
|
+
|
|
297
|
+
@strawberry.field
def datasets_last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
    """Timestamp of the most recent Dataset table update, if any.

    Reads the in-memory last-updated map on the request context; returns
    None when no update has been recorded for the Dataset model.
    """
    return info.context.last_updated_at.get(models.Dataset)
|
|
300
|
+
|
|
301
|
+
@strawberry.field
async def compare_experiments(
    self,
    info: Info[Context, None],
    experiment_ids: list[GlobalID],
    filter_condition: Optional[str] = UNSET,
) -> list[ExperimentComparison]:
    """Compare runs of several experiments side by side, one comparison per
    dataset example.

    All experiments must be distinct and belong to the same dataset; the
    examples shown are those of the latest dataset version referenced by the
    experiments. Raises ValueError for duplicate/unresolvable IDs or mixed
    datasets.
    """
    # Decode the relay GlobalIDs into ORM primary keys, verifying the type name.
    experiment_ids_ = [
        from_global_id_with_expected_type(experiment_id, OrmExperiment.__name__)
        for experiment_id in experiment_ids
    ]
    if len(set(experiment_ids_)) != len(experiment_ids_):
        raise ValueError("Experiment IDs must be unique.")

    async with info.context.db() as session:
        # Single aggregate query that both validates the inputs and fetches the
        # dataset/version context: distinct-dataset count, the dataset id, the
        # highest version id among the experiments, and how many of the
        # requested experiment ids actually resolved.
        validation_result = (
            await session.execute(
                select(
                    func.count(distinct(OrmVersion.dataset_id)),
                    func.max(OrmVersion.dataset_id),
                    func.max(OrmVersion.id),
                    func.count(OrmExperiment.id),
                )
                .select_from(OrmVersion)
                .join(
                    OrmExperiment,
                    OrmExperiment.dataset_version_id == OrmVersion.id,
                )
                .where(
                    OrmExperiment.id.in_(experiment_ids_),
                )
            )
        ).first()
        if validation_result is None:
            raise ValueError("No experiments could be found for input IDs.")

        num_datasets, dataset_id, version_id, num_resolved_experiment_ids = validation_result
        if num_datasets != 1:
            raise ValueError("Experiments must belong to the same dataset.")
        if num_resolved_experiment_ids != len(experiment_ids_):
            raise ValueError("Unable to resolve one or more experiment IDs.")

        # For each example, pick the latest revision at or before the chosen
        # dataset version (max revision id per example).
        revision_ids = (
            select(func.max(OrmRevision.id))
            .join(OrmExample, OrmExample.id == OrmRevision.dataset_example_id)
            .where(
                and_(
                    OrmRevision.dataset_version_id <= version_id,
                    OrmExample.dataset_id == dataset_id,
                )
            )
            .group_by(OrmRevision.dataset_example_id)
            .scalar_subquery()
        )
        # Examples still alive at that version: join on the latest revision and
        # drop examples whose latest revision is a DELETE.
        examples_query = (
            select(OrmExample)
            .distinct(OrmExample.id)
            .join(
                OrmRevision,
                onclause=and_(
                    OrmExample.id == OrmRevision.dataset_example_id,
                    OrmRevision.id.in_(revision_ids),
                    OrmRevision.revision_kind != "DELETE",
                ),
            )
            .order_by(OrmExample.id.desc())
        )

        if filter_condition:
            # Narrow the example set with the user-supplied run filter.
            examples_query = update_examples_query_with_filter_condition(
                query=examples_query,
                filter_condition=filter_condition,
                experiment_ids=experiment_ids_,
            )

        examples = (await session.scalars(examples_query)).all()

        # Group runs by (example id, experiment id); a single pair may have
        # multiple (repeated) runs.
        ExampleID: TypeAlias = int
        ExperimentID: TypeAlias = int
        runs: defaultdict[ExampleID, defaultdict[ExperimentID, list[OrmExperimentRun]]] = (
            defaultdict(lambda: defaultdict(list))
        )
        async for run in await session.stream_scalars(
            select(OrmExperimentRun)
            .where(
                and_(
                    OrmExperimentRun.dataset_example_id.in_(example.id for example in examples),
                    OrmExperimentRun.experiment_id.in_(experiment_ids_),
                )
            )
            # Eager-load only the trace_id to avoid per-run lazy loads later.
            .options(joinedload(OrmExperimentRun.trace).load_only(OrmTrace.trace_id))
        ):
            runs[run.dataset_example_id][run.experiment_id].append(run)

    # Assemble one comparison per example, preserving the caller's experiment
    # order; experiments without runs for an example yield an empty runs list.
    experiment_comparisons = []
    for example in examples:
        run_comparison_items = []
        for experiment_id in experiment_ids_:
            run_comparison_items.append(
                RunComparisonItem(
                    experiment_id=GlobalID(Experiment.__name__, str(experiment_id)),
                    runs=[
                        to_gql_experiment_run(run)
                        for run in sorted(
                            runs[example.id][experiment_id], key=lambda run: run.id
                        )
                    ],
                )
            )
        experiment_comparisons.append(
            ExperimentComparison(
                example=DatasetExample(
                    id_attr=example.id,
                    created_at=example.created_at,
                    version_id=version_id,
                ),
                run_comparison_items=run_comparison_items,
            )
        )
    return experiment_comparisons
@strawberry.field
async def validate_experiment_run_filter_condition(
    self,
    condition: str,
    experiment_ids: list[GlobalID],
) -> ValidationResult:
    """Check an experiment-run filter condition for syntax errors without executing it."""
    try:
        # Resolving the GlobalIDs and compiling both stay inside the try so
        # that any syntax error raised during compilation is reported.
        resolved_experiment_ids = [
            from_global_id_with_expected_type(experiment_id, OrmExperiment.__name__)
            for experiment_id in experiment_ids
        ]
        compile_sqlalchemy_filter_condition(
            filter_condition=condition,
            experiment_ids=resolved_experiment_ids,
        )
    except ExperimentRunFilterConditionSyntaxError as error:
        return ValidationResult(
            is_valid=False,
            error_message=str(error),
        )
    return ValidationResult(
        is_valid=True,
        error_message=None,
    )
@strawberry.field
async def functionality(self, info: Info[Context, None]) -> "Functionality":
    """Report which product areas have data: model inferences and/or traces."""
    async with info.context.db() as session:
        # Existence check: fetch at most one trace row.
        first_trace = await session.scalar(select(models.Trace).limit(1))
    return Functionality(
        model_inferences=not info.context.model.is_empty,
        tracing=first_trace is not None,
    )
@strawberry.field
def model(self) -> Model:
    """Return the (stateless) GraphQL model root object."""
    return Model()
@strawberry.field
async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
    """Relay node resolver: decode a GlobalID and dispatch on its type name.

    Raises NotFound for unknown ids or unsupported type names, and
    Unauthorized when a non-admin requests another user's User node.

    Fix: the ProjectSession branch previously raised "Unknown user: ..."
    (copy-pasted from the User branch); it now reports the session.
    """
    type_name, node_id = from_global_id(id)
    if type_name == "Dimension":
        # In-memory model lookups; no database access needed.
        dimension = info.context.model.scalar_dimensions[node_id]
        return to_gql_dimension(node_id, dimension)
    elif type_name == "EmbeddingDimension":
        embedding_dimension = info.context.model.embedding_dimensions[node_id]
        return to_gql_embedding_dimension(node_id, embedding_dimension)
    elif type_name == "Project":
        # Select only the columns the GraphQL Project object needs.
        project_stmt = select(
            models.Project.id,
            models.Project.name,
            models.Project.gradient_start_color,
            models.Project.gradient_end_color,
        ).where(models.Project.id == node_id)
        async with info.context.db() as session:
            project = (await session.execute(project_stmt)).first()
        if project is None:
            raise NotFound(f"Unknown project: {id}")
        return Project(
            id_attr=project.id,
            name=project.name,
            gradient_start_color=project.gradient_start_color,
            gradient_end_color=project.gradient_end_color,
        )
    elif type_name == "Trace":
        trace_stmt = select(models.Trace).filter_by(id=node_id)
        async with info.context.db() as session:
            trace = await session.scalar(trace_stmt)
        if trace is None:
            raise NotFound(f"Unknown trace: {id}")
        return to_gql_trace(trace)
    elif type_name == Span.__name__:
        span_stmt = (
            select(models.Span)
            # Eager-load the parent trace's trace_id to avoid a lazy load.
            .options(
                joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id)
            )
            .where(models.Span.id == node_id)
        )
        async with info.context.db() as session:
            span = await session.scalar(span_stmt)
        if span is None:
            raise NotFound(f"Unknown span: {id}")
        return to_gql_span(span)
    elif type_name == Dataset.__name__:
        dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id)
        async with info.context.db() as session:
            if (dataset := await session.scalar(dataset_stmt)) is None:
                raise NotFound(f"Unknown dataset: {id}")
        return to_gql_dataset(dataset)
    elif type_name == DatasetExample.__name__:
        example_id = node_id
        # Latest revision for the example; the example is gone if that
        # revision is a DELETE.
        latest_revision_id = (
            select(func.max(models.DatasetExampleRevision.id))
            .where(models.DatasetExampleRevision.dataset_example_id == example_id)
            .scalar_subquery()
        )
        async with info.context.db() as session:
            example = await session.scalar(
                select(models.DatasetExample)
                .join(
                    models.DatasetExampleRevision,
                    onclause=models.DatasetExampleRevision.dataset_example_id
                    == models.DatasetExample.id,
                )
                .where(
                    and_(
                        models.DatasetExample.id == example_id,
                        models.DatasetExampleRevision.id == latest_revision_id,
                        models.DatasetExampleRevision.revision_kind != "DELETE",
                    )
                )
            )
        if not example:
            raise NotFound(f"Unknown dataset example: {id}")
        return DatasetExample(
            id_attr=example.id,
            created_at=example.created_at,
        )
    elif type_name == Experiment.__name__:
        async with info.context.db() as session:
            experiment = await session.scalar(
                select(models.Experiment).where(models.Experiment.id == node_id)
            )
        if not experiment:
            raise NotFound(f"Unknown experiment: {id}")
        return Experiment(
            id_attr=experiment.id,
            name=experiment.name,
            project_name=experiment.project_name,
            description=experiment.description,
            created_at=experiment.created_at,
            updated_at=experiment.updated_at,
            metadata=experiment.metadata_,
        )
    elif type_name == ExperimentRun.__name__:
        async with info.context.db() as session:
            if not (
                run := await session.scalar(
                    select(models.ExperimentRun)
                    .where(models.ExperimentRun.id == node_id)
                    .options(
                        joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
                    )
                )
            ):
                raise NotFound(f"Unknown experiment run: {id}")
        return to_gql_experiment_run(run)
    elif type_name == User.__name__:
        # Only admins may look up a User node other than themselves.
        if int((user := info.context.user).identity) != node_id and not user.is_admin:
            raise Unauthorized(MSG_ADMIN_ONLY)
        async with info.context.db() as session:
            if not (
                user := await session.scalar(
                    select(models.User).where(models.User.id == node_id)
                )
            ):
                raise NotFound(f"Unknown user: {id}")
        return to_gql_user(user)
    elif type_name == ProjectSession.__name__:
        async with info.context.db() as session:
            if not (
                project_session := await session.scalar(
                    select(models.ProjectSession).filter_by(id=node_id)
                )
            ):
                # BUG FIX: previously said "Unknown user".
                raise NotFound(f"Unknown project session: {id}")
        return to_gql_project_session(project_session)
    raise NotFound(f"Unknown node type: {type_name}")
@strawberry.field
async def viewer(self, info: Info[Context, None]) -> Optional[User]:
    """Resolve the currently authenticated user, or None when anonymous."""
    request = info.context.get_request()
    try:
        auth_user = request.user
    except AssertionError:
        # request.user asserts that authentication middleware is installed;
        # treat its absence as "no viewer".
        return None
    if isinstance(auth_user, UnauthenticatedUser):
        return None
    stmt = (
        select(models.User)
        .where(models.User.id == int(auth_user.identity))
        .options(joinedload(models.User.role))
    )
    async with info.context.db() as session:
        orm_user = await session.scalar(stmt)
    if orm_user is None:
        return None
    return to_gql_user(orm_user)
@strawberry.field
def clusters(
    self,
    clusters: list[ClusterInput],
) -> list[Cluster]:
    """Convert caller-supplied cluster inputs into GraphQL Cluster objects,
    merging event ids that share a cluster id (position index when no id)."""
    events_by_cluster: dict[str, set[ID]] = defaultdict(set)
    for index, cluster_input in enumerate(clusters):
        key = cluster_input.id or str(index)
        events_by_cluster[key].update(cluster_input.event_ids)
    return to_gql_clusters(
        clustered_events=events_by_cluster,
    )
@strawberry.field
def hdbscan_clustering(
    self,
    info: Info[Context, None],
    event_ids: Annotated[
        list[ID],
        strawberry.argument(
            description="Event ID of the coordinates",
        ),
    ],
    coordinates_2d: Annotated[
        Optional[list[InputCoordinate2D]],
        strawberry.argument(
            description="Point coordinates. Must be either 2D or 3D.",
        ),
    ] = UNSET,
    coordinates_3d: Annotated[
        Optional[list[InputCoordinate3D]],
        strawberry.argument(
            description="Point coordinates. Must be either 2D or 3D.",
        ),
    ] = UNSET,
    min_cluster_size: Annotated[
        int,
        strawberry.argument(
            description="HDBSCAN minimum cluster size",
        ),
    ] = DEFAULT_MIN_CLUSTER_SIZE,
    cluster_min_samples: Annotated[
        int,
        strawberry.argument(
            description="HDBSCAN minimum samples",
        ),
    ] = DEFAULT_MIN_SAMPLES,
    cluster_selection_epsilon: Annotated[
        float,
        strawberry.argument(
            description="HDBSCAN cluster selection epsilon",
        ),
    ] = DEFAULT_CLUSTER_SELECTION_EPSILON,
) -> list[Cluster]:
    """Run HDBSCAN over the given 2D or 3D point coordinates and return the
    resulting clusters keyed by their index.

    Exactly one of coordinates_2d / coordinates_3d may be non-empty; a
    ValueError is raised if both are given or if the coordinate count does
    not match the event-id count.
    """
    # UNSET / None become empty lists so lengths can be compared uniformly.
    coordinates_3d = ensure_list(coordinates_3d)
    coordinates_2d = ensure_list(coordinates_2d)

    if len(coordinates_3d) > 0 and len(coordinates_2d) > 0:
        raise ValueError("must specify only one of 2D or 3D coordinates")

    # Convert the input coordinate objects into per-point numpy vectors;
    # 3D takes precedence when present (2D is only used when 3D is empty).
    if len(coordinates_3d) > 0:
        coordinates = list(
            map(
                lambda coord: np.array(
                    [coord.x, coord.y, coord.z],
                ),
                coordinates_3d,
            )
        )
    else:
        coordinates = list(
            map(
                lambda coord: np.array(
                    [coord.x, coord.y],
                ),
                coordinates_2d,
            )
        )

    if len(event_ids) != len(coordinates):
        raise ValueError(
            f"length mismatch between "
            f"event_ids ({len(event_ids)}) "
            f"and coordinates ({len(coordinates)})"
        )

    if len(event_ids) == 0:
        return []

    # Bucket both event ids and coordinates by inferences role so the two
    # stacked sequences below stay index-aligned with each other.
    grouped_event_ids: dict[
        Union[InferencesRole, AncillaryInferencesRole],
        list[ID],
    ] = defaultdict(list)
    grouped_coordinates: dict[
        Union[InferencesRole, AncillaryInferencesRole],
        list[npt.NDArray[np.float64]],
    ] = defaultdict(list)

    for event_id, coordinate in zip(event_ids, coordinates):
        row_id, inferences_role = unpack_event_id(event_id)
        grouped_coordinates[inferences_role].append(coordinate)
        grouped_event_ids[inferences_role].append(create_event_id(row_id, inferences_role))

    # Concatenate in a fixed role order (primary, reference, corpus); the
    # same order is used for both lists so row index i in the coordinate
    # matrix corresponds to stacked_event_ids[i].
    stacked_event_ids = (
        grouped_event_ids[InferencesRole.primary]
        + grouped_event_ids[InferencesRole.reference]
        + grouped_event_ids[AncillaryInferencesRole.corpus]
    )
    stacked_coordinates = np.stack(
        grouped_coordinates[InferencesRole.primary]
        + grouped_coordinates[InferencesRole.reference]
        + grouped_coordinates[AncillaryInferencesRole.corpus]
    )

    clusters = Hdbscan(
        min_cluster_size=min_cluster_size,
        min_samples=cluster_min_samples,
        cluster_selection_epsilon=cluster_selection_epsilon,
    ).find_clusters(stacked_coordinates)

    # Each cluster is a collection of row indices into the stacked matrix;
    # translate those back into event ids, keyed by cluster index.
    clustered_events = {
        str(i): {stacked_event_ids[row_idx] for row_idx in cluster}
        for i, cluster in enumerate(clusters)
    }

    return to_gql_clusters(
        clustered_events=clustered_events,
    )