arize-phoenix 3.16.1-py3-none-any.whl → 7.7.1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those versions.
Note: this version of arize-phoenix has been flagged as potentially problematic.
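If you need to confirm which side of this comparison an environment is on before reading the file-level changes below, a minimal standard-library check is enough (this assumes the package is already installed; nothing here is taken from the diff itself):

```python
# Minimal sketch: print the installed arize-phoenix version.
from importlib.metadata import version

print(version("arize-phoenix"))  # e.g. "3.16.1" or "7.7.1" for the two versions compared here
```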
- arize_phoenix-7.7.1.dist-info/METADATA +261 -0
- arize_phoenix-7.7.1.dist-info/RECORD +345 -0
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/WHEEL +1 -1
- arize_phoenix-7.7.1.dist-info/entry_points.txt +3 -0
- phoenix/__init__.py +86 -14
- phoenix/auth.py +309 -0
- phoenix/config.py +675 -45
- phoenix/core/model.py +32 -30
- phoenix/core/model_schema.py +102 -109
- phoenix/core/model_schema_adapter.py +48 -45
- phoenix/datetime_utils.py +24 -3
- phoenix/db/README.md +54 -0
- phoenix/db/__init__.py +4 -0
- phoenix/db/alembic.ini +85 -0
- phoenix/db/bulk_inserter.py +294 -0
- phoenix/db/engines.py +208 -0
- phoenix/db/enums.py +20 -0
- phoenix/db/facilitator.py +113 -0
- phoenix/db/helpers.py +159 -0
- phoenix/db/insertion/constants.py +2 -0
- phoenix/db/insertion/dataset.py +227 -0
- phoenix/db/insertion/document_annotation.py +171 -0
- phoenix/db/insertion/evaluation.py +191 -0
- phoenix/db/insertion/helpers.py +98 -0
- phoenix/db/insertion/span.py +193 -0
- phoenix/db/insertion/span_annotation.py +158 -0
- phoenix/db/insertion/trace_annotation.py +158 -0
- phoenix/db/insertion/types.py +256 -0
- phoenix/db/migrate.py +86 -0
- phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
- phoenix/db/migrations/env.py +114 -0
- phoenix/db/migrations/script.py.mako +26 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
- phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
- phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
- phoenix/db/models.py +807 -0
- phoenix/exceptions.py +5 -1
- phoenix/experiments/__init__.py +6 -0
- phoenix/experiments/evaluators/__init__.py +29 -0
- phoenix/experiments/evaluators/base.py +158 -0
- phoenix/experiments/evaluators/code_evaluators.py +184 -0
- phoenix/experiments/evaluators/llm_evaluators.py +473 -0
- phoenix/experiments/evaluators/utils.py +236 -0
- phoenix/experiments/functions.py +772 -0
- phoenix/experiments/tracing.py +86 -0
- phoenix/experiments/types.py +726 -0
- phoenix/experiments/utils.py +25 -0
- phoenix/inferences/__init__.py +0 -0
- phoenix/{datasets → inferences}/errors.py +6 -5
- phoenix/{datasets → inferences}/fixtures.py +49 -42
- phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
- phoenix/{datasets → inferences}/schema.py +11 -11
- phoenix/{datasets → inferences}/validation.py +13 -14
- phoenix/logging/__init__.py +3 -0
- phoenix/logging/_config.py +90 -0
- phoenix/logging/_filter.py +6 -0
- phoenix/logging/_formatter.py +69 -0
- phoenix/metrics/__init__.py +5 -4
- phoenix/metrics/binning.py +4 -3
- phoenix/metrics/metrics.py +2 -1
- phoenix/metrics/mixins.py +7 -6
- phoenix/metrics/retrieval_metrics.py +2 -1
- phoenix/metrics/timeseries.py +5 -4
- phoenix/metrics/wrappers.py +9 -3
- phoenix/pointcloud/clustering.py +5 -5
- phoenix/pointcloud/pointcloud.py +7 -5
- phoenix/pointcloud/projectors.py +5 -6
- phoenix/pointcloud/umap_parameters.py +53 -52
- phoenix/server/api/README.md +28 -0
- phoenix/server/api/auth.py +44 -0
- phoenix/server/api/context.py +152 -9
- phoenix/server/api/dataloaders/__init__.py +91 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
- phoenix/server/api/dataloaders/cache/__init__.py +3 -0
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
- phoenix/server/api/dataloaders/document_evaluations.py +31 -0
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
- phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
- phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/record_counts.py +116 -0
- phoenix/server/api/dataloaders/session_io.py +79 -0
- phoenix/server/api/dataloaders/session_num_traces.py +30 -0
- phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
- phoenix/server/api/dataloaders/session_token_usages.py +41 -0
- phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
- phoenix/server/api/dataloaders/span_annotations.py +26 -0
- phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +57 -0
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/token_counts.py +124 -0
- phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
- phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
- phoenix/server/api/dataloaders/user_roles.py +30 -0
- phoenix/server/api/dataloaders/users.py +33 -0
- phoenix/server/api/exceptions.py +48 -0
- phoenix/server/api/helpers/__init__.py +12 -0
- phoenix/server/api/helpers/dataset_helpers.py +217 -0
- phoenix/server/api/helpers/experiment_run_filters.py +763 -0
- phoenix/server/api/helpers/playground_clients.py +948 -0
- phoenix/server/api/helpers/playground_registry.py +70 -0
- phoenix/server/api/helpers/playground_spans.py +455 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
- phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
- phoenix/server/api/input_types/ClearProjectInput.py +15 -0
- phoenix/server/api/input_types/ClusterInput.py +2 -2
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
- phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
- phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
- phoenix/server/api/input_types/DimensionFilter.py +4 -4
- phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
- phoenix/server/api/input_types/Granularity.py +1 -1
- phoenix/server/api/input_types/InvocationParameters.py +162 -0
- phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
- phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
- phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
- phoenix/server/api/input_types/SpanSort.py +134 -69
- phoenix/server/api/input_types/TemplateOptions.py +10 -0
- phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
- phoenix/server/api/input_types/UserRoleInput.py +9 -0
- phoenix/server/api/mutations/__init__.py +28 -0
- phoenix/server/api/mutations/api_key_mutations.py +167 -0
- phoenix/server/api/mutations/chat_mutations.py +593 -0
- phoenix/server/api/mutations/dataset_mutations.py +591 -0
- phoenix/server/api/mutations/experiment_mutations.py +75 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
- phoenix/server/api/mutations/project_mutations.py +57 -0
- phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
- phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
- phoenix/server/api/mutations/user_mutations.py +329 -0
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +17 -0
- phoenix/server/api/openapi/schema.py +16 -0
- phoenix/server/api/queries.py +738 -0
- phoenix/server/api/routers/__init__.py +11 -0
- phoenix/server/api/routers/auth.py +284 -0
- phoenix/server/api/routers/embeddings.py +26 -0
- phoenix/server/api/routers/oauth2.py +488 -0
- phoenix/server/api/routers/v1/__init__.py +64 -0
- phoenix/server/api/routers/v1/datasets.py +1017 -0
- phoenix/server/api/routers/v1/evaluations.py +362 -0
- phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
- phoenix/server/api/routers/v1/experiment_runs.py +167 -0
- phoenix/server/api/routers/v1/experiments.py +308 -0
- phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
- phoenix/server/api/routers/v1/spans.py +267 -0
- phoenix/server/api/routers/v1/traces.py +208 -0
- phoenix/server/api/routers/v1/utils.py +95 -0
- phoenix/server/api/schema.py +44 -241
- phoenix/server/api/subscriptions.py +597 -0
- phoenix/server/api/types/Annotation.py +21 -0
- phoenix/server/api/types/AnnotationSummary.py +55 -0
- phoenix/server/api/types/AnnotatorKind.py +16 -0
- phoenix/server/api/types/ApiKey.py +27 -0
- phoenix/server/api/types/AuthMethod.py +9 -0
- phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
- phoenix/server/api/types/Cluster.py +25 -24
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/DataQualityMetric.py +31 -13
- phoenix/server/api/types/Dataset.py +288 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +32 -31
- phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
- phoenix/server/api/types/EmbeddingDimension.py +56 -49
- phoenix/server/api/types/Evaluation.py +25 -31
- phoenix/server/api/types/EvaluationSummary.py +30 -50
- phoenix/server/api/types/Event.py +20 -20
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +152 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +17 -0
- phoenix/server/api/types/ExperimentRun.py +119 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
- phoenix/server/api/types/GenerativeModel.py +9 -0
- phoenix/server/api/types/GenerativeProvider.py +85 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/LabelFraction.py +7 -0
- phoenix/server/api/types/MimeType.py +2 -2
- phoenix/server/api/types/Model.py +54 -54
- phoenix/server/api/types/PerformanceMetric.py +8 -5
- phoenix/server/api/types/Project.py +407 -142
- phoenix/server/api/types/ProjectSession.py +139 -0
- phoenix/server/api/types/Segments.py +4 -4
- phoenix/server/api/types/Span.py +221 -176
- phoenix/server/api/types/SpanAnnotation.py +43 -0
- phoenix/server/api/types/SpanIOValue.py +15 -0
- phoenix/server/api/types/SystemApiKey.py +9 -0
- phoenix/server/api/types/TemplateLanguage.py +10 -0
- phoenix/server/api/types/TimeSeries.py +19 -15
- phoenix/server/api/types/TokenUsage.py +11 -0
- phoenix/server/api/types/Trace.py +154 -0
- phoenix/server/api/types/TraceAnnotation.py +45 -0
- phoenix/server/api/types/UMAPPoints.py +7 -7
- phoenix/server/api/types/User.py +60 -0
- phoenix/server/api/types/UserApiKey.py +45 -0
- phoenix/server/api/types/UserRole.py +15 -0
- phoenix/server/api/types/node.py +4 -112
- phoenix/server/api/types/pagination.py +156 -57
- phoenix/server/api/utils.py +34 -0
- phoenix/server/app.py +864 -115
- phoenix/server/bearer_auth.py +163 -0
- phoenix/server/dml_event.py +136 -0
- phoenix/server/dml_event_handler.py +256 -0
- phoenix/server/email/__init__.py +0 -0
- phoenix/server/email/sender.py +97 -0
- phoenix/server/email/templates/__init__.py +0 -0
- phoenix/server/email/templates/password_reset.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +102 -0
- phoenix/server/jwt_store.py +505 -0
- phoenix/server/main.py +305 -116
- phoenix/server/oauth2.py +52 -0
- phoenix/server/openapi/__init__.py +0 -0
- phoenix/server/prometheus.py +111 -0
- phoenix/server/rate_limiters.py +188 -0
- phoenix/server/static/.vite/manifest.json +87 -0
- phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
- phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
- phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
- phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
- phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
- phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
- phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
- phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
- phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
- phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
- phoenix/server/telemetry.py +68 -0
- phoenix/server/templates/index.html +82 -23
- phoenix/server/thread_server.py +3 -3
- phoenix/server/types.py +275 -0
- phoenix/services.py +27 -18
- phoenix/session/client.py +743 -68
- phoenix/session/data_extractor.py +31 -7
- phoenix/session/evaluation.py +3 -9
- phoenix/session/session.py +263 -219
- phoenix/settings.py +22 -0
- phoenix/trace/__init__.py +2 -22
- phoenix/trace/attributes.py +338 -0
- phoenix/trace/dsl/README.md +116 -0
- phoenix/trace/dsl/filter.py +663 -213
- phoenix/trace/dsl/helpers.py +73 -21
- phoenix/trace/dsl/query.py +574 -201
- phoenix/trace/exporter.py +24 -19
- phoenix/trace/fixtures.py +368 -32
- phoenix/trace/otel.py +71 -219
- phoenix/trace/projects.py +3 -2
- phoenix/trace/schemas.py +33 -11
- phoenix/trace/span_evaluations.py +21 -16
- phoenix/trace/span_json_decoder.py +6 -4
- phoenix/trace/span_json_encoder.py +2 -2
- phoenix/trace/trace_dataset.py +47 -32
- phoenix/trace/utils.py +21 -4
- phoenix/utilities/__init__.py +0 -26
- phoenix/utilities/client.py +132 -0
- phoenix/utilities/deprecation.py +31 -0
- phoenix/utilities/error_handling.py +3 -2
- phoenix/utilities/json.py +109 -0
- phoenix/utilities/logging.py +8 -0
- phoenix/utilities/project.py +2 -2
- phoenix/utilities/re.py +49 -0
- phoenix/utilities/span_store.py +0 -23
- phoenix/utilities/template_formatters.py +99 -0
- phoenix/version.py +1 -1
- arize_phoenix-3.16.1.dist-info/METADATA +0 -495
- arize_phoenix-3.16.1.dist-info/RECORD +0 -178
- phoenix/core/project.py +0 -619
- phoenix/core/traces.py +0 -96
- phoenix/experimental/evals/__init__.py +0 -73
- phoenix/experimental/evals/evaluators.py +0 -413
- phoenix/experimental/evals/functions/__init__.py +0 -4
- phoenix/experimental/evals/functions/classify.py +0 -453
- phoenix/experimental/evals/functions/executor.py +0 -353
- phoenix/experimental/evals/functions/generate.py +0 -138
- phoenix/experimental/evals/functions/processing.py +0 -76
- phoenix/experimental/evals/models/__init__.py +0 -14
- phoenix/experimental/evals/models/anthropic.py +0 -175
- phoenix/experimental/evals/models/base.py +0 -170
- phoenix/experimental/evals/models/bedrock.py +0 -221
- phoenix/experimental/evals/models/litellm.py +0 -134
- phoenix/experimental/evals/models/openai.py +0 -448
- phoenix/experimental/evals/models/rate_limiters.py +0 -246
- phoenix/experimental/evals/models/vertex.py +0 -173
- phoenix/experimental/evals/models/vertexai.py +0 -186
- phoenix/experimental/evals/retrievals.py +0 -96
- phoenix/experimental/evals/templates/__init__.py +0 -50
- phoenix/experimental/evals/templates/default_templates.py +0 -472
- phoenix/experimental/evals/templates/template.py +0 -195
- phoenix/experimental/evals/utils/__init__.py +0 -172
- phoenix/experimental/evals/utils/threads.py +0 -27
- phoenix/server/api/helpers.py +0 -11
- phoenix/server/api/routers/evaluation_handler.py +0 -109
- phoenix/server/api/routers/span_handler.py +0 -70
- phoenix/server/api/routers/trace_handler.py +0 -60
- phoenix/server/api/types/DatasetRole.py +0 -23
- phoenix/server/static/index.css +0 -6
- phoenix/server/static/index.js +0 -7447
- phoenix/storage/span_store/__init__.py +0 -23
- phoenix/storage/span_store/text_file.py +0 -85
- phoenix/trace/dsl/missing.py +0 -60
- phoenix/trace/langchain/__init__.py +0 -3
- phoenix/trace/langchain/instrumentor.py +0 -35
- phoenix/trace/llama_index/__init__.py +0 -3
- phoenix/trace/llama_index/callback.py +0 -102
- phoenix/trace/openai/__init__.py +0 -3
- phoenix/trace/openai/instrumentor.py +0 -30
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/LICENSE +0 -0
- /phoenix/{datasets → db/insertion}/__init__.py +0 -0
- /phoenix/{experimental → db/migrations}/__init__.py +0 -0
- /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
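One migration implied by the rename entries above: the `phoenix.datasets` package becomes `phoenix.inferences`, and `datasets/dataset.py` becomes `inferences/inferences.py`. The sketch below shows the corresponding import change; the `Inferences` and `Schema` class names are assumptions inferred from the new module names, not something this diff confirms.

```python
# 3.16.1 (old layout, per the left-hand side of the renames above):
#   from phoenix.datasets.dataset import Dataset
#   from phoenix.datasets.schema import Schema

# 7.7.1 (new layout; class names assumed to follow the module rename):
from phoenix.inferences.inferences import Inferences
from phoenix.inferences.schema import Schema
```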
phoenix/server/api/subscriptions.py (new file)
@@ -0,0 +1,597 @@

import asyncio
import logging
from collections.abc import AsyncIterator, Iterator
from datetime import datetime, timedelta, timezone
from typing import (
    Any,
    AsyncGenerator,
    Coroutine,
    Iterable,
    Mapping,
    Optional,
    Sequence,
    TypeVar,
    cast,
)

import strawberry
from openinference.instrumentation import safe_json_dumps
from openinference.semconv.trace import SpanAttributes
from sqlalchemy import and_, func, insert, select
from sqlalchemy.orm import load_only
from strawberry.relay.types import GlobalID
from strawberry.types import Info
from typing_extensions import TypeAlias, assert_never

from phoenix.datetime_utils import local_now, normalize_datetime
from phoenix.db import models
from phoenix.server.api.auth import IsLocked, IsNotReadOnly
from phoenix.server.api.context import Context
from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
from phoenix.server.api.helpers.playground_clients import (
    PlaygroundStreamingClient,
    initialize_playground_clients,
)
from phoenix.server.api.helpers.playground_registry import (
    PLAYGROUND_CLIENT_REGISTRY,
)
from phoenix.server.api.helpers.playground_spans import (
    get_db_experiment_run,
    get_db_span,
    get_db_trace,
    streaming_llm_span,
)
from phoenix.server.api.input_types.ChatCompletionInput import (
    ChatCompletionInput,
    ChatCompletionOverDatasetInput,
)
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
    ChatCompletionSubscriptionError,
    ChatCompletionSubscriptionExperiment,
    ChatCompletionSubscriptionPayload,
    ChatCompletionSubscriptionResult,
)
from phoenix.server.api.types.Dataset import Dataset
from phoenix.server.api.types.DatasetExample import DatasetExample
from phoenix.server.api.types.DatasetVersion import DatasetVersion
from phoenix.server.api.types.Experiment import to_gql_experiment
from phoenix.server.api.types.ExperimentRun import to_gql_experiment_run
from phoenix.server.api.types.node import from_global_id_with_expected_type
from phoenix.server.api.types.Span import to_gql_span
from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
from phoenix.server.dml_event import SpanInsertEvent
from phoenix.server.types import DbSessionFactory
from phoenix.utilities.template_formatters import (
    FStringTemplateFormatter,
    MustacheTemplateFormatter,
    NoOpFormatter,
    TemplateFormatter,
    TemplateFormatterError,
)

GenericType = TypeVar("GenericType")

logger = logging.getLogger(__name__)

initialize_playground_clients()

ChatCompletionMessage: TypeAlias = tuple[
    ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]
]
DatasetExampleID: TypeAlias = GlobalID
ChatCompletionResult: TypeAlias = tuple[
    DatasetExampleID, Optional[models.Span], models.ExperimentRun
]
ChatStream: TypeAlias = AsyncGenerator[ChatCompletionSubscriptionPayload, None]
PLAYGROUND_PROJECT_NAME = "playground"


@strawberry.type
class Subscription:
    @strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def chat_completion(
        self, info: Info[Context, None], input: ChatCompletionInput
    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
        provider_key = input.model.provider_key
        llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
        if llm_client_class is None:
            raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
        try:
            llm_client = llm_client_class(
                model=input.model,
                api_key=input.api_key,
            )
        except CustomGraphQLError:
            raise
        except Exception as error:
            raise BadRequest(
                f"Failed to connect to LLM API for {provider_key.value} {input.model.name}: "
                f"{str(error)}"
            )

        messages = [
            (
                message.role,
                message.content,
                message.tool_call_id if isinstance(message.tool_call_id, str) else None,
                message.tool_calls if isinstance(message.tool_calls, list) else None,
            )
            for message in input.messages
        ]
        attributes = None
        if template_options := input.template:
            messages = list(
                _formatted_messages(
                    messages=messages,
                    template_language=template_options.language,
                    template_variables=template_options.variables,
                )
            )
            attributes = {PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(template_options.variables)}
        invocation_parameters = llm_client.construct_invocation_parameters(
            input.invocation_parameters
        )
        async with streaming_llm_span(
            input=input,
            messages=messages,
            invocation_parameters=invocation_parameters,
            attributes=attributes,
        ) as span:
            async for chunk in llm_client.chat_completion_create(
                messages=messages, tools=input.tools or [], **invocation_parameters
            ):
                span.add_response_chunk(chunk)
                yield chunk
        span.set_attributes(llm_client.attributes)
        if span.status_message is not None:
            yield ChatCompletionSubscriptionError(message=span.status_message)
        async with info.context.db() as session:
            if (
                playground_project_id := await session.scalar(
                    select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
                )
            ) is None:
                playground_project_id = await session.scalar(
                    insert(models.Project)
                    .returning(models.Project.id)
                    .values(
                        name=PLAYGROUND_PROJECT_NAME,
                        description="Traces from prompt playground",
                    )
                )
            db_trace = get_db_trace(span, playground_project_id)
            db_span = get_db_span(span, db_trace)
            session.add(db_span)
            await session.flush()
        info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
        yield ChatCompletionSubscriptionResult(span=to_gql_span(db_span))

    @strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def chat_completion_over_dataset(
        self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
        provider_key = input.model.provider_key
        llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
        if llm_client_class is None:
            raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
        try:
            llm_client = llm_client_class(
                model=input.model,
                api_key=input.api_key,
            )
        except CustomGraphQLError:
            raise
        except Exception as error:
            raise BadRequest(
                f"Failed to connect to LLM API for {provider_key.value} {input.model.name}: "
                f"{str(error)}"
            )

        dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
        version_id = (
            from_global_id_with_expected_type(
                global_id=input.dataset_version_id, expected_type_name=DatasetVersion.__name__
            )
            if input.dataset_version_id
            else None
        )
        async with info.context.db() as session:
            if (
                dataset := await session.scalar(
                    select(models.Dataset).where(models.Dataset.id == dataset_id)
                )
            ) is None:
                raise NotFound(f"Could not find dataset with ID {dataset_id}")
            if version_id is None:
                if (
                    resolved_version_id := await session.scalar(
                        select(models.DatasetVersion.id)
                        .where(models.DatasetVersion.dataset_id == dataset_id)
                        .order_by(models.DatasetVersion.id.desc())
                        .limit(1)
                    )
                ) is None:
                    raise NotFound(f"No versions found for dataset with ID {dataset_id}")
            else:
                if (
                    resolved_version_id := await session.scalar(
                        select(models.DatasetVersion.id).where(
                            and_(
                                models.DatasetVersion.dataset_id == dataset_id,
                                models.DatasetVersion.id == version_id,
                            )
                        )
                    )
                ) is None:
                    raise NotFound(f"Could not find dataset version with ID {version_id}")
            revision_ids = (
                select(func.max(models.DatasetExampleRevision.id))
                .join(models.DatasetExample)
                .where(
                    and_(
                        models.DatasetExample.dataset_id == dataset_id,
                        models.DatasetExampleRevision.dataset_version_id <= resolved_version_id,
                    )
                )
                .group_by(models.DatasetExampleRevision.dataset_example_id)
            )
            if not (
                revisions := [
                    rev
                    async for rev in await session.stream_scalars(
                        select(models.DatasetExampleRevision)
                        .where(
                            and_(
                                models.DatasetExampleRevision.id.in_(revision_ids),
                                models.DatasetExampleRevision.revision_kind != "DELETE",
                            )
                        )
                        .order_by(models.DatasetExampleRevision.dataset_example_id.asc())
                        .options(
                            load_only(
                                models.DatasetExampleRevision.dataset_example_id,
                                models.DatasetExampleRevision.input,
                            )
                        )
                    )
                ]
            ):
                raise NotFound("No examples found for the given dataset and version")
            if (
                playground_project_id := await session.scalar(
                    select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
                )
            ) is None:
                playground_project_id = await session.scalar(
                    insert(models.Project)
                    .returning(models.Project.id)
                    .values(
                        name=PLAYGROUND_PROJECT_NAME,
                        description="Traces from prompt playground",
                    )
                )
            experiment = models.Experiment(
                dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
                dataset_version_id=resolved_version_id,
                name=input.experiment_name or _default_playground_experiment_name(),
                description=input.experiment_description
                or _default_playground_experiment_description(dataset_name=dataset.name),
                repetitions=1,
                metadata_=input.experiment_metadata
                or _default_playground_experiment_metadata(
                    dataset_name=dataset.name,
                    dataset_id=input.dataset_id,
                    version_id=GlobalID(DatasetVersion.__name__, str(resolved_version_id)),
                ),
                project_name=PLAYGROUND_PROJECT_NAME,
            )
            session.add(experiment)
            await session.flush()
        yield ChatCompletionSubscriptionExperiment(
            experiment=to_gql_experiment(experiment)
        )  # eagerly yields experiment so it can be linked by consumers of the subscription

        results: asyncio.Queue[ChatCompletionResult] = asyncio.Queue()
        not_started: list[tuple[DatasetExampleID, ChatStream]] = [
            (
                GlobalID(DatasetExample.__name__, str(revision.dataset_example_id)),
                _stream_chat_completion_over_dataset_example(
                    input=input,
                    llm_client=llm_client,
                    revision=revision,
                    results=results,
                    experiment_id=experiment.id,
                    project_id=playground_project_id,
                ),
            )
            for revision in revisions
        ]
        in_progress: list[
            tuple[
                Optional[DatasetExampleID],
                ChatStream,
                asyncio.Task[ChatCompletionSubscriptionPayload],
            ]
        ] = []
        max_in_progress = 3
        write_batch_size = 10
        write_interval = timedelta(seconds=10)
        last_write_time = datetime.now()
        while not_started or in_progress:
            while not_started and len(in_progress) < max_in_progress:
                ex_id, stream = not_started.pop()
                task = _create_task_with_timeout(stream)
                in_progress.append((ex_id, stream, task))
            async_tasks_to_run = [task for _, _, task in in_progress]
            completed_tasks, _ = await asyncio.wait(
                async_tasks_to_run, return_when=asyncio.FIRST_COMPLETED
            )
            for completed_task in completed_tasks:
                idx = [task for _, _, task in in_progress].index(completed_task)
                example_id, stream, _ = in_progress[idx]
                try:
                    yield completed_task.result()
                except StopAsyncIteration:
                    del in_progress[idx]  # removes exhausted stream
                except asyncio.TimeoutError:
                    del in_progress[idx]  # removes timed-out stream
                    if example_id is not None:
                        yield ChatCompletionSubscriptionError(
                            message="Playground task timed out", dataset_example_id=example_id
                        )
                except Exception as error:
                    del in_progress[idx]  # removes failed stream
                    if example_id is not None:
                        yield ChatCompletionSubscriptionError(
                            message="An unexpected error occurred", dataset_example_id=example_id
                        )
                    logger.exception(error)
                else:
                    task = _create_task_with_timeout(stream)
                    in_progress[idx] = (example_id, stream, task)

            exceeded_write_batch_size = results.qsize() >= write_batch_size
            exceeded_write_interval = datetime.now() - last_write_time > write_interval
            write_already_in_progress = any(
                _is_result_payloads_stream(stream) for _, stream, _ in in_progress
            )
            if (
                not results.empty()
                and (exceeded_write_batch_size or exceeded_write_interval)
                and not write_already_in_progress
            ):
                result_payloads_stream = _chat_completion_result_payloads(
                    db=info.context.db, results=_drain_no_wait(results)
                )
                task = _create_task_with_timeout(result_payloads_stream)
                in_progress.append((None, result_payloads_stream, task))
                last_write_time = datetime.now()
        if remaining_results := await _drain(results):
            async for result_payload in _chat_completion_result_payloads(
                db=info.context.db, results=remaining_results
            ):
                yield result_payload


async def _stream_chat_completion_over_dataset_example(
    *,
    input: ChatCompletionOverDatasetInput,
    llm_client: PlaygroundStreamingClient,
    revision: models.DatasetExampleRevision,
    results: asyncio.Queue[ChatCompletionResult],
    experiment_id: int,
    project_id: int,
) -> ChatStream:
    example_id = GlobalID(DatasetExample.__name__, str(revision.dataset_example_id))
    invocation_parameters = llm_client.construct_invocation_parameters(input.invocation_parameters)
    messages = [
        (
            message.role,
            message.content,
            message.tool_call_id if isinstance(message.tool_call_id, str) else None,
            message.tool_calls if isinstance(message.tool_calls, list) else None,
        )
        for message in input.messages
    ]
    try:
        format_start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
        messages = list(
            _formatted_messages(
                messages=messages,
                template_language=input.template_language,
                template_variables=revision.input,
            )
        )
    except TemplateFormatterError as error:
        format_end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
        yield ChatCompletionSubscriptionError(message=str(error), dataset_example_id=example_id)
        await results.put(
            (
                example_id,
                None,
                models.ExperimentRun(
                    experiment_id=experiment_id,
                    dataset_example_id=revision.dataset_example_id,
                    trace_id=None,
                    output={},
                    repetition_number=1,
                    start_time=format_start_time,
                    end_time=format_end_time,
                    error=str(error),
                    trace=None,
                ),
            )
        )
        return
    async with streaming_llm_span(
        input=input,
        messages=messages,
        invocation_parameters=invocation_parameters,
        attributes={PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(revision.input)},
    ) as span:
        async for chunk in llm_client.chat_completion_create(
            messages=messages, tools=input.tools or [], **invocation_parameters
        ):
            span.add_response_chunk(chunk)
            chunk.dataset_example_id = example_id
            yield chunk
    span.set_attributes(llm_client.attributes)
    db_trace = get_db_trace(span, project_id)
    db_span = get_db_span(span, db_trace)
    db_run = get_db_experiment_run(
        db_span, db_trace, experiment_id=experiment_id, example_id=revision.dataset_example_id
    )
    await results.put((example_id, db_span, db_run))
    if span.status_message is not None:
        yield ChatCompletionSubscriptionError(
            message=span.status_message, dataset_example_id=example_id
        )


async def _chat_completion_result_payloads(
    *,
    db: DbSessionFactory,
    results: Sequence[ChatCompletionResult],
) -> ChatStream:
    if not results:
        return
    async with db() as session:
        for _, span, run in results:
            if span:
                session.add(span)
            session.add(run)
        await session.flush()
    for example_id, span, run in results:
        yield ChatCompletionSubscriptionResult(
            span=to_gql_span(span) if span else None,
            experiment_run=to_gql_experiment_run(run),
            dataset_example_id=example_id,
        )


def _is_result_payloads_stream(
    stream: ChatStream,
) -> bool:
    """
    Checks if the given generator was instantiated from
    `_chat_completion_result_payloads`
    """
    return stream.ag_code == _chat_completion_result_payloads.__code__


def _create_task_with_timeout(
    iterable: AsyncIterator[GenericType], timeout_in_seconds: int = 90
) -> asyncio.Task[GenericType]:
    return asyncio.create_task(
        _wait_for(
            _as_coroutine(iterable),
            timeout=timeout_in_seconds,
            timeout_message="Playground task timed out",
        )
    )


async def _wait_for(
    coro: Coroutine[None, None, GenericType],
    timeout: float,
    timeout_message: Optional[str] = None,
) -> GenericType:
    """
    A function that imitates asyncio.wait_for, but allows the task to be
    cancelled with a custom message.
    """
    task = asyncio.create_task(coro)
    done, pending = await asyncio.wait([task], timeout=timeout)
    assert len(done) + len(pending) == 1
    if done:
        task = done.pop()
        return task.result()
    task = pending.pop()
    task.cancel(msg=timeout_message)
    try:
        return await task
    except asyncio.CancelledError:
        raise asyncio.TimeoutError()


async def _drain(queue: asyncio.Queue[GenericType]) -> list[GenericType]:
    values: list[GenericType] = []
    while not queue.empty():
        values.append(await queue.get())
    return values


def _drain_no_wait(queue: asyncio.Queue[GenericType]) -> list[GenericType]:
    values: list[GenericType] = []
    while True:
        try:
            values.append(queue.get_nowait())
        except asyncio.QueueEmpty:
            break
    return values


async def _as_coroutine(iterable: AsyncIterator[GenericType]) -> GenericType:
    return await iterable.__anext__()


def _formatted_messages(
    *,
    messages: Iterable[ChatCompletionMessage],
    template_language: TemplateLanguage,
    template_variables: Mapping[str, Any],
) -> Iterator[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]]:
    """
    Formats the messages using the given template options.
    """
    template_formatter = _template_formatter(template_language=template_language)
    (
        roles,
        templates,
        tool_call_id,
        tool_calls,
    ) = zip(*messages)
    formatted_templates = map(
        lambda template: template_formatter.format(template, **template_variables),
        templates,
    )
    formatted_messages = zip(roles, formatted_templates, tool_call_id, tool_calls)
    return formatted_messages


def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
    """
    Instantiates the appropriate template formatter for the template language.
    """
    if template_language is TemplateLanguage.MUSTACHE:
        return MustacheTemplateFormatter()
    if template_language is TemplateLanguage.F_STRING:
        return FStringTemplateFormatter()
    if template_language is TemplateLanguage.NONE:
        return NoOpFormatter()
    assert_never(template_language)


def _default_playground_experiment_name() -> str:
    return "playground-experiment"


def _default_playground_experiment_description(dataset_name: str) -> str:
    return f'Playground experiment for dataset "{dataset_name}"'


def _default_playground_experiment_metadata(
    dataset_name: str, dataset_id: GlobalID, version_id: GlobalID
) -> dict[str, Any]:
    return {
        "dataset_name": dataset_name,
        "dataset_id": str(dataset_id),
        "dataset_version_id": str(version_id),
    }


LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES


phoenix/server/api/types/Annotation.py (new file)
@@ -0,0 +1,21 @@

from typing import Optional

import strawberry


@strawberry.interface
class Annotation:
    name: str = strawberry.field(
        description="Name of the annotation, e.g. 'helpfulness' or 'relevance'."
    )
    score: Optional[float] = strawberry.field(
        description="Value of the annotation in the form of a numeric score."
    )
    label: Optional[str] = strawberry.field(
        description="Value of the annotation in the form of a string, e.g. "
        "'helpful' or 'not helpful'. Note that the label is not necessarily binary."
    )
    explanation: Optional[str] = strawberry.field(
        description="The annotator's explanation for the annotation result (i.e. "
        "score or label, or both) given to the subject."
    )


phoenix/server/api/types/AnnotationSummary.py (new file)
@@ -0,0 +1,55 @@

from typing import Optional, Union, cast

import pandas as pd
import strawberry
from strawberry import Private

from phoenix.db import models
from phoenix.server.api.types.LabelFraction import LabelFraction

AnnotationType = Union[models.SpanAnnotation, models.TraceAnnotation]


@strawberry.type
class AnnotationSummary:
    df: Private[pd.DataFrame]

    def __init__(self, dataframe: pd.DataFrame) -> None:
        self.df = dataframe

    @strawberry.field
    def count(self) -> int:
        return cast(int, self.df.record_count.sum())

    @strawberry.field
    def labels(self) -> list[str]:
        return self.df.label.dropna().tolist()

    @strawberry.field
    def label_fractions(self) -> list[LabelFraction]:
        if not (n := self.df.label_count.sum()):
            return []
        return [
            LabelFraction(
                label=cast(str, row.label),
                fraction=row.label_count / n,
            )
            for row in self.df.loc[
                self.df.label.notna(),
                ["label", "label_count"],
            ].itertuples()
        ]

    @strawberry.field
    def mean_score(self) -> Optional[float]:
        if not (n := self.df.score_count.sum()):
            return None
        return cast(float, self.df.score_sum.sum() / n)

    @strawberry.field
    def score_count(self) -> int:
        return cast(int, self.df.score_count.sum())

    @strawberry.field
    def label_count(self) -> int:
        return cast(int, self.df.label_count.sum())