arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +2 -1
- phoenix/auth.py +27 -2
- phoenix/config.py +1594 -81
- phoenix/db/README.md +546 -28
- phoenix/db/bulk_inserter.py +119 -116
- phoenix/db/engines.py +140 -33
- phoenix/db/facilitator.py +22 -1
- phoenix/db/helpers.py +818 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +133 -1
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +41 -18
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +364 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/trace_retention.py +7 -6
- phoenix/experiments/functions.py +69 -19
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +60 -0
- phoenix/server/api/dataloaders/__init__.py +36 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/exceptions.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +263 -83
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +61 -19
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +5 -2
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/input_types/SpanSort.py +3 -2
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +15 -20
- phoenix/server/api/mutations/chat_mutations.py +106 -37
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +11 -9
- phoenix/server/api/mutations/project_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
- phoenix/server/api/mutations/trace_mutations.py +3 -3
- phoenix/server/api/mutations/user_mutations.py +55 -26
- phoenix/server/api/queries.py +501 -617
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +141 -87
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +349 -101
- phoenix/server/api/routers/v1/__init__.py +22 -4
- phoenix/server/api/routers/v1/annotation_configs.py +19 -30
- phoenix/server/api/routers/v1/annotations.py +455 -13
- phoenix/server/api/routers/v1/datasets.py +355 -68
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +20 -28
- phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
- phoenix/server/api/routers/v1/experiment_runs.py +335 -59
- phoenix/server/api/routers/v1/experiments.py +475 -47
- phoenix/server/api/routers/v1/projects.py +16 -50
- phoenix/server/api/routers/v1/prompts.py +50 -39
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +156 -96
- phoenix/server/api/routers/v1/traces.py +51 -77
- phoenix/server/api/routers/v1/users.py +64 -24
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +257 -93
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/Dataset.py +199 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +215 -68
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +120 -70
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/GenerativeProvider.py +1 -1
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +218 -185
- phoenix/server/api/types/ProjectSession.py +146 -29
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +130 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +184 -53
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +128 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +154 -36
- phoenix/server/authorization.py +5 -4
- phoenix/server/bearer_auth.py +13 -5
- phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
- phoenix/server/daemons/generative_model_store.py +61 -9
- phoenix/server/daemons/span_cost_calculator.py +10 -8
- phoenix/server/dml_event.py +13 -0
- phoenix/server/email/sender.py +29 -2
- phoenix/server/grpc_server.py +9 -9
- phoenix/server/jwt_store.py +8 -6
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +9 -3
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +43 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +51 -53
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
- phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +7 -1
- phoenix/server/thread_server.py +1 -2
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +44 -8
- phoenix/settings.py +2 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/query.py +2 -0
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
- phoenix/server/static/assets/pages-Creyamao.js +0 -8612
- phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
- phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
- phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,14 +1,19 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from dataclasses import asdict, dataclass
|
|
1
3
|
from datetime import datetime
|
|
2
|
-
from typing import TYPE_CHECKING, Annotated,
|
|
4
|
+
from typing import TYPE_CHECKING, Annotated, Optional
|
|
3
5
|
|
|
6
|
+
import pandas as pd
|
|
4
7
|
import strawberry
|
|
5
8
|
from openinference.semconv.trace import SpanAttributes
|
|
6
9
|
from sqlalchemy import select
|
|
7
|
-
from strawberry import UNSET, Info,
|
|
8
|
-
from strawberry.relay import Connection,
|
|
10
|
+
from strawberry import UNSET, Info, lazy
|
|
11
|
+
from strawberry.relay import Connection, Node, NodeID
|
|
9
12
|
|
|
10
13
|
from phoenix.db import models
|
|
11
14
|
from phoenix.server.api.context import Context
|
|
15
|
+
from phoenix.server.api.input_types.AnnotationFilter import AnnotationFilter, satisfies_filter
|
|
16
|
+
from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
|
|
12
17
|
from phoenix.server.api.types.CostBreakdown import CostBreakdown
|
|
13
18
|
from phoenix.server.api.types.MimeType import MimeType
|
|
14
19
|
from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list
|
|
@@ -18,44 +23,94 @@ from phoenix.server.api.types.SpanIOValue import SpanIOValue
|
|
|
18
23
|
from phoenix.server.api.types.TokenUsage import TokenUsage
|
|
19
24
|
|
|
20
25
|
if TYPE_CHECKING:
|
|
26
|
+
from phoenix.server.api.types.Project import Project
|
|
27
|
+
from phoenix.server.api.types.ProjectSessionAnnotation import ProjectSessionAnnotation
|
|
21
28
|
from phoenix.server.api.types.Trace import Trace
|
|
22
29
|
|
|
23
30
|
|
|
24
31
|
@strawberry.type
|
|
25
32
|
class ProjectSession(Node):
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
33
|
+
id: NodeID[int]
|
|
34
|
+
db_record: strawberry.Private[Optional[models.ProjectSession]] = None
|
|
35
|
+
|
|
36
|
+
def __post_init__(self) -> None:
|
|
37
|
+
if self.db_record and self.id != self.db_record.id:
|
|
38
|
+
raise ValueError("ProjectSession ID mismatch")
|
|
32
39
|
|
|
33
40
|
@strawberry.field
|
|
34
|
-
async def
|
|
41
|
+
async def session_id(
|
|
42
|
+
self,
|
|
43
|
+
info: Info[Context, None],
|
|
44
|
+
) -> str:
|
|
45
|
+
if self.db_record:
|
|
46
|
+
val = self.db_record.session_id
|
|
47
|
+
else:
|
|
48
|
+
val = await info.context.data_loaders.project_session_fields.load(
|
|
49
|
+
(self.id, models.ProjectSession.session_id),
|
|
50
|
+
)
|
|
51
|
+
return val
|
|
52
|
+
|
|
53
|
+
@strawberry.field
|
|
54
|
+
async def start_time(
|
|
55
|
+
self,
|
|
56
|
+
info: Info[Context, None],
|
|
57
|
+
) -> datetime:
|
|
58
|
+
if self.db_record:
|
|
59
|
+
val = self.db_record.start_time
|
|
60
|
+
else:
|
|
61
|
+
val = await info.context.data_loaders.project_session_fields.load(
|
|
62
|
+
(self.id, models.ProjectSession.start_time),
|
|
63
|
+
)
|
|
64
|
+
return val
|
|
65
|
+
|
|
66
|
+
@strawberry.field
|
|
67
|
+
async def end_time(
|
|
68
|
+
self,
|
|
69
|
+
info: Info[Context, None],
|
|
70
|
+
) -> datetime:
|
|
71
|
+
if self.db_record:
|
|
72
|
+
val = self.db_record.end_time
|
|
73
|
+
else:
|
|
74
|
+
val = await info.context.data_loaders.project_session_fields.load(
|
|
75
|
+
(self.id, models.ProjectSession.end_time),
|
|
76
|
+
)
|
|
77
|
+
return val
|
|
78
|
+
|
|
79
|
+
@strawberry.field
|
|
80
|
+
async def project(
|
|
81
|
+
self,
|
|
82
|
+
info: Info[Context, None],
|
|
83
|
+
) -> Annotated["Project", lazy(".Project")]:
|
|
35
84
|
from phoenix.server.api.types.Project import Project
|
|
36
85
|
|
|
37
|
-
|
|
86
|
+
if self.db_record:
|
|
87
|
+
project_rowid = self.db_record.project_id
|
|
88
|
+
else:
|
|
89
|
+
project_rowid = await info.context.data_loaders.project_session_fields.load(
|
|
90
|
+
(self.id, models.ProjectSession.project_id),
|
|
91
|
+
)
|
|
92
|
+
return Project(id=project_rowid)
|
|
38
93
|
|
|
39
94
|
@strawberry.field
|
|
40
95
|
async def num_traces(
|
|
41
96
|
self,
|
|
42
97
|
info: Info[Context, None],
|
|
43
98
|
) -> int:
|
|
44
|
-
return await info.context.data_loaders.session_num_traces.load(self.
|
|
99
|
+
return await info.context.data_loaders.session_num_traces.load(self.id)
|
|
45
100
|
|
|
46
101
|
@strawberry.field
|
|
47
102
|
async def num_traces_with_error(
|
|
48
103
|
self,
|
|
49
104
|
info: Info[Context, None],
|
|
50
105
|
) -> int:
|
|
51
|
-
return await info.context.data_loaders.session_num_traces_with_error.load(self.
|
|
106
|
+
return await info.context.data_loaders.session_num_traces_with_error.load(self.id)
|
|
52
107
|
|
|
53
108
|
@strawberry.field
|
|
54
109
|
async def first_input(
|
|
55
110
|
self,
|
|
56
111
|
info: Info[Context, None],
|
|
57
112
|
) -> Optional[SpanIOValue]:
|
|
58
|
-
record = await info.context.data_loaders.session_first_inputs.load(self.
|
|
113
|
+
record = await info.context.data_loaders.session_first_inputs.load(self.id)
|
|
59
114
|
if record is None:
|
|
60
115
|
return None
|
|
61
116
|
return SpanIOValue(
|
|
@@ -68,7 +123,7 @@ class ProjectSession(Node):
|
|
|
68
123
|
self,
|
|
69
124
|
info: Info[Context, None],
|
|
70
125
|
) -> Optional[SpanIOValue]:
|
|
71
|
-
record = await info.context.data_loaders.session_last_outputs.load(self.
|
|
126
|
+
record = await info.context.data_loaders.session_last_outputs.load(self.id)
|
|
72
127
|
if record is None:
|
|
73
128
|
return None
|
|
74
129
|
return SpanIOValue(
|
|
@@ -81,7 +136,7 @@ class ProjectSession(Node):
|
|
|
81
136
|
self,
|
|
82
137
|
info: Info[Context, None],
|
|
83
138
|
) -> TokenUsage:
|
|
84
|
-
usage = await info.context.data_loaders.session_token_usages.load(self.
|
|
139
|
+
usage = await info.context.data_loaders.session_token_usages.load(self.id)
|
|
85
140
|
return TokenUsage(
|
|
86
141
|
prompt=usage.prompt,
|
|
87
142
|
completion=usage.completion,
|
|
@@ -106,12 +161,12 @@ class ProjectSession(Node):
|
|
|
106
161
|
)
|
|
107
162
|
stmt = (
|
|
108
163
|
select(models.Trace)
|
|
109
|
-
.filter_by(project_session_rowid=self.
|
|
164
|
+
.filter_by(project_session_rowid=self.id)
|
|
110
165
|
.order_by(models.Trace.start_time)
|
|
111
166
|
)
|
|
112
167
|
async with info.context.db() as session:
|
|
113
168
|
traces = await session.stream_scalars(stmt)
|
|
114
|
-
data = [Trace(
|
|
169
|
+
data = [Trace(id=trace.id, db_record=trace) async for trace in traces]
|
|
115
170
|
return connection_from_list(data=data, args=args)
|
|
116
171
|
|
|
117
172
|
@strawberry.field
|
|
@@ -121,7 +176,7 @@ class ProjectSession(Node):
|
|
|
121
176
|
probability: float,
|
|
122
177
|
) -> Optional[float]:
|
|
123
178
|
return await info.context.data_loaders.session_trace_latency_ms_quantile.load(
|
|
124
|
-
(self.
|
|
179
|
+
(self.id, probability)
|
|
125
180
|
)
|
|
126
181
|
|
|
127
182
|
@strawberry.field
|
|
@@ -130,7 +185,7 @@ class ProjectSession(Node):
|
|
|
130
185
|
info: Info[Context, None],
|
|
131
186
|
) -> SpanCostSummary:
|
|
132
187
|
loader = info.context.data_loaders.span_cost_summary_by_project_session
|
|
133
|
-
summary = await loader.load(self.
|
|
188
|
+
summary = await loader.load(self.id)
|
|
134
189
|
return SpanCostSummary(
|
|
135
190
|
prompt=CostBreakdown(
|
|
136
191
|
tokens=summary.prompt.tokens,
|
|
@@ -152,7 +207,7 @@ class ProjectSession(Node):
|
|
|
152
207
|
info: Info[Context, None],
|
|
153
208
|
) -> list[SpanCostDetailSummaryEntry]:
|
|
154
209
|
loader = info.context.data_loaders.span_cost_detail_summary_entries_by_project_session
|
|
155
|
-
summary = await loader.load(self.
|
|
210
|
+
summary = await loader.load(self.id)
|
|
156
211
|
return [
|
|
157
212
|
SpanCostDetailSummaryEntry(
|
|
158
213
|
token_type=entry.token_type,
|
|
@@ -165,15 +220,77 @@ class ProjectSession(Node):
|
|
|
165
220
|
for entry in summary
|
|
166
221
|
]
|
|
167
222
|
|
|
223
|
+
@strawberry.field
|
|
224
|
+
async def session_annotations(
|
|
225
|
+
self,
|
|
226
|
+
info: Info[Context, None],
|
|
227
|
+
) -> list[Annotated["ProjectSessionAnnotation", lazy(".ProjectSessionAnnotation")]]:
|
|
228
|
+
"""Get all annotations for this session."""
|
|
229
|
+
from .ProjectSessionAnnotation import ProjectSessionAnnotation
|
|
230
|
+
|
|
231
|
+
stmt = select(models.ProjectSessionAnnotation).filter_by(project_session_id=self.id)
|
|
232
|
+
async with info.context.db() as session:
|
|
233
|
+
annotations = await session.stream_scalars(stmt)
|
|
234
|
+
return [
|
|
235
|
+
ProjectSessionAnnotation(id=annotation.id, db_record=annotation)
|
|
236
|
+
async for annotation in annotations
|
|
237
|
+
]
|
|
238
|
+
|
|
239
|
+
@strawberry.field(
|
|
240
|
+
description="Summarizes each annotation (by name) associated with the session"
|
|
241
|
+
) # type: ignore
|
|
242
|
+
async def session_annotation_summaries(
|
|
243
|
+
self,
|
|
244
|
+
info: Info[Context, None],
|
|
245
|
+
filter: Optional[AnnotationFilter] = None,
|
|
246
|
+
) -> list[AnnotationSummary]:
|
|
247
|
+
"""
|
|
248
|
+
Retrieves and summarizes annotations associated with this span.
|
|
249
|
+
|
|
250
|
+
This method aggregates annotation data by name and label, calculating metrics
|
|
251
|
+
such as count of occurrences and sum of scores. The results are organized
|
|
252
|
+
into a structured format that can be easily converted to a DataFrame.
|
|
253
|
+
|
|
254
|
+
Args:
|
|
255
|
+
info: GraphQL context information
|
|
256
|
+
filter: Optional filter to apply to annotations before processing
|
|
257
|
+
|
|
258
|
+
Returns:
|
|
259
|
+
A list of AnnotationSummary objects, each containing:
|
|
260
|
+
- name: The name of the annotation
|
|
261
|
+
- data: A list of dictionaries with label statistics
|
|
262
|
+
"""
|
|
263
|
+
# Load all annotations for this span from the data loader
|
|
264
|
+
annotations = await info.context.data_loaders.session_annotations_by_session.load(self.id)
|
|
265
|
+
|
|
266
|
+
# Apply filter if provided to narrow down the annotations
|
|
267
|
+
if filter:
|
|
268
|
+
annotations = [
|
|
269
|
+
annotation for annotation in annotations if satisfies_filter(annotation, filter)
|
|
270
|
+
]
|
|
271
|
+
|
|
272
|
+
@dataclass
|
|
273
|
+
class Metrics:
|
|
274
|
+
record_count: int = 0
|
|
275
|
+
label_count: int = 0
|
|
276
|
+
score_sum: float = 0
|
|
277
|
+
score_count: int = 0
|
|
278
|
+
|
|
279
|
+
summaries: defaultdict[str, defaultdict[Optional[str], Metrics]] = defaultdict(
|
|
280
|
+
lambda: defaultdict(Metrics)
|
|
281
|
+
)
|
|
282
|
+
for annotation in annotations:
|
|
283
|
+
metrics = summaries[annotation.name][annotation.label]
|
|
284
|
+
metrics.record_count += 1
|
|
285
|
+
metrics.label_count += int(annotation.label is not None)
|
|
286
|
+
metrics.score_sum += annotation.score or 0
|
|
287
|
+
metrics.score_count += int(annotation.score is not None)
|
|
168
288
|
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
project_rowid=project_session.project_id,
|
|
175
|
-
end_time=project_session.end_time,
|
|
176
|
-
)
|
|
289
|
+
result: list[AnnotationSummary] = []
|
|
290
|
+
for name, label_metrics in summaries.items():
|
|
291
|
+
rows = [{"label": label, **asdict(metrics)} for label, metrics in label_metrics.items()]
|
|
292
|
+
result.append(AnnotationSummary(name=name, df=pd.DataFrame(rows), simple_avg=True))
|
|
293
|
+
return result
|
|
177
294
|
|
|
178
295
|
|
|
179
296
|
INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
from math import isfinite
|
|
2
|
+
from typing import TYPE_CHECKING, Annotated, Optional
|
|
3
|
+
|
|
4
|
+
import strawberry
|
|
5
|
+
from strawberry.relay import GlobalID, Node, NodeID
|
|
6
|
+
from strawberry.scalars import JSON
|
|
7
|
+
from strawberry.types import Info
|
|
8
|
+
|
|
9
|
+
from phoenix.db import models
|
|
10
|
+
from phoenix.server.api.context import Context
|
|
11
|
+
from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
|
|
12
|
+
|
|
13
|
+
from .Annotation import Annotation
|
|
14
|
+
from .AnnotationSource import AnnotationSource
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from .ProjectSession import ProjectSession
|
|
18
|
+
from .User import User
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@strawberry.type
|
|
22
|
+
class ProjectSessionAnnotation(Node, Annotation):
|
|
23
|
+
id: NodeID[int]
|
|
24
|
+
db_record: strawberry.Private[Optional[models.ProjectSessionAnnotation]] = None
|
|
25
|
+
|
|
26
|
+
def __post_init__(self) -> None:
|
|
27
|
+
if self.db_record and self.id != self.db_record.id:
|
|
28
|
+
raise ValueError("ProjectSessionAnnotation ID mismatch")
|
|
29
|
+
|
|
30
|
+
@strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
|
|
31
|
+
async def name(
|
|
32
|
+
self,
|
|
33
|
+
info: Info[Context, None],
|
|
34
|
+
) -> str:
|
|
35
|
+
if self.db_record:
|
|
36
|
+
val = self.db_record.name
|
|
37
|
+
else:
|
|
38
|
+
val = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
39
|
+
(self.id, models.ProjectSessionAnnotation.name),
|
|
40
|
+
)
|
|
41
|
+
return val
|
|
42
|
+
|
|
43
|
+
@strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
|
|
44
|
+
async def annotator_kind(
|
|
45
|
+
self,
|
|
46
|
+
info: Info[Context, None],
|
|
47
|
+
) -> AnnotatorKind:
|
|
48
|
+
if self.db_record:
|
|
49
|
+
val = self.db_record.annotator_kind
|
|
50
|
+
else:
|
|
51
|
+
val = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
52
|
+
(self.id, models.ProjectSessionAnnotation.annotator_kind),
|
|
53
|
+
)
|
|
54
|
+
return AnnotatorKind(val)
|
|
55
|
+
|
|
56
|
+
@strawberry.field(
|
|
57
|
+
description="Value of the annotation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary." # noqa: E501
|
|
58
|
+
) # type: ignore
|
|
59
|
+
async def label(
|
|
60
|
+
self,
|
|
61
|
+
info: Info[Context, None],
|
|
62
|
+
) -> Optional[str]:
|
|
63
|
+
if self.db_record:
|
|
64
|
+
val = self.db_record.label
|
|
65
|
+
else:
|
|
66
|
+
val = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
67
|
+
(self.id, models.ProjectSessionAnnotation.label),
|
|
68
|
+
)
|
|
69
|
+
return val
|
|
70
|
+
|
|
71
|
+
@strawberry.field(description="Value of the annotation in the form of a numeric score.") # type: ignore
|
|
72
|
+
async def score(
|
|
73
|
+
self,
|
|
74
|
+
info: Info[Context, None],
|
|
75
|
+
) -> Optional[float]:
|
|
76
|
+
if self.db_record:
|
|
77
|
+
val = self.db_record.score
|
|
78
|
+
else:
|
|
79
|
+
val = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
80
|
+
(self.id, models.ProjectSessionAnnotation.score),
|
|
81
|
+
)
|
|
82
|
+
return val if val is not None and isfinite(val) else None
|
|
83
|
+
|
|
84
|
+
@strawberry.field(
|
|
85
|
+
description="The annotator's explanation for the annotation result (i.e. score or label, or both) given to the subject." # noqa: E501
|
|
86
|
+
) # type: ignore
|
|
87
|
+
async def explanation(
|
|
88
|
+
self,
|
|
89
|
+
info: Info[Context, None],
|
|
90
|
+
) -> Optional[str]:
|
|
91
|
+
if self.db_record:
|
|
92
|
+
val = self.db_record.explanation
|
|
93
|
+
else:
|
|
94
|
+
val = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
95
|
+
(self.id, models.ProjectSessionAnnotation.explanation),
|
|
96
|
+
)
|
|
97
|
+
return val
|
|
98
|
+
|
|
99
|
+
@strawberry.field(description="Metadata about the annotation.") # type: ignore
|
|
100
|
+
async def metadata(
|
|
101
|
+
self,
|
|
102
|
+
info: Info[Context, None],
|
|
103
|
+
) -> JSON:
|
|
104
|
+
if self.db_record:
|
|
105
|
+
val = self.db_record.metadata_
|
|
106
|
+
else:
|
|
107
|
+
val = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
108
|
+
(self.id, models.ProjectSessionAnnotation.metadata_),
|
|
109
|
+
)
|
|
110
|
+
return val
|
|
111
|
+
|
|
112
|
+
@strawberry.field(description="The identifier of the annotation.") # type: ignore
|
|
113
|
+
async def identifier(
|
|
114
|
+
self,
|
|
115
|
+
info: Info[Context, None],
|
|
116
|
+
) -> str:
|
|
117
|
+
if self.db_record:
|
|
118
|
+
val = self.db_record.identifier
|
|
119
|
+
else:
|
|
120
|
+
val = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
121
|
+
(self.id, models.ProjectSessionAnnotation.identifier),
|
|
122
|
+
)
|
|
123
|
+
return val
|
|
124
|
+
|
|
125
|
+
@strawberry.field(description="The source of the annotation.") # type: ignore
|
|
126
|
+
async def source(
|
|
127
|
+
self,
|
|
128
|
+
info: Info[Context, None],
|
|
129
|
+
) -> AnnotationSource:
|
|
130
|
+
if self.db_record:
|
|
131
|
+
val = self.db_record.source
|
|
132
|
+
else:
|
|
133
|
+
val = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
134
|
+
(self.id, models.ProjectSessionAnnotation.source),
|
|
135
|
+
)
|
|
136
|
+
return AnnotationSource(val)
|
|
137
|
+
|
|
138
|
+
@strawberry.field(description="The project session associated with the annotation.") # type: ignore
|
|
139
|
+
async def project_session_id(
|
|
140
|
+
self,
|
|
141
|
+
info: Info[Context, None],
|
|
142
|
+
) -> GlobalID:
|
|
143
|
+
from phoenix.server.api.types.ProjectSession import ProjectSession
|
|
144
|
+
|
|
145
|
+
if self.db_record:
|
|
146
|
+
project_session_id = self.db_record.project_session_id
|
|
147
|
+
else:
|
|
148
|
+
project_session_id = (
|
|
149
|
+
await info.context.data_loaders.project_session_annotation_fields.load(
|
|
150
|
+
(self.id, models.ProjectSessionAnnotation.project_session_id),
|
|
151
|
+
)
|
|
152
|
+
)
|
|
153
|
+
return GlobalID(type_name=ProjectSession.__name__, node_id=str(project_session_id))
|
|
154
|
+
|
|
155
|
+
@strawberry.field(description="The project session associated with the annotation.") # type: ignore
|
|
156
|
+
async def project_session(
|
|
157
|
+
self,
|
|
158
|
+
info: Info[Context, None],
|
|
159
|
+
) -> Annotated["ProjectSession", strawberry.lazy(".ProjectSession")]:
|
|
160
|
+
if self.db_record:
|
|
161
|
+
project_session_id = self.db_record.project_session_id
|
|
162
|
+
else:
|
|
163
|
+
project_session_id = (
|
|
164
|
+
await info.context.data_loaders.project_session_annotation_fields.load(
|
|
165
|
+
(self.id, models.ProjectSessionAnnotation.project_session_id),
|
|
166
|
+
)
|
|
167
|
+
)
|
|
168
|
+
from .ProjectSession import ProjectSession
|
|
169
|
+
|
|
170
|
+
return ProjectSession(id=project_session_id)
|
|
171
|
+
|
|
172
|
+
@strawberry.field(description="The user that produced the annotation.") # type: ignore
|
|
173
|
+
async def user(
|
|
174
|
+
self,
|
|
175
|
+
info: Info[Context, None],
|
|
176
|
+
) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
|
|
177
|
+
if self.db_record:
|
|
178
|
+
user_id = self.db_record.user_id
|
|
179
|
+
else:
|
|
180
|
+
user_id = await info.context.data_loaders.project_session_annotation_fields.load(
|
|
181
|
+
(self.id, models.ProjectSessionAnnotation.user_id),
|
|
182
|
+
)
|
|
183
|
+
if user_id is None:
|
|
184
|
+
return None
|
|
185
|
+
from .User import User
|
|
186
|
+
|
|
187
|
+
return User(id=user_id)
|
|
@@ -106,5 +106,5 @@ class ProjectTraceRetentionPolicy(Node):
|
|
|
106
106
|
project_rowids = await info.context.data_loaders.projects_by_trace_retention_policy_id.load(
|
|
107
107
|
self.id
|
|
108
108
|
)
|
|
109
|
-
data = [Project(
|
|
109
|
+
data = [Project(id=project_rowid) for project_rowid in project_rowids]
|
|
110
110
|
return connection_from_list(data=data, args=args)
|
|
@@ -6,9 +6,11 @@ import strawberry
|
|
|
6
6
|
from sqlalchemy import func, select
|
|
7
7
|
from strawberry import UNSET
|
|
8
8
|
from strawberry.relay import Connection, GlobalID, Node, NodeID
|
|
9
|
+
from strawberry.scalars import JSON
|
|
9
10
|
from strawberry.types import Info
|
|
10
11
|
|
|
11
12
|
from phoenix.db import models
|
|
13
|
+
from phoenix.db.types.identifier import Identifier as IdentifierModel
|
|
12
14
|
from phoenix.server.api.context import Context
|
|
13
15
|
from phoenix.server.api.exceptions import NotFound
|
|
14
16
|
from phoenix.server.api.types.Identifier import Identifier
|
|
@@ -19,24 +21,96 @@ from phoenix.server.api.types.pagination import (
|
|
|
19
21
|
connection_from_list,
|
|
20
22
|
)
|
|
21
23
|
|
|
24
|
+
from .PromptLabel import PromptLabel
|
|
22
25
|
from .PromptVersion import (
|
|
23
26
|
PromptVersion,
|
|
24
27
|
to_gql_prompt_version,
|
|
25
28
|
)
|
|
26
|
-
from .PromptVersionTag import PromptVersionTag
|
|
29
|
+
from .PromptVersionTag import PromptVersionTag
|
|
27
30
|
|
|
28
31
|
|
|
29
32
|
@strawberry.type
|
|
30
33
|
class Prompt(Node):
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
34
|
+
id: NodeID[int]
|
|
35
|
+
db_record: strawberry.Private[Optional[models.Prompt]] = None
|
|
36
|
+
|
|
37
|
+
def __post_init__(self) -> None:
|
|
38
|
+
if self.db_record and self.id != self.db_record.id:
|
|
39
|
+
raise ValueError("Prompt ID mismatch")
|
|
40
|
+
|
|
41
|
+
@strawberry.field
|
|
42
|
+
async def source_prompt_id(
|
|
43
|
+
self,
|
|
44
|
+
info: Info[Context, None],
|
|
45
|
+
) -> Optional[GlobalID]:
|
|
46
|
+
if self.db_record:
|
|
47
|
+
source_id = self.db_record.source_prompt_id
|
|
48
|
+
else:
|
|
49
|
+
source_id = await info.context.data_loaders.prompt_fields.load(
|
|
50
|
+
(self.id, models.Prompt.source_prompt_id),
|
|
51
|
+
)
|
|
52
|
+
if not source_id:
|
|
53
|
+
return None
|
|
54
|
+
return GlobalID(Prompt.__name__, str(source_id))
|
|
55
|
+
|
|
56
|
+
@strawberry.field
|
|
57
|
+
async def name(
|
|
58
|
+
self,
|
|
59
|
+
info: Info[Context, None],
|
|
60
|
+
) -> Identifier:
|
|
61
|
+
if self.db_record:
|
|
62
|
+
val = self.db_record.name
|
|
63
|
+
else:
|
|
64
|
+
val = await info.context.data_loaders.prompt_fields.load(
|
|
65
|
+
(self.id, models.Prompt.name),
|
|
66
|
+
)
|
|
67
|
+
return Identifier(val.root)
|
|
68
|
+
|
|
69
|
+
@strawberry.field
|
|
70
|
+
async def description(
|
|
71
|
+
self,
|
|
72
|
+
info: Info[Context, None],
|
|
73
|
+
) -> Optional[str]:
|
|
74
|
+
if self.db_record:
|
|
75
|
+
val = self.db_record.description
|
|
76
|
+
else:
|
|
77
|
+
val = await info.context.data_loaders.prompt_fields.load(
|
|
78
|
+
(self.id, models.Prompt.description),
|
|
79
|
+
)
|
|
80
|
+
return val
|
|
81
|
+
|
|
82
|
+
@strawberry.field
|
|
83
|
+
async def metadata(
|
|
84
|
+
self,
|
|
85
|
+
info: Info[Context, None],
|
|
86
|
+
) -> JSON:
|
|
87
|
+
if self.db_record:
|
|
88
|
+
val = self.db_record.metadata_
|
|
89
|
+
else:
|
|
90
|
+
val = await info.context.data_loaders.prompt_fields.load(
|
|
91
|
+
(self.id, models.Prompt.metadata_),
|
|
92
|
+
)
|
|
93
|
+
return val
|
|
94
|
+
|
|
95
|
+
@strawberry.field
|
|
96
|
+
async def created_at(
|
|
97
|
+
self,
|
|
98
|
+
info: Info[Context, None],
|
|
99
|
+
) -> datetime:
|
|
100
|
+
if self.db_record:
|
|
101
|
+
val = self.db_record.created_at
|
|
102
|
+
else:
|
|
103
|
+
val = await info.context.data_loaders.prompt_fields.load(
|
|
104
|
+
(self.id, models.Prompt.created_at),
|
|
105
|
+
)
|
|
106
|
+
return val
|
|
36
107
|
|
|
37
108
|
@strawberry.field
|
|
38
109
|
async def version(
|
|
39
|
-
self,
|
|
110
|
+
self,
|
|
111
|
+
info: Info[Context, None],
|
|
112
|
+
version_id: Optional[GlobalID] = None,
|
|
113
|
+
tag_name: Optional[Identifier] = None,
|
|
40
114
|
) -> PromptVersion:
|
|
41
115
|
async with info.context.db() as session:
|
|
42
116
|
if version_id:
|
|
@@ -44,15 +118,28 @@ class Prompt(Node):
|
|
|
44
118
|
version = await session.scalar(
|
|
45
119
|
select(models.PromptVersion).where(
|
|
46
120
|
models.PromptVersion.id == v_id,
|
|
47
|
-
models.PromptVersion.prompt_id == self.
|
|
121
|
+
models.PromptVersion.prompt_id == self.id,
|
|
48
122
|
)
|
|
49
123
|
)
|
|
50
124
|
if not version:
|
|
51
125
|
raise NotFound(f"Prompt version not found: {version_id}")
|
|
126
|
+
elif tag_name:
|
|
127
|
+
try:
|
|
128
|
+
name = IdentifierModel(tag_name)
|
|
129
|
+
except ValueError:
|
|
130
|
+
raise NotFound(f"Prompt version tag not found: {tag_name}")
|
|
131
|
+
version = await session.scalar(
|
|
132
|
+
select(models.PromptVersion)
|
|
133
|
+
.where(models.PromptVersion.prompt_id == self.id)
|
|
134
|
+
.join_from(models.PromptVersion, models.PromptVersionTag)
|
|
135
|
+
.where(models.PromptVersionTag.name == name)
|
|
136
|
+
)
|
|
137
|
+
if not version:
|
|
138
|
+
raise NotFound(f"This prompt has no associated versions by tag {tag_name}")
|
|
52
139
|
else:
|
|
53
140
|
stmt = (
|
|
54
141
|
select(models.PromptVersion)
|
|
55
|
-
.where(models.PromptVersion.prompt_id == self.
|
|
142
|
+
.where(models.PromptVersion.prompt_id == self.id)
|
|
56
143
|
.order_by(models.PromptVersion.id.desc())
|
|
57
144
|
.limit(1)
|
|
58
145
|
)
|
|
@@ -65,10 +152,11 @@ class Prompt(Node):
|
|
|
65
152
|
async def version_tags(self, info: Info[Context, None]) -> list[PromptVersionTag]:
|
|
66
153
|
async with info.context.db() as session:
|
|
67
154
|
stmt = select(models.PromptVersionTag).where(
|
|
68
|
-
models.PromptVersionTag.prompt_id == self.
|
|
155
|
+
models.PromptVersionTag.prompt_id == self.id
|
|
69
156
|
)
|
|
70
157
|
return [
|
|
71
|
-
|
|
158
|
+
PromptVersionTag(id=tag.id, db_record=tag)
|
|
159
|
+
async for tag in await session.stream_scalars(stmt)
|
|
72
160
|
]
|
|
73
161
|
|
|
74
162
|
@strawberry.field
|
|
@@ -89,7 +177,7 @@ class Prompt(Node):
|
|
|
89
177
|
row_number = func.row_number().over(order_by=models.PromptVersion.id).label("row_number")
|
|
90
178
|
stmt = (
|
|
91
179
|
select(models.PromptVersion, row_number)
|
|
92
|
-
.where(models.PromptVersion.prompt_id == self.
|
|
180
|
+
.where(models.PromptVersion.prompt_id == self.id)
|
|
93
181
|
.order_by(models.PromptVersion.id.desc())
|
|
94
182
|
)
|
|
95
183
|
async with info.context.db() as session:
|
|
@@ -101,34 +189,26 @@ class Prompt(Node):
|
|
|
101
189
|
|
|
102
190
|
@strawberry.field
|
|
103
191
|
async def source_prompt(self, info: Info[Context, None]) -> Optional["Prompt"]:
|
|
104
|
-
if
|
|
192
|
+
if self.db_record:
|
|
193
|
+
id_ = self.db_record.source_prompt_id
|
|
194
|
+
else:
|
|
195
|
+
id_ = await info.context.data_loaders.prompt_fields.load(
|
|
196
|
+
(self.id, models.Prompt.source_prompt_id),
|
|
197
|
+
)
|
|
198
|
+
if not id_:
|
|
105
199
|
return None
|
|
200
|
+
async with info.context.db() as session:
|
|
201
|
+
source_prompt = await session.get(models.Prompt, id_)
|
|
202
|
+
if not source_prompt:
|
|
203
|
+
raise NotFound(f"Source prompt not found: {id_}")
|
|
204
|
+
return Prompt(id=source_prompt.id, db_record=source_prompt)
|
|
106
205
|
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
)
|
|
110
|
-
|
|
206
|
+
@strawberry.field
|
|
207
|
+
async def labels(self, info: Info[Context, None]) -> list["PromptLabel"]:
|
|
111
208
|
async with info.context.db() as session:
|
|
112
|
-
|
|
113
|
-
select(models.
|
|
209
|
+
labels = await session.scalars(
|
|
210
|
+
select(models.PromptLabel)
|
|
211
|
+
.join(models.PromptPromptLabel)
|
|
212
|
+
.where(models.PromptPromptLabel.prompt_id == self.id)
|
|
114
213
|
)
|
|
115
|
-
|
|
116
|
-
raise NotFound(f"Source prompt not found: {self.source_prompt_id}")
|
|
117
|
-
return to_gql_prompt_from_orm(source_prompt)
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
def to_gql_prompt_from_orm(orm_model: "models.Prompt") -> Prompt:
|
|
121
|
-
if not orm_model.source_prompt_id:
|
|
122
|
-
source_prompt_gid = None
|
|
123
|
-
else:
|
|
124
|
-
source_prompt_gid = GlobalID(
|
|
125
|
-
Prompt.__name__,
|
|
126
|
-
str(orm_model.source_prompt_id),
|
|
127
|
-
)
|
|
128
|
-
return Prompt(
|
|
129
|
-
id_attr=orm_model.id,
|
|
130
|
-
source_prompt_id=source_prompt_gid,
|
|
131
|
-
name=Identifier(orm_model.name.root),
|
|
132
|
-
description=orm_model.description,
|
|
133
|
-
created_at=orm_model.created_at,
|
|
134
|
-
)
|
|
214
|
+
return [PromptLabel(id=label.id, db_record=label) for label in labels]
|