arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from sqlalchemy import func, select, tuple_
|
|
5
|
+
from strawberry.dataloader import DataLoader
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.types import DbSessionFactory
|
|
10
|
+
|
|
11
|
+
ExperimentID: TypeAlias = int
|
|
12
|
+
DatasetExampleID: TypeAlias = int
|
|
13
|
+
AnnotationName: TypeAlias = str
|
|
14
|
+
MeanAnnotationScore: TypeAlias = float
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
|
|
18
|
+
class AnnotationSummary:
|
|
19
|
+
annotation_name: AnnotationName
|
|
20
|
+
mean_score: Optional[MeanAnnotationScore]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
Key: TypeAlias = tuple[ExperimentID, DatasetExampleID]
|
|
24
|
+
Result: TypeAlias = list[AnnotationSummary]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ExperimentRepeatedRunGroupAnnotationSummariesDataLoader(DataLoader[Key, Result]):
|
|
28
|
+
def __init__(
|
|
29
|
+
self,
|
|
30
|
+
db: DbSessionFactory,
|
|
31
|
+
) -> None:
|
|
32
|
+
super().__init__(load_fn=self._load_fn)
|
|
33
|
+
self._db = db
|
|
34
|
+
|
|
35
|
+
async def _load_fn(self, keys: list[Key]) -> list[Result]:
|
|
36
|
+
annotation_summaries_query = (
|
|
37
|
+
select(
|
|
38
|
+
models.ExperimentRun.experiment_id.label("experiment_id"),
|
|
39
|
+
models.ExperimentRun.dataset_example_id.label("dataset_example_id"),
|
|
40
|
+
models.ExperimentRunAnnotation.name.label("annotation_name"),
|
|
41
|
+
func.avg(models.ExperimentRunAnnotation.score).label("mean_score"),
|
|
42
|
+
)
|
|
43
|
+
.select_from(models.ExperimentRunAnnotation)
|
|
44
|
+
.join(
|
|
45
|
+
models.ExperimentRun,
|
|
46
|
+
models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
|
|
47
|
+
)
|
|
48
|
+
.where(
|
|
49
|
+
tuple_(
|
|
50
|
+
models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id
|
|
51
|
+
).in_(set(keys))
|
|
52
|
+
)
|
|
53
|
+
.group_by(
|
|
54
|
+
models.ExperimentRun.experiment_id,
|
|
55
|
+
models.ExperimentRun.dataset_example_id,
|
|
56
|
+
models.ExperimentRunAnnotation.name,
|
|
57
|
+
)
|
|
58
|
+
)
|
|
59
|
+
async with self._db() as session:
|
|
60
|
+
annotation_summaries = (await session.execute(annotation_summaries_query)).all()
|
|
61
|
+
annotation_summaries_by_key: dict[Key, list[AnnotationSummary]] = {}
|
|
62
|
+
for summary in annotation_summaries:
|
|
63
|
+
key = (summary.experiment_id, summary.dataset_example_id)
|
|
64
|
+
gql_summary = AnnotationSummary(
|
|
65
|
+
annotation_name=summary.annotation_name,
|
|
66
|
+
mean_score=summary.mean_score,
|
|
67
|
+
)
|
|
68
|
+
if key not in annotation_summaries_by_key:
|
|
69
|
+
annotation_summaries_by_key[key] = []
|
|
70
|
+
annotation_summaries_by_key[key].append(gql_summary)
|
|
71
|
+
return [
|
|
72
|
+
sorted(
|
|
73
|
+
annotation_summaries_by_key.get(key, []),
|
|
74
|
+
key=lambda summary: summary.annotation_name,
|
|
75
|
+
)
|
|
76
|
+
for key in keys
|
|
77
|
+
]
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select, tuple_
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db import models
|
|
8
|
+
from phoenix.server.types import DbSessionFactory
|
|
9
|
+
|
|
10
|
+
ExperimentID: TypeAlias = int
|
|
11
|
+
DatasetExampleID: TypeAlias = int
|
|
12
|
+
Key: TypeAlias = tuple[ExperimentID, DatasetExampleID]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
|
|
16
|
+
class ExperimentRepeatedRunGroup:
|
|
17
|
+
experiment_rowid: int
|
|
18
|
+
dataset_example_rowid: int
|
|
19
|
+
runs: list[models.ExperimentRun]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
Result: TypeAlias = ExperimentRepeatedRunGroup
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ExperimentRepeatedRunGroupsDataLoader(DataLoader[Key, Result]):
|
|
26
|
+
def __init__(self, db: DbSessionFactory) -> None:
|
|
27
|
+
super().__init__(load_fn=self._load_fn)
|
|
28
|
+
self._db = db
|
|
29
|
+
|
|
30
|
+
async def _load_fn(self, keys: list[Key]) -> list[Result]:
|
|
31
|
+
repeated_run_groups_query = (
|
|
32
|
+
select(models.ExperimentRun)
|
|
33
|
+
.where(
|
|
34
|
+
tuple_(
|
|
35
|
+
models.ExperimentRun.experiment_id,
|
|
36
|
+
models.ExperimentRun.dataset_example_id,
|
|
37
|
+
).in_(set(keys))
|
|
38
|
+
)
|
|
39
|
+
.order_by(models.ExperimentRun.repetition_number)
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
async with self._db() as session:
|
|
43
|
+
runs_by_key: dict[Key, list[models.ExperimentRun]] = {}
|
|
44
|
+
for run in (await session.scalars(repeated_run_groups_query)).all():
|
|
45
|
+
key = (run.experiment_id, run.dataset_example_id)
|
|
46
|
+
if key not in runs_by_key:
|
|
47
|
+
runs_by_key[key] = []
|
|
48
|
+
runs_by_key[key].append(run)
|
|
49
|
+
|
|
50
|
+
return [
|
|
51
|
+
ExperimentRepeatedRunGroup(
|
|
52
|
+
experiment_rowid=experiment_id,
|
|
53
|
+
dataset_example_rowid=dataset_example_id,
|
|
54
|
+
runs=runs_by_key.get((experiment_id, dataset_example_id), []),
|
|
55
|
+
)
|
|
56
|
+
for (experiment_id, dataset_example_id) in keys
|
|
57
|
+
]
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select, tuple_
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db import models
|
|
8
|
+
from phoenix.server.types import DbSessionFactory
|
|
9
|
+
|
|
10
|
+
ExperimentId: TypeAlias = int
|
|
11
|
+
DatasetExampleId: TypeAlias = int
|
|
12
|
+
Key: TypeAlias = tuple[ExperimentId, DatasetExampleId]
|
|
13
|
+
Result: TypeAlias = list[models.ExperimentRun]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ExperimentRunsByExperimentAndExampleDataLoader(DataLoader[Key, Result]):
|
|
17
|
+
def __init__(self, db: DbSessionFactory) -> None:
|
|
18
|
+
super().__init__(load_fn=self._load_fn)
|
|
19
|
+
self._db = db
|
|
20
|
+
|
|
21
|
+
async def _load_fn(self, keys: list[Key]) -> list[Result]:
|
|
22
|
+
runs_by_key: defaultdict[Key, Result] = defaultdict(list)
|
|
23
|
+
|
|
24
|
+
async with self._db() as session:
|
|
25
|
+
stmt = (
|
|
26
|
+
select(models.ExperimentRun)
|
|
27
|
+
.where(
|
|
28
|
+
tuple_(
|
|
29
|
+
models.ExperimentRun.experiment_id,
|
|
30
|
+
models.ExperimentRun.dataset_example_id,
|
|
31
|
+
).in_(keys)
|
|
32
|
+
)
|
|
33
|
+
.order_by(
|
|
34
|
+
models.ExperimentRun.experiment_id,
|
|
35
|
+
models.ExperimentRun.dataset_example_id,
|
|
36
|
+
models.ExperimentRun.repetition_number,
|
|
37
|
+
)
|
|
38
|
+
)
|
|
39
|
+
result = await session.stream_scalars(stmt)
|
|
40
|
+
async for run in result:
|
|
41
|
+
key = (run.experiment_id, run.dataset_example_id)
|
|
42
|
+
runs_by_key[key].append(run)
|
|
43
|
+
|
|
44
|
+
return [runs_by_key[key] for key in keys]
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from sqlalchemy import func, select
|
|
5
|
+
from strawberry.dataloader import DataLoader
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.types import DbSessionFactory
|
|
10
|
+
|
|
11
|
+
GenerativeModelID: TypeAlias = int
|
|
12
|
+
Key: TypeAlias = GenerativeModelID
|
|
13
|
+
Result: TypeAlias = Optional[datetime]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class LastUsedTimesByGenerativeModelIdDataLoader(DataLoader[Key, Result]):
|
|
17
|
+
def __init__(self, db: DbSessionFactory) -> None:
|
|
18
|
+
super().__init__(load_fn=self._load_fn)
|
|
19
|
+
self._db = db
|
|
20
|
+
|
|
21
|
+
async def _load_fn(self, keys: list[Key]) -> list[Result]:
|
|
22
|
+
async with self._db() as session:
|
|
23
|
+
last_used_times_by_model_id: dict[Key, Result] = {
|
|
24
|
+
model_id: last_used_time
|
|
25
|
+
async for model_id, last_used_time in await session.stream(
|
|
26
|
+
select(
|
|
27
|
+
models.SpanCost.model_id,
|
|
28
|
+
func.max(models.SpanCost.span_start_time).label("last_used_time"),
|
|
29
|
+
)
|
|
30
|
+
.select_from(models.SpanCost)
|
|
31
|
+
.where(models.SpanCost.model_id.in_(keys))
|
|
32
|
+
.group_by(models.SpanCost.model_id)
|
|
33
|
+
)
|
|
34
|
+
}
|
|
35
|
+
return [last_used_times_by_model_id.get(model_id) for model_id in keys]
|
|
@@ -25,6 +25,7 @@ from phoenix.db import models
|
|
|
25
25
|
from phoenix.db.helpers import SupportedSQLDialect
|
|
26
26
|
from phoenix.server.api.dataloaders.cache import TwoTierCache
|
|
27
27
|
from phoenix.server.api.input_types.TimeRange import TimeRange
|
|
28
|
+
from phoenix.server.session_filters import get_filtered_session_rowids_subquery
|
|
28
29
|
from phoenix.server.types import DbSessionFactory
|
|
29
30
|
from phoenix.trace.dsl import SpanFilter
|
|
30
31
|
|
|
@@ -32,13 +33,16 @@ Kind: TypeAlias = Literal["span", "trace"]
|
|
|
32
33
|
ProjectRowId: TypeAlias = int
|
|
33
34
|
TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
|
|
34
35
|
FilterCondition: TypeAlias = Optional[str]
|
|
36
|
+
SessionFilterCondition: TypeAlias = Optional[str]
|
|
35
37
|
Probability: TypeAlias = float
|
|
36
38
|
QuantileValue: TypeAlias = float
|
|
37
39
|
|
|
38
|
-
Segment: TypeAlias = tuple[Kind, TimeInterval, FilterCondition]
|
|
40
|
+
Segment: TypeAlias = tuple[Kind, TimeInterval, FilterCondition, SessionFilterCondition]
|
|
39
41
|
Param: TypeAlias = tuple[ProjectRowId, Probability]
|
|
40
42
|
|
|
41
|
-
Key: TypeAlias = tuple[
|
|
43
|
+
Key: TypeAlias = tuple[
|
|
44
|
+
Kind, ProjectRowId, Optional[TimeRange], FilterCondition, SessionFilterCondition, Probability
|
|
45
|
+
]
|
|
42
46
|
Result: TypeAlias = Optional[QuantileValue]
|
|
43
47
|
ResultPosition: TypeAlias = int
|
|
44
48
|
DEFAULT_VALUE: Result = None
|
|
@@ -47,15 +51,18 @@ FloatCol: TypeAlias = SQLColumnExpression[Float[float]]
|
|
|
47
51
|
|
|
48
52
|
|
|
49
53
|
def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
|
|
50
|
-
kind, project_rowid, time_range, filter_condition, probability = key
|
|
54
|
+
kind, project_rowid, time_range, filter_condition, session_filter_condition, probability = key
|
|
51
55
|
interval = (
|
|
52
56
|
(time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
|
|
53
57
|
)
|
|
54
|
-
return (kind, interval, filter_condition), (
|
|
58
|
+
return (kind, interval, filter_condition, session_filter_condition), (
|
|
59
|
+
project_rowid,
|
|
60
|
+
probability,
|
|
61
|
+
)
|
|
55
62
|
|
|
56
63
|
|
|
57
64
|
_Section: TypeAlias = ProjectRowId
|
|
58
|
-
_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, Kind, Probability]
|
|
65
|
+
_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, SessionFilterCondition, Kind, Probability]
|
|
59
66
|
|
|
60
67
|
|
|
61
68
|
class LatencyMsQuantileCache(
|
|
@@ -71,8 +78,17 @@ class LatencyMsQuantileCache(
|
|
|
71
78
|
)
|
|
72
79
|
|
|
73
80
|
def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
|
|
74
|
-
(
|
|
75
|
-
|
|
81
|
+
(
|
|
82
|
+
(kind, interval, filter_condition, session_filter_condition),
|
|
83
|
+
(project_rowid, probability),
|
|
84
|
+
) = _cache_key_fn(key)
|
|
85
|
+
return project_rowid, (
|
|
86
|
+
interval,
|
|
87
|
+
filter_condition,
|
|
88
|
+
session_filter_condition,
|
|
89
|
+
kind,
|
|
90
|
+
probability,
|
|
91
|
+
)
|
|
76
92
|
|
|
77
93
|
|
|
78
94
|
class LatencyMsQuantileDataLoader(DataLoader[Key, Result]):
|
|
@@ -113,11 +129,18 @@ async def _get_results(
|
|
|
113
129
|
segment: Segment,
|
|
114
130
|
params: Mapping[Param, list[ResultPosition]],
|
|
115
131
|
) -> AsyncIterator[tuple[ResultPosition, QuantileValue]]:
|
|
116
|
-
kind, (start_time, end_time), filter_condition = segment
|
|
132
|
+
kind, (start_time, end_time), filter_condition, session_filter_condition = segment
|
|
117
133
|
stmt = select(models.Trace.project_rowid)
|
|
118
134
|
if kind == "trace":
|
|
119
135
|
latency_column = cast(FloatCol, models.Trace.latency_ms)
|
|
120
136
|
time_column = models.Trace.start_time
|
|
137
|
+
if filter_condition:
|
|
138
|
+
sf = SpanFilter(filter_condition)
|
|
139
|
+
stmt = stmt.where(
|
|
140
|
+
models.Trace.id.in_(
|
|
141
|
+
sf(select(models.Span.trace_rowid).distinct()).scalar_subquery()
|
|
142
|
+
)
|
|
143
|
+
)
|
|
121
144
|
elif kind == "span":
|
|
122
145
|
latency_column = cast(FloatCol, models.Span.latency_ms)
|
|
123
146
|
time_column = models.Span.start_time
|
|
@@ -127,6 +150,15 @@ async def _get_results(
|
|
|
127
150
|
stmt = sf(stmt)
|
|
128
151
|
else:
|
|
129
152
|
assert_never(kind)
|
|
153
|
+
if session_filter_condition:
|
|
154
|
+
project_rowids = [project_rowid for project_rowid, _ in params]
|
|
155
|
+
filtered_session_rowids = get_filtered_session_rowids_subquery(
|
|
156
|
+
session_filter_condition=session_filter_condition,
|
|
157
|
+
project_rowids=project_rowids,
|
|
158
|
+
start_time=start_time,
|
|
159
|
+
end_time=end_time,
|
|
160
|
+
)
|
|
161
|
+
stmt = stmt.where(models.Trace.project_session_rowid.in_(filtered_session_rowids))
|
|
130
162
|
if start_time:
|
|
131
163
|
stmt = stmt.where(start_time <= time_column)
|
|
132
164
|
if end_time:
|
|
@@ -3,13 +3,14 @@ from datetime import datetime
|
|
|
3
3
|
from typing import Any, Literal, Optional
|
|
4
4
|
|
|
5
5
|
from cachetools import LFUCache, TTLCache
|
|
6
|
-
from sqlalchemy import Select, func, select
|
|
6
|
+
from sqlalchemy import Select, distinct, func, select
|
|
7
7
|
from strawberry.dataloader import AbstractCache, DataLoader
|
|
8
8
|
from typing_extensions import TypeAlias, assert_never
|
|
9
9
|
|
|
10
10
|
from phoenix.db import models
|
|
11
11
|
from phoenix.server.api.dataloaders.cache import TwoTierCache
|
|
12
12
|
from phoenix.server.api.input_types.TimeRange import TimeRange
|
|
13
|
+
from phoenix.server.session_filters import get_filtered_session_rowids_subquery
|
|
13
14
|
from phoenix.server.types import DbSessionFactory
|
|
14
15
|
from phoenix.trace.dsl import SpanFilter
|
|
15
16
|
|
|
@@ -17,27 +18,35 @@ Kind: TypeAlias = Literal["span", "trace"]
|
|
|
17
18
|
ProjectRowId: TypeAlias = int
|
|
18
19
|
TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
|
|
19
20
|
FilterCondition: TypeAlias = Optional[str]
|
|
21
|
+
SessionFilterCondition: TypeAlias = Optional[str]
|
|
20
22
|
SpanCount: TypeAlias = int
|
|
21
23
|
|
|
22
|
-
Segment: TypeAlias = tuple[Kind, TimeInterval, FilterCondition]
|
|
24
|
+
Segment: TypeAlias = tuple[Kind, TimeInterval, FilterCondition, SessionFilterCondition]
|
|
23
25
|
Param: TypeAlias = ProjectRowId
|
|
24
26
|
|
|
25
|
-
Key: TypeAlias = tuple[
|
|
27
|
+
Key: TypeAlias = tuple[
|
|
28
|
+
Kind, ProjectRowId, Optional[TimeRange], FilterCondition, SessionFilterCondition
|
|
29
|
+
]
|
|
26
30
|
Result: TypeAlias = SpanCount
|
|
27
31
|
ResultPosition: TypeAlias = int
|
|
28
32
|
DEFAULT_VALUE: Result = 0
|
|
29
33
|
|
|
30
34
|
|
|
31
35
|
def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
|
|
32
|
-
kind, project_rowid, time_range, filter_condition = key
|
|
36
|
+
kind, project_rowid, time_range, filter_condition, session_filter_condition = key
|
|
33
37
|
interval = (
|
|
34
38
|
(time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
|
|
35
39
|
)
|
|
36
|
-
return (
|
|
40
|
+
return (
|
|
41
|
+
kind,
|
|
42
|
+
interval,
|
|
43
|
+
filter_condition,
|
|
44
|
+
session_filter_condition,
|
|
45
|
+
), project_rowid
|
|
37
46
|
|
|
38
47
|
|
|
39
48
|
_Section: TypeAlias = ProjectRowId
|
|
40
|
-
_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, Kind]
|
|
49
|
+
_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, SessionFilterCondition, Kind]
|
|
41
50
|
|
|
42
51
|
|
|
43
52
|
class RecordCountCache(
|
|
@@ -53,8 +62,10 @@ class RecordCountCache(
|
|
|
53
62
|
)
|
|
54
63
|
|
|
55
64
|
def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
|
|
56
|
-
(kind, interval, filter_condition), project_rowid = _cache_key_fn(
|
|
57
|
-
|
|
65
|
+
(kind, interval, filter_condition, session_filter_condition), project_rowid = _cache_key_fn(
|
|
66
|
+
key
|
|
67
|
+
)
|
|
68
|
+
return project_rowid, (interval, filter_condition, session_filter_condition, kind)
|
|
58
69
|
|
|
59
70
|
|
|
60
71
|
class RecordCountDataLoader(DataLoader[Key, Result]):
|
|
@@ -93,7 +104,7 @@ def _get_stmt(
|
|
|
93
104
|
segment: Segment,
|
|
94
105
|
*project_rowids: Param,
|
|
95
106
|
) -> Select[Any]:
|
|
96
|
-
kind, (start_time, end_time), filter_condition = segment
|
|
107
|
+
kind, (start_time, end_time), filter_condition, session_filter_condition = segment
|
|
97
108
|
pid = models.Trace.project_rowid
|
|
98
109
|
stmt = select(pid)
|
|
99
110
|
if kind == "span":
|
|
@@ -102,12 +113,28 @@ def _get_stmt(
|
|
|
102
113
|
if filter_condition:
|
|
103
114
|
sf = SpanFilter(filter_condition)
|
|
104
115
|
stmt = sf(stmt)
|
|
116
|
+
stmt = stmt.add_columns(func.count().label("count"))
|
|
105
117
|
elif kind == "trace":
|
|
106
118
|
time_column = models.Trace.start_time
|
|
119
|
+
if filter_condition:
|
|
120
|
+
stmt = stmt.join(models.Span, models.Trace.id == models.Span.trace_rowid)
|
|
121
|
+
stmt = stmt.add_columns(func.count(distinct(models.Trace.id)).label("count"))
|
|
122
|
+
sf = SpanFilter(filter_condition)
|
|
123
|
+
stmt = sf(stmt)
|
|
124
|
+
else:
|
|
125
|
+
stmt = stmt.add_columns(func.count().label("count"))
|
|
107
126
|
else:
|
|
108
127
|
assert_never(kind)
|
|
109
|
-
stmt = stmt.add_columns(func.count().label("count"))
|
|
110
128
|
stmt = stmt.where(pid.in_(project_rowids))
|
|
129
|
+
|
|
130
|
+
if session_filter_condition:
|
|
131
|
+
filtered_session_rowids = get_filtered_session_rowids_subquery(
|
|
132
|
+
session_filter_condition=session_filter_condition,
|
|
133
|
+
project_rowids=project_rowids,
|
|
134
|
+
start_time=start_time,
|
|
135
|
+
end_time=end_time,
|
|
136
|
+
)
|
|
137
|
+
stmt = stmt.where(models.Trace.project_session_rowid.in_(filtered_session_rowids))
|
|
111
138
|
stmt = stmt.group_by(pid)
|
|
112
139
|
if start_time:
|
|
113
140
|
stmt = stmt.where(start_time <= time_column)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db.models import ProjectSessionAnnotation
|
|
8
|
+
from phoenix.server.types import DbSessionFactory
|
|
9
|
+
|
|
10
|
+
# Integer rowid identifying a project session.
ProjectSessionId: TypeAlias = int
# Dataloader key: the session whose annotations are being requested.
Key: TypeAlias = ProjectSessionId
# Dataloader result: every annotation attached to that session.
Result: TypeAlias = list[ProjectSessionAnnotation]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SessionAnnotationsBySessionDataLoader(DataLoader[Key, Result]):
    """Batch-loads ProjectSessionAnnotation rows grouped by session rowid.

    Sessions with no annotations map to an empty list, preserving the
    one-result-per-key contract of the dataloader.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = select(ProjectSessionAnnotation).where(
            ProjectSessionAnnotation.project_session_id.in_(keys)
        )
        # Group streamed rows under their owning session id.
        grouped: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            records = await session.stream_scalars(stmt)
            async for record in records:
                grouped[record.project_session_id].append(record)
        # defaultdict yields [] for sessions that had no annotations.
        return [grouped[key] for key in keys]
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db import models
|
|
8
|
+
from phoenix.server.types import DbSessionFactory
|
|
9
|
+
|
|
10
|
+
# Integer rowid identifying a span.
SpanRowId: TypeAlias = int
# Dataloader key: the span whose cost record is being requested.
Key: TypeAlias = SpanRowId
# Dataloader result: the span's cost row, or None when no cost was recorded.
Result: TypeAlias = Optional[models.SpanCost]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SpanCostBySpanDataLoader(DataLoader[Key, Result]):
    """Batch-loads the SpanCost row (if any) for each requested span rowid."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        query = select(models.SpanCost).where(models.SpanCost.span_rowid.in_(keys))
        costs_by_span: dict[Key, models.SpanCost] = {}
        async with self._db() as session:
            async for cost in await session.stream_scalars(query):
                costs_by_span[cost.span_rowid] = cost
        # .get returns None for spans without a cost row.
        return [costs_by_span.get(key) for key in keys]
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import func, select
|
|
4
|
+
from sqlalchemy.sql.functions import coalesce
|
|
5
|
+
from strawberry.dataloader import DataLoader
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.api.dataloaders.types import (
|
|
10
|
+
CostBreakdown,
|
|
11
|
+
SpanCostDetailSummaryEntry,
|
|
12
|
+
)
|
|
13
|
+
from phoenix.server.types import DbSessionFactory
|
|
14
|
+
|
|
15
|
+
# Integer id identifying a generative model row.
GenerativeModelId: TypeAlias = int
# Dataloader key: the model whose cost-detail summary is being requested.
Key: TypeAlias = GenerativeModelId
# Dataloader result: one entry per (token_type, is_prompt) aggregate.
Result: TypeAlias = list[SpanCostDetailSummaryEntry]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SpanCostDetailSummaryEntriesByGenerativeModelDataLoader(DataLoader[Key, Result]):
    """Batch-loads aggregated cost-detail summaries keyed by generative model id.

    For each model, returns one SpanCostDetailSummaryEntry per
    (token_type, is_prompt) pair, with cost and token sums taken over all
    span cost details attached to that model's span costs.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        model_id = models.SpanCost.model_id
        # COALESCE(SUM(...), 0) keeps the aggregates numeric when all inputs
        # are NULL.
        stmt = (
            select(
                model_id,
                models.SpanCostDetail.token_type,
                models.SpanCostDetail.is_prompt,
                coalesce(func.sum(models.SpanCostDetail.cost), 0).label("cost"),
                coalesce(func.sum(models.SpanCostDetail.tokens), 0).label("tokens"),
            )
            .select_from(models.SpanCostDetail)
            .join(models.SpanCost, models.SpanCostDetail.span_cost_id == models.SpanCost.id)
            .where(model_id.in_(keys))
            .group_by(model_id, models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
        )
        summaries: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for key, token_type, is_prompt, cost, tokens in rows:
                summaries[key].append(
                    SpanCostDetailSummaryEntry(
                        token_type=token_type,
                        is_prompt=is_prompt,
                        value=CostBreakdown(tokens=tokens, cost=cost),
                    )
                )
        # Copy each list so callers cannot mutate the shared grouping dict.
        return [list(summaries[key]) for key in keys]
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import func, select
|
|
4
|
+
from sqlalchemy.sql.functions import coalesce
|
|
5
|
+
from strawberry.dataloader import DataLoader
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.api.dataloaders.types import (
|
|
10
|
+
CostBreakdown,
|
|
11
|
+
SpanCostDetailSummaryEntry,
|
|
12
|
+
)
|
|
13
|
+
from phoenix.server.types import DbSessionFactory
|
|
14
|
+
|
|
15
|
+
# Integer rowid identifying a project session.
ProjectSessionRowId: TypeAlias = int
# Dataloader key: the session whose cost-detail summary is being requested.
Key: TypeAlias = ProjectSessionRowId
# Dataloader result: one entry per (token_type, is_prompt) aggregate.
Result: TypeAlias = list[SpanCostDetailSummaryEntry]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SpanCostDetailSummaryEntriesByProjectSessionDataLoader(DataLoader[Key, Result]):
    """Batch-loads aggregated cost-detail summaries keyed by project session.

    For each session, returns one SpanCostDetailSummaryEntry per
    (token_type, is_prompt) pair, with cost and token sums taken over all
    span cost details reachable through that session's traces.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        session_rowid = models.Trace.project_session_rowid
        # COALESCE(SUM(...), 0) keeps the aggregates numeric when all inputs
        # are NULL. The Trace join links span costs back to their session.
        stmt = (
            select(
                session_rowid,
                models.SpanCostDetail.token_type,
                models.SpanCostDetail.is_prompt,
                coalesce(func.sum(models.SpanCostDetail.cost), 0).label("cost"),
                coalesce(func.sum(models.SpanCostDetail.tokens), 0).label("tokens"),
            )
            .select_from(models.SpanCostDetail)
            .join(models.SpanCost, models.SpanCostDetail.span_cost_id == models.SpanCost.id)
            .join(models.Trace, models.SpanCost.trace_rowid == models.Trace.id)
            .where(session_rowid.in_(keys))
            .group_by(
                session_rowid,
                models.SpanCostDetail.token_type,
                models.SpanCostDetail.is_prompt,
            )
        )
        summaries: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for key, token_type, is_prompt, cost, tokens in rows:
                summaries[key].append(
                    SpanCostDetailSummaryEntry(
                        token_type=token_type,
                        is_prompt=is_prompt,
                        value=CostBreakdown(tokens=tokens, cost=cost),
                    )
                )
        # Copy each list so callers cannot mutate the shared grouping dict.
        return [list(summaries[key]) for key in keys]
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select
|
|
4
|
+
from sqlalchemy.orm import contains_eager
|
|
5
|
+
from strawberry.dataloader import DataLoader
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.api.dataloaders.types import (
|
|
10
|
+
CostBreakdown,
|
|
11
|
+
SpanCostDetailSummaryEntry,
|
|
12
|
+
)
|
|
13
|
+
from phoenix.server.types import DbSessionFactory
|
|
14
|
+
|
|
15
|
+
# Integer rowid identifying a span.
SpanRowID: TypeAlias = int
# Dataloader key: the span whose cost-detail entries are being requested.
Key: TypeAlias = SpanRowID
# Dataloader result: one entry per cost-detail row attached to the span.
Result: TypeAlias = list[SpanCostDetailSummaryEntry]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SpanCostDetailSummaryEntriesBySpanDataLoader(DataLoader[Key, Result]):
    """Batch-loads span cost detail summary entries grouped by span rowid.

    Unlike the per-model and per-session loaders, no SQL aggregation is
    performed: each SpanCostDetail row becomes one entry.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = (
            select(models.SpanCostDetail)
            .join(models.SpanCost, models.SpanCostDetail.span_cost_id == models.SpanCost.id)
            .where(models.SpanCost.span_rowid.in_(keys))
            # Populate .span_cost from the join so reading it below does not
            # trigger a lazy load.
            .options(contains_eager(models.SpanCostDetail.span_cost))
        )
        grouped: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            async for detail in await session.stream_scalars(stmt):
                grouped[detail.span_cost.span_rowid].append(
                    SpanCostDetailSummaryEntry(
                        token_type=detail.token_type,
                        is_prompt=detail.is_prompt,
                        value=CostBreakdown(
                            tokens=detail.tokens,
                            cost=detail.cost,
                        ),
                    )
                )
        # Copy each list so callers cannot mutate the shared grouping dict.
        return [list(grouped[key]) for key in keys]
|