arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,8 +1,9 @@
|
|
|
1
|
-
from typing import Optional
|
|
1
|
+
from typing import Optional, Union
|
|
2
2
|
|
|
3
3
|
import strawberry
|
|
4
4
|
from strawberry import UNSET
|
|
5
5
|
from strawberry.relay import GlobalID
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
6
7
|
|
|
7
8
|
from phoenix.db import models
|
|
8
9
|
from phoenix.server.api.exceptions import BadRequest
|
|
@@ -11,7 +12,7 @@ from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
@strawberry.input
|
|
14
|
-
class
|
|
15
|
+
class AnnotationFilterCondition:
|
|
15
16
|
names: Optional[list[str]] = UNSET
|
|
16
17
|
sources: Optional[list[AnnotationSource]] = UNSET
|
|
17
18
|
user_ids: Optional[list[Optional[GlobalID]]] = UNSET
|
|
@@ -26,42 +27,49 @@ class SpanAnnotationFilterCondition:
|
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
@strawberry.input
class AnnotationFilter:
    # Conditions an annotation must satisfy in order to pass the filter.
    include: Optional[AnnotationFilterCondition] = UNSET
    # Conditions that cause an annotation to be rejected when matched.
    exclude: Optional[AnnotationFilterCondition] = UNSET

    def __post_init__(self) -> None:
        """Reject a filter in which neither side was supplied.

        Raises:
            BadRequest: if both ``include`` and ``exclude`` are left unset.
        """
        neither_given = all(part is UNSET for part in (self.include, self.exclude))
        if neither_given:
            raise BadRequest("include and exclude cannot both be unset")
|
|
36
37
|
|
|
37
38
|
|
|
38
|
-
|
|
39
|
+
# ORM annotation models accepted by satisfies_filter; each is read for its
# ``name``, ``source`` and ``user_id`` attributes.
_Annotation: TypeAlias = Union[
    models.SpanAnnotation,
    models.TraceAnnotation,
    models.ProjectSessionAnnotation,
]


def _to_user_rowids(user_ids: list[Optional[GlobalID]]) -> list[Optional[int]]:
    """Convert User GlobalIDs to database row ids, keeping ``None`` entries as ``None``.

    Keeping ``None`` lets a filter match (or exclude) annotations whose ``user_id`` is null.
    """
    return [
        None if user_id is None else from_global_id_with_expected_type(user_id, "User")
        for user_id in user_ids
    ]


def satisfies_filter(annotation: _Annotation, filter: AnnotationFilter) -> bool:
    """
    Returns true if the annotation satisfies the filter and false otherwise.

    For ``filter.include``, every non-empty list (names, sources, user_ids) must contain
    the annotation's corresponding value. For ``filter.exclude``, the annotation is
    rejected if any non-empty list contains its corresponding value. Empty or unset
    lists are ignored.
    """
    annotation_source = AnnotationSource(annotation.source)
    if include := filter.include:
        if include.names and annotation.name not in include.names:
            return False
        if include.sources and annotation_source not in include.sources:
            return False
        # The GlobalIDs are converted only when a user_ids list was actually supplied.
        if include.user_ids and annotation.user_id not in _to_user_rowids(include.user_ids):
            return False
    if exclude := filter.exclude:
        if exclude.names and annotation.name in exclude.names:
            return False
        if exclude.sources and annotation_source in exclude.sources:
            return False
        if exclude.user_ids and annotation.user_id in _to_user_rowids(exclude.user_ids):
            return False
    return True
|
|
@@ -8,6 +8,7 @@ from strawberry.scalars import JSON
|
|
|
8
8
|
from phoenix.server.api.helpers.prompts.models import (
|
|
9
9
|
PromptTemplateFormat,
|
|
10
10
|
)
|
|
11
|
+
from phoenix.server.api.input_types.GenerativeCredentialInput import GenerativeCredentialInput
|
|
11
12
|
from phoenix.server.api.types.Identifier import Identifier
|
|
12
13
|
|
|
13
14
|
from .ChatCompletionMessageInput import ChatCompletionMessageInput
|
|
@@ -22,9 +23,10 @@ class ChatCompletionInput:
|
|
|
22
23
|
model: GenerativeModelInput
|
|
23
24
|
invocation_parameters: list[InvocationParameterInput] = strawberry.field(default_factory=list)
|
|
24
25
|
tools: Optional[list[JSON]] = UNSET
|
|
25
|
-
|
|
26
|
+
credentials: Optional[list[GenerativeCredentialInput]] = UNSET
|
|
26
27
|
template: Optional[PromptTemplateOptions] = UNSET
|
|
27
28
|
prompt_name: Optional[Identifier] = None
|
|
29
|
+
repetitions: int
|
|
28
30
|
|
|
29
31
|
|
|
30
32
|
@strawberry.input
|
|
@@ -33,10 +35,12 @@ class ChatCompletionOverDatasetInput:
|
|
|
33
35
|
model: GenerativeModelInput
|
|
34
36
|
invocation_parameters: list[InvocationParameterInput] = strawberry.field(default_factory=list)
|
|
35
37
|
tools: Optional[list[JSON]] = UNSET
|
|
36
|
-
|
|
38
|
+
credentials: Optional[list[GenerativeCredentialInput]] = UNSET
|
|
37
39
|
template_format: PromptTemplateFormat = PromptTemplateFormat.MUSTACHE
|
|
40
|
+
repetitions: int
|
|
38
41
|
dataset_id: GlobalID
|
|
39
42
|
dataset_version_id: Optional[GlobalID] = None
|
|
43
|
+
split_ids: Optional[list[GlobalID]] = None
|
|
40
44
|
experiment_name: Optional[str] = None
|
|
41
45
|
experiment_description: Optional[str] = None
|
|
42
46
|
experiment_metadata: Optional[JSON] = strawberry.field(default_factory=dict)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import strawberry
|
|
5
|
+
from strawberry import UNSET
|
|
6
|
+
|
|
7
|
+
from phoenix.server.api.exceptions import BadRequest
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# A "#RRGGBB" hex color. It is applied with ``fullmatch``: a ``$`` anchor used with
# ``re.match`` also matches just before a trailing newline, so "#aabbcc\n" would
# wrongly validate.
_HEX_COLOR_PATTERN = re.compile(r"#[0-9a-fA-F]{6}")


@strawberry.input
class CreateProjectInput:
    name: str
    description: Optional[str] = UNSET
    # Optional six-digit hex colors ("#RRGGBB"); validated in __post_init__.
    gradient_start_color: Optional[str] = UNSET
    gradient_end_color: Optional[str] = UNSET

    def __post_init__(self) -> None:
        """Validate the project input.

        Empty/``None``/unset gradient colors are not validated.

        Raises:
            BadRequest: if ``name`` is empty or whitespace-only, or if a supplied
                gradient color is not exactly a ``#RRGGBB`` hex string.
        """
        if not self.name.strip():
            raise BadRequest("Name cannot be empty")
        if self.gradient_start_color and not _HEX_COLOR_PATTERN.fullmatch(
            self.gradient_start_color
        ):
            raise BadRequest("Gradient start color must be a valid hex color")
        if self.gradient_end_color and not _HEX_COLOR_PATTERN.fullmatch(
            self.gradient_end_color
        ):
            raise BadRequest("Gradient end color must be a valid hex color")
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
from strawberry.relay import GlobalID
|
|
5
|
+
from strawberry.scalars import JSON
|
|
6
|
+
|
|
7
|
+
from phoenix.server.api.exceptions import BadRequest
|
|
8
|
+
from phoenix.server.api.types.AnnotationSource import AnnotationSource
|
|
9
|
+
from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _blank_to_none(value: Optional[str]) -> Optional[str]:
    """Strip a string and map an empty result to ``None``; non-strings pass through."""
    if not isinstance(value, str):
        return value
    return value.strip() or None


@strawberry.input
class CreateProjectSessionAnnotationInput:
    project_session_id: GlobalID
    name: str
    annotator_kind: AnnotatorKind = AnnotatorKind.HUMAN
    label: Optional[str] = None
    score: Optional[float] = None
    explanation: Optional[str] = None
    metadata: JSON = strawberry.field(default_factory=dict)
    source: AnnotationSource = AnnotationSource.APP
    identifier: Optional[str] = strawberry.UNSET

    def __post_init__(self) -> None:
        """Normalize the text fields and require at least one annotation value.

        ``name`` and ``identifier`` are stripped. ``label`` and ``explanation`` are
        stripped and become ``None`` when blank.

        Raises:
            BadRequest: if ``score`` is ``None`` and ``label``/``explanation`` are both empty.
        """
        self.name = self.name.strip()
        self.label = _blank_to_none(self.label)
        self.explanation = _blank_to_none(self.explanation)
        if isinstance(self.identifier, str):
            self.identifier = self.identifier.strip()
        has_any_value = self.score is not None or bool(self.label) or bool(self.explanation)
        if not has_any_value:
            raise BadRequest("At least one of score, label, or explanation must be not null/empty.")
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import strawberry
|
|
5
|
+
from strawberry import UNSET
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@strawberry.enum
class DatasetFilterColumn(Enum):
    # Columns available for DatasetFilter.col; only the dataset name for now.
    name = "name"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@strawberry.input(description="A filter for datasets")
class DatasetFilter:
    # Column to match against; presumably paired with ``value`` (confirm in resolver).
    col: Optional[DatasetFilterColumn] = None
    # Text to match in ``col``.
    value: Optional[str] = None
    # Label names to filter by; defaults to UNSET (distinct from an explicit null).
    filter_labels: Optional[list[str]] = UNSET
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
import operator
|
|
2
|
+
from enum import Enum, auto
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
|
|
5
|
+
import strawberry
|
|
6
|
+
from sqlalchemy import ColumnElement, Select, and_, func, literal, or_, select, tuple_
|
|
7
|
+
from sqlalchemy.sql.selectable import NamedFromClause
|
|
8
|
+
from strawberry import Maybe
|
|
9
|
+
from typing_extensions import assert_never
|
|
10
|
+
|
|
11
|
+
from phoenix.db import models
|
|
12
|
+
from phoenix.server.api.types.pagination import (
|
|
13
|
+
Cursor,
|
|
14
|
+
CursorSortColumn,
|
|
15
|
+
CursorSortColumnDataType,
|
|
16
|
+
CursorSortColumnValue,
|
|
17
|
+
)
|
|
18
|
+
from phoenix.server.api.types.SortDir import SortDir
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@strawberry.enum
class ExperimentRunMetric(Enum):
    # Built-in per-run metrics that ExperimentRunColumn.metric can select for sorting.
    latencyMs = auto()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@strawberry.input(one_of=True)
class ExperimentRunColumn:
    # one_of=True: exactly one of the fields below may be provided.
    # A built-in run metric to sort by.
    metric: Maybe[ExperimentRunMetric]
    # Name of an experiment-run annotation; runs are sorted by its mean score.
    annotation_name: Maybe[str]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@strawberry.input(description="The sort key and direction for experiment run connections")
class ExperimentRunSort:
    # What to sort on: a run metric or an annotation name.
    col: ExperimentRunColumn
    # Sort direction; ties are broken by run id in the same direction.
    dir: SortDir
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_experiment_run_cursor(
    run: models.ExperimentRun, annotation_score: Optional[float], sort: Optional[ExperimentRunSort]
) -> Cursor:
    """Build the pagination cursor for ``run`` under ``sort``.

    The cursor always carries the run's row id. With a latency sort it also carries
    ``run.latency_ms`` as a FLOAT. With an annotation sort it carries
    ``annotation_score``, typed NULL when the score is ``None``. Without a sort it
    carries no sort column.
    """
    if not sort:
        return Cursor(rowid=run.id, sort_column=None)

    sort_column: Optional[CursorSortColumn] = None
    if sort.col.metric:
        metric = sort.col.metric.value
        assert metric is not None
        if metric is not ExperimentRunMetric.latencyMs:
            assert_never(metric)
        sort_column = CursorSortColumn(type=CursorSortColumnDataType.FLOAT, value=run.latency_ms)
    elif sort.col.annotation_name:
        # A run without a matching annotation has a NULL mean score; encode that in the type.
        score_type = (
            CursorSortColumnDataType.NULL
            if annotation_score is None
            else CursorSortColumnDataType.FLOAT
        )
        sort_column = CursorSortColumn(type=score_type, value=annotation_score)
    return Cursor(rowid=run.id, sort_column=sort_column)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def add_order_by_and_page_start_to_query(
|
|
67
|
+
query: Select[Any],
|
|
68
|
+
sort: Optional[ExperimentRunSort],
|
|
69
|
+
experiment_rowid: int,
|
|
70
|
+
after_experiment_run_rowid: Optional[int],
|
|
71
|
+
after_sort_column_value: Optional[CursorSortColumnValue] = None,
|
|
72
|
+
) -> Select[Any]:
|
|
73
|
+
mean_annotation_scores: Optional[NamedFromClause] = None
|
|
74
|
+
if sort and sort.col.annotation_name:
|
|
75
|
+
annotation_name = sort.col.annotation_name.value
|
|
76
|
+
assert annotation_name is not None
|
|
77
|
+
mean_annotation_scores = _get_mean_annotation_scores_subquery(annotation_name)
|
|
78
|
+
order_by_columns = _get_order_by_columns(
|
|
79
|
+
sort=sort, experiment_rowid=experiment_rowid, mean_annotation_scores=mean_annotation_scores
|
|
80
|
+
)
|
|
81
|
+
query = query.order_by(*order_by_columns)
|
|
82
|
+
if after_experiment_run_rowid is not None:
|
|
83
|
+
query = _add_after_expression(
|
|
84
|
+
query=query,
|
|
85
|
+
sort=sort,
|
|
86
|
+
experiment_run_rowid=after_experiment_run_rowid,
|
|
87
|
+
after_sort_column_value=after_sort_column_value,
|
|
88
|
+
mean_annotation_scores=mean_annotation_scores,
|
|
89
|
+
)
|
|
90
|
+
query = _add_joins_and_selects_to_query(
|
|
91
|
+
query=query,
|
|
92
|
+
sort=sort,
|
|
93
|
+
mean_annotation_scores=mean_annotation_scores,
|
|
94
|
+
)
|
|
95
|
+
return query
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _get_order_by_columns(
|
|
99
|
+
sort: Optional[ExperimentRunSort],
|
|
100
|
+
experiment_rowid: int,
|
|
101
|
+
mean_annotation_scores: Optional[NamedFromClause],
|
|
102
|
+
) -> tuple[ColumnElement[Any], ...]:
|
|
103
|
+
if not sort:
|
|
104
|
+
# Ideally, this would sort the runs by (example_id, repetition_number),
|
|
105
|
+
# but this would require making the cursor more complex or adding an additional query
|
|
106
|
+
# to handle the after cursor.
|
|
107
|
+
return (models.ExperimentRun.id.asc(),)
|
|
108
|
+
sort_direction = sort.dir
|
|
109
|
+
if sort.col.metric:
|
|
110
|
+
metric = sort.col.metric.value
|
|
111
|
+
assert metric is not None
|
|
112
|
+
if metric is ExperimentRunMetric.latencyMs:
|
|
113
|
+
if sort_direction is SortDir.asc:
|
|
114
|
+
return (models.ExperimentRun.latency_ms.asc(), models.ExperimentRun.id.asc())
|
|
115
|
+
else:
|
|
116
|
+
return (models.ExperimentRun.latency_ms.desc(), models.ExperimentRun.id.desc())
|
|
117
|
+
else:
|
|
118
|
+
assert_never(metric)
|
|
119
|
+
elif sort.col.annotation_name:
|
|
120
|
+
annotation_name = sort.col.annotation_name.value
|
|
121
|
+
assert annotation_name is not None
|
|
122
|
+
assert mean_annotation_scores is not None
|
|
123
|
+
if sort_direction is SortDir.asc:
|
|
124
|
+
return (
|
|
125
|
+
mean_annotation_scores.c.score.asc().nulls_last(),
|
|
126
|
+
models.ExperimentRun.id.asc(),
|
|
127
|
+
)
|
|
128
|
+
else:
|
|
129
|
+
return (
|
|
130
|
+
mean_annotation_scores.c.score.desc().nulls_last(),
|
|
131
|
+
models.ExperimentRun.id.desc(),
|
|
132
|
+
)
|
|
133
|
+
raise NotImplementedError
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _add_after_expression(
|
|
137
|
+
query: Select[Any],
|
|
138
|
+
sort: Optional[ExperimentRunSort],
|
|
139
|
+
experiment_run_rowid: int,
|
|
140
|
+
after_sort_column_value: Optional[CursorSortColumnValue],
|
|
141
|
+
mean_annotation_scores: Optional[NamedFromClause],
|
|
142
|
+
) -> Select[Any]:
|
|
143
|
+
if not sort:
|
|
144
|
+
# Ideally, this would return the runs sorted by (example_id, repetition_number),
|
|
145
|
+
# but this would require making the cursor more complex or adding an additional query.
|
|
146
|
+
return query.where(models.ExperimentRun.id > literal(experiment_run_rowid))
|
|
147
|
+
sort_direction = sort.dir
|
|
148
|
+
compare_fn = operator.gt if sort_direction is SortDir.asc else operator.lt
|
|
149
|
+
if sort.col.metric:
|
|
150
|
+
metric = sort.col.metric.value
|
|
151
|
+
assert metric is not None
|
|
152
|
+
if metric is ExperimentRunMetric.latencyMs:
|
|
153
|
+
assert after_sort_column_value is not None
|
|
154
|
+
return query.where(
|
|
155
|
+
compare_fn(
|
|
156
|
+
tuple_(models.ExperimentRun.latency_ms, models.ExperimentRun.id),
|
|
157
|
+
tuple_(
|
|
158
|
+
literal(after_sort_column_value),
|
|
159
|
+
literal(experiment_run_rowid),
|
|
160
|
+
),
|
|
161
|
+
)
|
|
162
|
+
)
|
|
163
|
+
else:
|
|
164
|
+
assert_never(metric)
|
|
165
|
+
elif sort.col.annotation_name:
|
|
166
|
+
annotation_name = sort.col.annotation_name.value
|
|
167
|
+
assert annotation_name is not None
|
|
168
|
+
assert mean_annotation_scores is not None
|
|
169
|
+
if after_sort_column_value is None:
|
|
170
|
+
return query.where(
|
|
171
|
+
and_(
|
|
172
|
+
compare_fn(models.ExperimentRun.id, literal(experiment_run_rowid)),
|
|
173
|
+
mean_annotation_scores.c.score.is_(None),
|
|
174
|
+
)
|
|
175
|
+
)
|
|
176
|
+
else:
|
|
177
|
+
return query.where(
|
|
178
|
+
or_(
|
|
179
|
+
compare_fn(
|
|
180
|
+
tuple_(mean_annotation_scores.c.score, models.ExperimentRun.id),
|
|
181
|
+
tuple_(
|
|
182
|
+
literal(after_sort_column_value),
|
|
183
|
+
literal(experiment_run_rowid),
|
|
184
|
+
),
|
|
185
|
+
),
|
|
186
|
+
mean_annotation_scores.c.score.is_(None),
|
|
187
|
+
)
|
|
188
|
+
)
|
|
189
|
+
raise NotImplementedError
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _get_mean_annotation_scores_subquery(annotation_name: str) -> NamedFromClause:
    """
    Build a subquery holding the average annotation score of each experiment run.

    Only `ExperimentRunAnnotation` rows whose `name` equals `annotation_name` are
    considered. They are grouped by `experiment_run_id` and their `score` values
    are averaged.

    Returns:
        A subquery aliased ``mean_annotation_scores`` with the columns ``score``
        (the average) and ``experiment_run_id``.
    """
    annotation = models.ExperimentRunAnnotation
    run = models.ExperimentRun

    averaged = select(
        func.avg(annotation.score).label("score"),
        annotation.experiment_run_id.label("experiment_run_id"),
    )
    # Inner join against ExperimentRun: only annotations whose run row exists are kept.
    per_run = (
        averaged.select_from(annotation)
        .join(run, annotation.experiment_run_id == run.id)
        .where(annotation.name == annotation_name)
        .group_by(annotation.experiment_run_id)
    )
    return per_run.subquery().alias("mean_annotation_scores")
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _add_joins_and_selects_to_query(
    query: Select[tuple[models.ExperimentRun]],
    sort: Optional[ExperimentRunSort],
    mean_annotation_scores: Optional[NamedFromClause],
) -> Select[tuple[models.ExperimentRun]]:
    """
    Add the joins and extra selected columns that the requested sort requires.

    - No sort, or a sort on the ``latencyMs`` metric: `query` is returned as-is.
    - Sort on an annotation name: `mean_annotation_scores` is LEFT OUTER JOINed
      on the experiment run id, and its ``score`` column is added to the select
      list.

    Raises:
        NotImplementedError: if the sort column names neither a metric nor an
            annotation.
    """
    if not sort:
        return query

    column = sort.col

    if column.metric:
        metric = column.metric.value
        assert metric is not None
        if metric is not ExperimentRunMetric.latencyMs:
            assert_never(metric)
        # Latency needs no extra join or selected column.
        return query

    if column.annotation_name:
        annotation_name = column.annotation_name.value
        assert annotation_name is not None
        assert mean_annotation_scores is not None
        # Outer join: runs without this annotation are kept, with a NULL score.
        joined = query.join(
            mean_annotation_scores,
            mean_annotation_scores.c.experiment_run_id == models.ExperimentRun.id,
            isouter=True,
        )
        # The score has to be selected so its value can be included in the cursor.
        return joined.add_columns(mean_annotation_scores.c.score.label("score"))

    raise NotImplementedError
|
|
@@ -2,6 +2,7 @@ from typing import Optional
|
|
|
2
2
|
|
|
3
3
|
import strawberry
|
|
4
4
|
from strawberry import UNSET
|
|
5
|
+
from strawberry.scalars import JSON
|
|
5
6
|
|
|
6
7
|
from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
|
|
7
8
|
|
|
@@ -17,3 +18,7 @@ class GenerativeModelInput:
|
|
|
17
18
|
""" The endpoint to use for the model. Only required for Azure OpenAI models. """
|
|
18
19
|
api_version: Optional[str] = UNSET
|
|
19
20
|
""" The API version to use for the model. """
|
|
21
|
+
region: Optional[str] = UNSET
|
|
22
|
+
""" The region to use for the model. """
|
|
23
|
+
custom_headers: Optional[JSON] = UNSET
|
|
24
|
+
""" Custom headers to use for the model. """
|
|
@@ -1,8 +1,16 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
1
2
|
from enum import Enum, auto
|
|
3
|
+
from typing import Any, Optional
|
|
2
4
|
|
|
3
5
|
import strawberry
|
|
6
|
+
from sqlalchemy import and_, desc, func, nulls_last, select
|
|
7
|
+
from sqlalchemy.orm import InstrumentedAttribute
|
|
8
|
+
from sqlalchemy.sql.expression import Select
|
|
9
|
+
from strawberry import UNSET
|
|
4
10
|
from typing_extensions import assert_never
|
|
5
11
|
|
|
12
|
+
from phoenix.db import models
|
|
13
|
+
from phoenix.db.helpers import truncate_name
|
|
6
14
|
from phoenix.server.api.types.pagination import CursorSortColumnDataType
|
|
7
15
|
from phoenix.server.api.types.SortDir import SortDir
|
|
8
16
|
|
|
@@ -13,6 +21,30 @@ class ProjectSessionColumn(Enum):
|
|
|
13
21
|
endTime = auto()
|
|
14
22
|
tokenCountTotal = auto()
|
|
15
23
|
numTraces = auto()
|
|
24
|
+
costTotal = auto()
|
|
25
|
+
|
|
26
|
+
@property
def column_name(self) -> str:
    """SQL label used when this sort column is added to a SELECT.

    The label is the member name plus a fixed suffix, passed through
    `truncate_name`.
    """
    raw_label = self.name + "_project_session_sort_column"
    return truncate_name(raw_label)
|
|
29
|
+
|
|
30
|
+
def as_orm_expression(self, joined_table: Optional[Any] = None) -> Any:
    """
    Return the SQL expression for this sort column, labelled with `column_name`.

    `startTime` and `endTime` map directly onto `ProjectSession` columns.
    `tokenCountTotal`, `numTraces` and `costTotal` are read from the ``key``
    column of `joined_table`, which must be provided for them. In SOURCE that
    table is the subquery returned by `join_tables`.
    """
    expr: Any
    if self is ProjectSessionColumn.startTime:
        expr = models.ProjectSession.start_time
    elif self is ProjectSessionColumn.endTime:
        expr = models.ProjectSession.end_time
    elif (
        self is ProjectSessionColumn.tokenCountTotal
        or self is ProjectSessionColumn.numTraces
        or self is ProjectSessionColumn.costTotal
    ):
        # Aggregated values live in the joined per-session subquery.
        assert joined_table is not None
        expr = joined_table.c.key
    else:
        assert_never(self)
    return expr.label(self.column_name)
|
|
16
48
|
|
|
17
49
|
@property
|
|
18
50
|
def data_type(self) -> CursorSortColumnDataType:
|
|
@@ -20,10 +52,138 @@ class ProjectSessionColumn(Enum):
|
|
|
20
52
|
return CursorSortColumnDataType.INT
|
|
21
53
|
if self is ProjectSessionColumn.startTime or self is ProjectSessionColumn.endTime:
|
|
22
54
|
return CursorSortColumnDataType.DATETIME
|
|
55
|
+
if self is ProjectSessionColumn.costTotal:
|
|
56
|
+
return CursorSortColumnDataType.FLOAT
|
|
23
57
|
assert_never(self)
|
|
24
58
|
|
|
59
|
+
def join_tables(self, stmt: Select[Any]) -> tuple[Select[Any], Any]:
    """
    Join the per-session aggregate subquery this column sorts on, if any.

    For aggregate columns, a subquery with columns ``id`` (the project session
    rowid) and ``key`` (the aggregated value) is built and inner-joined to
    `ProjectSession`:

    - ``tokenCountTotal``: sum of `cumulative_llm_token_count_total` over
      root spans (``parent_id IS NULL``).
    - ``numTraces``: count of traces.
    - ``costTotal``: sum of `SpanCost.total_cost`.

    NOTE(review): because the join is an inner join, sessions with no row in
    the subquery (e.g. no `SpanCost` rows when sorting by ``costTotal``) are
    excluded from the result. Confirm this filtering is intended.

    Returns:
        ``(stmt_with_join, subquery)`` for aggregate columns, or
        ``(stmt, None)`` unchanged for all other columns.
    """
    aggregate: Select[Any]
    if self is ProjectSessionColumn.tokenCountTotal:
        aggregate = (
            select(
                models.Trace.project_session_rowid.label("id"),
                func.sum(models.Span.cumulative_llm_token_count_total).label("key"),
            )
            .join_from(models.Trace, models.Span)
            .where(models.Span.parent_id.is_(None))
        )
    elif self is ProjectSessionColumn.numTraces:
        aggregate = select(
            models.Trace.project_session_rowid.label("id"),
            func.count(models.Trace.id).label("key"),
        )
    elif self is ProjectSessionColumn.costTotal:
        aggregate = select(
            models.Trace.project_session_rowid.label("id"),
            func.sum(models.SpanCost.total_cost).label("key"),
        ).join_from(
            models.Trace,
            models.SpanCost,
            models.Trace.id == models.SpanCost.trace_rowid,
        )
    else:
        # Columns stored directly on ProjectSession need no join.
        return stmt, None

    per_session = aggregate.group_by(models.Trace.project_session_rowid).subquery()
    joined = stmt.join(per_session, models.ProjectSession.id == per_session.c.id)
    return joined, per_session
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@strawberry.enum
class ProjectSessionAnnoAttr(Enum):
    # Attribute of a ProjectSessionAnnotation that project sessions can be sorted by.
    score = "score"
    label = "label"

    @property
    def column_name(self) -> str:
        """SQL label given to this annotation attribute when it is selected."""
        return self.value + "_anno_sort_column"

    @property
    def orm_expression(self) -> Any:
        """The matching `ProjectSessionAnnotation` column, labelled with `column_name`."""
        column: InstrumentedAttribute[Any]
        if self is ProjectSessionAnnoAttr.label:
            column = models.ProjectSessionAnnotation.label
        elif self is ProjectSessionAnnoAttr.score:
            column = models.ProjectSessionAnnotation.score
        else:
            assert_never(self)
        return column.label(self.column_name)

    @property
    def data_type(self) -> CursorSortColumnDataType:
        """Cursor data type of the attribute: FLOAT for scores, STRING for labels."""
        if self is ProjectSessionAnnoAttr.score:
            return CursorSortColumnDataType.FLOAT
        if self is ProjectSessionAnnoAttr.label:
            return CursorSortColumnDataType.STRING
        assert_never(self)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@strawberry.input
class ProjectSessionAnnoResultKey:
    # Selects which annotation result project sessions are sorted by
    # (consumed by ProjectSessionSort.update_orm_expr).

    # Annotation name; matched against ProjectSessionAnnotation.name in the join.
    name: str
    # Which attribute of that annotation (score or label) to sort on.
    attr: ProjectSessionAnnoAttr
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@dataclass(frozen=True)
class ProjectSessionSortConfig:
    """
    Immutable result of applying a `ProjectSessionSort` to a statement.

    Produced by `ProjectSessionSort.update_orm_expr`. It pairs the sorted
    statement with details about the sort column. The caller presumably uses
    those details for cursor-based pagination (TODO confirm at the call site).
    """

    # Statement with the sort column selected and ORDER BY applied.
    stmt: Select[Any]
    # Labelled sort-column expression, without DESC / NULLS LAST wrapping.
    orm_expression: Any
    # Requested sort direction.
    dir: SortDir
    # SQL label of the sort column in the select list.
    column_name: str
    # Data type of the sort column's value, as a cursor column type.
    column_data_type: CursorSortColumnDataType
|
|
144
|
+
|
|
25
145
|
|
|
26
146
|
@strawberry.input(description="The sort key and direction for ProjectSession connections.")
class ProjectSessionSort:
    # Exactly one of `col` / `anno_result_key` must be set; update_orm_expr enforces it.
    col: Optional[ProjectSessionColumn] = UNSET
    anno_result_key: Optional[ProjectSessionAnnoResultKey] = UNSET
    dir: SortDir

    def update_orm_expr(self, stmt: Select[Any]) -> ProjectSessionSortConfig:
        """
        Apply this sort to `stmt`.

        The sort column is added to the select list and the statement is
        ordered by it, with NULLS LAST, in `dir` order. Any join the sort needs
        is added first:

        - `col`: the column's aggregate subquery, via `ProjectSessionColumn.join_tables`.
        - `anno_result_key`: an inner join on `ProjectSessionAnnotation` rows
          with the requested name.

        Returns:
            A `ProjectSessionSortConfig` describing the sorted statement.

        Raises:
            ValueError: unless exactly one of `col` / `anno_result_key` is set.
        """
        col = self.col
        anno_result_key = self.anno_result_key
        descending = self.dir == SortDir.desc

        if col and not anno_result_key:
            stmt, joined_table = col.join_tables(stmt)
            selected = col.as_orm_expression(joined_table)
            ordering = desc(selected) if descending else selected
            return ProjectSessionSortConfig(
                stmt=stmt.add_columns(selected).order_by(nulls_last(ordering)),
                orm_expression=col.as_orm_expression(joined_table),
                dir=self.dir,
                column_name=col.column_name,
                column_data_type=col.data_type,
            )

        if anno_result_key and not col:
            attr = anno_result_key.attr
            selected = attr.orm_expression
            ordering = desc(selected) if descending else selected
            joined = stmt.add_columns(selected).join(
                models.ProjectSessionAnnotation,
                onclause=and_(
                    models.ProjectSessionAnnotation.project_session_id == models.ProjectSession.id,
                    models.ProjectSessionAnnotation.name == anno_result_key.name,
                ),
            )
            return ProjectSessionSortConfig(
                stmt=joined.order_by(nulls_last(ordering)),
                orm_expression=attr.orm_expression,
                dir=self.dir,
                column_name=attr.column_name,
                column_data_type=attr.data_type,
            )

        raise ValueError(
            "Exactly one of `col` or `annoResultKey` must be specified on `ProjectSessionSort`."
        )
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@strawberry.enum
class PromptFilterColumn(Enum):
    # Column a PromptFilter applies to; currently only the prompt name.
    name = "name"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@strawberry.input(description="The filter key and value for prompt connections")
class PromptFilter:
    # Column to filter prompts on.
    col: PromptFilterColumn
    # Value to filter by; the matching semantics (exact vs. substring) are decided
    # by the resolver that consumes this input, which is not visible here.
    value: str
|