arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -6,9 +6,11 @@ import strawberry
|
|
|
6
6
|
from sqlalchemy import func, select
|
|
7
7
|
from strawberry import UNSET
|
|
8
8
|
from strawberry.relay import Connection, GlobalID, Node, NodeID
|
|
9
|
+
from strawberry.scalars import JSON
|
|
9
10
|
from strawberry.types import Info
|
|
10
11
|
|
|
11
12
|
from phoenix.db import models
|
|
13
|
+
from phoenix.db.types.identifier import Identifier as IdentifierModel
|
|
12
14
|
from phoenix.server.api.context import Context
|
|
13
15
|
from phoenix.server.api.exceptions import NotFound
|
|
14
16
|
from phoenix.server.api.types.Identifier import Identifier
|
|
@@ -19,24 +21,96 @@ from phoenix.server.api.types.pagination import (
|
|
|
19
21
|
connection_from_list,
|
|
20
22
|
)
|
|
21
23
|
|
|
24
|
+
from .PromptLabel import PromptLabel
|
|
22
25
|
from .PromptVersion import (
|
|
23
26
|
PromptVersion,
|
|
24
27
|
to_gql_prompt_version,
|
|
25
28
|
)
|
|
26
|
-
from .PromptVersionTag import PromptVersionTag
|
|
29
|
+
from .PromptVersionTag import PromptVersionTag
|
|
27
30
|
|
|
28
31
|
|
|
29
32
|
@strawberry.type
|
|
30
33
|
class Prompt(Node):
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
34
|
+
id: NodeID[int]
|
|
35
|
+
db_record: strawberry.Private[Optional[models.Prompt]] = None
|
|
36
|
+
|
|
37
|
+
def __post_init__(self) -> None:
|
|
38
|
+
if self.db_record and self.id != self.db_record.id:
|
|
39
|
+
raise ValueError("Prompt ID mismatch")
|
|
40
|
+
|
|
41
|
+
@strawberry.field
|
|
42
|
+
async def source_prompt_id(
|
|
43
|
+
self,
|
|
44
|
+
info: Info[Context, None],
|
|
45
|
+
) -> Optional[GlobalID]:
|
|
46
|
+
if self.db_record:
|
|
47
|
+
source_id = self.db_record.source_prompt_id
|
|
48
|
+
else:
|
|
49
|
+
source_id = await info.context.data_loaders.prompt_fields.load(
|
|
50
|
+
(self.id, models.Prompt.source_prompt_id),
|
|
51
|
+
)
|
|
52
|
+
if not source_id:
|
|
53
|
+
return None
|
|
54
|
+
return GlobalID(Prompt.__name__, str(source_id))
|
|
55
|
+
|
|
56
|
+
@strawberry.field
|
|
57
|
+
async def name(
|
|
58
|
+
self,
|
|
59
|
+
info: Info[Context, None],
|
|
60
|
+
) -> Identifier:
|
|
61
|
+
if self.db_record:
|
|
62
|
+
val = self.db_record.name
|
|
63
|
+
else:
|
|
64
|
+
val = await info.context.data_loaders.prompt_fields.load(
|
|
65
|
+
(self.id, models.Prompt.name),
|
|
66
|
+
)
|
|
67
|
+
return Identifier(val.root)
|
|
68
|
+
|
|
69
|
+
@strawberry.field
|
|
70
|
+
async def description(
|
|
71
|
+
self,
|
|
72
|
+
info: Info[Context, None],
|
|
73
|
+
) -> Optional[str]:
|
|
74
|
+
if self.db_record:
|
|
75
|
+
val = self.db_record.description
|
|
76
|
+
else:
|
|
77
|
+
val = await info.context.data_loaders.prompt_fields.load(
|
|
78
|
+
(self.id, models.Prompt.description),
|
|
79
|
+
)
|
|
80
|
+
return val
|
|
81
|
+
|
|
82
|
+
@strawberry.field
|
|
83
|
+
async def metadata(
|
|
84
|
+
self,
|
|
85
|
+
info: Info[Context, None],
|
|
86
|
+
) -> JSON:
|
|
87
|
+
if self.db_record:
|
|
88
|
+
val = self.db_record.metadata_
|
|
89
|
+
else:
|
|
90
|
+
val = await info.context.data_loaders.prompt_fields.load(
|
|
91
|
+
(self.id, models.Prompt.metadata_),
|
|
92
|
+
)
|
|
93
|
+
return val
|
|
94
|
+
|
|
95
|
+
@strawberry.field
|
|
96
|
+
async def created_at(
|
|
97
|
+
self,
|
|
98
|
+
info: Info[Context, None],
|
|
99
|
+
) -> datetime:
|
|
100
|
+
if self.db_record:
|
|
101
|
+
val = self.db_record.created_at
|
|
102
|
+
else:
|
|
103
|
+
val = await info.context.data_loaders.prompt_fields.load(
|
|
104
|
+
(self.id, models.Prompt.created_at),
|
|
105
|
+
)
|
|
106
|
+
return val
|
|
36
107
|
|
|
37
108
|
@strawberry.field
|
|
38
109
|
async def version(
|
|
39
|
-
self,
|
|
110
|
+
self,
|
|
111
|
+
info: Info[Context, None],
|
|
112
|
+
version_id: Optional[GlobalID] = None,
|
|
113
|
+
tag_name: Optional[Identifier] = None,
|
|
40
114
|
) -> PromptVersion:
|
|
41
115
|
async with info.context.db() as session:
|
|
42
116
|
if version_id:
|
|
@@ -44,15 +118,28 @@ class Prompt(Node):
|
|
|
44
118
|
version = await session.scalar(
|
|
45
119
|
select(models.PromptVersion).where(
|
|
46
120
|
models.PromptVersion.id == v_id,
|
|
47
|
-
models.PromptVersion.prompt_id == self.
|
|
121
|
+
models.PromptVersion.prompt_id == self.id,
|
|
48
122
|
)
|
|
49
123
|
)
|
|
50
124
|
if not version:
|
|
51
125
|
raise NotFound(f"Prompt version not found: {version_id}")
|
|
126
|
+
elif tag_name:
|
|
127
|
+
try:
|
|
128
|
+
name = IdentifierModel(tag_name)
|
|
129
|
+
except ValueError:
|
|
130
|
+
raise NotFound(f"Prompt version tag not found: {tag_name}")
|
|
131
|
+
version = await session.scalar(
|
|
132
|
+
select(models.PromptVersion)
|
|
133
|
+
.where(models.PromptVersion.prompt_id == self.id)
|
|
134
|
+
.join_from(models.PromptVersion, models.PromptVersionTag)
|
|
135
|
+
.where(models.PromptVersionTag.name == name)
|
|
136
|
+
)
|
|
137
|
+
if not version:
|
|
138
|
+
raise NotFound(f"This prompt has no associated versions by tag {tag_name}")
|
|
52
139
|
else:
|
|
53
140
|
stmt = (
|
|
54
141
|
select(models.PromptVersion)
|
|
55
|
-
.where(models.PromptVersion.prompt_id == self.
|
|
142
|
+
.where(models.PromptVersion.prompt_id == self.id)
|
|
56
143
|
.order_by(models.PromptVersion.id.desc())
|
|
57
144
|
.limit(1)
|
|
58
145
|
)
|
|
@@ -65,10 +152,11 @@ class Prompt(Node):
|
|
|
65
152
|
async def version_tags(self, info: Info[Context, None]) -> list[PromptVersionTag]:
|
|
66
153
|
async with info.context.db() as session:
|
|
67
154
|
stmt = select(models.PromptVersionTag).where(
|
|
68
|
-
models.PromptVersionTag.prompt_id == self.
|
|
155
|
+
models.PromptVersionTag.prompt_id == self.id
|
|
69
156
|
)
|
|
70
157
|
return [
|
|
71
|
-
|
|
158
|
+
PromptVersionTag(id=tag.id, db_record=tag)
|
|
159
|
+
async for tag in await session.stream_scalars(stmt)
|
|
72
160
|
]
|
|
73
161
|
|
|
74
162
|
@strawberry.field
|
|
@@ -89,7 +177,7 @@ class Prompt(Node):
|
|
|
89
177
|
row_number = func.row_number().over(order_by=models.PromptVersion.id).label("row_number")
|
|
90
178
|
stmt = (
|
|
91
179
|
select(models.PromptVersion, row_number)
|
|
92
|
-
.where(models.PromptVersion.prompt_id == self.
|
|
180
|
+
.where(models.PromptVersion.prompt_id == self.id)
|
|
93
181
|
.order_by(models.PromptVersion.id.desc())
|
|
94
182
|
)
|
|
95
183
|
async with info.context.db() as session:
|
|
@@ -101,34 +189,26 @@ class Prompt(Node):
|
|
|
101
189
|
|
|
102
190
|
@strawberry.field
|
|
103
191
|
async def source_prompt(self, info: Info[Context, None]) -> Optional["Prompt"]:
|
|
104
|
-
if
|
|
192
|
+
if self.db_record:
|
|
193
|
+
id_ = self.db_record.source_prompt_id
|
|
194
|
+
else:
|
|
195
|
+
id_ = await info.context.data_loaders.prompt_fields.load(
|
|
196
|
+
(self.id, models.Prompt.source_prompt_id),
|
|
197
|
+
)
|
|
198
|
+
if not id_:
|
|
105
199
|
return None
|
|
200
|
+
async with info.context.db() as session:
|
|
201
|
+
source_prompt = await session.get(models.Prompt, id_)
|
|
202
|
+
if not source_prompt:
|
|
203
|
+
raise NotFound(f"Source prompt not found: {id_}")
|
|
204
|
+
return Prompt(id=source_prompt.id, db_record=source_prompt)
|
|
106
205
|
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
)
|
|
110
|
-
|
|
206
|
+
@strawberry.field
|
|
207
|
+
async def labels(self, info: Info[Context, None]) -> list["PromptLabel"]:
|
|
111
208
|
async with info.context.db() as session:
|
|
112
|
-
|
|
113
|
-
select(models.
|
|
209
|
+
labels = await session.scalars(
|
|
210
|
+
select(models.PromptLabel)
|
|
211
|
+
.join(models.PromptPromptLabel)
|
|
212
|
+
.where(models.PromptPromptLabel.prompt_id == self.id)
|
|
114
213
|
)
|
|
115
|
-
|
|
116
|
-
raise NotFound(f"Source prompt not found: {self.source_prompt_id}")
|
|
117
|
-
return to_gql_prompt_from_orm(source_prompt)
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
def to_gql_prompt_from_orm(orm_model: "models.Prompt") -> Prompt:
|
|
121
|
-
if not orm_model.source_prompt_id:
|
|
122
|
-
source_prompt_gid = None
|
|
123
|
-
else:
|
|
124
|
-
source_prompt_gid = GlobalID(
|
|
125
|
-
Prompt.__name__,
|
|
126
|
-
str(orm_model.source_prompt_id),
|
|
127
|
-
)
|
|
128
|
-
return Prompt(
|
|
129
|
-
id_attr=orm_model.id,
|
|
130
|
-
source_prompt_id=source_prompt_gid,
|
|
131
|
-
name=Identifier(orm_model.name.root),
|
|
132
|
-
description=orm_model.description,
|
|
133
|
-
created_at=orm_model.created_at,
|
|
134
|
-
)
|
|
214
|
+
return [PromptLabel(id=label.id, db_record=label) for label in labels]
|
|
@@ -1,41 +1,58 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
3
|
import strawberry
|
|
4
|
-
from sqlalchemy import select
|
|
5
4
|
from strawberry.relay import Node, NodeID
|
|
6
5
|
from strawberry.types import Info
|
|
7
6
|
|
|
8
7
|
from phoenix.db import models
|
|
9
8
|
from phoenix.server.api.context import Context
|
|
10
9
|
from phoenix.server.api.types.Identifier import Identifier
|
|
11
|
-
from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
|
|
12
10
|
|
|
13
11
|
|
|
14
12
|
@strawberry.type
|
|
15
13
|
class PromptLabel(Node):
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
14
|
+
id: NodeID[int]
|
|
15
|
+
db_record: strawberry.Private[Optional[models.PromptLabel]] = None
|
|
16
|
+
|
|
17
|
+
def __post_init__(self) -> None:
|
|
18
|
+
if self.db_record and self.id != self.db_record.id:
|
|
19
|
+
raise ValueError("PromptLabel ID mismatch")
|
|
20
|
+
|
|
21
|
+
@strawberry.field
|
|
22
|
+
async def name(
|
|
23
|
+
self,
|
|
24
|
+
info: Info[Context, None],
|
|
25
|
+
) -> Identifier:
|
|
26
|
+
if self.db_record:
|
|
27
|
+
val = self.db_record.name
|
|
28
|
+
else:
|
|
29
|
+
val = await info.context.data_loaders.prompt_label_fields.load(
|
|
30
|
+
(self.id, models.PromptLabel.name),
|
|
31
|
+
)
|
|
32
|
+
return Identifier(val)
|
|
33
|
+
|
|
34
|
+
@strawberry.field
|
|
35
|
+
async def description(
|
|
36
|
+
self,
|
|
37
|
+
info: Info[Context, None],
|
|
38
|
+
) -> Optional[str]:
|
|
39
|
+
if self.db_record:
|
|
40
|
+
val = self.db_record.description
|
|
41
|
+
else:
|
|
42
|
+
val = await info.context.data_loaders.prompt_label_fields.load(
|
|
43
|
+
(self.id, models.PromptLabel.description),
|
|
44
|
+
)
|
|
45
|
+
return val
|
|
19
46
|
|
|
20
47
|
@strawberry.field
|
|
21
|
-
async def
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
48
|
+
async def color(
|
|
49
|
+
self,
|
|
50
|
+
info: Info[Context, None],
|
|
51
|
+
) -> str:
|
|
52
|
+
if self.db_record:
|
|
53
|
+
val = self.db_record.color
|
|
54
|
+
else:
|
|
55
|
+
val = await info.context.data_loaders.prompt_label_fields.load(
|
|
56
|
+
(self.id, models.PromptLabel.color),
|
|
29
57
|
)
|
|
30
|
-
|
|
31
|
-
to_gql_prompt_from_orm(prompt_orm)
|
|
32
|
-
async for prompt_orm in await session.stream_scalars(statement)
|
|
33
|
-
]
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def to_gql_prompt_label(label_orm: models.PromptLabel) -> PromptLabel:
|
|
37
|
-
return PromptLabel(
|
|
38
|
-
id_attr=label_orm.id,
|
|
39
|
-
name=Identifier(label_orm.name),
|
|
40
|
-
description=label_orm.description,
|
|
41
|
-
)
|
|
58
|
+
return val
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
-
from typing import Optional
|
|
2
|
+
from typing import TYPE_CHECKING, Annotated, Optional
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
5
|
from sqlalchemy import select
|
|
@@ -18,7 +18,7 @@ from phoenix.server.api.helpers.prompts.models import (
|
|
|
18
18
|
denormalize_tools,
|
|
19
19
|
get_raw_invocation_parameters,
|
|
20
20
|
)
|
|
21
|
-
from phoenix.server.api.types.PromptVersionTag import PromptVersionTag
|
|
21
|
+
from phoenix.server.api.types.PromptVersionTag import PromptVersionTag
|
|
22
22
|
from phoenix.server.api.types.PromptVersionTemplate import (
|
|
23
23
|
PromptTemplate,
|
|
24
24
|
to_gql_template_from_orm,
|
|
@@ -26,7 +26,9 @@ from phoenix.server.api.types.PromptVersionTemplate import (
|
|
|
26
26
|
|
|
27
27
|
from .ResponseFormat import ResponseFormat
|
|
28
28
|
from .ToolDefinition import ToolDefinition
|
|
29
|
-
|
|
29
|
+
|
|
30
|
+
if TYPE_CHECKING:
|
|
31
|
+
from .User import User
|
|
30
32
|
|
|
31
33
|
|
|
32
34
|
@strawberry.type
|
|
@@ -53,16 +55,17 @@ class PromptVersion(Node):
|
|
|
53
55
|
models.PromptVersionTag.prompt_version_id == self.id_attr
|
|
54
56
|
)
|
|
55
57
|
return [
|
|
56
|
-
|
|
58
|
+
PromptVersionTag(id=tag.id, db_record=tag)
|
|
59
|
+
async for tag in await session.stream_scalars(stmt)
|
|
57
60
|
]
|
|
58
61
|
|
|
59
62
|
@strawberry.field
|
|
60
|
-
async def user(self
|
|
63
|
+
async def user(self) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
|
|
61
64
|
if self.user_id is None:
|
|
62
65
|
return None
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
return
|
|
66
|
+
from .User import User
|
|
67
|
+
|
|
68
|
+
return User(id=self.user_id)
|
|
66
69
|
|
|
67
70
|
@strawberry.field
|
|
68
71
|
async def previous_version(self, info: Info[Context, None]) -> Optional["PromptVersion"]:
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Optional
|
|
1
|
+
from typing import TYPE_CHECKING, Annotated, Optional
|
|
2
2
|
|
|
3
3
|
import strawberry
|
|
4
4
|
from strawberry import Info
|
|
@@ -7,34 +7,74 @@ from strawberry.relay import GlobalID, Node, NodeID
|
|
|
7
7
|
from phoenix.db import models
|
|
8
8
|
from phoenix.server.api.context import Context
|
|
9
9
|
from phoenix.server.api.types.Identifier import Identifier
|
|
10
|
-
from phoenix.server.api.types.User import User
|
|
10
|
+
from phoenix.server.api.types.User import User
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from .User import User
|
|
11
14
|
|
|
12
15
|
|
|
13
16
|
@strawberry.type
|
|
14
17
|
class PromptVersionTag(Node):
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
18
|
+
id: NodeID[int]
|
|
19
|
+
db_record: strawberry.Private[Optional[models.PromptVersionTag]] = None
|
|
20
|
+
|
|
21
|
+
def __post_init__(self) -> None:
|
|
22
|
+
if self.db_record and self.id != self.db_record.id:
|
|
23
|
+
raise ValueError("PromptVersionTag ID mismatch")
|
|
24
|
+
|
|
25
|
+
@strawberry.field
|
|
26
|
+
async def prompt_version_id(
|
|
27
|
+
self,
|
|
28
|
+
info: Info[Context, None],
|
|
29
|
+
) -> GlobalID:
|
|
30
|
+
from phoenix.server.api.types.PromptVersion import PromptVersion
|
|
31
|
+
|
|
32
|
+
if self.db_record:
|
|
33
|
+
version_id = self.db_record.prompt_version_id
|
|
34
|
+
else:
|
|
35
|
+
version_id = await info.context.data_loaders.prompt_version_tag_fields.load(
|
|
36
|
+
(self.id, models.PromptVersionTag.prompt_version_id),
|
|
37
|
+
)
|
|
38
|
+
return GlobalID(PromptVersion.__name__, str(version_id))
|
|
20
39
|
|
|
21
40
|
@strawberry.field
|
|
22
|
-
async def
|
|
23
|
-
|
|
41
|
+
async def name(
|
|
42
|
+
self,
|
|
43
|
+
info: Info[Context, None],
|
|
44
|
+
) -> Identifier:
|
|
45
|
+
if self.db_record:
|
|
46
|
+
val = self.db_record.name
|
|
47
|
+
else:
|
|
48
|
+
val = await info.context.data_loaders.prompt_version_tag_fields.load(
|
|
49
|
+
(self.id, models.PromptVersionTag.name),
|
|
50
|
+
)
|
|
51
|
+
return Identifier(val.root)
|
|
52
|
+
|
|
53
|
+
@strawberry.field
|
|
54
|
+
async def description(
|
|
55
|
+
self,
|
|
56
|
+
info: Info[Context, None],
|
|
57
|
+
) -> Optional[str]:
|
|
58
|
+
if self.db_record:
|
|
59
|
+
val = self.db_record.description
|
|
60
|
+
else:
|
|
61
|
+
val = await info.context.data_loaders.prompt_version_tag_fields.load(
|
|
62
|
+
(self.id, models.PromptVersionTag.description),
|
|
63
|
+
)
|
|
64
|
+
return val
|
|
65
|
+
|
|
66
|
+
@strawberry.field
|
|
67
|
+
async def user(
|
|
68
|
+
self, info: Info[Context, None]
|
|
69
|
+
) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
|
|
70
|
+
if self.db_record:
|
|
71
|
+
user_id = self.db_record.user_id
|
|
72
|
+
else:
|
|
73
|
+
user_id = await info.context.data_loaders.prompt_version_tag_fields.load(
|
|
74
|
+
(self.id, models.PromptVersionTag.user_id),
|
|
75
|
+
)
|
|
76
|
+
if user_id is None:
|
|
24
77
|
return None
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
return
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
def to_gql_prompt_version_tag(prompt_version_tag: models.PromptVersionTag) -> PromptVersionTag:
|
|
31
|
-
from phoenix.server.api.types.PromptVersion import PromptVersion
|
|
32
|
-
|
|
33
|
-
version_gid = GlobalID(PromptVersion.__name__, str(prompt_version_tag.prompt_version_id))
|
|
34
|
-
return PromptVersionTag(
|
|
35
|
-
id_attr=prompt_version_tag.id,
|
|
36
|
-
prompt_version_id=version_gid,
|
|
37
|
-
name=Identifier(prompt_version_tag.name.root),
|
|
38
|
-
description=prompt_version_tag.description,
|
|
39
|
-
user_id=prompt_version_tag.user_id,
|
|
40
|
-
)
|
|
78
|
+
from .User import User
|
|
79
|
+
|
|
80
|
+
return User(id=user_id)
|