arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,27 +1,36 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from dataclasses import asdict, dataclass
|
|
3
5
|
from datetime import datetime
|
|
4
6
|
from typing import TYPE_CHECKING, Annotated, Optional, Union
|
|
5
7
|
|
|
8
|
+
import pandas as pd
|
|
6
9
|
import strawberry
|
|
10
|
+
from aioitertools.itertools import islice
|
|
7
11
|
from openinference.semconv.trace import SpanAttributes
|
|
8
|
-
from sqlalchemy import desc, select
|
|
9
|
-
from strawberry import ID, UNSET,
|
|
12
|
+
from sqlalchemy import desc, or_, select
|
|
13
|
+
from strawberry import ID, UNSET, lazy
|
|
10
14
|
from strawberry.relay import Connection, GlobalID, Node, NodeID
|
|
11
15
|
from strawberry.types import Info
|
|
12
16
|
from typing_extensions import TypeAlias
|
|
13
17
|
|
|
14
18
|
from phoenix.db import models
|
|
15
19
|
from phoenix.server.api.context import Context
|
|
20
|
+
from phoenix.server.api.input_types.AnnotationFilter import AnnotationFilter, satisfies_filter
|
|
16
21
|
from phoenix.server.api.input_types.TraceAnnotationSort import TraceAnnotationSort
|
|
22
|
+
from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
|
|
23
|
+
from phoenix.server.api.types.CostBreakdown import CostBreakdown
|
|
17
24
|
from phoenix.server.api.types.pagination import (
|
|
18
|
-
|
|
25
|
+
Cursor,
|
|
19
26
|
CursorString,
|
|
20
|
-
|
|
27
|
+
connection_from_cursors_and_nodes,
|
|
21
28
|
)
|
|
22
29
|
from phoenix.server.api.types.SortDir import SortDir
|
|
23
30
|
from phoenix.server.api.types.Span import Span
|
|
24
|
-
from phoenix.server.api.types.
|
|
31
|
+
from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
|
|
32
|
+
from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
|
|
33
|
+
from phoenix.server.api.types.TraceAnnotation import TraceAnnotation
|
|
25
34
|
|
|
26
35
|
if TYPE_CHECKING:
|
|
27
36
|
from phoenix.server.api.types.Project import Project
|
|
@@ -33,11 +42,11 @@ TraceRowId: TypeAlias = int
|
|
|
33
42
|
|
|
34
43
|
@strawberry.type
|
|
35
44
|
class Trace(Node):
|
|
36
|
-
|
|
37
|
-
|
|
45
|
+
id: NodeID[TraceRowId]
|
|
46
|
+
db_record: strawberry.Private[Optional[models.Trace]] = None
|
|
38
47
|
|
|
39
48
|
def __post_init__(self) -> None:
|
|
40
|
-
if self.
|
|
49
|
+
if self.db_record and self.id != self.db_record.id:
|
|
41
50
|
raise ValueError("Trace ID mismatch")
|
|
42
51
|
|
|
43
52
|
@strawberry.field
|
|
@@ -45,11 +54,11 @@ class Trace(Node):
|
|
|
45
54
|
self,
|
|
46
55
|
info: Info[Context, None],
|
|
47
56
|
) -> ID:
|
|
48
|
-
if self.
|
|
49
|
-
trace_id = self.
|
|
57
|
+
if self.db_record:
|
|
58
|
+
trace_id = self.db_record.trace_id
|
|
50
59
|
else:
|
|
51
60
|
trace_id = await info.context.data_loaders.trace_fields.load(
|
|
52
|
-
(self.
|
|
61
|
+
(self.id, models.Trace.trace_id),
|
|
53
62
|
)
|
|
54
63
|
return ID(trace_id)
|
|
55
64
|
|
|
@@ -58,11 +67,11 @@ class Trace(Node):
|
|
|
58
67
|
self,
|
|
59
68
|
info: Info[Context, None],
|
|
60
69
|
) -> datetime:
|
|
61
|
-
if self.
|
|
62
|
-
start_time = self.
|
|
70
|
+
if self.db_record:
|
|
71
|
+
start_time = self.db_record.start_time
|
|
63
72
|
else:
|
|
64
73
|
start_time = await info.context.data_loaders.trace_fields.load(
|
|
65
|
-
(self.
|
|
74
|
+
(self.id, models.Trace.start_time),
|
|
66
75
|
)
|
|
67
76
|
return start_time
|
|
68
77
|
|
|
@@ -71,11 +80,11 @@ class Trace(Node):
|
|
|
71
80
|
self,
|
|
72
81
|
info: Info[Context, None],
|
|
73
82
|
) -> datetime:
|
|
74
|
-
if self.
|
|
75
|
-
end_time = self.
|
|
83
|
+
if self.db_record:
|
|
84
|
+
end_time = self.db_record.end_time
|
|
76
85
|
else:
|
|
77
86
|
end_time = await info.context.data_loaders.trace_fields.load(
|
|
78
|
-
(self.
|
|
87
|
+
(self.id, models.Trace.end_time),
|
|
79
88
|
)
|
|
80
89
|
return end_time
|
|
81
90
|
|
|
@@ -84,11 +93,11 @@ class Trace(Node):
|
|
|
84
93
|
self,
|
|
85
94
|
info: Info[Context, None],
|
|
86
95
|
) -> Optional[float]:
|
|
87
|
-
if self.
|
|
88
|
-
latency_ms = self.
|
|
96
|
+
if self.db_record:
|
|
97
|
+
latency_ms = self.db_record.latency_ms
|
|
89
98
|
else:
|
|
90
99
|
latency_ms = await info.context.data_loaders.trace_fields.load(
|
|
91
|
-
(self.
|
|
100
|
+
(self.id, models.Trace.latency_ms),
|
|
92
101
|
)
|
|
93
102
|
return latency_ms
|
|
94
103
|
|
|
@@ -97,26 +106,26 @@ class Trace(Node):
|
|
|
97
106
|
self,
|
|
98
107
|
info: Info[Context, None],
|
|
99
108
|
) -> Annotated["Project", strawberry.lazy(".Project")]:
|
|
100
|
-
if self.
|
|
101
|
-
project_rowid = self.
|
|
109
|
+
if self.db_record:
|
|
110
|
+
project_rowid = self.db_record.project_rowid
|
|
102
111
|
else:
|
|
103
112
|
project_rowid = await info.context.data_loaders.trace_fields.load(
|
|
104
|
-
(self.
|
|
113
|
+
(self.id, models.Trace.project_rowid),
|
|
105
114
|
)
|
|
106
115
|
from phoenix.server.api.types.Project import Project
|
|
107
116
|
|
|
108
|
-
return Project(
|
|
117
|
+
return Project(id=project_rowid)
|
|
109
118
|
|
|
110
119
|
@strawberry.field
|
|
111
120
|
async def project_id(
|
|
112
121
|
self,
|
|
113
122
|
info: Info[Context, None],
|
|
114
123
|
) -> GlobalID:
|
|
115
|
-
if self.
|
|
116
|
-
project_rowid = self.
|
|
124
|
+
if self.db_record:
|
|
125
|
+
project_rowid = self.db_record.project_rowid
|
|
117
126
|
else:
|
|
118
127
|
project_rowid = await info.context.data_loaders.trace_fields.load(
|
|
119
|
-
(self.
|
|
128
|
+
(self.id, models.Trace.project_rowid),
|
|
120
129
|
)
|
|
121
130
|
from phoenix.server.api.types.Project import Project
|
|
122
131
|
|
|
@@ -127,11 +136,11 @@ class Trace(Node):
|
|
|
127
136
|
self,
|
|
128
137
|
info: Info[Context, None],
|
|
129
138
|
) -> Optional[GlobalID]:
|
|
130
|
-
if self.
|
|
131
|
-
project_session_rowid = self.
|
|
139
|
+
if self.db_record:
|
|
140
|
+
project_session_rowid = self.db_record.project_session_rowid
|
|
132
141
|
else:
|
|
133
142
|
project_session_rowid = await info.context.data_loaders.trace_fields.load(
|
|
134
|
-
(self.
|
|
143
|
+
(self.id, models.Trace.project_session_rowid),
|
|
135
144
|
)
|
|
136
145
|
if project_session_rowid is None:
|
|
137
146
|
return None
|
|
@@ -144,39 +153,40 @@ class Trace(Node):
|
|
|
144
153
|
self,
|
|
145
154
|
info: Info[Context, None],
|
|
146
155
|
) -> Union[Annotated["ProjectSession", lazy(".ProjectSession")], None]:
|
|
147
|
-
if self.
|
|
148
|
-
project_session_rowid = self.
|
|
156
|
+
if self.db_record:
|
|
157
|
+
project_session_rowid = self.db_record.project_session_rowid
|
|
149
158
|
else:
|
|
150
159
|
project_session_rowid = await info.context.data_loaders.trace_fields.load(
|
|
151
|
-
(self.
|
|
160
|
+
(self.id, models.Trace.project_session_rowid),
|
|
152
161
|
)
|
|
153
162
|
if project_session_rowid is None:
|
|
154
163
|
return None
|
|
155
|
-
from phoenix.server.api.types.ProjectSession import to_gql_project_session
|
|
156
164
|
|
|
157
165
|
stmt = select(models.ProjectSession).filter_by(id=project_session_rowid)
|
|
158
166
|
async with info.context.db() as session:
|
|
159
167
|
project_session = await session.scalar(stmt)
|
|
160
168
|
if project_session is None:
|
|
161
169
|
return None
|
|
162
|
-
|
|
170
|
+
from .ProjectSession import ProjectSession
|
|
171
|
+
|
|
172
|
+
return ProjectSession(id=project_session.id, db_record=project_session)
|
|
163
173
|
|
|
164
174
|
@strawberry.field
|
|
165
175
|
async def root_span(
|
|
166
176
|
self,
|
|
167
177
|
info: Info[Context, None],
|
|
168
178
|
) -> Optional[Span]:
|
|
169
|
-
span_rowid = await info.context.data_loaders.trace_root_spans.load(self.
|
|
179
|
+
span_rowid = await info.context.data_loaders.trace_root_spans.load(self.id)
|
|
170
180
|
if span_rowid is None:
|
|
171
181
|
return None
|
|
172
|
-
return Span(
|
|
182
|
+
return Span(id=span_rowid)
|
|
173
183
|
|
|
174
184
|
@strawberry.field
|
|
175
185
|
async def num_spans(
|
|
176
186
|
self,
|
|
177
187
|
info: Info[Context, None],
|
|
178
188
|
) -> int:
|
|
179
|
-
return await info.context.data_loaders.num_spans_per_trace.load(self.
|
|
189
|
+
return await info.context.data_loaders.num_spans_per_trace.load(self.id)
|
|
180
190
|
|
|
181
191
|
@strawberry.field
|
|
182
192
|
async def spans(
|
|
@@ -186,26 +196,94 @@ class Trace(Node):
|
|
|
186
196
|
last: Optional[int] = UNSET,
|
|
187
197
|
after: Optional[CursorString] = UNSET,
|
|
188
198
|
before: Optional[CursorString] = UNSET,
|
|
199
|
+
root_spans_only: Optional[bool] = UNSET,
|
|
200
|
+
orphan_span_as_root_span: Optional[bool] = True,
|
|
189
201
|
) -> Connection[Span]:
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
stmt = (
|
|
202
|
+
# Validate pagination arguments
|
|
203
|
+
if isinstance(first, int) and first <= 0:
|
|
204
|
+
raise ValueError('Argument "first" must be a positive int')
|
|
205
|
+
|
|
206
|
+
# Build base query for spans in this trace
|
|
207
|
+
base_query = (
|
|
197
208
|
select(models.Span.id)
|
|
198
209
|
.join(models.Trace)
|
|
199
|
-
.where(models.Trace.id == self.
|
|
210
|
+
.where(models.Trace.id == self.id)
|
|
200
211
|
# Sort descending because the root span tends to show up later
|
|
201
212
|
# in the ingestion process.
|
|
202
213
|
.order_by(desc(models.Span.id))
|
|
203
|
-
.limit(first)
|
|
204
214
|
)
|
|
215
|
+
# Handle cursor pagination (forward pagination only)
|
|
216
|
+
if after is not UNSET and after is not None:
|
|
217
|
+
# Type narrowing: after is guaranteed to be str at this point
|
|
218
|
+
assert after is not None # Type narrowing for mypy
|
|
219
|
+
try:
|
|
220
|
+
cursor = Cursor.from_string(after)
|
|
221
|
+
except Exception as e:
|
|
222
|
+
raise ValueError(f"Invalid cursor format: {after}") from e
|
|
223
|
+
# For descending order, "after" means we want spans with smaller IDs
|
|
224
|
+
# (going forward in descending order)
|
|
225
|
+
base_query = base_query.where(models.Span.id < cursor.rowid)
|
|
226
|
+
# Note: backward pagination (last/before) is not yet implemented
|
|
227
|
+
# as it requires more complex handling with reversed ordering
|
|
228
|
+
if before is not UNSET or (last is not UNSET and last is not None):
|
|
229
|
+
raise ValueError("Backward pagination (last/before) is not yet supported")
|
|
230
|
+
|
|
231
|
+
# Build final query based on filtering requirements
|
|
232
|
+
if root_spans_only:
|
|
233
|
+
if orphan_span_as_root_span:
|
|
234
|
+
# A root span is either a span with no parent_id or an orphan span
|
|
235
|
+
# (a span whose parent_id references a span that doesn't exist in the current trace)
|
|
236
|
+
# We need parent_id to check for orphan spans, so add it to the query
|
|
237
|
+
# and create a CTE
|
|
238
|
+
candidate_spans = base_query.add_columns(models.Span.parent_id).cte(
|
|
239
|
+
"candidate_spans"
|
|
240
|
+
)
|
|
241
|
+
# Subquery to get all span_ids that exist in this trace
|
|
242
|
+
parent_spans_in_trace = (
|
|
243
|
+
select(models.Span.span_id)
|
|
244
|
+
.where(models.Span.trace_rowid == self.id)
|
|
245
|
+
.alias("parent_spans")
|
|
246
|
+
)
|
|
247
|
+
# Filter candidates to only root spans (NULL parent_id or orphan spans)
|
|
248
|
+
stmt = (
|
|
249
|
+
select(candidate_spans.c.id)
|
|
250
|
+
.where(
|
|
251
|
+
or_(
|
|
252
|
+
candidate_spans.c.parent_id.is_(None),
|
|
253
|
+
~select(1)
|
|
254
|
+
.where(candidate_spans.c.parent_id == parent_spans_in_trace.c.span_id)
|
|
255
|
+
.exists(),
|
|
256
|
+
)
|
|
257
|
+
)
|
|
258
|
+
.order_by(desc(candidate_spans.c.id))
|
|
259
|
+
)
|
|
260
|
+
else:
|
|
261
|
+
# Only include explicit root spans (spans with parent_id = NULL)
|
|
262
|
+
stmt = base_query.where(models.Span.parent_id.is_(None))
|
|
263
|
+
else:
|
|
264
|
+
# Return all spans (no root span filtering)
|
|
265
|
+
stmt = base_query
|
|
266
|
+
|
|
267
|
+
# Over-fetch by one to determine whether there's a next page
|
|
268
|
+
limit = first if isinstance(first, int) else 50
|
|
269
|
+
stmt = stmt.limit(limit + 1)
|
|
270
|
+
|
|
271
|
+
cursors_and_nodes = []
|
|
205
272
|
async with info.context.db() as session:
|
|
206
273
|
span_rowids = await session.stream_scalars(stmt)
|
|
207
|
-
|
|
208
|
-
|
|
274
|
+
async for span_rowid in islice(span_rowids, limit):
|
|
275
|
+
cursor = Cursor(rowid=span_rowid)
|
|
276
|
+
cursors_and_nodes.append((cursor, Span(id=span_rowid)))
|
|
277
|
+
has_next_page = True
|
|
278
|
+
try:
|
|
279
|
+
await span_rowids.__anext__()
|
|
280
|
+
except StopAsyncIteration:
|
|
281
|
+
has_next_page = False
|
|
282
|
+
return connection_from_cursors_and_nodes(
|
|
283
|
+
cursors_and_nodes,
|
|
284
|
+
has_previous_page=False,
|
|
285
|
+
has_next_page=has_next_page,
|
|
286
|
+
)
|
|
209
287
|
|
|
210
288
|
@strawberry.field(description="Annotations associated with the trace.") # type: ignore
|
|
211
289
|
async def trace_annotations(
|
|
@@ -214,7 +292,7 @@ class Trace(Node):
|
|
|
214
292
|
sort: Optional[TraceAnnotationSort] = None,
|
|
215
293
|
) -> list[TraceAnnotation]:
|
|
216
294
|
async with info.context.db() as session:
|
|
217
|
-
stmt = select(models.TraceAnnotation).filter_by(trace_rowid=self.
|
|
295
|
+
stmt = select(models.TraceAnnotation).filter_by(trace_rowid=self.id)
|
|
218
296
|
if sort:
|
|
219
297
|
sort_col = getattr(models.TraceAnnotation, sort.col.value)
|
|
220
298
|
if sort.dir is SortDir.desc:
|
|
@@ -224,7 +302,101 @@ class Trace(Node):
|
|
|
224
302
|
else:
|
|
225
303
|
stmt = stmt.order_by(models.TraceAnnotation.created_at.desc())
|
|
226
304
|
annotations = await session.scalars(stmt)
|
|
227
|
-
return [
|
|
305
|
+
return [
|
|
306
|
+
TraceAnnotation(id=annotation.id, db_record=annotation) for annotation in annotations
|
|
307
|
+
]
|
|
308
|
+
|
|
309
|
+
@strawberry.field(description="Summarizes each annotation (by name) associated with the trace") # type: ignore
|
|
310
|
+
async def trace_annotation_summaries(
|
|
311
|
+
self,
|
|
312
|
+
info: Info[Context, None],
|
|
313
|
+
filter: Optional[AnnotationFilter] = None,
|
|
314
|
+
) -> list[AnnotationSummary]:
|
|
315
|
+
"""
|
|
316
|
+
Retrieves and summarizes annotations associated with this span.
|
|
317
|
+
|
|
318
|
+
This method aggregates annotation data by name and label, calculating metrics
|
|
319
|
+
such as count of occurrences and sum of scores. The results are organized
|
|
320
|
+
into a structured format that can be easily converted to a DataFrame.
|
|
321
|
+
|
|
322
|
+
Args:
|
|
323
|
+
info: GraphQL context information
|
|
324
|
+
filter: Optional filter to apply to annotations before processing
|
|
325
|
+
|
|
326
|
+
Returns:
|
|
327
|
+
A list of AnnotationSummary objects, each containing:
|
|
328
|
+
- name: The name of the annotation
|
|
329
|
+
- data: A list of dictionaries with label statistics
|
|
330
|
+
"""
|
|
331
|
+
# Load all annotations for this span from the data loader
|
|
332
|
+
annotations = await info.context.data_loaders.trace_annotations_by_trace.load(self.id)
|
|
333
|
+
|
|
334
|
+
# Apply filter if provided to narrow down the annotations
|
|
335
|
+
if filter:
|
|
336
|
+
annotations = [
|
|
337
|
+
annotation for annotation in annotations if satisfies_filter(annotation, filter)
|
|
338
|
+
]
|
|
339
|
+
|
|
340
|
+
@dataclass
|
|
341
|
+
class Metrics:
|
|
342
|
+
record_count: int = 0
|
|
343
|
+
label_count: int = 0
|
|
344
|
+
score_sum: float = 0
|
|
345
|
+
score_count: int = 0
|
|
346
|
+
|
|
347
|
+
summaries: defaultdict[str, defaultdict[Optional[str], Metrics]] = defaultdict(
|
|
348
|
+
lambda: defaultdict(Metrics)
|
|
349
|
+
)
|
|
350
|
+
for annotation in annotations:
|
|
351
|
+
metrics = summaries[annotation.name][annotation.label]
|
|
352
|
+
metrics.record_count += 1
|
|
353
|
+
metrics.label_count += int(annotation.label is not None)
|
|
354
|
+
metrics.score_sum += annotation.score or 0
|
|
355
|
+
metrics.score_count += int(annotation.score is not None)
|
|
356
|
+
|
|
357
|
+
result: list[AnnotationSummary] = []
|
|
358
|
+
for name, label_metrics in summaries.items():
|
|
359
|
+
rows = [{"label": label, **asdict(metrics)} for label, metrics in label_metrics.items()]
|
|
360
|
+
result.append(AnnotationSummary(name=name, df=pd.DataFrame(rows), simple_avg=True))
|
|
361
|
+
return result
|
|
362
|
+
|
|
363
|
+
@strawberry.field
|
|
364
|
+
async def cost_summary(
|
|
365
|
+
self,
|
|
366
|
+
info: Info[Context, None],
|
|
367
|
+
) -> SpanCostSummary:
|
|
368
|
+
loader = info.context.data_loaders.span_cost_summary_by_trace
|
|
369
|
+
summary = await loader.load(self.id)
|
|
370
|
+
return SpanCostSummary(
|
|
371
|
+
prompt=CostBreakdown(
|
|
372
|
+
tokens=summary.prompt.tokens,
|
|
373
|
+
cost=summary.prompt.cost,
|
|
374
|
+
),
|
|
375
|
+
completion=CostBreakdown(
|
|
376
|
+
tokens=summary.completion.tokens,
|
|
377
|
+
cost=summary.completion.cost,
|
|
378
|
+
),
|
|
379
|
+
total=CostBreakdown(
|
|
380
|
+
tokens=summary.total.tokens,
|
|
381
|
+
cost=summary.total.cost,
|
|
382
|
+
),
|
|
383
|
+
)
|
|
384
|
+
|
|
385
|
+
@strawberry.field
|
|
386
|
+
async def cost_detail_summary_entries(
|
|
387
|
+
self,
|
|
388
|
+
info: Info[Context, None],
|
|
389
|
+
) -> list[SpanCostDetailSummaryEntry]:
|
|
390
|
+
loader = info.context.data_loaders.span_cost_detail_summary_entries_by_trace
|
|
391
|
+
entries = await loader.load(self.id)
|
|
392
|
+
return [
|
|
393
|
+
SpanCostDetailSummaryEntry(
|
|
394
|
+
token_type=entry.token_type,
|
|
395
|
+
is_prompt=entry.is_prompt,
|
|
396
|
+
value=CostBreakdown(tokens=entry.value.tokens, cost=entry.value.cost),
|
|
397
|
+
)
|
|
398
|
+
for entry in entries
|
|
399
|
+
]
|
|
228
400
|
|
|
229
401
|
|
|
230
402
|
INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
from
|
|
1
|
+
from math import isfinite
|
|
2
|
+
from typing import TYPE_CHECKING, Annotated, Optional
|
|
2
3
|
|
|
3
4
|
import strawberry
|
|
4
|
-
from strawberry import
|
|
5
|
-
from strawberry.relay import GlobalID, Node, NodeID
|
|
5
|
+
from strawberry.relay import Node, NodeID
|
|
6
6
|
from strawberry.scalars import JSON
|
|
7
7
|
from strawberry.types import Info
|
|
8
8
|
|
|
@@ -11,58 +11,157 @@ from phoenix.server.api.context import Context
|
|
|
11
11
|
from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
|
|
12
12
|
|
|
13
13
|
from .AnnotationSource import AnnotationSource
|
|
14
|
-
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from .Trace import Trace
|
|
17
|
+
from .User import User
|
|
15
18
|
|
|
16
19
|
|
|
17
20
|
@strawberry.type
|
|
18
21
|
class TraceAnnotation(Node):
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
22
|
+
id: NodeID[int]
|
|
23
|
+
db_record: strawberry.Private[Optional[models.TraceAnnotation]] = None
|
|
24
|
+
|
|
25
|
+
def __post_init__(self) -> None:
|
|
26
|
+
if self.db_record and self.id != self.db_record.id:
|
|
27
|
+
raise ValueError("TraceAnnotation ID mismatch")
|
|
28
|
+
|
|
29
|
+
@strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
|
|
30
|
+
async def name(
|
|
31
|
+
self,
|
|
32
|
+
info: Info[Context, None],
|
|
33
|
+
) -> str:
|
|
34
|
+
if self.db_record:
|
|
35
|
+
val = self.db_record.name
|
|
36
|
+
else:
|
|
37
|
+
val = await info.context.data_loaders.trace_annotation_fields.load(
|
|
38
|
+
(self.id, models.TraceAnnotation.name),
|
|
39
|
+
)
|
|
40
|
+
return val
|
|
41
|
+
|
|
42
|
+
@strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
|
|
43
|
+
async def annotator_kind(
|
|
44
|
+
self,
|
|
45
|
+
info: Info[Context, None],
|
|
46
|
+
) -> AnnotatorKind:
|
|
47
|
+
if self.db_record:
|
|
48
|
+
val = self.db_record.annotator_kind
|
|
49
|
+
else:
|
|
50
|
+
val = await info.context.data_loaders.trace_annotation_fields.load(
|
|
51
|
+
(self.id, models.TraceAnnotation.annotator_kind),
|
|
52
|
+
)
|
|
53
|
+
return AnnotatorKind(val)
|
|
54
|
+
|
|
55
|
+
@strawberry.field(
|
|
56
|
+
description="Value of the annotation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary." # noqa: E501
|
|
57
|
+
) # type: ignore
|
|
58
|
+
async def label(
|
|
59
|
+
self,
|
|
60
|
+
info: Info[Context, None],
|
|
61
|
+
) -> Optional[str]:
|
|
62
|
+
if self.db_record:
|
|
63
|
+
val = self.db_record.label
|
|
64
|
+
else:
|
|
65
|
+
val = await info.context.data_loaders.trace_annotation_fields.load(
|
|
66
|
+
(self.id, models.TraceAnnotation.label),
|
|
67
|
+
)
|
|
68
|
+
return val
|
|
69
|
+
|
|
70
|
+
@strawberry.field(description="Value of the annotation in the form of a numeric score.") # type: ignore
|
|
71
|
+
async def score(
|
|
72
|
+
self,
|
|
73
|
+
info: Info[Context, None],
|
|
74
|
+
) -> Optional[float]:
|
|
75
|
+
if self.db_record:
|
|
76
|
+
val = self.db_record.score
|
|
77
|
+
else:
|
|
78
|
+
val = await info.context.data_loaders.trace_annotation_fields.load(
|
|
79
|
+
(self.id, models.TraceAnnotation.score),
|
|
80
|
+
)
|
|
81
|
+
return val if val is not None and isfinite(val) else None
|
|
82
|
+
|
|
83
|
+
@strawberry.field(
|
|
84
|
+
description="The annotator's explanation for the annotation result (i.e. score or label, or both) given to the subject." # noqa: E501
|
|
85
|
+
) # type: ignore
|
|
86
|
+
async def explanation(
|
|
87
|
+
self,
|
|
88
|
+
info: Info[Context, None],
|
|
89
|
+
) -> Optional[str]:
|
|
90
|
+
if self.db_record:
|
|
91
|
+
val = self.db_record.explanation
|
|
92
|
+
else:
|
|
93
|
+
val = await info.context.data_loaders.trace_annotation_fields.load(
|
|
94
|
+
(self.id, models.TraceAnnotation.explanation),
|
|
95
|
+
)
|
|
96
|
+
return val
|
|
97
|
+
|
|
98
|
+
@strawberry.field(description="Metadata about the annotation.") # type: ignore
|
|
99
|
+
async def metadata(
|
|
100
|
+
self,
|
|
101
|
+
info: Info[Context, None],
|
|
102
|
+
) -> JSON:
|
|
103
|
+
if self.db_record:
|
|
104
|
+
val = self.db_record.metadata_
|
|
105
|
+
else:
|
|
106
|
+
val = await info.context.data_loaders.trace_annotation_fields.load(
|
|
107
|
+
(self.id, models.TraceAnnotation.metadata_),
|
|
108
|
+
)
|
|
109
|
+
return val
|
|
110
|
+
|
|
111
|
+
@strawberry.field(description="The identifier of the annotation.") # type: ignore
|
|
112
|
+
async def identifier(
|
|
113
|
+
self,
|
|
114
|
+
info: Info[Context, None],
|
|
115
|
+
) -> str:
|
|
116
|
+
if self.db_record:
|
|
117
|
+
val = self.db_record.identifier
|
|
118
|
+
else:
|
|
119
|
+
val = await info.context.data_loaders.trace_annotation_fields.load(
|
|
120
|
+
(self.id, models.TraceAnnotation.identifier),
|
|
121
|
+
)
|
|
122
|
+
return val
|
|
123
|
+
|
|
124
|
+
@strawberry.field(description="The source of the annotation.") # type: ignore
|
|
125
|
+
async def source(
|
|
126
|
+
self,
|
|
127
|
+
info: Info[Context, None],
|
|
128
|
+
) -> AnnotationSource:
|
|
129
|
+
if self.db_record:
|
|
130
|
+
val = self.db_record.source
|
|
131
|
+
else:
|
|
132
|
+
val = await info.context.data_loaders.trace_annotation_fields.load(
|
|
133
|
+
(self.id, models.TraceAnnotation.source),
|
|
134
|
+
)
|
|
135
|
+
return AnnotationSource(val)
|
|
136
|
+
|
|
137
|
+
@strawberry.field(description="The trace associated with the annotation.") # type: ignore
|
|
138
|
+
async def trace(
|
|
139
|
+
self,
|
|
140
|
+
info: Info[Context, None],
|
|
141
|
+
) -> Annotated["Trace", strawberry.lazy(".Trace")]:
|
|
142
|
+
if self.db_record:
|
|
143
|
+
trace_rowid = self.db_record.trace_rowid
|
|
144
|
+
else:
|
|
145
|
+
trace_rowid = await info.context.data_loaders.trace_annotation_fields.load(
|
|
146
|
+
(self.id, models.TraceAnnotation.trace_rowid),
|
|
147
|
+
)
|
|
148
|
+
from .Trace import Trace
|
|
149
|
+
|
|
150
|
+
return Trace(id=trace_rowid)
|
|
151
|
+
|
|
152
|
+
@strawberry.field(description="The user that produced the annotation.") # type: ignore
|
|
38
153
|
async def user(
|
|
39
154
|
self,
|
|
40
155
|
info: Info[Context, None],
|
|
41
|
-
) -> Optional[User]:
|
|
42
|
-
if self.
|
|
156
|
+
) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
|
|
157
|
+
if self.db_record:
|
|
158
|
+
user_id = self.db_record.user_id
|
|
159
|
+
else:
|
|
160
|
+
user_id = await info.context.data_loaders.trace_annotation_fields.load(
|
|
161
|
+
(self.id, models.TraceAnnotation.user_id),
|
|
162
|
+
)
|
|
163
|
+
if user_id is None:
|
|
43
164
|
return None
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
return to_gql_user(user)
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def to_gql_trace_annotation(
|
|
51
|
-
annotation: models.TraceAnnotation,
|
|
52
|
-
) -> TraceAnnotation:
|
|
53
|
-
"""
|
|
54
|
-
Converts an ORM trace annotation to a GraphQL TraceAnnotation.
|
|
55
|
-
"""
|
|
56
|
-
return TraceAnnotation(
|
|
57
|
-
id_attr=annotation.id,
|
|
58
|
-
user_id=annotation.user_id,
|
|
59
|
-
trace_rowid=annotation.trace_rowid,
|
|
60
|
-
name=annotation.name,
|
|
61
|
-
annotator_kind=AnnotatorKind(annotation.annotator_kind),
|
|
62
|
-
label=annotation.label,
|
|
63
|
-
score=annotation.score,
|
|
64
|
-
explanation=annotation.explanation,
|
|
65
|
-
metadata=annotation.metadata_,
|
|
66
|
-
identifier=annotation.identifier,
|
|
67
|
-
source=AnnotationSource(annotation.source),
|
|
68
|
-
)
|
|
165
|
+
from .User import User
|
|
166
|
+
|
|
167
|
+
return User(id=user_id)
|