arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +2 -1
- phoenix/auth.py +27 -2
- phoenix/config.py +1594 -81
- phoenix/db/README.md +546 -28
- phoenix/db/bulk_inserter.py +119 -116
- phoenix/db/engines.py +140 -33
- phoenix/db/facilitator.py +22 -1
- phoenix/db/helpers.py +818 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +133 -1
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +41 -18
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +364 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/trace_retention.py +7 -6
- phoenix/experiments/functions.py +69 -19
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +60 -0
- phoenix/server/api/dataloaders/__init__.py +36 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/exceptions.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +263 -83
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +61 -19
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +5 -2
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/input_types/SpanSort.py +3 -2
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +15 -20
- phoenix/server/api/mutations/chat_mutations.py +106 -37
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +11 -9
- phoenix/server/api/mutations/project_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
- phoenix/server/api/mutations/trace_mutations.py +3 -3
- phoenix/server/api/mutations/user_mutations.py +55 -26
- phoenix/server/api/queries.py +501 -617
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +141 -87
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +349 -101
- phoenix/server/api/routers/v1/__init__.py +22 -4
- phoenix/server/api/routers/v1/annotation_configs.py +19 -30
- phoenix/server/api/routers/v1/annotations.py +455 -13
- phoenix/server/api/routers/v1/datasets.py +355 -68
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +20 -28
- phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
- phoenix/server/api/routers/v1/experiment_runs.py +335 -59
- phoenix/server/api/routers/v1/experiments.py +475 -47
- phoenix/server/api/routers/v1/projects.py +16 -50
- phoenix/server/api/routers/v1/prompts.py +50 -39
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +156 -96
- phoenix/server/api/routers/v1/traces.py +51 -77
- phoenix/server/api/routers/v1/users.py +64 -24
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +257 -93
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/Dataset.py +199 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +215 -68
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +120 -70
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/GenerativeProvider.py +1 -1
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +218 -185
- phoenix/server/api/types/ProjectSession.py +146 -29
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +130 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +184 -53
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +128 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +154 -36
- phoenix/server/authorization.py +5 -4
- phoenix/server/bearer_auth.py +13 -5
- phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
- phoenix/server/daemons/generative_model_store.py +61 -9
- phoenix/server/daemons/span_cost_calculator.py +10 -8
- phoenix/server/dml_event.py +13 -0
- phoenix/server/email/sender.py +29 -2
- phoenix/server/grpc_server.py +9 -9
- phoenix/server/jwt_store.py +8 -6
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +9 -3
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +43 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +51 -53
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
- phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +7 -1
- phoenix/server/thread_server.py +1 -2
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +44 -8
- phoenix/settings.py +2 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/query.py +2 -0
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
- phoenix/server/static/assets/pages-Creyamao.js +0 -8612
- phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
- phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
- phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
|
-
from sqlalchemy import
|
|
3
|
+
from sqlalchemy import func, select
|
|
4
4
|
from strawberry.dataloader import DataLoader
|
|
5
5
|
from typing_extensions import TypeAlias
|
|
6
6
|
|
|
@@ -23,36 +23,29 @@ class ExperimentErrorRatesDataLoader(DataLoader[Key, Result]):
|
|
|
23
23
|
|
|
24
24
|
async def _load_fn(self, keys: list[Key]) -> list[Result]:
|
|
25
25
|
experiment_ids = keys
|
|
26
|
-
|
|
27
|
-
select(models.Experiment.id)
|
|
28
|
-
.where(models.Experiment.id.in_(set(experiment_ids)))
|
|
29
|
-
.subquery()
|
|
30
|
-
)
|
|
31
|
-
query = (
|
|
26
|
+
average_repetition_error_rates_subquery = (
|
|
32
27
|
select(
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
(
|
|
36
|
-
|
|
37
|
-
func.count(models.ExperimentRun.error)
|
|
38
|
-
/ func.count(models.ExperimentRun.id),
|
|
39
|
-
),
|
|
40
|
-
else_=None,
|
|
41
|
-
),
|
|
28
|
+
models.ExperimentRun.experiment_id.label("experiment_id"),
|
|
29
|
+
(
|
|
30
|
+
func.count(models.ExperimentRun.error) / func.count(models.ExperimentRun.id)
|
|
31
|
+
).label("average_repetition_error_rate"),
|
|
42
32
|
)
|
|
43
|
-
.
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
)
|
|
48
|
-
.group_by(resolved_experiment_ids.c.id)
|
|
33
|
+
.where(models.ExperimentRun.experiment_id.in_(experiment_ids))
|
|
34
|
+
.group_by(models.ExperimentRun.dataset_example_id, models.ExperimentRun.experiment_id)
|
|
35
|
+
.subquery()
|
|
36
|
+
.alias("average_repetition_error_rates")
|
|
49
37
|
)
|
|
38
|
+
average_run_error_rates_query = select(
|
|
39
|
+
average_repetition_error_rates_subquery.c.experiment_id,
|
|
40
|
+
func.avg(average_repetition_error_rates_subquery.c.average_repetition_error_rate).label(
|
|
41
|
+
"average_run_error_rates"
|
|
42
|
+
),
|
|
43
|
+
).group_by(average_repetition_error_rates_subquery.c.experiment_id)
|
|
50
44
|
async with self._db() as session:
|
|
51
|
-
|
|
45
|
+
average_run_error_rates = {
|
|
52
46
|
experiment_id: error_rate
|
|
53
|
-
async for experiment_id, error_rate in await session.stream(
|
|
47
|
+
async for experiment_id, error_rate in await session.stream(
|
|
48
|
+
average_run_error_rates_query
|
|
49
|
+
)
|
|
54
50
|
}
|
|
55
|
-
return [
|
|
56
|
-
error_rates.get(experiment_id, ValueError(f"Unknown experiment ID: {experiment_id}"))
|
|
57
|
-
for experiment_id in keys
|
|
58
|
-
]
|
|
51
|
+
return [average_run_error_rates.get(experiment_id) for experiment_id in experiment_ids]
|
|
from dataclasses import dataclass
from typing import Optional

from sqlalchemy import func, select, tuple_
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

ExperimentID: TypeAlias = int
DatasetExampleID: TypeAlias = int
AnnotationName: TypeAlias = str
MeanAnnotationScore: TypeAlias = float


@dataclass
class AnnotationSummary:
    # One aggregated annotation for a repeated-run group: the annotation's
    # name and the mean of its scores across the group's runs (None when no
    # run in the group carries a score for this annotation).
    annotation_name: AnnotationName
    mean_score: Optional[MeanAnnotationScore]


Key: TypeAlias = tuple[ExperimentID, DatasetExampleID]
Result: TypeAlias = list[AnnotationSummary]


class ExperimentRepeatedRunGroupAnnotationSummariesDataLoader(DataLoader[Key, Result]):
    """Batch-loads mean annotation scores for experiment repeated-run groups.

    A repeated-run group is the set of experiment runs sharing one
    (experiment_id, dataset_example_id) pair. For each requested key this
    loader returns a list of ``AnnotationSummary`` objects — one per
    annotation name seen on the group's runs — sorted by annotation name.
    """

    def __init__(
        self,
        db: DbSessionFactory,
    ) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # A single grouped query covers every requested (experiment, example)
        # pair; the tuple IN-filter keeps the scan limited to the batch.
        stmt = (
            select(
                models.ExperimentRun.experiment_id.label("experiment_id"),
                models.ExperimentRun.dataset_example_id.label("dataset_example_id"),
                models.ExperimentRunAnnotation.name.label("annotation_name"),
                func.avg(models.ExperimentRunAnnotation.score).label("mean_score"),
            )
            .select_from(models.ExperimentRunAnnotation)
            .join(
                models.ExperimentRun,
                models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
            )
            .where(
                tuple_(
                    models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id
                ).in_(set(keys))
            )
            .group_by(
                models.ExperimentRun.experiment_id,
                models.ExperimentRun.dataset_example_id,
                models.ExperimentRunAnnotation.name,
            )
        )
        async with self._db() as session:
            rows = (await session.execute(stmt)).all()
        # Bucket the aggregated rows by their (experiment, example) key.
        summaries_by_key: dict[Key, list[AnnotationSummary]] = {}
        for row in rows:
            summaries_by_key.setdefault(
                (row.experiment_id, row.dataset_example_id), []
            ).append(
                AnnotationSummary(
                    annotation_name=row.annotation_name,
                    mean_score=row.mean_score,
                )
            )
        # Keys with no matching runs yield an empty list; each group's
        # summaries are ordered alphabetically by annotation name.
        return [
            sorted(summaries_by_key.get(key, []), key=lambda s: s.annotation_name)
            for key in keys
        ]
|
from dataclasses import dataclass

from sqlalchemy import select, tuple_
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

ExperimentID: TypeAlias = int
DatasetExampleID: TypeAlias = int
Key: TypeAlias = tuple[ExperimentID, DatasetExampleID]


@dataclass
class ExperimentRepeatedRunGroup:
    # The group of runs produced by repeating one dataset example within one
    # experiment; ``runs`` is ordered by repetition number.
    experiment_rowid: int
    dataset_example_rowid: int
    runs: list[models.ExperimentRun]


Result: TypeAlias = ExperimentRepeatedRunGroup


class ExperimentRepeatedRunGroupsDataLoader(DataLoader[Key, Result]):
    """Batch-loads the repeated runs for (experiment, dataset example) pairs.

    Every requested key gets a ``ExperimentRepeatedRunGroup``; a key with no
    matching runs still yields a group, with an empty ``runs`` list.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # Fetch every run for the batch in one query, ordered so each
        # group's runs come back sorted by repetition number.
        stmt = (
            select(models.ExperimentRun)
            .where(
                tuple_(
                    models.ExperimentRun.experiment_id,
                    models.ExperimentRun.dataset_example_id,
                ).in_(set(keys))
            )
            .order_by(models.ExperimentRun.repetition_number)
        )

        grouped: dict[Key, list[models.ExperimentRun]] = {}
        async with self._db() as session:
            for run in (await session.scalars(stmt)).all():
                grouped.setdefault(
                    (run.experiment_id, run.dataset_example_id), []
                ).append(run)

        return [
            ExperimentRepeatedRunGroup(
                experiment_rowid=experiment_id,
                dataset_example_rowid=dataset_example_id,
                runs=grouped.get((experiment_id, dataset_example_id), []),
            )
            for (experiment_id, dataset_example_id) in keys
        ]
|
from collections import defaultdict

from sqlalchemy import select, tuple_
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

ExperimentId: TypeAlias = int
DatasetExampleId: TypeAlias = int
Key: TypeAlias = tuple[ExperimentId, DatasetExampleId]
Result: TypeAlias = list[models.ExperimentRun]


class ExperimentRunsByExperimentAndExampleDataLoader(DataLoader[Key, Result]):
    """Batch-loads experiment runs keyed by (experiment, dataset example).

    Runs for each key are returned in ascending repetition-number order;
    keys with no runs map to an empty list.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # Order by the grouping columns first so rows for the same key arrive
        # contiguously, then by repetition number within each group.
        stmt = (
            select(models.ExperimentRun)
            .where(
                tuple_(
                    models.ExperimentRun.experiment_id,
                    models.ExperimentRun.dataset_example_id,
                ).in_(keys)
            )
            .order_by(
                models.ExperimentRun.experiment_id,
                models.ExperimentRun.dataset_example_id,
                models.ExperimentRun.repetition_number,
            )
        )

        grouped: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            # Stream rather than materialize: the batch may span many runs.
            async for run in await session.stream_scalars(stmt):
                grouped[(run.experiment_id, run.dataset_example_id)].append(run)

        return [grouped[key] for key in keys]
|
@@ -25,6 +25,7 @@ from phoenix.db import models
|
|
|
25
25
|
from phoenix.db.helpers import SupportedSQLDialect
|
|
26
26
|
from phoenix.server.api.dataloaders.cache import TwoTierCache
|
|
27
27
|
from phoenix.server.api.input_types.TimeRange import TimeRange
|
|
28
|
+
from phoenix.server.session_filters import get_filtered_session_rowids_subquery
|
|
28
29
|
from phoenix.server.types import DbSessionFactory
|
|
29
30
|
from phoenix.trace.dsl import SpanFilter
|
|
30
31
|
|
|
@@ -32,13 +33,16 @@ Kind: TypeAlias = Literal["span", "trace"]
|
|
|
32
33
|
ProjectRowId: TypeAlias = int
|
|
33
34
|
TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
|
|
34
35
|
FilterCondition: TypeAlias = Optional[str]
|
|
36
|
+
SessionFilterCondition: TypeAlias = Optional[str]
|
|
35
37
|
Probability: TypeAlias = float
|
|
36
38
|
QuantileValue: TypeAlias = float
|
|
37
39
|
|
|
38
|
-
Segment: TypeAlias = tuple[Kind, TimeInterval, FilterCondition]
|
|
40
|
+
Segment: TypeAlias = tuple[Kind, TimeInterval, FilterCondition, SessionFilterCondition]
|
|
39
41
|
Param: TypeAlias = tuple[ProjectRowId, Probability]
|
|
40
42
|
|
|
41
|
-
Key: TypeAlias = tuple[
|
|
43
|
+
Key: TypeAlias = tuple[
|
|
44
|
+
Kind, ProjectRowId, Optional[TimeRange], FilterCondition, SessionFilterCondition, Probability
|
|
45
|
+
]
|
|
42
46
|
Result: TypeAlias = Optional[QuantileValue]
|
|
43
47
|
ResultPosition: TypeAlias = int
|
|
44
48
|
DEFAULT_VALUE: Result = None
|
|
@@ -47,15 +51,18 @@ FloatCol: TypeAlias = SQLColumnExpression[Float[float]]
|
|
|
47
51
|
|
|
48
52
|
|
|
49
53
|
def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
|
|
50
|
-
kind, project_rowid, time_range, filter_condition, probability = key
|
|
54
|
+
kind, project_rowid, time_range, filter_condition, session_filter_condition, probability = key
|
|
51
55
|
interval = (
|
|
52
56
|
(time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
|
|
53
57
|
)
|
|
54
|
-
return (kind, interval, filter_condition), (
|
|
58
|
+
return (kind, interval, filter_condition, session_filter_condition), (
|
|
59
|
+
project_rowid,
|
|
60
|
+
probability,
|
|
61
|
+
)
|
|
55
62
|
|
|
56
63
|
|
|
57
64
|
_Section: TypeAlias = ProjectRowId
|
|
58
|
-
_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, Kind, Probability]
|
|
65
|
+
_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, SessionFilterCondition, Kind, Probability]
|
|
59
66
|
|
|
60
67
|
|
|
61
68
|
class LatencyMsQuantileCache(
|
|
@@ -71,8 +78,17 @@ class LatencyMsQuantileCache(
|
|
|
71
78
|
)
|
|
72
79
|
|
|
73
80
|
def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
|
|
74
|
-
(
|
|
75
|
-
|
|
81
|
+
(
|
|
82
|
+
(kind, interval, filter_condition, session_filter_condition),
|
|
83
|
+
(project_rowid, probability),
|
|
84
|
+
) = _cache_key_fn(key)
|
|
85
|
+
return project_rowid, (
|
|
86
|
+
interval,
|
|
87
|
+
filter_condition,
|
|
88
|
+
session_filter_condition,
|
|
89
|
+
kind,
|
|
90
|
+
probability,
|
|
91
|
+
)
|
|
76
92
|
|
|
77
93
|
|
|
78
94
|
class LatencyMsQuantileDataLoader(DataLoader[Key, Result]):
|
|
@@ -113,11 +129,18 @@ async def _get_results(
|
|
|
113
129
|
segment: Segment,
|
|
114
130
|
params: Mapping[Param, list[ResultPosition]],
|
|
115
131
|
) -> AsyncIterator[tuple[ResultPosition, QuantileValue]]:
|
|
116
|
-
kind, (start_time, end_time), filter_condition = segment
|
|
132
|
+
kind, (start_time, end_time), filter_condition, session_filter_condition = segment
|
|
117
133
|
stmt = select(models.Trace.project_rowid)
|
|
118
134
|
if kind == "trace":
|
|
119
135
|
latency_column = cast(FloatCol, models.Trace.latency_ms)
|
|
120
136
|
time_column = models.Trace.start_time
|
|
137
|
+
if filter_condition:
|
|
138
|
+
sf = SpanFilter(filter_condition)
|
|
139
|
+
stmt = stmt.where(
|
|
140
|
+
models.Trace.id.in_(
|
|
141
|
+
sf(select(models.Span.trace_rowid).distinct()).scalar_subquery()
|
|
142
|
+
)
|
|
143
|
+
)
|
|
121
144
|
elif kind == "span":
|
|
122
145
|
latency_column = cast(FloatCol, models.Span.latency_ms)
|
|
123
146
|
time_column = models.Span.start_time
|
|
@@ -127,6 +150,15 @@ async def _get_results(
|
|
|
127
150
|
stmt = sf(stmt)
|
|
128
151
|
else:
|
|
129
152
|
assert_never(kind)
|
|
153
|
+
if session_filter_condition:
|
|
154
|
+
project_rowids = [project_rowid for project_rowid, _ in params]
|
|
155
|
+
filtered_session_rowids = get_filtered_session_rowids_subquery(
|
|
156
|
+
session_filter_condition=session_filter_condition,
|
|
157
|
+
project_rowids=project_rowids,
|
|
158
|
+
start_time=start_time,
|
|
159
|
+
end_time=end_time,
|
|
160
|
+
)
|
|
161
|
+
stmt = stmt.where(models.Trace.project_session_rowid.in_(filtered_session_rowids))
|
|
130
162
|
if start_time:
|
|
131
163
|
stmt = stmt.where(start_time <= time_column)
|
|
132
164
|
if end_time:
|
|
@@ -3,13 +3,14 @@ from datetime import datetime
|
|
|
3
3
|
from typing import Any, Literal, Optional
|
|
4
4
|
|
|
5
5
|
from cachetools import LFUCache, TTLCache
|
|
6
|
-
from sqlalchemy import Select, func, select
|
|
6
|
+
from sqlalchemy import Select, distinct, func, select
|
|
7
7
|
from strawberry.dataloader import AbstractCache, DataLoader
|
|
8
8
|
from typing_extensions import TypeAlias, assert_never
|
|
9
9
|
|
|
10
10
|
from phoenix.db import models
|
|
11
11
|
from phoenix.server.api.dataloaders.cache import TwoTierCache
|
|
12
12
|
from phoenix.server.api.input_types.TimeRange import TimeRange
|
|
13
|
+
from phoenix.server.session_filters import get_filtered_session_rowids_subquery
|
|
13
14
|
from phoenix.server.types import DbSessionFactory
|
|
14
15
|
from phoenix.trace.dsl import SpanFilter
|
|
15
16
|
|
|
@@ -17,27 +18,35 @@ Kind: TypeAlias = Literal["span", "trace"]
|
|
|
17
18
|
ProjectRowId: TypeAlias = int
|
|
18
19
|
TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
|
|
19
20
|
FilterCondition: TypeAlias = Optional[str]
|
|
21
|
+
SessionFilterCondition: TypeAlias = Optional[str]
|
|
20
22
|
SpanCount: TypeAlias = int
|
|
21
23
|
|
|
22
|
-
Segment: TypeAlias = tuple[Kind, TimeInterval, FilterCondition]
|
|
24
|
+
Segment: TypeAlias = tuple[Kind, TimeInterval, FilterCondition, SessionFilterCondition]
|
|
23
25
|
Param: TypeAlias = ProjectRowId
|
|
24
26
|
|
|
25
|
-
Key: TypeAlias = tuple[
|
|
27
|
+
Key: TypeAlias = tuple[
|
|
28
|
+
Kind, ProjectRowId, Optional[TimeRange], FilterCondition, SessionFilterCondition
|
|
29
|
+
]
|
|
26
30
|
Result: TypeAlias = SpanCount
|
|
27
31
|
ResultPosition: TypeAlias = int
|
|
28
32
|
DEFAULT_VALUE: Result = 0
|
|
29
33
|
|
|
30
34
|
|
|
31
35
|
def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
|
|
32
|
-
kind, project_rowid, time_range, filter_condition = key
|
|
36
|
+
kind, project_rowid, time_range, filter_condition, session_filter_condition = key
|
|
33
37
|
interval = (
|
|
34
38
|
(time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
|
|
35
39
|
)
|
|
36
|
-
return (
|
|
40
|
+
return (
|
|
41
|
+
kind,
|
|
42
|
+
interval,
|
|
43
|
+
filter_condition,
|
|
44
|
+
session_filter_condition,
|
|
45
|
+
), project_rowid
|
|
37
46
|
|
|
38
47
|
|
|
39
48
|
_Section: TypeAlias = ProjectRowId
|
|
40
|
-
_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, Kind]
|
|
49
|
+
_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, SessionFilterCondition, Kind]
|
|
41
50
|
|
|
42
51
|
|
|
43
52
|
class RecordCountCache(
|
|
@@ -53,8 +62,10 @@ class RecordCountCache(
|
|
|
53
62
|
)
|
|
54
63
|
|
|
55
64
|
def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
|
|
56
|
-
(kind, interval, filter_condition), project_rowid = _cache_key_fn(
|
|
57
|
-
|
|
65
|
+
(kind, interval, filter_condition, session_filter_condition), project_rowid = _cache_key_fn(
|
|
66
|
+
key
|
|
67
|
+
)
|
|
68
|
+
return project_rowid, (interval, filter_condition, session_filter_condition, kind)
|
|
58
69
|
|
|
59
70
|
|
|
60
71
|
class RecordCountDataLoader(DataLoader[Key, Result]):
|
|
@@ -93,7 +104,7 @@ def _get_stmt(
|
|
|
93
104
|
segment: Segment,
|
|
94
105
|
*project_rowids: Param,
|
|
95
106
|
) -> Select[Any]:
|
|
96
|
-
kind, (start_time, end_time), filter_condition = segment
|
|
107
|
+
kind, (start_time, end_time), filter_condition, session_filter_condition = segment
|
|
97
108
|
pid = models.Trace.project_rowid
|
|
98
109
|
stmt = select(pid)
|
|
99
110
|
if kind == "span":
|
|
@@ -102,12 +113,28 @@ def _get_stmt(
|
|
|
102
113
|
if filter_condition:
|
|
103
114
|
sf = SpanFilter(filter_condition)
|
|
104
115
|
stmt = sf(stmt)
|
|
116
|
+
stmt = stmt.add_columns(func.count().label("count"))
|
|
105
117
|
elif kind == "trace":
|
|
106
118
|
time_column = models.Trace.start_time
|
|
119
|
+
if filter_condition:
|
|
120
|
+
stmt = stmt.join(models.Span, models.Trace.id == models.Span.trace_rowid)
|
|
121
|
+
stmt = stmt.add_columns(func.count(distinct(models.Trace.id)).label("count"))
|
|
122
|
+
sf = SpanFilter(filter_condition)
|
|
123
|
+
stmt = sf(stmt)
|
|
124
|
+
else:
|
|
125
|
+
stmt = stmt.add_columns(func.count().label("count"))
|
|
107
126
|
else:
|
|
108
127
|
assert_never(kind)
|
|
109
|
-
stmt = stmt.add_columns(func.count().label("count"))
|
|
110
128
|
stmt = stmt.where(pid.in_(project_rowids))
|
|
129
|
+
|
|
130
|
+
if session_filter_condition:
|
|
131
|
+
filtered_session_rowids = get_filtered_session_rowids_subquery(
|
|
132
|
+
session_filter_condition=session_filter_condition,
|
|
133
|
+
project_rowids=project_rowids,
|
|
134
|
+
start_time=start_time,
|
|
135
|
+
end_time=end_time,
|
|
136
|
+
)
|
|
137
|
+
stmt = stmt.where(models.Trace.project_session_rowid.in_(filtered_session_rowids))
|
|
111
138
|
stmt = stmt.group_by(pid)
|
|
112
139
|
if start_time:
|
|
113
140
|
stmt = stmt.where(start_time <= time_column)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db.models import ProjectSessionAnnotation
|
|
8
|
+
from phoenix.server.types import DbSessionFactory
|
|
9
|
+
|
|
10
|
+
ProjectSessionId: TypeAlias = int
|
|
11
|
+
Key: TypeAlias = ProjectSessionId
|
|
12
|
+
Result: TypeAlias = list[ProjectSessionAnnotation]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SessionAnnotationsBySessionDataLoader(DataLoader[Key, Result]):
|
|
16
|
+
def __init__(self, db: DbSessionFactory) -> None:
|
|
17
|
+
super().__init__(load_fn=self._load_fn)
|
|
18
|
+
self._db = db
|
|
19
|
+
|
|
20
|
+
async def _load_fn(self, keys: list[Key]) -> list[Result]:
|
|
21
|
+
annotations_by_id: defaultdict[Key, Result] = defaultdict(list)
|
|
22
|
+
async with self._db() as session:
|
|
23
|
+
async for annotation in await session.stream_scalars(
|
|
24
|
+
select(ProjectSessionAnnotation).where(
|
|
25
|
+
ProjectSessionAnnotation.project_session_id.in_(keys)
|
|
26
|
+
)
|
|
27
|
+
):
|
|
28
|
+
annotations_by_id[annotation.project_session_id].append(annotation)
|
|
29
|
+
return [annotations_by_id[key] for key in keys]
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import func, select, tuple_
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db import models
|
|
8
|
+
from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
|
|
9
|
+
from phoenix.server.types import DbSessionFactory
|
|
10
|
+
|
|
11
|
+
ExperimentId: TypeAlias = int
|
|
12
|
+
DatasetExampleId: TypeAlias = int
|
|
13
|
+
Key: TypeAlias = tuple[ExperimentId, DatasetExampleId]
|
|
14
|
+
Result: TypeAlias = SpanCostSummary
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SpanCostSummaryByExperimentRepeatedRunGroupDataLoader(DataLoader[Key, Result]):
|
|
18
|
+
def __init__(self, db: DbSessionFactory) -> None:
|
|
19
|
+
super().__init__(load_fn=self._load_fn)
|
|
20
|
+
self._db = db
|
|
21
|
+
|
|
22
|
+
async def _load_fn(self, keys: list[Key]) -> list[Result]:
|
|
23
|
+
stmt = (
|
|
24
|
+
select(
|
|
25
|
+
models.ExperimentRun.experiment_id,
|
|
26
|
+
models.ExperimentRun.dataset_example_id,
|
|
27
|
+
func.sum(models.SpanCost.prompt_cost).label("prompt_cost"),
|
|
28
|
+
func.sum(models.SpanCost.completion_cost).label("completion_cost"),
|
|
29
|
+
func.sum(models.SpanCost.total_cost).label("total_cost"),
|
|
30
|
+
func.sum(models.SpanCost.prompt_tokens).label("prompt_tokens"),
|
|
31
|
+
func.sum(models.SpanCost.completion_tokens).label("completion_tokens"),
|
|
32
|
+
func.sum(models.SpanCost.total_tokens).label("total_tokens"),
|
|
33
|
+
)
|
|
34
|
+
.select_from(models.ExperimentRun)
|
|
35
|
+
.join(models.Trace, models.ExperimentRun.trace_id == models.Trace.trace_id)
|
|
36
|
+
.join(models.SpanCost, models.SpanCost.trace_rowid == models.Trace.id)
|
|
37
|
+
.where(
|
|
38
|
+
tuple_(
|
|
39
|
+
models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id
|
|
40
|
+
).in_(set(keys))
|
|
41
|
+
)
|
|
42
|
+
.group_by(models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id)
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
results: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
|
|
46
|
+
async with self._db() as session:
|
|
47
|
+
data = await session.stream(stmt)
|
|
48
|
+
async for (
|
|
49
|
+
experiment_id,
|
|
50
|
+
dataset_example_id,
|
|
51
|
+
prompt_cost,
|
|
52
|
+
completion_cost,
|
|
53
|
+
total_cost,
|
|
54
|
+
prompt_tokens,
|
|
55
|
+
completion_tokens,
|
|
56
|
+
total_tokens,
|
|
57
|
+
) in data:
|
|
58
|
+
summary = SpanCostSummary(
|
|
59
|
+
prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
|
|
60
|
+
completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
|
|
61
|
+
total=CostBreakdown(tokens=total_tokens, cost=total_cost),
|
|
62
|
+
)
|
|
63
|
+
results[(experiment_id, dataset_example_id)] = summary
|
|
64
|
+
return [results.get(key, SpanCostSummary()) for key in keys]
|
|
@@ -12,32 +12,38 @@ from phoenix.db import models
|
|
|
12
12
|
from phoenix.server.api.dataloaders.cache import TwoTierCache
|
|
13
13
|
from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
|
|
14
14
|
from phoenix.server.api.input_types.TimeRange import TimeRange
|
|
15
|
+
from phoenix.server.session_filters import get_filtered_session_rowids_subquery
|
|
15
16
|
from phoenix.server.types import DbSessionFactory
|
|
16
17
|
from phoenix.trace.dsl import SpanFilter
|
|
17
18
|
|
|
18
19
|
ProjectRowId: TypeAlias = int
|
|
19
20
|
TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
|
|
20
21
|
FilterCondition: TypeAlias = Optional[str]
|
|
22
|
+
SessionFilterCondition: TypeAlias = Optional[str]
|
|
21
23
|
|
|
22
|
-
Segment: TypeAlias = tuple[
|
|
24
|
+
Segment: TypeAlias = tuple[
|
|
25
|
+
TimeInterval,
|
|
26
|
+
FilterCondition,
|
|
27
|
+
SessionFilterCondition,
|
|
28
|
+
]
|
|
23
29
|
Param: TypeAlias = ProjectRowId
|
|
24
30
|
|
|
25
|
-
Key: TypeAlias = tuple[ProjectRowId, Optional[TimeRange], FilterCondition]
|
|
31
|
+
Key: TypeAlias = tuple[ProjectRowId, Optional[TimeRange], FilterCondition, SessionFilterCondition]
|
|
26
32
|
Result: TypeAlias = SpanCostSummary
|
|
27
33
|
ResultPosition: TypeAlias = int
|
|
28
34
|
DEFAULT_VALUE: Result = SpanCostSummary()
|
|
29
35
|
|
|
30
36
|
|
|
31
37
|
def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
|
|
32
|
-
project_rowid, time_range, filter_condition = key
|
|
38
|
+
project_rowid, time_range, filter_condition, session_filter_condition = key
|
|
33
39
|
interval = (
|
|
34
40
|
(time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
|
|
35
41
|
)
|
|
36
|
-
return (interval, filter_condition), project_rowid
|
|
42
|
+
return (interval, filter_condition, session_filter_condition), project_rowid
|
|
37
43
|
|
|
38
44
|
|
|
39
45
|
_Section: TypeAlias = ProjectRowId
|
|
40
|
-
_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition]
|
|
46
|
+
_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, SessionFilterCondition]
|
|
41
47
|
|
|
42
48
|
|
|
43
49
|
class SpanCostSummaryCache(
|
|
@@ -53,8 +59,8 @@ class SpanCostSummaryCache(
|
|
|
53
59
|
)
|
|
54
60
|
|
|
55
61
|
def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
|
|
56
|
-
(interval, filter_condition), project_rowid = _cache_key_fn(key)
|
|
57
|
-
return project_rowid, (interval, filter_condition)
|
|
62
|
+
(interval, filter_condition, session_filter_condition), project_rowid = _cache_key_fn(key)
|
|
63
|
+
return project_rowid, (interval, filter_condition, session_filter_condition)
|
|
58
64
|
|
|
59
65
|
|
|
60
66
|
class SpanCostSummaryByProjectDataLoader(DataLoader[Key, Result]):
|
|
@@ -106,12 +112,12 @@ def _get_stmt(
|
|
|
106
112
|
segment: Segment,
|
|
107
113
|
*params: Param,
|
|
108
114
|
) -> Select[Any]:
|
|
109
|
-
|
|
110
|
-
|
|
115
|
+
project_rowids = params
|
|
116
|
+
(start_time, end_time), filter_condition, session_filter_condition = segment
|
|
111
117
|
|
|
112
118
|
stmt: Select[Any] = (
|
|
113
119
|
select(
|
|
114
|
-
|
|
120
|
+
models.Trace.project_rowid,
|
|
115
121
|
coalesce(func.sum(models.SpanCost.prompt_cost), 0).label("prompt_cost"),
|
|
116
122
|
coalesce(func.sum(models.SpanCost.completion_cost), 0).label("completion_cost"),
|
|
117
123
|
coalesce(func.sum(models.SpanCost.total_cost), 0).label("total_cost"),
|
|
@@ -119,8 +125,10 @@ def _get_stmt(
|
|
|
119
125
|
coalesce(func.sum(models.SpanCost.completion_tokens), 0).label("completion_tokens"),
|
|
120
126
|
coalesce(func.sum(models.SpanCost.total_tokens), 0).label("total_tokens"),
|
|
121
127
|
)
|
|
122
|
-
.
|
|
123
|
-
.
|
|
128
|
+
.select_from(models.Trace)
|
|
129
|
+
.join(models.SpanCost, models.Trace.id == models.SpanCost.trace_rowid)
|
|
130
|
+
.where(models.Trace.project_rowid.in_(project_rowids))
|
|
131
|
+
.group_by(models.Trace.project_rowid)
|
|
124
132
|
)
|
|
125
133
|
|
|
126
134
|
if start_time:
|
|
@@ -132,7 +140,13 @@ def _get_stmt(
|
|
|
132
140
|
sf = SpanFilter(filter_condition)
|
|
133
141
|
stmt = sf(stmt.join_from(models.SpanCost, models.Span))
|
|
134
142
|
|
|
135
|
-
|
|
136
|
-
|
|
143
|
+
if session_filter_condition:
|
|
144
|
+
filtered_session_rowids = get_filtered_session_rowids_subquery(
|
|
145
|
+
session_filter_condition=session_filter_condition,
|
|
146
|
+
project_rowids=project_rowids,
|
|
147
|
+
start_time=start_time,
|
|
148
|
+
end_time=end_time,
|
|
149
|
+
)
|
|
150
|
+
stmt = stmt.where(models.Trace.project_session_rowid.in_(filtered_session_rowids))
|
|
137
151
|
|
|
138
152
|
return stmt
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
3
|
from sqlalchemy import select
|
|
4
|
-
from sqlalchemy.orm import joinedload, load_only
|
|
5
4
|
from strawberry.dataloader import DataLoader
|
|
6
5
|
from typing_extensions import TypeAlias
|
|
7
6
|
|
|
@@ -22,14 +21,9 @@ class SpanCostsDataLoader(DataLoader[Key, Result]):
|
|
|
22
21
|
span_ids = list(set(keys))
|
|
23
22
|
async with self._db() as session:
|
|
24
23
|
costs = {
|
|
25
|
-
|
|
26
|
-
async for
|
|
27
|
-
select(models.
|
|
28
|
-
.where(models.Span.id.in_(span_ids))
|
|
29
|
-
.options(
|
|
30
|
-
load_only(models.Span.id),
|
|
31
|
-
joinedload(models.Span.span_cost),
|
|
32
|
-
)
|
|
24
|
+
span_cost.span_rowid: span_cost
|
|
25
|
+
async for span_cost in await session.stream_scalars(
|
|
26
|
+
select(models.SpanCost).where(models.SpanCost.span_rowid.in_(span_ids))
|
|
33
27
|
)
|
|
34
28
|
}
|
|
35
29
|
return [costs.get(span_id) for span_id in keys]
|