arize-phoenix 3.16.1__py3-none-any.whl → 7.7.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of arize-phoenix has been flagged as possibly problematic.
- arize_phoenix-7.7.1.dist-info/METADATA +261 -0
- arize_phoenix-7.7.1.dist-info/RECORD +345 -0
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/WHEEL +1 -1
- arize_phoenix-7.7.1.dist-info/entry_points.txt +3 -0
- phoenix/__init__.py +86 -14
- phoenix/auth.py +309 -0
- phoenix/config.py +675 -45
- phoenix/core/model.py +32 -30
- phoenix/core/model_schema.py +102 -109
- phoenix/core/model_schema_adapter.py +48 -45
- phoenix/datetime_utils.py +24 -3
- phoenix/db/README.md +54 -0
- phoenix/db/__init__.py +4 -0
- phoenix/db/alembic.ini +85 -0
- phoenix/db/bulk_inserter.py +294 -0
- phoenix/db/engines.py +208 -0
- phoenix/db/enums.py +20 -0
- phoenix/db/facilitator.py +113 -0
- phoenix/db/helpers.py +159 -0
- phoenix/db/insertion/constants.py +2 -0
- phoenix/db/insertion/dataset.py +227 -0
- phoenix/db/insertion/document_annotation.py +171 -0
- phoenix/db/insertion/evaluation.py +191 -0
- phoenix/db/insertion/helpers.py +98 -0
- phoenix/db/insertion/span.py +193 -0
- phoenix/db/insertion/span_annotation.py +158 -0
- phoenix/db/insertion/trace_annotation.py +158 -0
- phoenix/db/insertion/types.py +256 -0
- phoenix/db/migrate.py +86 -0
- phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
- phoenix/db/migrations/env.py +114 -0
- phoenix/db/migrations/script.py.mako +26 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
- phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
- phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
- phoenix/db/models.py +807 -0
- phoenix/exceptions.py +5 -1
- phoenix/experiments/__init__.py +6 -0
- phoenix/experiments/evaluators/__init__.py +29 -0
- phoenix/experiments/evaluators/base.py +158 -0
- phoenix/experiments/evaluators/code_evaluators.py +184 -0
- phoenix/experiments/evaluators/llm_evaluators.py +473 -0
- phoenix/experiments/evaluators/utils.py +236 -0
- phoenix/experiments/functions.py +772 -0
- phoenix/experiments/tracing.py +86 -0
- phoenix/experiments/types.py +726 -0
- phoenix/experiments/utils.py +25 -0
- phoenix/inferences/__init__.py +0 -0
- phoenix/{datasets → inferences}/errors.py +6 -5
- phoenix/{datasets → inferences}/fixtures.py +49 -42
- phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
- phoenix/{datasets → inferences}/schema.py +11 -11
- phoenix/{datasets → inferences}/validation.py +13 -14
- phoenix/logging/__init__.py +3 -0
- phoenix/logging/_config.py +90 -0
- phoenix/logging/_filter.py +6 -0
- phoenix/logging/_formatter.py +69 -0
- phoenix/metrics/__init__.py +5 -4
- phoenix/metrics/binning.py +4 -3
- phoenix/metrics/metrics.py +2 -1
- phoenix/metrics/mixins.py +7 -6
- phoenix/metrics/retrieval_metrics.py +2 -1
- phoenix/metrics/timeseries.py +5 -4
- phoenix/metrics/wrappers.py +9 -3
- phoenix/pointcloud/clustering.py +5 -5
- phoenix/pointcloud/pointcloud.py +7 -5
- phoenix/pointcloud/projectors.py +5 -6
- phoenix/pointcloud/umap_parameters.py +53 -52
- phoenix/server/api/README.md +28 -0
- phoenix/server/api/auth.py +44 -0
- phoenix/server/api/context.py +152 -9
- phoenix/server/api/dataloaders/__init__.py +91 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
- phoenix/server/api/dataloaders/cache/__init__.py +3 -0
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
- phoenix/server/api/dataloaders/document_evaluations.py +31 -0
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
- phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
- phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/record_counts.py +116 -0
- phoenix/server/api/dataloaders/session_io.py +79 -0
- phoenix/server/api/dataloaders/session_num_traces.py +30 -0
- phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
- phoenix/server/api/dataloaders/session_token_usages.py +41 -0
- phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
- phoenix/server/api/dataloaders/span_annotations.py +26 -0
- phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +57 -0
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/token_counts.py +124 -0
- phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
- phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
- phoenix/server/api/dataloaders/user_roles.py +30 -0
- phoenix/server/api/dataloaders/users.py +33 -0
- phoenix/server/api/exceptions.py +48 -0
- phoenix/server/api/helpers/__init__.py +12 -0
- phoenix/server/api/helpers/dataset_helpers.py +217 -0
- phoenix/server/api/helpers/experiment_run_filters.py +763 -0
- phoenix/server/api/helpers/playground_clients.py +948 -0
- phoenix/server/api/helpers/playground_registry.py +70 -0
- phoenix/server/api/helpers/playground_spans.py +455 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
- phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
- phoenix/server/api/input_types/ClearProjectInput.py +15 -0
- phoenix/server/api/input_types/ClusterInput.py +2 -2
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
- phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
- phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
- phoenix/server/api/input_types/DimensionFilter.py +4 -4
- phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
- phoenix/server/api/input_types/Granularity.py +1 -1
- phoenix/server/api/input_types/InvocationParameters.py +162 -0
- phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
- phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
- phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
- phoenix/server/api/input_types/SpanSort.py +134 -69
- phoenix/server/api/input_types/TemplateOptions.py +10 -0
- phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
- phoenix/server/api/input_types/UserRoleInput.py +9 -0
- phoenix/server/api/mutations/__init__.py +28 -0
- phoenix/server/api/mutations/api_key_mutations.py +167 -0
- phoenix/server/api/mutations/chat_mutations.py +593 -0
- phoenix/server/api/mutations/dataset_mutations.py +591 -0
- phoenix/server/api/mutations/experiment_mutations.py +75 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
- phoenix/server/api/mutations/project_mutations.py +57 -0
- phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
- phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
- phoenix/server/api/mutations/user_mutations.py +329 -0
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +17 -0
- phoenix/server/api/openapi/schema.py +16 -0
- phoenix/server/api/queries.py +738 -0
- phoenix/server/api/routers/__init__.py +11 -0
- phoenix/server/api/routers/auth.py +284 -0
- phoenix/server/api/routers/embeddings.py +26 -0
- phoenix/server/api/routers/oauth2.py +488 -0
- phoenix/server/api/routers/v1/__init__.py +64 -0
- phoenix/server/api/routers/v1/datasets.py +1017 -0
- phoenix/server/api/routers/v1/evaluations.py +362 -0
- phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
- phoenix/server/api/routers/v1/experiment_runs.py +167 -0
- phoenix/server/api/routers/v1/experiments.py +308 -0
- phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
- phoenix/server/api/routers/v1/spans.py +267 -0
- phoenix/server/api/routers/v1/traces.py +208 -0
- phoenix/server/api/routers/v1/utils.py +95 -0
- phoenix/server/api/schema.py +44 -241
- phoenix/server/api/subscriptions.py +597 -0
- phoenix/server/api/types/Annotation.py +21 -0
- phoenix/server/api/types/AnnotationSummary.py +55 -0
- phoenix/server/api/types/AnnotatorKind.py +16 -0
- phoenix/server/api/types/ApiKey.py +27 -0
- phoenix/server/api/types/AuthMethod.py +9 -0
- phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
- phoenix/server/api/types/Cluster.py +25 -24
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/DataQualityMetric.py +31 -13
- phoenix/server/api/types/Dataset.py +288 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +32 -31
- phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
- phoenix/server/api/types/EmbeddingDimension.py +56 -49
- phoenix/server/api/types/Evaluation.py +25 -31
- phoenix/server/api/types/EvaluationSummary.py +30 -50
- phoenix/server/api/types/Event.py +20 -20
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +152 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +17 -0
- phoenix/server/api/types/ExperimentRun.py +119 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
- phoenix/server/api/types/GenerativeModel.py +9 -0
- phoenix/server/api/types/GenerativeProvider.py +85 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/LabelFraction.py +7 -0
- phoenix/server/api/types/MimeType.py +2 -2
- phoenix/server/api/types/Model.py +54 -54
- phoenix/server/api/types/PerformanceMetric.py +8 -5
- phoenix/server/api/types/Project.py +407 -142
- phoenix/server/api/types/ProjectSession.py +139 -0
- phoenix/server/api/types/Segments.py +4 -4
- phoenix/server/api/types/Span.py +221 -176
- phoenix/server/api/types/SpanAnnotation.py +43 -0
- phoenix/server/api/types/SpanIOValue.py +15 -0
- phoenix/server/api/types/SystemApiKey.py +9 -0
- phoenix/server/api/types/TemplateLanguage.py +10 -0
- phoenix/server/api/types/TimeSeries.py +19 -15
- phoenix/server/api/types/TokenUsage.py +11 -0
- phoenix/server/api/types/Trace.py +154 -0
- phoenix/server/api/types/TraceAnnotation.py +45 -0
- phoenix/server/api/types/UMAPPoints.py +7 -7
- phoenix/server/api/types/User.py +60 -0
- phoenix/server/api/types/UserApiKey.py +45 -0
- phoenix/server/api/types/UserRole.py +15 -0
- phoenix/server/api/types/node.py +4 -112
- phoenix/server/api/types/pagination.py +156 -57
- phoenix/server/api/utils.py +34 -0
- phoenix/server/app.py +864 -115
- phoenix/server/bearer_auth.py +163 -0
- phoenix/server/dml_event.py +136 -0
- phoenix/server/dml_event_handler.py +256 -0
- phoenix/server/email/__init__.py +0 -0
- phoenix/server/email/sender.py +97 -0
- phoenix/server/email/templates/__init__.py +0 -0
- phoenix/server/email/templates/password_reset.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +102 -0
- phoenix/server/jwt_store.py +505 -0
- phoenix/server/main.py +305 -116
- phoenix/server/oauth2.py +52 -0
- phoenix/server/openapi/__init__.py +0 -0
- phoenix/server/prometheus.py +111 -0
- phoenix/server/rate_limiters.py +188 -0
- phoenix/server/static/.vite/manifest.json +87 -0
- phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
- phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
- phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
- phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
- phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
- phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
- phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
- phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
- phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
- phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
- phoenix/server/telemetry.py +68 -0
- phoenix/server/templates/index.html +82 -23
- phoenix/server/thread_server.py +3 -3
- phoenix/server/types.py +275 -0
- phoenix/services.py +27 -18
- phoenix/session/client.py +743 -68
- phoenix/session/data_extractor.py +31 -7
- phoenix/session/evaluation.py +3 -9
- phoenix/session/session.py +263 -219
- phoenix/settings.py +22 -0
- phoenix/trace/__init__.py +2 -22
- phoenix/trace/attributes.py +338 -0
- phoenix/trace/dsl/README.md +116 -0
- phoenix/trace/dsl/filter.py +663 -213
- phoenix/trace/dsl/helpers.py +73 -21
- phoenix/trace/dsl/query.py +574 -201
- phoenix/trace/exporter.py +24 -19
- phoenix/trace/fixtures.py +368 -32
- phoenix/trace/otel.py +71 -219
- phoenix/trace/projects.py +3 -2
- phoenix/trace/schemas.py +33 -11
- phoenix/trace/span_evaluations.py +21 -16
- phoenix/trace/span_json_decoder.py +6 -4
- phoenix/trace/span_json_encoder.py +2 -2
- phoenix/trace/trace_dataset.py +47 -32
- phoenix/trace/utils.py +21 -4
- phoenix/utilities/__init__.py +0 -26
- phoenix/utilities/client.py +132 -0
- phoenix/utilities/deprecation.py +31 -0
- phoenix/utilities/error_handling.py +3 -2
- phoenix/utilities/json.py +109 -0
- phoenix/utilities/logging.py +8 -0
- phoenix/utilities/project.py +2 -2
- phoenix/utilities/re.py +49 -0
- phoenix/utilities/span_store.py +0 -23
- phoenix/utilities/template_formatters.py +99 -0
- phoenix/version.py +1 -1
- arize_phoenix-3.16.1.dist-info/METADATA +0 -495
- arize_phoenix-3.16.1.dist-info/RECORD +0 -178
- phoenix/core/project.py +0 -619
- phoenix/core/traces.py +0 -96
- phoenix/experimental/evals/__init__.py +0 -73
- phoenix/experimental/evals/evaluators.py +0 -413
- phoenix/experimental/evals/functions/__init__.py +0 -4
- phoenix/experimental/evals/functions/classify.py +0 -453
- phoenix/experimental/evals/functions/executor.py +0 -353
- phoenix/experimental/evals/functions/generate.py +0 -138
- phoenix/experimental/evals/functions/processing.py +0 -76
- phoenix/experimental/evals/models/__init__.py +0 -14
- phoenix/experimental/evals/models/anthropic.py +0 -175
- phoenix/experimental/evals/models/base.py +0 -170
- phoenix/experimental/evals/models/bedrock.py +0 -221
- phoenix/experimental/evals/models/litellm.py +0 -134
- phoenix/experimental/evals/models/openai.py +0 -448
- phoenix/experimental/evals/models/rate_limiters.py +0 -246
- phoenix/experimental/evals/models/vertex.py +0 -173
- phoenix/experimental/evals/models/vertexai.py +0 -186
- phoenix/experimental/evals/retrievals.py +0 -96
- phoenix/experimental/evals/templates/__init__.py +0 -50
- phoenix/experimental/evals/templates/default_templates.py +0 -472
- phoenix/experimental/evals/templates/template.py +0 -195
- phoenix/experimental/evals/utils/__init__.py +0 -172
- phoenix/experimental/evals/utils/threads.py +0 -27
- phoenix/server/api/helpers.py +0 -11
- phoenix/server/api/routers/evaluation_handler.py +0 -109
- phoenix/server/api/routers/span_handler.py +0 -70
- phoenix/server/api/routers/trace_handler.py +0 -60
- phoenix/server/api/types/DatasetRole.py +0 -23
- phoenix/server/static/index.css +0 -6
- phoenix/server/static/index.js +0 -7447
- phoenix/storage/span_store/__init__.py +0 -23
- phoenix/storage/span_store/text_file.py +0 -85
- phoenix/trace/dsl/missing.py +0 -60
- phoenix/trace/langchain/__init__.py +0 -3
- phoenix/trace/langchain/instrumentor.py +0 -35
- phoenix/trace/llama_index/__init__.py +0 -3
- phoenix/trace/llama_index/callback.py +0 -102
- phoenix/trace/openai/__init__.py +0 -3
- phoenix/trace/openai/instrumentor.py +0 -30
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/LICENSE +0 -0
- /phoenix/{datasets → db/insertion}/__init__.py +0 -0
- /phoenix/{experimental → db/migrations}/__init__.py +0 -0
- /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
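The file list shows the old `phoenix.datasets` package becoming `phoenix.inferences` (with `dataset.py` renamed to `inferences.py`), so import paths written against 3.16.x no longer resolve in 7.7.1. A minimal migration sketch, assuming the class was renamed from `Dataset` to `Inferences` along with its module (the `Schema` field used here is illustrative):

```python
import pandas as pd

# 3.16.1 (removed): from phoenix.datasets.dataset import Dataset
# 7.7.1, per the rename entries above (class name assumed):
from phoenix.inferences.inferences import Inferences
from phoenix.inferences.schema import Schema

df = pd.DataFrame({"prediction_id": ["a", "b"], "prediction_label": ["cat", "dog"]})
inferences = Inferences(
    dataframe=df,
    schema=Schema(prediction_id_column_name="prediction_id"),
)
```

Selected hunks from the diff follow.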
phoenix/server/api/types/DatasetExample.py

```diff
@@ -0,0 +1,85 @@
+from datetime import datetime
+from typing import Optional
+
+import strawberry
+from sqlalchemy import select
+from sqlalchemy.orm import joinedload
+from strawberry import UNSET
+from strawberry.relay.types import Connection, GlobalID, Node, NodeID
+from strawberry.types import Info
+
+from phoenix.db import models
+from phoenix.server.api.context import Context
+from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
+from phoenix.server.api.types.DatasetVersion import DatasetVersion
+from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+from phoenix.server.api.types.pagination import (
+    ConnectionArgs,
+    CursorString,
+    connection_from_list,
+)
+from phoenix.server.api.types.Span import Span, to_gql_span
+
+
+@strawberry.type
+class DatasetExample(Node):
+    id_attr: NodeID[int]
+    created_at: datetime
+    version_id: strawberry.Private[Optional[int]] = None
+
+    @strawberry.field
+    async def revision(
+        self,
+        info: Info[Context, None],
+        dataset_version_id: Optional[GlobalID] = UNSET,
+    ) -> DatasetExampleRevision:
+        example_id = self.id_attr
+        version_id: Optional[int] = None
+        if dataset_version_id:
+            version_id = from_global_id_with_expected_type(
+                global_id=dataset_version_id, expected_type_name=DatasetVersion.__name__
+            )
+        elif self.version_id is not None:
+            version_id = self.version_id
+        return await info.context.data_loaders.dataset_example_revisions.load(
+            (example_id, version_id)
+        )
+
+    @strawberry.field
+    async def span(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[Span]:
+        return (
+            to_gql_span(span)
+            if (span := await info.context.data_loaders.dataset_example_spans.load(self.id_attr))
+            else None
+        )
+
+    @strawberry.field
+    async def experiment_runs(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[ExperimentRun]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        example_id = self.id_attr
+        query = (
+            select(models.ExperimentRun)
+            .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
+            .join(models.Experiment, models.Experiment.id == models.ExperimentRun.experiment_id)
+            .where(models.ExperimentRun.dataset_example_id == example_id)
+            .order_by(models.Experiment.id.desc())
+        )
+        async with info.context.db() as session:
+            runs = (await session.scalars(query)).all()
+        return connection_from_list([to_gql_experiment_run(run) for run in runs], args)
```
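The new `DatasetExample` type resolves `revision` through a request-scoped DataLoader keyed by `(example_id, version_id)`, so sibling examples requested in one GraphQL query are fetched in a single batch. A minimal sketch of that pattern using strawberry's `DataLoader` (the loader body is a stand-in for Phoenix's SQL, not the real query):

```python
import asyncio
from typing import Optional

from strawberry.dataloader import DataLoader

Key = tuple[int, Optional[int]]  # (example_id, version_id), as in revision() above

async def load_revisions(keys: list[Key]) -> list[dict]:
    # Phoenix would resolve every key with one database query; here we
    # fabricate one payload per key to show the 1:1 key -> result contract.
    print(f"one batch for {len(keys)} keys")
    return [{"example_id": e, "version_id": v} for e, v in keys]

async def main() -> None:
    loader = DataLoader(load_fn=load_revisions)
    # Concurrent .load() calls coalesce into a single load_revisions call.
    a, b = await asyncio.gather(loader.load((1, None)), loader.load((2, 7)))
    print(a, b)

asyncio.run(main())
```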
phoenix/server/api/types/DatasetExampleRevision.py

```diff
@@ -0,0 +1,34 @@
+from datetime import datetime
+from enum import Enum
+
+import strawberry
+
+from phoenix.db import models
+from phoenix.server.api.types.ExampleRevisionInterface import ExampleRevision
+
+
+@strawberry.enum
+class RevisionKind(Enum):
+    CREATE = "CREATE"
+    PATCH = "PATCH"
+    DELETE = "DELETE"
+
+
+@strawberry.type
+class DatasetExampleRevision(ExampleRevision):
+    """
+    Represents a revision (i.e., update or alteration) of a dataset example.
+    """
+
+    revision_kind: RevisionKind
+    created_at: datetime
+
+    @classmethod
+    def from_orm_revision(cls, revision: models.DatasetExampleRevision) -> "DatasetExampleRevision":
+        return cls(
+            input=revision.input,
+            output=revision.output,
+            metadata=revision.metadata_,
+            revision_kind=RevisionKind(revision.revision_kind),
+            created_at=revision.created_at,
+        )
```
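`from_orm_revision` converts the string stored on the ORM row into the enum by value, so an unexpected value in the database fails loudly with `ValueError` instead of producing a bogus member. Illustrated with a plain `Enum` (the strawberry decorator does not change this behavior):

```python
from enum import Enum

class RevisionKind(Enum):
    CREATE = "CREATE"
    PATCH = "PATCH"
    DELETE = "DELETE"

assert RevisionKind("PATCH") is RevisionKind.PATCH  # the by-value lookup used above
try:
    RevisionKind("UPSERT")
except ValueError as e:
    print(e)  # 'UPSERT' is not a valid RevisionKind
```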
phoenix/server/api/types/DatasetVersion.py

```diff
@@ -0,0 +1,14 @@
+from datetime import datetime
+from typing import Optional
+
+import strawberry
+from strawberry.relay import Node, NodeID
+from strawberry.scalars import JSON
+
+
+@strawberry.type
+class DatasetVersion(Node):
+    id_attr: NodeID[int]
+    description: Optional[str]
+    metadata: JSON
+    created_at: datetime
```
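Like the other new node types, `DatasetVersion` exposes its integer primary key as `NodeID[int]`, which strawberry serializes into an opaque Relay global ID (base64 of `"TypeName:id"`). A quick illustration:

```python
from strawberry.relay import GlobalID

gid = GlobalID(type_name="DatasetVersion", node_id="1")
print(str(gid))  # RGF0YXNldFZlcnNpb246MQ== (base64 of "DatasetVersion:1")
```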
phoenix/server/api/types/Dimension.py

```diff
@@ -1,9 +1,10 @@
 from collections import defaultdict
-from typing import Any,
+from typing import Any, Optional
 
 import pandas as pd
 import strawberry
 from strawberry import UNSET
+from strawberry.relay import Node, NodeID
 from strawberry.types import Info
 from typing_extensions import Annotated
 
@@ -17,12 +18,11 @@ from ..context import Context
 from ..input_types.Granularity import Granularity
 from ..input_types.TimeRange import TimeRange
 from .DataQualityMetric import DataQualityMetric
-from .DatasetRole import DatasetRole
 from .DatasetValues import DatasetValues
 from .DimensionDataType import DimensionDataType
 from .DimensionShape import DimensionShape
 from .DimensionType import DimensionType
-from .
+from .InferencesRole import InferencesRole
 from .ScalarDriftMetricEnum import ScalarDriftMetric
 from .Segments import (
     GqlBinFactory,
@@ -40,6 +40,7 @@ from .TimeSeries import (
 
 @strawberry.type
 class Dimension(Node):
+    id_attr: NodeID[int]
     name: str = strawberry.field(description="The name of the dimension (a.k.a. the column name)")
     type: DimensionType = strawberry.field(
         description="Whether the dimension represents a feature, tag, prediction, or actual."
@@ -62,16 +63,16 @@ class Dimension(Node):
         """
         Computes a drift metric between all reference data and the primary data
         belonging to the input time range (inclusive of the time range start and
-        exclusive of the time range end). Returns None if no reference
-
+        exclusive of the time range end). Returns None if no reference inferences
+        exist, if no primary data exists in the input time range, or if the
        input time range is invalid.
        """
        model = info.context.model
        if model[REFERENCE].empty:
            return None
-
+        inferences = model[PRIMARY]
        time_range, granularity = ensure_timeseries_parameters(
-
+            inferences,
            time_range,
        )
        data = get_drift_timeseries_data(
@@ -92,18 +93,18 @@ class Dimension(Node):
         info: Info[Context, None],
         metric: DataQualityMetric,
         time_range: Optional[TimeRange] = UNSET,
-
-        Optional[
+        inferences_role: Annotated[
+            Optional[InferencesRole],
             strawberry.argument(
-                description="The
+                description="The inferences (primary or reference) to query",
             ),
-        ] =
+        ] = InferencesRole.primary,
     ) -> Optional[float]:
-        if not isinstance(
-
-
+        if not isinstance(inferences_role, InferencesRole):
+            inferences_role = InferencesRole.primary
+        inferences = info.context.model[inferences_role.value]
         time_range, granularity = ensure_timeseries_parameters(
-
+            inferences,
             time_range,
         )
         data = get_data_quality_timeseries_data(
@@ -111,7 +112,7 @@ class Dimension(Node):
             metric,
             time_range,
             granularity,
-
+            inferences_role,
         )
         return data[0].value if len(data) else None
 
@@ -122,7 +123,7 @@ class Dimension(Node):
             " Missing values are excluded. Non-categorical dimensions return an empty list."
         )
     ) # type: ignore # https://github.com/strawberry-graphql/strawberry/issues/1929
-    def categories(self) ->
+    def categories(self) -> list[str]:
         return list(self.dimension.categories)
 
     @strawberry.field(
@@ -139,18 +140,18 @@ class Dimension(Node):
         metric: DataQualityMetric,
         time_range: TimeRange,
         granularity: Granularity,
-
-        Optional[
+        inferences_role: Annotated[
+            Optional[InferencesRole],
             strawberry.argument(
-                description="The
+                description="The inferences (primary or reference) to query",
             ),
-        ] =
+        ] = InferencesRole.primary,
     ) -> DataQualityTimeSeries:
-        if not isinstance(
-
-
+        if not isinstance(inferences_role, InferencesRole):
+            inferences_role = InferencesRole.primary
+        inferences = info.context.model[inferences_role.value]
         time_range, granularity = ensure_timeseries_parameters(
-
+            inferences,
             time_range,
             granularity,
         )
@@ -160,7 +161,7 @@ class Dimension(Node):
             metric,
             time_range,
             granularity,
-
+            inferences_role,
         )
     )
 
@@ -182,9 +183,9 @@ class Dimension(Node):
         model = info.context.model
         if model[REFERENCE].empty:
             return DriftTimeSeries(data=[])
-
+        inferences = model[PRIMARY]
         time_range, granularity = ensure_timeseries_parameters(
-
+            inferences,
             time_range,
             granularity,
         )
@@ -202,7 +203,7 @@ class Dimension(Node):
         )
 
     @strawberry.field(
-        description="
+        description="The segments across both inference sets and returns the counts per segment",
     ) # type: ignore
     def segments_comparison(
         self,
@@ -249,8 +250,8 @@ class Dimension(Node):
         if isinstance(binning_method, binning.IntervalBinning) and binning_method.bins is not None:
             all_bins = all_bins.union(binning_method.bins)
         for bin in all_bins:
-            values:
-            for role in ms.
+            values: dict[ms.InferencesRole, Any] = defaultdict(lambda: None)
+            for role in ms.InferencesRole:
                 if model[role].empty:
                     continue
                 try:
```
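Several resolvers above reset `inferences_role` to `InferencesRole.primary` with an `isinstance` check rather than relying on the declared default, because strawberry passes the `UNSET` sentinel (not `None`) when an optional argument is omitted. A self-contained sketch of that guard, with a stand-in enum in place of Phoenix's `InferencesRole`:

```python
from enum import Enum, auto

from strawberry import UNSET

class InferencesRole(Enum):  # stand-in for Phoenix's strawberry enum
    primary = auto()
    reference = auto()

def resolve(inferences_role=UNSET) -> InferencesRole:
    # UNSET (and None) fail the isinstance check and fall back to primary.
    if not isinstance(inferences_role, InferencesRole):
        inferences_role = InferencesRole.primary
    return inferences_role

assert resolve() is InferencesRole.primary
assert resolve(InferencesRole.reference) is InferencesRole.reference
```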
phoenix/server/api/types/DocumentEvaluationSummary.py

```diff
@@ -1,6 +1,7 @@
 import math
+from collections.abc import Iterable
 from functools import cached_property
-from typing import Any,
+from typing import Any, Optional
 
 import pandas as pd
 import strawberry
@@ -24,8 +25,8 @@ class DocumentEvaluationSummary:
     ) -> None:
         self.evaluation_name = evaluation_name
         self.metrics_collection = pd.Series(metrics_collection, dtype=object)
-        self._cached_average_ndcg_results:
-        self._cached_average_precision_results:
+        self._cached_average_ndcg_results: dict[Optional[int], tuple[float, int]] = {}
+        self._cached_average_precision_results: dict[Optional[int], tuple[float, int]] = {}
 
     @strawberry.field
     def average_ndcg(self, k: Optional[int] = UNSET) -> Optional[float]:
@@ -67,7 +68,7 @@ class DocumentEvaluationSummary:
         _, count = self._average_hit
         return count
 
-    def _average_ndcg(self, k: Optional[int] = None) ->
+    def _average_ndcg(self, k: Optional[int] = None) -> tuple[float, int]:
         if (result := self._cached_average_ndcg_results.get(k)) is not None:
             return result
         values = self.metrics_collection.apply(lambda m: m.ndcg(k))
@@ -75,20 +76,20 @@ class DocumentEvaluationSummary:
         self._cached_average_ndcg_results[k] = result
         return result
 
-    def _average_precision(self, k: Optional[int] = None) ->
+    def _average_precision(self, k: Optional[int] = None) -> tuple[float, int]:
         if (result := self._cached_average_precision_results.get(k)) is not None:
             return result
         values = self.metrics_collection.apply(lambda m: m.precision(k))
         result = (values.mean(), values.count())
-        self.
+        self._cached_average_precision_results[k] = result
         return result
 
     @cached_property
-    def _average_reciprocal_rank(self) ->
+    def _average_reciprocal_rank(self) -> tuple[float, int]:
         values = self.metrics_collection.apply(lambda m: m.reciprocal_rank())
         return values.mean(), values.count()
 
     @cached_property
-    def _average_hit(self) ->
+    def _average_hit(self) -> tuple[float, int]:
         values = self.metrics_collection.apply(lambda m: m.hit())
         return values.mean(), values.count()
```
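The annotations added above make the caching contract explicit: each `_average_*` helper memoizes a `(mean, count)` tuple per cutoff `k`, so repeated GraphQL fields over the same summary trigger only one pandas pass per distinct `k`. A reduced sketch of the pattern, with a fake metrics object standing in for Phoenix's retrieval metrics:

```python
from typing import Optional

import pandas as pd

class FakeMetrics:
    """Stand-in for a retrieval-metrics object exposing ndcg(k)."""

    def __init__(self, scores: list[float]) -> None:
        self.scores = scores

    def ndcg(self, k: Optional[int]) -> float:
        return sum(self.scores[:k]) / len(self.scores)  # toy formula, not real NDCG

class Summary:
    def __init__(self, collection: list[FakeMetrics]) -> None:
        self.metrics_collection = pd.Series(collection, dtype=object)
        self._cached: dict[Optional[int], tuple[float, int]] = {}

    def average_ndcg(self, k: Optional[int] = None) -> tuple[float, int]:
        if (result := self._cached.get(k)) is not None:
            return result  # cache hit: no second pandas pass
        values = self.metrics_collection.apply(lambda m: m.ndcg(k))
        result = (values.mean(), values.count())
        self._cached[k] = result
        return result

s = Summary([FakeMetrics([1.0, 0.5]), FakeMetrics([0.0, 1.0])])
print(s.average_ndcg(1))  # computed once
print(s.average_ndcg(1))  # served from the cache
```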
phoenix/server/api/types/EmbeddingDimension.py

```diff
@@ -1,13 +1,15 @@
 from collections import defaultdict
+from collections.abc import Iterable, Iterator
 from datetime import timedelta
 from itertools import chain, repeat
-from typing import Any,
+from typing import Any, Optional, Union, cast
 
 import numpy as np
 import numpy.typing as npt
 import pandas as pd
 import strawberry
 from strawberry import UNSET
+from strawberry.relay import GlobalID, Node, NodeID
 from strawberry.scalars import ID
 from strawberry.types import Info
 from typing_extensions import Annotated
@@ -22,7 +24,7 @@ from phoenix.core.model_schema import (
     PRIMARY,
     PROMPT,
     REFERENCE,
-
+    Inferences,
 )
 from phoenix.metrics.timeseries import row_interval_from_sorted_time_index
 from phoenix.pointcloud.clustering import Hdbscan
@@ -31,7 +33,7 @@ from phoenix.pointcloud.projectors import Umap
 from phoenix.server.api.context import Context
 from phoenix.server.api.input_types.TimeRange import TimeRange
 from phoenix.server.api.types.Cluster import to_gql_clusters
-from phoenix.server.api.types.
+from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
 from phoenix.server.api.types.VectorDriftMetricEnum import VectorDriftMetric
 
 from ..input_types.Granularity import Granularity
@@ -39,7 +41,6 @@ from .DataQualityMetric import DataQualityMetric
 from .EmbeddingMetadata import EmbeddingMetadata
 from .Event import create_event_id, unpack_event_id
 from .EventMetadata import EventMetadata
-from .node import GlobalID, Node
 from .Retrieval import Retrieval
 from .TimeSeries import (
     DataQualityTimeSeries,
@@ -70,6 +71,7 @@ CORPUS = "CORPUS"
 class EmbeddingDimension(Node):
     """A embedding dimension of a model. Represents unstructured data"""
 
+    id_attr: NodeID[int]
     name: str
     dimension: strawberry.Private[ms.EmbeddingDimension]
 
@@ -155,16 +157,16 @@ class EmbeddingDimension(Node):
         metric: DataQualityMetric,
         time_range: TimeRange,
         granularity: Granularity,
-
-        Optional[
+        inferences_role: Annotated[
+            Optional[InferencesRole],
             strawberry.argument(
                 description="The dataset (primary or reference) to query",
             ),
-        ] =
+        ] = InferencesRole.primary,
     ) -> DataQualityTimeSeries:
-        if not isinstance(
-
-        dataset = info.context.model[
+        if not isinstance(inferences_role, InferencesRole):
+            inferences_role = InferencesRole.primary
+        dataset = info.context.model[inferences_role.value]
         time_range, granularity = ensure_timeseries_parameters(
             dataset,
             time_range,
@@ -176,7 +178,7 @@ class EmbeddingDimension(Node):
             metric,
             time_range,
             granularity,
-
+            inferences_role,
         )
     )
 
@@ -312,18 +314,18 @@ class EmbeddingDimension(Node):
         ] = DEFAULT_CLUSTER_SELECTION_EPSILON,
     ) -> UMAPPoints:
         model = info.context.model
-        data:
-        retrievals:
-        for
-
-            row_id_start, row_id_stop = 0, len(
-            if
+        data: dict[ID, npt.NDArray[np.float64]] = {}
+        retrievals: list[tuple[ID, Any, Any]] = []
+        for inferences in model[Inferences]:
+            inferences_id = inferences.role
+            row_id_start, row_id_stop = 0, len(inferences)
+            if inferences_id is PRIMARY:
                 row_id_start, row_id_stop = row_interval_from_sorted_time_index(
-                    time_index=cast(pd.DatetimeIndex,
+                    time_index=cast(pd.DatetimeIndex, inferences.index),
                     time_start=time_range.start,
                     time_stop=time_range.end,
                 )
-            vector_column = self.dimension[
+            vector_column = self.dimension[inferences_id]
             samples_collected = 0
             for row_id in _row_indices(
                 row_id_start,
@@ -337,7 +339,7 @@ class EmbeddingDimension(Node):
                 # of dunder method __len__.
                 if not hasattr(embedding_vector, "__len__"):
                     continue
-                event_id = create_event_id(row_id,
+                event_id = create_event_id(row_id, inferences_id)
                 data[event_id] = embedding_vector
                 samples_collected += 1
                 if isinstance(
@@ -347,23 +349,23 @@ class EmbeddingDimension(Node):
                     retrievals.append(
                         (
                             event_id,
-                            self.dimension.context_retrieval_ids(
-                            self.dimension.context_retrieval_scores(
+                            self.dimension.context_retrieval_ids(inferences).iloc[row_id],
+                            self.dimension.context_retrieval_scores(inferences).iloc[row_id],
                         )
                     )
 
-        context_retrievals:
+        context_retrievals: list[Retrieval] = []
         if isinstance(
             self.dimension,
             ms.RetrievalEmbeddingDimension,
         ) and (corpus := info.context.corpus):
-
-            for row_id, document_embedding_vector in enumerate(
+            corpus_inferences = corpus[PRIMARY]
+            for row_id, document_embedding_vector in enumerate(corpus_inferences[PROMPT]):
                 if not hasattr(document_embedding_vector, "__len__"):
                     continue
-                event_id = create_event_id(row_id,
+                event_id = create_event_id(row_id, AncillaryInferencesRole.corpus)
                 data[event_id] = document_embedding_vector
-            corpus_primary_key =
+            corpus_primary_key = corpus_inferences.primary_key
             for event_id, retrieval_ids, retrieval_scores in retrievals:
                 if not isinstance(retrieval_ids, Iterable):
                     continue
@@ -385,7 +387,7 @@ class EmbeddingDimension(Node):
                     )
                 except KeyError:
                     continue
-                document_embedding_vector =
+                document_embedding_vector = corpus_inferences[PROMPT].iloc[document_row_id]
                 if not hasattr(document_embedding_vector, "__len__"):
                     continue
                 context_retrievals.append(
@@ -393,7 +395,7 @@ class EmbeddingDimension(Node):
                         query_id=event_id,
                         document_id=create_event_id(
                             document_row_id,
-
+                            AncillaryInferencesRole.corpus,
                         ),
                         relevance=document_score,
                     )
@@ -413,48 +415,53 @@ class EmbeddingDimension(Node):
             ),
         ).generate(data, n_components=n_components)
 
-        points:
+        points: dict[Union[InferencesRole, AncillaryInferencesRole], list[UMAPPoint]] = defaultdict(
+            list
+        )
         for event_id, vector in vectors.items():
-            row_id,
-            if isinstance(
-                dataset = model[
+            row_id, inferences_role = unpack_event_id(event_id)
+            if isinstance(inferences_role, InferencesRole):
+                dataset = model[inferences_role.value]
                 embedding_metadata = EmbeddingMetadata(
-                    prediction_id=dataset[PREDICTION_ID][row_id],
-                    link_to_data=dataset[self.dimension.link_to_data][row_id],
-                    raw_data=dataset[self.dimension.raw_data][row_id],
+                    prediction_id=dataset[PREDICTION_ID].iloc[row_id],
+                    link_to_data=dataset[self.dimension.link_to_data].iloc[row_id],
+                    raw_data=dataset[self.dimension.raw_data].iloc[row_id],
                 )
             elif (corpus := info.context.corpus) is not None:
                 dataset = corpus[PRIMARY]
                 dimension = cast(ms.EmbeddingDimension, corpus[PROMPT])
                 embedding_metadata = EmbeddingMetadata(
-                    prediction_id=dataset[PREDICTION_ID][row_id],
-                    link_to_data=dataset[dimension.link_to_data][row_id],
-                    raw_data=dataset[dimension.raw_data][row_id],
+                    prediction_id=dataset[PREDICTION_ID].iloc[row_id],
+                    link_to_data=dataset[dimension.link_to_data].iloc[row_id],
+                    raw_data=dataset[dimension.raw_data].iloc[row_id],
                 )
             else:
                 continue
-            points[
+            points[inferences_role].append(
                 UMAPPoint(
-                    id=GlobalID(
+                    id=GlobalID(
+                        type_name=f"{type(self).__name__}:{str(inferences_role)}",
+                        node_id=str(row_id),
+                    ),
                     event_id=event_id,
                     coordinates=to_gql_coordinates(vector),
                     event_metadata=EventMetadata(
-                        prediction_label=dataset[PREDICTION_LABEL][row_id],
-                        prediction_score=dataset[PREDICTION_SCORE][row_id],
-                        actual_label=dataset[ACTUAL_LABEL][row_id],
-                        actual_score=dataset[ACTUAL_SCORE][row_id],
+                        prediction_label=dataset[PREDICTION_LABEL].iloc[row_id],
+                        prediction_score=dataset[PREDICTION_SCORE].iloc[row_id],
+                        actual_label=dataset[ACTUAL_LABEL].iloc[row_id],
+                        actual_score=dataset[ACTUAL_SCORE].iloc[row_id],
                     ),
                     embedding_metadata=embedding_metadata,
                 )
             )
 
         return UMAPPoints(
-            data=points[
-            reference_data=points[
+            data=points[InferencesRole.primary],
+            reference_data=points[InferencesRole.reference],
             clusters=to_gql_clusters(
                 clustered_events=clustered_events,
             ),
-            corpus_data=points[
+            corpus_data=points[AncillaryInferencesRole.corpus],
             context_retrievals=context_retrievals,
         )
 
@@ -470,7 +477,7 @@ def _row_indices(
         return
     shuffled_indices = np.arange(start, stop)
     np.random.shuffle(shuffled_indices)
-    yield from shuffled_indices
+    yield from shuffled_indices  # type: ignore[misc,unused-ignore]
 
 
 def to_gql_embedding_dimension(
```