arize-phoenix 3.16.1__py3-none-any.whl → 7.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize_phoenix-7.7.1.dist-info/METADATA +261 -0
- arize_phoenix-7.7.1.dist-info/RECORD +345 -0
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/WHEEL +1 -1
- arize_phoenix-7.7.1.dist-info/entry_points.txt +3 -0
- phoenix/__init__.py +86 -14
- phoenix/auth.py +309 -0
- phoenix/config.py +675 -45
- phoenix/core/model.py +32 -30
- phoenix/core/model_schema.py +102 -109
- phoenix/core/model_schema_adapter.py +48 -45
- phoenix/datetime_utils.py +24 -3
- phoenix/db/README.md +54 -0
- phoenix/db/__init__.py +4 -0
- phoenix/db/alembic.ini +85 -0
- phoenix/db/bulk_inserter.py +294 -0
- phoenix/db/engines.py +208 -0
- phoenix/db/enums.py +20 -0
- phoenix/db/facilitator.py +113 -0
- phoenix/db/helpers.py +159 -0
- phoenix/db/insertion/constants.py +2 -0
- phoenix/db/insertion/dataset.py +227 -0
- phoenix/db/insertion/document_annotation.py +171 -0
- phoenix/db/insertion/evaluation.py +191 -0
- phoenix/db/insertion/helpers.py +98 -0
- phoenix/db/insertion/span.py +193 -0
- phoenix/db/insertion/span_annotation.py +158 -0
- phoenix/db/insertion/trace_annotation.py +158 -0
- phoenix/db/insertion/types.py +256 -0
- phoenix/db/migrate.py +86 -0
- phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
- phoenix/db/migrations/env.py +114 -0
- phoenix/db/migrations/script.py.mako +26 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
- phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
- phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
- phoenix/db/models.py +807 -0
- phoenix/exceptions.py +5 -1
- phoenix/experiments/__init__.py +6 -0
- phoenix/experiments/evaluators/__init__.py +29 -0
- phoenix/experiments/evaluators/base.py +158 -0
- phoenix/experiments/evaluators/code_evaluators.py +184 -0
- phoenix/experiments/evaluators/llm_evaluators.py +473 -0
- phoenix/experiments/evaluators/utils.py +236 -0
- phoenix/experiments/functions.py +772 -0
- phoenix/experiments/tracing.py +86 -0
- phoenix/experiments/types.py +726 -0
- phoenix/experiments/utils.py +25 -0
- phoenix/inferences/__init__.py +0 -0
- phoenix/{datasets → inferences}/errors.py +6 -5
- phoenix/{datasets → inferences}/fixtures.py +49 -42
- phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
- phoenix/{datasets → inferences}/schema.py +11 -11
- phoenix/{datasets → inferences}/validation.py +13 -14
- phoenix/logging/__init__.py +3 -0
- phoenix/logging/_config.py +90 -0
- phoenix/logging/_filter.py +6 -0
- phoenix/logging/_formatter.py +69 -0
- phoenix/metrics/__init__.py +5 -4
- phoenix/metrics/binning.py +4 -3
- phoenix/metrics/metrics.py +2 -1
- phoenix/metrics/mixins.py +7 -6
- phoenix/metrics/retrieval_metrics.py +2 -1
- phoenix/metrics/timeseries.py +5 -4
- phoenix/metrics/wrappers.py +9 -3
- phoenix/pointcloud/clustering.py +5 -5
- phoenix/pointcloud/pointcloud.py +7 -5
- phoenix/pointcloud/projectors.py +5 -6
- phoenix/pointcloud/umap_parameters.py +53 -52
- phoenix/server/api/README.md +28 -0
- phoenix/server/api/auth.py +44 -0
- phoenix/server/api/context.py +152 -9
- phoenix/server/api/dataloaders/__init__.py +91 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
- phoenix/server/api/dataloaders/cache/__init__.py +3 -0
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
- phoenix/server/api/dataloaders/document_evaluations.py +31 -0
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
- phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
- phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/record_counts.py +116 -0
- phoenix/server/api/dataloaders/session_io.py +79 -0
- phoenix/server/api/dataloaders/session_num_traces.py +30 -0
- phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
- phoenix/server/api/dataloaders/session_token_usages.py +41 -0
- phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
- phoenix/server/api/dataloaders/span_annotations.py +26 -0
- phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +57 -0
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/token_counts.py +124 -0
- phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
- phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
- phoenix/server/api/dataloaders/user_roles.py +30 -0
- phoenix/server/api/dataloaders/users.py +33 -0
- phoenix/server/api/exceptions.py +48 -0
- phoenix/server/api/helpers/__init__.py +12 -0
- phoenix/server/api/helpers/dataset_helpers.py +217 -0
- phoenix/server/api/helpers/experiment_run_filters.py +763 -0
- phoenix/server/api/helpers/playground_clients.py +948 -0
- phoenix/server/api/helpers/playground_registry.py +70 -0
- phoenix/server/api/helpers/playground_spans.py +455 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
- phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
- phoenix/server/api/input_types/ClearProjectInput.py +15 -0
- phoenix/server/api/input_types/ClusterInput.py +2 -2
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
- phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
- phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
- phoenix/server/api/input_types/DimensionFilter.py +4 -4
- phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
- phoenix/server/api/input_types/Granularity.py +1 -1
- phoenix/server/api/input_types/InvocationParameters.py +162 -0
- phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
- phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
- phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
- phoenix/server/api/input_types/SpanSort.py +134 -69
- phoenix/server/api/input_types/TemplateOptions.py +10 -0
- phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
- phoenix/server/api/input_types/UserRoleInput.py +9 -0
- phoenix/server/api/mutations/__init__.py +28 -0
- phoenix/server/api/mutations/api_key_mutations.py +167 -0
- phoenix/server/api/mutations/chat_mutations.py +593 -0
- phoenix/server/api/mutations/dataset_mutations.py +591 -0
- phoenix/server/api/mutations/experiment_mutations.py +75 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
- phoenix/server/api/mutations/project_mutations.py +57 -0
- phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
- phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
- phoenix/server/api/mutations/user_mutations.py +329 -0
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +17 -0
- phoenix/server/api/openapi/schema.py +16 -0
- phoenix/server/api/queries.py +738 -0
- phoenix/server/api/routers/__init__.py +11 -0
- phoenix/server/api/routers/auth.py +284 -0
- phoenix/server/api/routers/embeddings.py +26 -0
- phoenix/server/api/routers/oauth2.py +488 -0
- phoenix/server/api/routers/v1/__init__.py +64 -0
- phoenix/server/api/routers/v1/datasets.py +1017 -0
- phoenix/server/api/routers/v1/evaluations.py +362 -0
- phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
- phoenix/server/api/routers/v1/experiment_runs.py +167 -0
- phoenix/server/api/routers/v1/experiments.py +308 -0
- phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
- phoenix/server/api/routers/v1/spans.py +267 -0
- phoenix/server/api/routers/v1/traces.py +208 -0
- phoenix/server/api/routers/v1/utils.py +95 -0
- phoenix/server/api/schema.py +44 -241
- phoenix/server/api/subscriptions.py +597 -0
- phoenix/server/api/types/Annotation.py +21 -0
- phoenix/server/api/types/AnnotationSummary.py +55 -0
- phoenix/server/api/types/AnnotatorKind.py +16 -0
- phoenix/server/api/types/ApiKey.py +27 -0
- phoenix/server/api/types/AuthMethod.py +9 -0
- phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
- phoenix/server/api/types/Cluster.py +25 -24
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/DataQualityMetric.py +31 -13
- phoenix/server/api/types/Dataset.py +288 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +32 -31
- phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
- phoenix/server/api/types/EmbeddingDimension.py +56 -49
- phoenix/server/api/types/Evaluation.py +25 -31
- phoenix/server/api/types/EvaluationSummary.py +30 -50
- phoenix/server/api/types/Event.py +20 -20
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +152 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +17 -0
- phoenix/server/api/types/ExperimentRun.py +119 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
- phoenix/server/api/types/GenerativeModel.py +9 -0
- phoenix/server/api/types/GenerativeProvider.py +85 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/LabelFraction.py +7 -0
- phoenix/server/api/types/MimeType.py +2 -2
- phoenix/server/api/types/Model.py +54 -54
- phoenix/server/api/types/PerformanceMetric.py +8 -5
- phoenix/server/api/types/Project.py +407 -142
- phoenix/server/api/types/ProjectSession.py +139 -0
- phoenix/server/api/types/Segments.py +4 -4
- phoenix/server/api/types/Span.py +221 -176
- phoenix/server/api/types/SpanAnnotation.py +43 -0
- phoenix/server/api/types/SpanIOValue.py +15 -0
- phoenix/server/api/types/SystemApiKey.py +9 -0
- phoenix/server/api/types/TemplateLanguage.py +10 -0
- phoenix/server/api/types/TimeSeries.py +19 -15
- phoenix/server/api/types/TokenUsage.py +11 -0
- phoenix/server/api/types/Trace.py +154 -0
- phoenix/server/api/types/TraceAnnotation.py +45 -0
- phoenix/server/api/types/UMAPPoints.py +7 -7
- phoenix/server/api/types/User.py +60 -0
- phoenix/server/api/types/UserApiKey.py +45 -0
- phoenix/server/api/types/UserRole.py +15 -0
- phoenix/server/api/types/node.py +4 -112
- phoenix/server/api/types/pagination.py +156 -57
- phoenix/server/api/utils.py +34 -0
- phoenix/server/app.py +864 -115
- phoenix/server/bearer_auth.py +163 -0
- phoenix/server/dml_event.py +136 -0
- phoenix/server/dml_event_handler.py +256 -0
- phoenix/server/email/__init__.py +0 -0
- phoenix/server/email/sender.py +97 -0
- phoenix/server/email/templates/__init__.py +0 -0
- phoenix/server/email/templates/password_reset.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +102 -0
- phoenix/server/jwt_store.py +505 -0
- phoenix/server/main.py +305 -116
- phoenix/server/oauth2.py +52 -0
- phoenix/server/openapi/__init__.py +0 -0
- phoenix/server/prometheus.py +111 -0
- phoenix/server/rate_limiters.py +188 -0
- phoenix/server/static/.vite/manifest.json +87 -0
- phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
- phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
- phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
- phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
- phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
- phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
- phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
- phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
- phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
- phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
- phoenix/server/telemetry.py +68 -0
- phoenix/server/templates/index.html +82 -23
- phoenix/server/thread_server.py +3 -3
- phoenix/server/types.py +275 -0
- phoenix/services.py +27 -18
- phoenix/session/client.py +743 -68
- phoenix/session/data_extractor.py +31 -7
- phoenix/session/evaluation.py +3 -9
- phoenix/session/session.py +263 -219
- phoenix/settings.py +22 -0
- phoenix/trace/__init__.py +2 -22
- phoenix/trace/attributes.py +338 -0
- phoenix/trace/dsl/README.md +116 -0
- phoenix/trace/dsl/filter.py +663 -213
- phoenix/trace/dsl/helpers.py +73 -21
- phoenix/trace/dsl/query.py +574 -201
- phoenix/trace/exporter.py +24 -19
- phoenix/trace/fixtures.py +368 -32
- phoenix/trace/otel.py +71 -219
- phoenix/trace/projects.py +3 -2
- phoenix/trace/schemas.py +33 -11
- phoenix/trace/span_evaluations.py +21 -16
- phoenix/trace/span_json_decoder.py +6 -4
- phoenix/trace/span_json_encoder.py +2 -2
- phoenix/trace/trace_dataset.py +47 -32
- phoenix/trace/utils.py +21 -4
- phoenix/utilities/__init__.py +0 -26
- phoenix/utilities/client.py +132 -0
- phoenix/utilities/deprecation.py +31 -0
- phoenix/utilities/error_handling.py +3 -2
- phoenix/utilities/json.py +109 -0
- phoenix/utilities/logging.py +8 -0
- phoenix/utilities/project.py +2 -2
- phoenix/utilities/re.py +49 -0
- phoenix/utilities/span_store.py +0 -23
- phoenix/utilities/template_formatters.py +99 -0
- phoenix/version.py +1 -1
- arize_phoenix-3.16.1.dist-info/METADATA +0 -495
- arize_phoenix-3.16.1.dist-info/RECORD +0 -178
- phoenix/core/project.py +0 -619
- phoenix/core/traces.py +0 -96
- phoenix/experimental/evals/__init__.py +0 -73
- phoenix/experimental/evals/evaluators.py +0 -413
- phoenix/experimental/evals/functions/__init__.py +0 -4
- phoenix/experimental/evals/functions/classify.py +0 -453
- phoenix/experimental/evals/functions/executor.py +0 -353
- phoenix/experimental/evals/functions/generate.py +0 -138
- phoenix/experimental/evals/functions/processing.py +0 -76
- phoenix/experimental/evals/models/__init__.py +0 -14
- phoenix/experimental/evals/models/anthropic.py +0 -175
- phoenix/experimental/evals/models/base.py +0 -170
- phoenix/experimental/evals/models/bedrock.py +0 -221
- phoenix/experimental/evals/models/litellm.py +0 -134
- phoenix/experimental/evals/models/openai.py +0 -448
- phoenix/experimental/evals/models/rate_limiters.py +0 -246
- phoenix/experimental/evals/models/vertex.py +0 -173
- phoenix/experimental/evals/models/vertexai.py +0 -186
- phoenix/experimental/evals/retrievals.py +0 -96
- phoenix/experimental/evals/templates/__init__.py +0 -50
- phoenix/experimental/evals/templates/default_templates.py +0 -472
- phoenix/experimental/evals/templates/template.py +0 -195
- phoenix/experimental/evals/utils/__init__.py +0 -172
- phoenix/experimental/evals/utils/threads.py +0 -27
- phoenix/server/api/helpers.py +0 -11
- phoenix/server/api/routers/evaluation_handler.py +0 -109
- phoenix/server/api/routers/span_handler.py +0 -70
- phoenix/server/api/routers/trace_handler.py +0 -60
- phoenix/server/api/types/DatasetRole.py +0 -23
- phoenix/server/static/index.css +0 -6
- phoenix/server/static/index.js +0 -7447
- phoenix/storage/span_store/__init__.py +0 -23
- phoenix/storage/span_store/text_file.py +0 -85
- phoenix/trace/dsl/missing.py +0 -60
- phoenix/trace/langchain/__init__.py +0 -3
- phoenix/trace/langchain/instrumentor.py +0 -35
- phoenix/trace/llama_index/__init__.py +0 -3
- phoenix/trace/llama_index/callback.py +0 -102
- phoenix/trace/openai/__init__.py +0 -3
- phoenix/trace/openai/instrumentor.py +0 -30
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/LICENSE +0 -0
- /phoenix/{datasets → db/insertion}/__init__.py +0 -0
- /phoenix/{experimental → db/migrations}/__init__.py +0 -0
- /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0

phoenix/server/api/helpers/experiment_run_filters.py (new file)

@@ -0,0 +1,763 @@
+import ast
+import operator
+from abc import ABC, abstractmethod
+from copy import deepcopy
+from dataclasses import dataclass, field
+from hashlib import sha256
+from typing import Any, Callable, Literal, Optional, Union, get_args
+
+from sqlalchemy import (
+    BinaryExpression,
+    Boolean,
+    Float,
+    Integer,
+    Null,
+    Select,
+    String,
+    and_,
+    cast,
+    literal,
+    or_,
+)
+from sqlalchemy.orm import aliased
+from sqlalchemy.sql import operators as sqlalchemy_operators
+from typing_extensions import TypeAlias, TypeGuard, assert_never
+
+from phoenix.db import models
+
+SupportedComparisonOperator: TypeAlias = Union[
+    ast.Is,
+    ast.IsNot,
+    ast.In,
+    ast.NotIn,
+    ast.Eq,
+    ast.NotEq,
+    ast.Lt,
+    ast.LtE,
+    ast.Gt,
+    ast.GtE,
+]
+SupportedConstantType: TypeAlias = Union[bool, int, float, str, None]
+SQLAlchemyDataType: TypeAlias = Union[Boolean, Integer, Float[float], String]
+ExperimentID: TypeAlias = int
+SupportedUnaryBooleanOperator: TypeAlias = ast.Not
+SupportedUnaryTermOperator: TypeAlias = ast.USub
+SupportedDatasetExampleAttributeName: TypeAlias = Literal["input", "reference_output", "metadata"]
+SupportedExperimentRunAttributeName: TypeAlias = Literal["output", "error", "latency_ms", "evals"]
+SupportedExperimentRunEvalAttributeName: TypeAlias = Literal["score", "explanation", "label"]
+EvalName: TypeAlias = str
+
+
+def update_examples_query_with_filter_condition(
+    query: Select[Any], filter_condition: str, experiment_ids: list[int]
+) -> Select[Any]:
+    orm_filter_condition, transformer = compile_sqlalchemy_filter_condition(
+        filter_condition=filter_condition, experiment_ids=experiment_ids
+    )
+    for experiment_id in experiment_ids:
+        experiment_runs = transformer.get_experiment_runs_alias(experiment_id)
+        if experiment_runs is None:
+            continue
+        query = query.join(
+            experiment_runs,
+            onclause=and_(
+                experiment_runs.dataset_example_id == models.DatasetExample.id,
+                experiment_runs.experiment_id == experiment_id,
+            ),
+            isouter=True,
+        )
+        experiment_run_annotations_aliases = transformer.get_experiment_run_annotations_aliases(
+            experiment_id
+        )
+        for eval_name, experiment_run_annotations in experiment_run_annotations_aliases.items():
+            query = query.join(
+                experiment_run_annotations,
+                onclause=(
+                    and_(
+                        experiment_run_annotations.experiment_run_id == experiment_runs.id,
+                        experiment_run_annotations.name == eval_name,
+                    )
+                ),
+                isouter=True,
+            )
+    query = query.where(orm_filter_condition)
+    return query
+
+
+def compile_sqlalchemy_filter_condition(
+    filter_condition: str, experiment_ids: list[int]
+) -> tuple[Any, "SQLAlchemyTransformer"]:
+    try:
+        original_tree = ast.parse(filter_condition, mode="eval")
+    except SyntaxError as error:
+        raise ExperimentRunFilterConditionSyntaxError(str(error))
+
+    trees_with_bound_attribute_names = _bind_free_attribute_names(original_tree, experiment_ids)
+    has_free_attribute_names = bool(trees_with_bound_attribute_names)
+    if has_free_attribute_names:
+        # compile the filter condition once for each experiment and return the disjunction
+        sqlalchemy_transformer = SQLAlchemyTransformer(experiment_ids=experiment_ids)
+        compiled_filter_conditions: dict[ExperimentID, BinaryExpression[Any]] = {}
+        for experiment_id, tree in trees_with_bound_attribute_names.items():
+            sqlalchemy_tree = sqlalchemy_transformer.visit(tree)
+            node = sqlalchemy_tree.body
+            if not isinstance(node, BooleanExpression):
+                raise ExperimentRunFilterConditionSyntaxError(
+                    "Filter condition must be a boolean expression"
+                )
+            compiled_filter_conditions[experiment_id] = node.compile()
+        return or_(*compiled_filter_conditions.values()), sqlalchemy_transformer
+
+    # compile the filter condition once for all experiments
+    sqlalchemy_transformer = SQLAlchemyTransformer(experiment_ids)
+    sqlalchemy_tree = sqlalchemy_transformer.visit(original_tree)
+    node = sqlalchemy_tree.body
+    if not isinstance(node, BooleanExpression):
+        raise ExperimentRunFilterConditionSyntaxError(
+            "Filter condition must be a boolean expression"
+        )
+    compiled_filter_condition = node.compile()
+    return compiled_filter_condition, sqlalchemy_transformer
+
+
+def _bind_free_attribute_names(
+    tree: ast.AST, experiment_ids: list[ExperimentID]
+) -> dict[ExperimentID, ast.AST]:
+    trees_with_bound_attribute_names: dict[ExperimentID, ast.AST] = {}
+    for experiment_index, experiment_id in enumerate(experiment_ids):
+        binder = FreeAttributeNameBinder(experiment_index=experiment_index)
+        trees_with_bound_attribute_names[experiment_id] = binder.visit(deepcopy(tree))
+        has_free_attribute_names = binder.binds_free_attribute_name
+        if not has_free_attribute_names:
+            return {}  # return early since there are no free attribute names
+    return trees_with_bound_attribute_names
+
+
+class FreeAttributeNameBinder(ast.NodeTransformer):
+    def __init__(self, *, experiment_index: int) -> None:
+        super().__init__()
+        self._experiment_index = experiment_index
+        self._binds_free_attribute_name = False
+
+    def visit_Name(self, node: ast.Name) -> Any:
+        name = node.id
+        if _is_supported_experiment_run_attribute_name(name):
+            self._binds_free_attribute_name = True
+            return ast.Attribute(
+                value=ast.Subscript(
+                    value=ast.Name(id="experiments", ctx=ast.Load()),
+                    slice=ast.Constant(value=self._experiment_index),
+                    ctx=ast.Load(),
+                ),
+                attr=name,
+                ctx=node.ctx,
+            )
+        return node
+
+    @property
+    def binds_free_attribute_name(self) -> bool:
+        return self._binds_free_attribute_name
+
+
+class ExperimentRunFilterConditionSyntaxError(Exception):
+    pass
+
+
+@dataclass(frozen=True)
+class ExperimentRunFilterConditionNode(ABC):
+    """
+    A node in a tree representing a SQLAlchemy expression.
+    """
+
+    ast_node: ast.AST
+
+    @abstractmethod
+    def compile(self) -> Any:
+        """
+        Compiles the node into a SQLAlchemy expression.
+        """
+        raise NotImplementedError
+
+
+@dataclass(frozen=True)
+class Term(ExperimentRunFilterConditionNode):
+    @property
+    def data_type(self) -> Optional[SQLAlchemyDataType]:
+        return None
+
+
+@dataclass(frozen=True)
+class Constant(Term):
+    value: SupportedConstantType
+
+    def compile(self) -> Any:
+        value = self.value
+        if value is None:
+            return Null()
+        return literal(value)
+
+    @property
+    def data_type(self) -> Optional[SQLAlchemyDataType]:
+        value = self.value
+        if isinstance(value, bool):
+            return Boolean()
+        elif isinstance(value, int):
+            return Integer()
+        elif isinstance(value, float):
+            return Float()
+        elif isinstance(value, str):
+            return String()
+        elif value is None:
+            return None
+        assert_never(value)
+
+
+class ExperimentsName(ExperimentRunFilterConditionNode):
+    def compile(self) -> Any:
+        raise ExperimentRunFilterConditionSyntaxError("Select an experiment with [<index>]")
+
+
+@dataclass(frozen=True)
+class ExperimentRun(ExperimentRunFilterConditionNode):
+    slice: Constant
+    experiment_ids: list[int]
+    experiment_id: int = field(init=False)
+
+    def __post_init__(self) -> None:
+        experiment_index = self.slice.value
+        if not isinstance(experiment_index, int):
+            raise ExperimentRunFilterConditionSyntaxError("Index to experiments must be an integer")
+        if not (0 <= experiment_index < len(self.experiment_ids)):
+            raise ExperimentRunFilterConditionSyntaxError("Select an experiment with [<index>]")
+        object.__setattr__(self, "experiment_id", self.experiment_ids[experiment_index])
+
+    def compile(self) -> Any:
+        raise ExperimentRunFilterConditionSyntaxError("Add an attribute")
+
+
+@dataclass(frozen=True)
+class Attribute(Term):
+    pass
+
+
+@dataclass(frozen=True)
+class HasAliasedTables:
+    transformer: "SQLAlchemyTransformer"
+
+    def experiment_run_alias(self, experiment_id: ExperimentID) -> Any:
+        return self.transformer.get_experiment_runs_alias(
+            experiment_id
+        ) or self.transformer.create_experiment_runs_alias(experiment_id)
+
+    def experiment_run_annotation_alias(
+        self, experiment_id: ExperimentID, eval_name: EvalName
+    ) -> Any:
+        return self.transformer.get_experiment_run_annotations_alias(
+            experiment_id, eval_name
+        ) or self.transformer.create_experiment_run_annotations_alias(experiment_id, eval_name)
+
+
+@dataclass(frozen=True)
+class DatasetExampleAttribute(HasAliasedTables, Attribute):
+    attribute_name: str
+    _attribute_name: SupportedDatasetExampleAttributeName = field(init=False)
+
+    def __post_init__(self) -> None:
+        if not _is_supported_dataset_example_attribute(self.attribute_name):
+            raise ExperimentRunFilterConditionSyntaxError("Unknown name")
+        object.__setattr__(self, "_attribute_name", self.attribute_name)
+
+    def compile(self) -> Any:
+        attribute_name = self._attribute_name
+        if attribute_name == "input":
+            return models.DatasetExampleRevision.input
+        elif attribute_name == "reference_output":
+            return models.DatasetExampleRevision.output
+        elif attribute_name == "metadata":
+            return models.DatasetExampleRevision.metadata_
+        assert_never(attribute_name)
+
+
+@dataclass(frozen=True)
+class ExperimentRunAttribute(HasAliasedTables, Attribute):
+    attribute_name: str
+    experiment_id: int
+    _attribute_name: SupportedExperimentRunAttributeName = field(init=False)
+
+    def __post_init__(self) -> None:
+        if not _is_supported_experiment_run_attribute_name(self.attribute_name):
+            raise ExperimentRunFilterConditionSyntaxError("Unknown name")
+        object.__setattr__(self, "_attribute_name", self.attribute_name)
+
+    def compile(self) -> Any:
+        attribute_name = self._attribute_name
+        experiment_id = self.experiment_id
+        if attribute_name == "evals":
+            raise ExperimentRunFilterConditionSyntaxError("Select an eval with [<eval-name>]")
+        elif attribute_name == "output":
+            aliased_experiment_run = self.experiment_run_alias(experiment_id)
+            return aliased_experiment_run.output["task_output"]
+        elif attribute_name == "error":
+            aliased_experiment_run = self.experiment_run_alias(experiment_id)
+            return aliased_experiment_run.error
+        elif attribute_name == "latency_ms":
+            aliased_experiment_run = self.experiment_run_alias(experiment_id)
+            return aliased_experiment_run.latency_ms
+        assert_never(attribute_name)
+
+    @property
+    def is_eval_attribute(self) -> bool:
+        return self.attribute_name == "evals"
+
+    @property
+    def is_json_attribute(self) -> bool:
+        return self.attribute_name in ("input", "reference_output", "output")
+
+    @property
+    def data_type(self) -> Optional[SQLAlchemyDataType]:
+        attribute_name = self._attribute_name
+        if attribute_name == "evals":
+            return None
+        elif attribute_name == "output":
+            return None
+        elif attribute_name == "error":
+            return String()
+        elif attribute_name == "latency_ms":
+            return Float()
+        assert_never(attribute_name)
+
+
+@dataclass(frozen=True)
+class JSONAttribute(Attribute):
+    attribute: Attribute
+    index_constant: Constant
+    _index_value: Union[int, str] = field(init=False)
+
+    def __post_init__(self) -> None:
+        index_value = self.index_constant.value
+        if not isinstance(index_value, (int, str)):
+            raise ExperimentRunFilterConditionSyntaxError("Index must be an integer or string")
+        object.__setattr__(self, "_index_value", index_value)
+
+    def compile(self) -> Any:
+        compiled_attribute = self.attribute.compile()
+        return compiled_attribute[self._index_value]
+
+
+@dataclass(frozen=True)
+class ExperimentRunEval(ExperimentRunFilterConditionNode):
+    experiment_run_attribute: ExperimentRunAttribute
+    eval_name: str
+    experiment_id: int = field(init=False)
+
+    def __post_init__(self) -> None:
+        if not isinstance(self.eval_name, str):
+            raise ExperimentRunFilterConditionSyntaxError("Eval must be indexed by string")
+        object.__setattr__(self, "experiment_id", self.experiment_run_attribute.experiment_id)
+
+    def compile(self) -> Any:
+        raise ExperimentRunFilterConditionSyntaxError(
+            "Choose an attribute for your eval (label, score, etc.)"
+        )
+
+
+@dataclass(frozen=True)
+class ExperimentRunEvalAttribute(HasAliasedTables, Attribute):
+    experiment_run_eval: ExperimentRunEval
+    attribute_name: str
+    experiment_id: int = field(init=False)
+    _attribute_name: SupportedExperimentRunEvalAttributeName = field(init=False)
+    _eval_name: str = field(init=False)
+
+    def __post_init__(self) -> None:
+        if not _is_supported_experiment_run_eval_attribute_name(self.attribute_name):
+            raise ExperimentRunFilterConditionSyntaxError("Unknown eval attribute")
+        object.__setattr__(self, "experiment_id", self.experiment_run_eval.experiment_id)
+        object.__setattr__(self, "_attribute_name", self.attribute_name)
+        object.__setattr__(self, "_eval_name", self.experiment_run_eval.eval_name)
+
+    def compile(self) -> Any:
+        experiment_id = self.experiment_id
+        eval_name = self._eval_name
+        attribute_name = self._attribute_name
+        experiment_run_annotations = self.experiment_run_annotation_alias(experiment_id, eval_name)
+        return getattr(experiment_run_annotations, attribute_name)
+
+    @property
+    def data_type(self) -> Optional[SQLAlchemyDataType]:
+        attribute_name = self._attribute_name
+        if attribute_name == "label":
+            return String()
+        elif attribute_name == "score":
+            return Float()
+        elif attribute_name == "explanation":
+            return String()
+        assert_never(attribute_name)
+
+
+@dataclass(frozen=True)
+class UnaryTermOperation(Term):
+    operand: Term
+    operator: SupportedUnaryTermOperator
+
+    def compile(self) -> Any:
+        operator = self.operator
+        operand = self.operand
+        sqlalchemy_operator: Callable[[Any], Any]
+        if isinstance(operator, ast.USub):
+            sqlalchemy_operator = sqlalchemy_operators.neg
+        else:
+            assert_never(operator)
+        compiled_operand = operand.compile()
+        return sqlalchemy_operator(compiled_operand)
+
+
+@dataclass(frozen=True)
+class BooleanExpression(ExperimentRunFilterConditionNode):
+    pass
+
+
+@dataclass(frozen=True)
+class ComparisonOperation(BooleanExpression):
+    left_operand: Term
+    right_operand: Term
+    operator: ast.cmpop
+    _operator: SupportedComparisonOperator = field(init=False)
+
+    def __post_init__(self) -> None:
+        operator = self.operator
+        if not _is_supported_comparison_operator(operator):
+            raise ExperimentRunFilterConditionSyntaxError("Unsupported comparison operator")
+        object.__setattr__(self, "_operator", operator)
+
+    def compile(self) -> Any:
+        left_operand = self.left_operand
+        right_operand = self.right_operand
+        operator = self._operator
+        compiled_left_operand = left_operand.compile()
+        compiled_right_operand = right_operand.compile()
+        cast_type = _get_cast_type_for_comparison(
+            operator=operator,
+            left_operand=left_operand,
+            right_operand=right_operand,
+        )
+        if cast_type is not None:
+            if left_operand.data_type is None:
+                compiled_left_operand = cast(compiled_left_operand, cast_type)
+            if right_operand.data_type is None:
+                compiled_right_operand = cast(compiled_right_operand, cast_type)
+        sqlalchemy_operator = _get_sqlalchemy_comparison_operator(operator)
+        return sqlalchemy_operator(compiled_left_operand, compiled_right_operand)
+
+
+@dataclass(frozen=True)
+class UnaryBooleanOperation(BooleanExpression):
+    operand: ExperimentRunFilterConditionNode
+    operator: SupportedUnaryBooleanOperator
+
+    def __post_init__(self) -> None:
+        if not isinstance(self.operand, BooleanExpression):
+            raise ExperimentRunFilterConditionSyntaxError("Operand must be a boolean expression")
+
+    def compile(self) -> Any:
+        operator = self.operator
+        sqlalchemy_operator: Callable[[Any], Any]
+        if isinstance(operator, ast.Not):
+            sqlalchemy_operator = sqlalchemy_operators.inv
+        else:
+            assert_never(operator)
+        compiled_operand = self.operand.compile()
+        return sqlalchemy_operator(compiled_operand)
+
+
+@dataclass(frozen=True)
+class BooleanOperation(BooleanExpression):
+    operator: ast.boolop
+    operands: list[BooleanExpression]
+
+    def __post_init__(self) -> None:
+        if len(self.operands) < 2:
+            raise ExperimentRunFilterConditionSyntaxError(
+                "Boolean operators require at least two operands"
+            )
+
+    def compile(self) -> Any:
+        ast_operator = self.operator
+        operands = [operand.compile() for operand in self.operands]
+        if isinstance(ast_operator, ast.And):
+            return and_(*operands)
+        elif isinstance(ast_operator, ast.Or):
+            return or_(*operands)
+        raise ExperimentRunFilterConditionSyntaxError("Unsupported boolean operator")
+
+
+class SQLAlchemyTransformer(ast.NodeTransformer):
+    def __init__(self, experiment_ids: list[int]) -> None:
+        if not experiment_ids:
+            raise ValueError("Must provide one or more experiments")
+        self._experiment_ids = experiment_ids
+        self._aliased_experiment_runs: dict[ExperimentID, Any] = {}
+        self._aliased_experiment_run_annotations: dict[ExperimentID, dict[EvalName, Any]] = {}
+
+    def visit_Constant(self, node: ast.Constant) -> Constant:
+        return Constant(value=node.value, ast_node=node)
+
+    def visit_Name(self, node: ast.Name) -> ExperimentRunFilterConditionNode:
+        name = node.id
+        if name == "experiments":
+            return ExperimentsName(ast_node=node)
+        elif _is_supported_dataset_example_attribute(name):
+            return DatasetExampleAttribute(
+                attribute_name=name,
+                transformer=self,
+                ast_node=node,
+            )
+        raise ExperimentRunFilterConditionSyntaxError("Unknown name")
+
+    def visit_UnaryOp(self, node: ast.UnaryOp) -> Union[UnaryBooleanOperation, UnaryTermOperation]:
+        operator = node.op
+        operand = self.visit(node.operand)
+        if _is_supported_unary_boolean_operator(operator):
+            return UnaryBooleanOperation(operand=operand, operator=operator, ast_node=node)
+        if _is_supported_unary_term_operator(operator):
+            return UnaryTermOperation(operand=operand, operator=operator, ast_node=node)
+        raise ExperimentRunFilterConditionSyntaxError("Unsupported unary operator")
+
+    def visit_BoolOp(self, node: ast.BoolOp) -> BooleanOperation:
+        operator = node.op
+        operands = [self.visit(value) for value in node.values]
+        return BooleanOperation(operator=operator, operands=operands, ast_node=node)
+
+    def visit_Compare(self, node: ast.Compare) -> ExperimentRunFilterConditionNode:
+        if not (len(node.ops) == 1 and len(node.comparators) == 1):
+            raise ExperimentRunFilterConditionSyntaxError("Only binary comparisons are supported")
+        left_operand = self.visit(node.left)
+        right_operand = self.visit(node.comparators[0])
+        operator = node.ops[0]
+        return ComparisonOperation(
+            left_operand=left_operand,
+            right_operand=right_operand,
+            operator=operator,
+            ast_node=node,
+        )
+
+    def visit_Subscript(self, node: ast.Subscript) -> ExperimentRunFilterConditionNode:
+        container = self.visit(node.value)
+        key = self.visit(node.slice)
+        if isinstance(container, ExperimentsName):
+            if not isinstance(key, Constant):
+                raise ExperimentRunFilterConditionSyntaxError("Index must be a constant")
+            return ExperimentRun(
+                slice=key,
+                experiment_ids=self._experiment_ids,
+                ast_node=node,
+            )
+        if isinstance(container, ExperimentRunAttribute):
+            if container.is_eval_attribute:
+                return ExperimentRunEval(
+                    experiment_run_attribute=container,
+                    eval_name=key.value,
+                    ast_node=node,
+                )
+        if isinstance(container, (JSONAttribute, DatasetExampleAttribute)) or (
+            isinstance(container, ExperimentRunAttribute) and container.is_json_attribute
+        ):
+            return JSONAttribute(
+                attribute=container,
+                index_constant=key,
+                ast_node=node,
+            )
+        raise ExperimentRunFilterConditionSyntaxError("Invalid subscript")
+
+    def visit_Attribute(self, node: ast.Attribute) -> ExperimentRunFilterConditionNode:
+        parent = self.visit(node.value)
+        attribute_name = node.attr
+        if isinstance(parent, ExperimentRun):
+            if _is_supported_experiment_run_attribute_name(attribute_name):
+                return ExperimentRunAttribute(
+                    attribute_name=attribute_name,
+                    experiment_id=parent.experiment_id,
+                    transformer=self,
+                    ast_node=node,
+                )
+            elif _is_supported_dataset_example_attribute(attribute_name):
+                return DatasetExampleAttribute(
+                    attribute_name=attribute_name,
+                    transformer=self,
+                    ast_node=node,
+                )
+            raise ExperimentRunFilterConditionSyntaxError("Unknown attribute")
+        if isinstance(parent, ExperimentRunEval):
+            return ExperimentRunEvalAttribute(
+                attribute_name=attribute_name,
+                experiment_run_eval=parent,
+                transformer=self,
+                ast_node=node,
+            )
+        raise ExperimentRunFilterConditionSyntaxError("Unknown attribute")
+
+    def create_experiment_runs_alias(self, experiment_id: ExperimentID) -> Any:
+        if self.get_experiment_runs_alias(experiment_id) is not None:
+            raise ValueError(f"Alias already exists for experiment ID: {experiment_id}")
+        experiment_index = self.get_experiment_index(experiment_id)
+        alias_name = f"experiment_runs_{experiment_index}"
+        aliased_table = aliased(models.ExperimentRun, name=alias_name)
+        self._aliased_experiment_runs[experiment_id] = aliased_table
+        return aliased_table
+
+    def get_experiment_runs_alias(self, experiment_id: ExperimentID) -> Any:
+        return self._aliased_experiment_runs.get(experiment_id)
+
+    def create_experiment_run_annotations_alias(
+        self, experiment_id: ExperimentID, eval_name: EvalName
+    ) -> Any:
+        if self.get_experiment_run_annotations_alias(experiment_id, eval_name) is not None:
+            raise ValueError(
+                f"Alias exists for experiment ID and eval name: {(experiment_id, eval_name)}"
+            )
+        self._ensure_experiment_runs_alias_exists(
+            experiment_id
+        )  # experiment_runs are needed so we have something to join experiment_run_annotations to
+        experiment_index = self.get_experiment_index(experiment_id)
+        eval_name_hash = sha256(eval_name.encode()).hexdigest()[:9]
+        alias_name = (  # postgres truncates identifiers at 63 chars, so cap the length
+            f"experiment_run_annotations_{experiment_index}_{eval_name_hash}"
+        )
+        aliased_table = aliased(models.ExperimentRunAnnotation, name=alias_name)
+        if experiment_id not in self._aliased_experiment_run_annotations:
+            self._aliased_experiment_run_annotations[experiment_id] = {}
+        self._aliased_experiment_run_annotations[experiment_id][eval_name] = aliased_table
+        return aliased_table
+
+    def get_experiment_run_annotations_alias(
+        self, experiment_id: ExperimentID, eval_name: EvalName
+    ) -> Any:
+        return self._aliased_experiment_run_annotations.get(experiment_id, {}).get(eval_name)
+
+    def get_experiment_run_annotations_aliases(
+        self, experiment_id: ExperimentID
+    ) -> dict[EvalName, Any]:
+        return self._aliased_experiment_run_annotations.get(experiment_id, {})
+
+    def get_experiment_index(self, experiment_id: ExperimentID) -> int:
+        return self._experiment_ids.index(experiment_id)
+
+    def _ensure_experiment_runs_alias_exists(self, experiment_id: ExperimentID) -> None:
+        if self.get_experiment_runs_alias(experiment_id) is None:
+            self.create_experiment_runs_alias(experiment_id)
+
+
+def _get_sqlalchemy_comparison_operator(
+    ast_operator: SupportedComparisonOperator,
+) -> Callable[[Any, Any], Any]:
+    if isinstance(ast_operator, ast.Eq):
+        return operator.eq
+    elif isinstance(ast_operator, ast.NotEq):
+        return operator.ne
+    elif isinstance(ast_operator, ast.Lt):
+        return sqlalchemy_operators.lt
+    elif isinstance(ast_operator, ast.LtE):
+        return sqlalchemy_operators.le
+    elif isinstance(ast_operator, ast.Gt):
+        return sqlalchemy_operators.gt
+    elif isinstance(ast_operator, ast.GtE):
+        return sqlalchemy_operators.ge
+    elif isinstance(ast_operator, ast.Is):
+        return sqlalchemy_operators.is_
+    elif isinstance(ast_operator, ast.IsNot):
+        return sqlalchemy_operators.is_not
+    elif isinstance(ast_operator, ast.In):
+        return lambda left, right: models.TextContains(right, left)
+    elif isinstance(ast_operator, ast.NotIn):
+        return lambda left, right: ~models.TextContains(right, left)
+    assert_never(ast_operator)
+
+
+def _get_cast_type_for_comparison(
+    *,
+    operator: SupportedComparisonOperator,
+    left_operand: Term,
+    right_operand: Term,
+) -> Optional[SQLAlchemyDataType]:
+    """
+    Some column types (e.g., JSON columns) require an explicit cast before
+    comparing with non-null values. We don't know the true type of the value in
+    the JSON column, so we use heuristics to cast to a reasonable type given the
+    operator and operands. There are three cases:
+
+    1. Both operands have known types.
+    2. One operand has a known type and the other does not.
+    3. Neither operand has a known type, e.g., both are JSON attributes.
+
+    In the first case, a cast is not needed. In the second case, we cast the
+    operand with the unknown type to the type of the operand being compared. In
+    the third case, we cast both operands to the same type using heuristics
+    based on the operator.
+    """
+
+    left_operand_data_type = left_operand.data_type
+    right_operand_data_type = right_operand.data_type
+    if left_operand_data_type is not None and right_operand_data_type is not None:
+        return None  # Both operands have known data types, so no cast is needed.
+
+    if isinstance(operator, (ast.Gt, ast.GtE, ast.Lt, ast.LtE)):
+        # These operations should always cast to float, even if a comparison is
+        # being made to an integer.
+        return Float()
+
+    if isinstance(operator, (ast.In, ast.NotIn)):
+        # These operations are performed on strings.
+        return String()
+
+    # If one operand is None, don't cast.
+    left_operand_is_null = isinstance(left_operand, Constant) and left_operand.value is None
+    right_operand_is_null = isinstance(right_operand, Constant) and right_operand.value is None
+    if left_operand_is_null or right_operand_is_null:
+        return None
+
+    # If one operand has a known type and the other does not, cast to the known type.
+    if left_operand_data_type is not None and right_operand_data_type is None:
+        return left_operand_data_type
+    elif left_operand_data_type is None and right_operand_data_type is not None:
+        return right_operand_data_type
+
+    # If neither operand has a known type, we infer a cast type from the comparison operator.
+    if isinstance(operator, (ast.Eq, ast.NotEq, ast.Is, ast.IsNot)):
+        return String()
+    assert_never(operator)
+
+
+def _is_supported_comparison_operator(
+    operator: ast.cmpop,
+) -> TypeGuard[SupportedComparisonOperator]:
+    return isinstance(operator, get_args(SupportedComparisonOperator))
+
+
+def _is_supported_dataset_example_attribute(
+    name: str,
+) -> TypeGuard[SupportedDatasetExampleAttributeName]:
+    return name in get_args(SupportedDatasetExampleAttributeName)
+
+
+def _is_supported_experiment_run_attribute_name(
+    name: str,
+) -> TypeGuard[SupportedExperimentRunAttributeName]:
+    return name in get_args(SupportedExperimentRunAttributeName)
+
+
+def _is_supported_experiment_run_eval_attribute_name(
+    name: str,
+) -> TypeGuard[SupportedExperimentRunEvalAttributeName]:
+    return name in get_args(SupportedExperimentRunEvalAttributeName)
+
+
+def _is_supported_unary_boolean_operator(
+    operator: ast.unaryop,
+) -> TypeGuard[SupportedUnaryBooleanOperator]:
+    return isinstance(operator, SupportedUnaryBooleanOperator)
+
+
+def _is_supported_unary_term_operator(
+    operator: ast.unaryop,
+) -> TypeGuard[SupportedUnaryTermOperator]:
+    return isinstance(operator, SupportedUnaryTermOperator)
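
The core of this new module is a two-pass compilation: free names such as `latency_ms` or `evals` are first bound to a specific experiment by `FreeAttributeNameBinder`, then the bound tree is lowered to SQLAlchemy expressions by `SQLAlchemyTransformer`. A minimal sketch of the binding step, using only the standard-library `ast` module and assuming the module is importable from the path shown in the listing above:

```python
# Hypothetical illustration of the free-attribute-name binding step (not part of the release).
import ast

from phoenix.server.api.helpers.experiment_run_filters import FreeAttributeNameBinder

tree = ast.parse("latency_ms > 1000 and evals['accuracy'].score >= 0.8", mode="eval")
binder = FreeAttributeNameBinder(experiment_index=0)
bound = ast.fix_missing_locations(binder.visit(tree))

print(binder.binds_free_attribute_name)  # True: bare run attributes were rewritten
print(ast.unparse(bound))
# experiments[0].latency_ms > 1000 and experiments[0].evals['accuracy'].score >= 0.8
```

When a condition contains such free names, `compile_sqlalchemy_filter_condition` compiles one copy per experiment and ORs them together; fully qualified conditions (e.g. `experiments[1].error is None`) are compiled once for all experiments.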
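
And a sketch of how a compiled condition is attached to a dataset-examples query through `update_examples_query_with_filter_condition`; the base select and the experiment IDs here are made-up values for illustration:

```python
# Hypothetical usage sketch; the experiment IDs and the base query are assumptions.
from sqlalchemy import select

from phoenix.db import models
from phoenix.server.api.helpers.experiment_run_filters import (
    ExperimentRunFilterConditionSyntaxError,
    update_examples_query_with_filter_condition,
)

base_query = select(models.DatasetExample)
try:
    query = update_examples_query_with_filter_condition(
        query=base_query,
        filter_condition="evals['accuracy'].score >= 0.8 and error is None",
        experiment_ids=[1, 2],  # hypothetical experiment rowids
    )
except ExperimentRunFilterConditionSyntaxError as error:
    print(f"invalid filter condition: {error}")
else:
    # The returned Select now LEFT OUTER JOINs the aliased experiment_runs and
    # experiment_run_annotations tables and carries the compiled WHERE clause.
    print(query)
```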
|