arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import func, select
|
|
4
|
+
from sqlalchemy.sql.functions import coalesce
|
|
5
|
+
from strawberry.dataloader import DataLoader
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.api.dataloaders.types import (
|
|
10
|
+
CostBreakdown,
|
|
11
|
+
SpanCostDetailSummaryEntry,
|
|
12
|
+
)
|
|
13
|
+
from phoenix.server.types import DbSessionFactory
|
|
14
|
+
|
|
15
|
+
TraceRowId: TypeAlias = int
Key: TypeAlias = TraceRowId
Result: TypeAlias = list[SpanCostDetailSummaryEntry]


class SpanCostDetailSummaryEntriesByTraceDataLoader(DataLoader[Key, Result]):
    """Batch-load summed span cost detail entries for traces.

    For every requested trace row id, the loader yields one
    ``SpanCostDetailSummaryEntry`` per distinct ``(token_type, is_prompt)``
    pair among the ``SpanCostDetail`` rows attached to that trace's
    ``SpanCost`` rows. Each entry carries the summed cost and token count
    (NULL sums become 0). Traces without any detail rows map to ``[]``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        detail = models.SpanCostDetail
        trace_rowid = models.SpanCost.trace_rowid
        token_type_col = detail.token_type
        is_prompt_col = detail.is_prompt
        stmt = (
            select(
                trace_rowid,
                token_type_col,
                is_prompt_col,
                coalesce(func.sum(detail.cost), 0).label("cost"),
                coalesce(func.sum(detail.tokens), 0).label("tokens"),
            )
            .select_from(detail)
            .join(models.SpanCost, detail.span_cost_id == models.SpanCost.id)
            .where(trace_rowid.in_(keys))
            .group_by(trace_rowid, token_type_col, is_prompt_col)
        )
        entries_by_trace: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for trace_id, kind, prompt_flag, summed_cost, summed_tokens in rows:
                entries_by_trace[trace_id].append(
                    SpanCostDetailSummaryEntry(
                        token_type=kind,
                        is_prompt=prompt_flag,
                        value=CostBreakdown(tokens=summed_tokens, cost=summed_cost),
                    )
                )
        # Hand back a fresh list per key so results never alias one another.
        return [list(entries_by_trace[key]) for key in keys]
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db import models
|
|
8
|
+
from phoenix.server.types import DbSessionFactory
|
|
9
|
+
|
|
10
|
+
SpanCostId: TypeAlias = int
Key: TypeAlias = SpanCostId
Result: TypeAlias = list[models.SpanCostDetail]


class SpanCostDetailsBySpanCostDataLoader(DataLoader[Key, Result]):
    """Batch-load ``SpanCostDetail`` rows grouped by their ``span_cost_id``.

    Span cost ids that have no detail rows resolve to an empty list.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        detail_model = models.SpanCostDetail
        query = select(detail_model).where(detail_model.span_cost_id.in_(keys))
        grouped: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            scalars = await session.stream_scalars(query)
            async for detail in scalars:
                grouped[detail.span_cost_id].append(detail)
        return [grouped[key] for key in keys]
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import func, select
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db import models
|
|
8
|
+
from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
|
|
9
|
+
from phoenix.server.types import DbSessionFactory
|
|
10
|
+
|
|
11
|
+
ExperimentId: TypeAlias = int
Key: TypeAlias = ExperimentId
Result: TypeAlias = SpanCostSummary


class SpanCostSummaryByExperimentDataLoader(DataLoader[Key, Result]):
    """Batch-load aggregated span cost summaries for experiments.

    For each experiment id, every run of the experiment is matched to its
    trace (``ExperimentRun.trace_id == Trace.trace_id``). The prompt,
    completion and total costs and token counts of all ``SpanCost`` rows in
    those traces are then summed. Experiments with no matching ``SpanCost``
    rows resolve to a default ``SpanCostSummary()``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # SQL SUM() returns NULL when every summed value is NULL. Coalesce to 0,
        # as the per-run and per-model loaders do, so an experiment's summary
        # stays consistent with the summaries of its individual runs.
        stmt = (
            select(
                models.ExperimentRun.experiment_id,
                func.coalesce(func.sum(models.SpanCost.prompt_cost), 0).label("prompt_cost"),
                func.coalesce(func.sum(models.SpanCost.completion_cost), 0).label(
                    "completion_cost"
                ),
                func.coalesce(func.sum(models.SpanCost.total_cost), 0).label("total_cost"),
                func.coalesce(func.sum(models.SpanCost.prompt_tokens), 0).label("prompt_tokens"),
                func.coalesce(func.sum(models.SpanCost.completion_tokens), 0).label(
                    "completion_tokens"
                ),
                func.coalesce(func.sum(models.SpanCost.total_tokens), 0).label("total_tokens"),
            )
            .select_from(models.ExperimentRun)
            .join(models.Trace, models.ExperimentRun.trace_id == models.Trace.trace_id)
            .join(models.SpanCost, models.SpanCost.trace_rowid == models.Trace.id)
            .where(models.ExperimentRun.experiment_id.in_(keys))
            .group_by(models.ExperimentRun.experiment_id)
        )

        # Missing keys fall back to a fresh default SpanCostSummary().
        results: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            data = await session.stream(stmt)
            async for (
                experiment_id,
                prompt_cost,
                completion_cost,
                total_cost,
                prompt_tokens,
                completion_tokens,
                total_tokens,
            ) in data:
                results[experiment_id] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        return [results[key] for key in keys]
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import func, select, tuple_
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db import models
|
|
8
|
+
from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
|
|
9
|
+
from phoenix.server.types import DbSessionFactory
|
|
10
|
+
|
|
11
|
+
ExperimentId: TypeAlias = int
DatasetExampleId: TypeAlias = int
Key: TypeAlias = tuple[ExperimentId, DatasetExampleId]
Result: TypeAlias = SpanCostSummary


class SpanCostSummaryByExperimentRepeatedRunGroupDataLoader(DataLoader[Key, Result]):
    """Batch-load span cost summaries per (experiment id, dataset example id).

    All runs sharing an experiment id and a dataset example id are grouped.
    The ``SpanCost`` rows of their traces are then summed. Groups with no
    matching ``SpanCost`` rows resolve to a default ``SpanCostSummary()``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # SQL SUM() returns NULL when every summed value is NULL. Coalesce to 0
        # so a group's summary agrees with the per-run summaries, which are
        # coalesced the same way.
        stmt = (
            select(
                models.ExperimentRun.experiment_id,
                models.ExperimentRun.dataset_example_id,
                func.coalesce(func.sum(models.SpanCost.prompt_cost), 0).label("prompt_cost"),
                func.coalesce(func.sum(models.SpanCost.completion_cost), 0).label(
                    "completion_cost"
                ),
                func.coalesce(func.sum(models.SpanCost.total_cost), 0).label("total_cost"),
                func.coalesce(func.sum(models.SpanCost.prompt_tokens), 0).label("prompt_tokens"),
                func.coalesce(func.sum(models.SpanCost.completion_tokens), 0).label(
                    "completion_tokens"
                ),
                func.coalesce(func.sum(models.SpanCost.total_tokens), 0).label("total_tokens"),
            )
            .select_from(models.ExperimentRun)
            .join(models.Trace, models.ExperimentRun.trace_id == models.Trace.trace_id)
            .join(models.SpanCost, models.SpanCost.trace_rowid == models.Trace.id)
            .where(
                tuple_(
                    models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id
                ).in_(set(keys))
            )
            .group_by(models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id)
        )

        results: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            data = await session.stream(stmt)
            async for (
                experiment_id,
                dataset_example_id,
                prompt_cost,
                completion_cost,
                total_cost,
                prompt_tokens,
                completion_tokens,
                total_tokens,
            ) in data:
                results[(experiment_id, dataset_example_id)] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        return [results.get(key, SpanCostSummary()) for key in keys]
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import func, select
|
|
4
|
+
from sqlalchemy.sql.functions import coalesce
|
|
5
|
+
from strawberry.dataloader import DataLoader
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
|
|
10
|
+
from phoenix.server.types import DbSessionFactory
|
|
11
|
+
|
|
12
|
+
ExperimentRunId: TypeAlias = int
Key: TypeAlias = ExperimentRunId
Result: TypeAlias = SpanCostSummary


class SpanCostSummaryByExperimentRunDataLoader(DataLoader[Key, Result]):
    """Batch-load span cost summaries for individual experiment runs.

    Each run is matched to its trace via ``trace_id``. Every ``SpanCost`` row
    of that trace is then summed, with NULL sums coalesced to 0. Runs with no
    matching ``SpanCost`` rows resolve to a default ``SpanCostSummary()``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        run = models.ExperimentRun
        span_cost = models.SpanCost
        trace = models.Trace
        stmt = (
            select(
                run.id,
                coalesce(func.sum(span_cost.prompt_cost), 0).label("prompt_cost"),
                coalesce(func.sum(span_cost.completion_cost), 0).label("completion_cost"),
                coalesce(func.sum(span_cost.total_cost), 0).label("total_cost"),
                coalesce(func.sum(span_cost.prompt_tokens), 0).label("prompt_tokens"),
                coalesce(func.sum(span_cost.completion_tokens), 0).label("completion_tokens"),
                coalesce(func.sum(span_cost.total_tokens), 0).label("total_tokens"),
            )
            .select_from(run)
            .join(trace, run.trace_id == trace.trace_id)
            .join(span_cost, span_cost.trace_rowid == trace.id)
            .where(run.id.in_(keys))
            .group_by(run.id)
        )

        summaries: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            async for row in await session.stream(stmt):
                (
                    run_id,
                    prompt_cost,
                    completion_cost,
                    total_cost,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                ) = row
                summaries[run_id] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        return [summaries[key] for key in keys]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import func, select
|
|
4
|
+
from sqlalchemy.sql.functions import coalesce
|
|
5
|
+
from strawberry.dataloader import DataLoader
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
|
|
10
|
+
from phoenix.server.types import DbSessionFactory
|
|
11
|
+
|
|
12
|
+
GenerativeModelId: TypeAlias = int
Key: TypeAlias = GenerativeModelId
Result: TypeAlias = SpanCostSummary


class SpanCostSummaryByGenerativeModelDataLoader(DataLoader[Key, Result]):
    """Batch-load span cost summaries aggregated per generative model.

    All ``SpanCost`` rows whose ``model_id`` matches a requested key are
    summed, with NULL sums coalesced to 0. Models with no ``SpanCost`` rows
    resolve to a default ``SpanCostSummary()``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        span_cost = models.SpanCost
        model_id = span_cost.model_id
        stmt = (
            select(
                model_id,
                coalesce(func.sum(span_cost.prompt_cost), 0).label("prompt_cost"),
                coalesce(func.sum(span_cost.completion_cost), 0).label("completion_cost"),
                coalesce(func.sum(span_cost.total_cost), 0).label("total_cost"),
                coalesce(func.sum(span_cost.prompt_tokens), 0).label("prompt_tokens"),
                coalesce(func.sum(span_cost.completion_tokens), 0).label("completion_tokens"),
                coalesce(func.sum(span_cost.total_tokens), 0).label("total_tokens"),
            )
            .where(model_id.in_(keys))
            .group_by(model_id)
        )
        by_model: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for row in rows:
                key, p_cost, c_cost, t_cost, p_tokens, c_tokens, t_tokens = row
                prompt = CostBreakdown(tokens=p_tokens, cost=p_cost)
                completion = CostBreakdown(tokens=c_tokens, cost=c_cost)
                total = CostBreakdown(tokens=t_tokens, cost=t_cost)
                by_model[key] = SpanCostSummary(
                    prompt=prompt,
                    completion=completion,
                    total=total,
                )
        return [by_model[key] for key in keys]
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
|
|
5
|
+
from cachetools import LFUCache, TTLCache
|
|
6
|
+
from sqlalchemy import Select, func, select
|
|
7
|
+
from sqlalchemy.sql.functions import coalesce
|
|
8
|
+
from strawberry.dataloader import AbstractCache, DataLoader
|
|
9
|
+
from typing_extensions import TypeAlias
|
|
10
|
+
|
|
11
|
+
from phoenix.db import models
|
|
12
|
+
from phoenix.server.api.dataloaders.cache import TwoTierCache
|
|
13
|
+
from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
|
|
14
|
+
from phoenix.server.api.input_types.TimeRange import TimeRange
|
|
15
|
+
from phoenix.server.session_filters import get_filtered_session_rowids_subquery
|
|
16
|
+
from phoenix.server.types import DbSessionFactory
|
|
17
|
+
from phoenix.trace.dsl import SpanFilter
|
|
18
|
+
|
|
19
|
+
ProjectRowId: TypeAlias = int
|
|
20
|
+
TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
|
|
21
|
+
FilterCondition: TypeAlias = Optional[str]
|
|
22
|
+
SessionFilterCondition: TypeAlias = Optional[str]
|
|
23
|
+
|
|
24
|
+
Segment: TypeAlias = tuple[
|
|
25
|
+
TimeInterval,
|
|
26
|
+
FilterCondition,
|
|
27
|
+
SessionFilterCondition,
|
|
28
|
+
]
|
|
29
|
+
Param: TypeAlias = ProjectRowId
|
|
30
|
+
|
|
31
|
+
Key: TypeAlias = tuple[ProjectRowId, Optional[TimeRange], FilterCondition, SessionFilterCondition]
|
|
32
|
+
Result: TypeAlias = SpanCostSummary
|
|
33
|
+
ResultPosition: TypeAlias = int
|
|
34
|
+
DEFAULT_VALUE: Result = SpanCostSummary()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
    """Split a loader key into its query segment and its project row id.

    Used as the DataLoader ``cache_key_fn`` and to batch keys in ``_load_fn``.

    Args:
        key: ``(project_rowid, time_range, filter_condition, session_filter_condition)``.

    Returns:
        ``((interval, filter_condition, session_filter_condition), project_rowid)``.
        ``interval`` is ``(time_range.start, time_range.end)`` when ``time_range``
        is a ``TimeRange``; otherwise it is ``(None, None)``.
    """
    project_rowid, time_range, filter_condition, session_filter_condition = key
    if isinstance(time_range, TimeRange):
        interval: TimeInterval = (time_range.start, time_range.end)
    else:
        interval = (None, None)
    segment: Segment = (interval, filter_condition, session_filter_condition)
    return segment, project_rowid
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
_Section: TypeAlias = ProjectRowId
_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, SessionFilterCondition]


class SpanCostSummaryCache(TwoTierCache[Key, Result, _Section, _SubKey]):
    """Two-tier cache for per-project span cost summaries.

    ``_cache_key`` splits each key into a section (the project row id) and a
    sub-key ``(interval, filter_condition, session_filter_condition)``.
    ``TwoTierCache`` is defined elsewhere; presumably it keeps one
    ``sub_cache_factory`` cache per section inside ``main_cache``. Verify against
    its implementation.
    """

    def __init__(self) -> None:
        # Outer tier: up to 64 entries, each expiring after 3600 s (one hour).
        # Time intervals keep moving forward, but the UI rounds interval endpoints
        # down to the hour, so anything older than an hour is unlikely to be hit.
        outer_cache = TTLCache(maxsize=64, ttl=3600)
        super().__init__(
            main_cache=outer_cache,
            sub_cache_factory=lambda: LFUCache(maxsize=2 * 2 * 3),
        )

    def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
        """Return ``(project_rowid, (interval, filter_condition, session_filter_condition))``."""
        segment, project_rowid = _cache_key_fn(key)
        return project_rowid, segment
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class SpanCostSummaryByProjectDataLoader(DataLoader[Key, Result]):
    """Batch-load aggregated span cost summaries per project.

    Each key is ``(project_rowid, time_range, filter_condition,
    session_filter_condition)``. Keys are grouped by segment (everything but
    the project row id) via ``_cache_key_fn``. One query is issued per segment,
    covering every project requested within it. Keys whose project has no
    matching cost rows resolve to ``DEFAULT_VALUE``.

    Args:
        db: Factory producing async database sessions.
        cache_map: Optional DataLoader cache implementation.
    """

    def __init__(
        self,
        db: DbSessionFactory,
        cache_map: Optional[AbstractCache[Key, Result]] = None,
    ) -> None:
        super().__init__(
            load_fn=self._load_fn,
            cache_key_fn=_cache_key_fn,
            cache_map=cache_map,
        )
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return one summary per key, in the same order as ``keys``."""
        # segment -> project_rowid -> positions in `keys` that asked for that project
        positions: defaultdict[Segment, defaultdict[Param, list[ResultPosition]]] = defaultdict(
            lambda: defaultdict(list)
        )
        for index, key in enumerate(keys):
            segment, project_rowid = _cache_key_fn(key)
            positions[segment][project_rowid].append(index)

        summaries: list[Result] = [DEFAULT_VALUE for _ in keys]
        async with self._db() as session:
            for segment, positions_by_project in positions.items():
                stream = await session.stream(_get_stmt(segment, *positions_by_project))
                async for (
                    project_rowid,
                    prompt_cost,
                    completion_cost,
                    total_cost,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                ) in stream:
                    summary = SpanCostSummary(
                        prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                        completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                        total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                    )
                    for index in positions_by_project.get(project_rowid, ()):
                        summaries[index] = summary
        return summaries
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _get_stmt(
    segment: Segment,
    *params: Param,
) -> Select[Any]:
    """Build the aggregate span-cost query for one segment.

    Args:
        segment: ``((start_time, end_time), filter_condition, session_filter_condition)``.
        *params: Project row ids to aggregate over.

    Returns:
        A SELECT yielding one row per project that has at least one matching
        span cost (inner join on ``SpanCost``). The columns are ``project_rowid,
        prompt_cost, completion_cost, total_cost, prompt_tokens,
        completion_tokens, total_tokens``, and every sum is coalesced to 0.
    """
    (start_time, end_time), filter_condition, session_filter_condition = segment
    project_rowids = params
    trace = models.Trace
    cost = models.SpanCost

    stmt: Select[Any] = (
        select(
            trace.project_rowid,
            coalesce(func.sum(cost.prompt_cost), 0).label("prompt_cost"),
            coalesce(func.sum(cost.completion_cost), 0).label("completion_cost"),
            coalesce(func.sum(cost.total_cost), 0).label("total_cost"),
            coalesce(func.sum(cost.prompt_tokens), 0).label("prompt_tokens"),
            coalesce(func.sum(cost.completion_tokens), 0).label("completion_tokens"),
            coalesce(func.sum(cost.total_tokens), 0).label("total_tokens"),
        )
        .select_from(trace)
        .join(cost, trace.id == cost.trace_rowid)
        .where(trace.project_rowid.in_(project_rowids))
        .group_by(trace.project_rowid)
    )

    # Half-open window on trace start time: [start_time, end_time).
    if start_time:
        stmt = stmt.where(start_time <= trace.start_time)
    if end_time:
        stmt = stmt.where(trace.start_time < end_time)

    if filter_condition:
        # The span filter operates on spans, so join them in from SpanCost first.
        span_filter = SpanFilter(filter_condition)
        stmt = span_filter(stmt.join_from(cost, models.Span))

    if session_filter_condition:
        allowed_session_rowids = get_filtered_session_rowids_subquery(
            session_filter_condition=session_filter_condition,
            project_rowids=project_rowids,
            start_time=start_time,
            end_time=end_time,
        )
        stmt = stmt.where(trace.project_session_rowid.in_(allowed_session_rowids))

    return stmt
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import func, select
|
|
4
|
+
from sqlalchemy.sql.functions import coalesce
|
|
5
|
+
from strawberry.dataloader import DataLoader
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
|
|
10
|
+
from phoenix.server.types import DbSessionFactory
|
|
11
|
+
|
|
12
|
+
ProjectSessionRowId: TypeAlias = int
Key: TypeAlias = ProjectSessionRowId
Result: TypeAlias = SpanCostSummary


class SpanCostSummaryByProjectSessionDataLoader(DataLoader[Key, Result]):
    """Batch-load span cost totals aggregated per project session.

    Costs are summed over every ``SpanCost`` whose trace belongs to the session.
    Sessions without any cost rows resolve to a default-constructed
    ``SpanCostSummary()``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return one summary per session row id, in the same order as ``keys``."""
        session_rowid = models.Trace.project_session_rowid
        cost = models.SpanCost
        stmt = (
            select(
                session_rowid,
                coalesce(func.sum(cost.prompt_cost), 0).label("prompt_cost"),
                coalesce(func.sum(cost.completion_cost), 0).label("completion_cost"),
                coalesce(func.sum(cost.total_cost), 0).label("total_cost"),
                coalesce(func.sum(cost.prompt_tokens), 0).label("prompt_tokens"),
                coalesce(func.sum(cost.completion_tokens), 0).label("completion_tokens"),
                coalesce(func.sum(cost.total_tokens), 0).label("total_tokens"),
            )
            .join_from(cost, models.Trace)
            .where(session_rowid.in_(keys))
            .group_by(session_rowid)
        )
        # Missing keys materialize as SpanCostSummary() on lookup.
        summaries: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            async for (
                rowid,
                prompt_cost,
                completion_cost,
                total_cost,
                prompt_tokens,
                completion_tokens,
                total_tokens,
            ) in await session.stream(stmt):
                summaries[rowid] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        return [summaries[key] for key in keys]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import func, select
|
|
4
|
+
from sqlalchemy.sql.functions import coalesce
|
|
5
|
+
from strawberry.dataloader import DataLoader
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
|
|
10
|
+
from phoenix.server.types import DbSessionFactory
|
|
11
|
+
|
|
12
|
+
TraceRowId: TypeAlias = int
Key: TypeAlias = TraceRowId
Result: TypeAlias = SpanCostSummary


class SpanCostSummaryByTraceDataLoader(DataLoader[Key, Result]):
    """Batch-load span cost totals aggregated per trace row id.

    Traces without any ``SpanCost`` rows resolve to a default-constructed
    ``SpanCostSummary()``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return one summary per trace row id, in the same order as ``keys``."""
        cost = models.SpanCost
        trace_rowid = cost.trace_rowid
        stmt = (
            select(
                trace_rowid,
                coalesce(func.sum(cost.prompt_cost), 0).label("prompt_cost"),
                coalesce(func.sum(cost.completion_cost), 0).label("completion_cost"),
                coalesce(func.sum(cost.total_cost), 0).label("total_cost"),
                coalesce(func.sum(cost.prompt_tokens), 0).label("prompt_tokens"),
                coalesce(func.sum(cost.completion_tokens), 0).label("completion_tokens"),
                coalesce(func.sum(cost.total_tokens), 0).label("total_tokens"),
            )
            .where(trace_rowid.in_(keys))
            .group_by(trace_rowid)
        )
        # Missing keys materialize as SpanCostSummary() on lookup.
        summaries: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            async for (
                rowid,
                prompt_cost,
                completion_cost,
                total_cost,
                prompt_tokens,
                completion_tokens,
                total_tokens,
            ) in await session.stream(stmt):
                summaries[rowid] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        return [summaries[key] for key in keys]
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db import models
|
|
8
|
+
from phoenix.server.types import DbSessionFactory
|
|
9
|
+
|
|
10
|
+
SpanID: TypeAlias = int  # a span *row id* (integer primary key), not an OTel span id
Key: TypeAlias = SpanID
Result: TypeAlias = Optional[models.SpanCost]


class SpanCostsDataLoader(DataLoader[Key, Result]):
    """Batch-load the ``SpanCost`` row for each span row id.

    Keys without a matching ``SpanCost`` resolve to ``None``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return the cost row (or ``None``) for each key, in the order of ``keys``."""
        distinct_rowids = list(set(keys))
        by_span_rowid: dict[Key, models.SpanCost] = {}
        async with self._db() as session:
            stmt = select(models.SpanCost).where(models.SpanCost.span_rowid.in_(distinct_rowids))
            async for span_cost in await session.stream_scalars(stmt):
                by_span_rowid[span_cost.span_rowid] = span_cost
        return [by_span_rowid.get(key) for key in keys]
|
|
@@ -18,7 +18,7 @@ _AttrStrIdentifier: TypeAlias = str
|
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class TableFieldsDataLoader(DataLoader[Key, Result]):
|
|
21
|
-
def __init__(self, db: DbSessionFactory, table: type[models.
|
|
21
|
+
def __init__(self, db: DbSessionFactory, table: type[models.HasId]) -> None:
|
|
22
22
|
super().__init__(load_fn=self._load_fn)
|
|
23
23
|
self._db = db
|
|
24
24
|
self._table = table
|
|
@@ -37,7 +37,7 @@ class TableFieldsDataLoader(DataLoader[Key, Result]):
|
|
|
37
37
|
|
|
38
38
|
def _get_stmt(
|
|
39
39
|
keys: Iterable[tuple[RowId, QueryableAttribute[Any]]],
|
|
40
|
-
table: type[models.
|
|
40
|
+
table: type[models.HasId],
|
|
41
41
|
) -> tuple[
|
|
42
42
|
Select[Any],
|
|
43
43
|
dict[_ResultColumnPosition, _AttrStrIdentifier],
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db import models
|
|
8
|
+
from phoenix.server.types import DbSessionFactory
|
|
9
|
+
|
|
10
|
+
ModelId: TypeAlias = int
Key: TypeAlias = ModelId
Result: TypeAlias = list[models.TokenPrice]


class TokenPricesByModelDataLoader(DataLoader[Key, Result]):
    """Batch-load every ``TokenPrice`` row belonging to each model id.

    Models with no prices resolve to an empty list. Within a list, rows keep the
    order in which the database returned them; the query has no ORDER BY.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return the price list for each model id, in the same order as ``keys``."""
        prices_by_model: defaultdict[Key, Result] = defaultdict(list)
        stmt = select(models.TokenPrice).where(models.TokenPrice.model_id.in_(keys))
        async with self._db() as session:
            async for price in await session.stream_scalars(stmt):
                prices_by_model[price.model_id].append(price)
        return [prices_by_model[model_id] for model_id in keys]
|