arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +2 -1
- phoenix/auth.py +27 -2
- phoenix/config.py +1594 -81
- phoenix/db/README.md +546 -28
- phoenix/db/bulk_inserter.py +119 -116
- phoenix/db/engines.py +140 -33
- phoenix/db/facilitator.py +22 -1
- phoenix/db/helpers.py +818 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +133 -1
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +41 -18
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +364 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/trace_retention.py +7 -6
- phoenix/experiments/functions.py +69 -19
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +60 -0
- phoenix/server/api/dataloaders/__init__.py +36 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/exceptions.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +263 -83
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +61 -19
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +5 -2
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/input_types/SpanSort.py +3 -2
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +15 -20
- phoenix/server/api/mutations/chat_mutations.py +106 -37
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +11 -9
- phoenix/server/api/mutations/project_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
- phoenix/server/api/mutations/trace_mutations.py +3 -3
- phoenix/server/api/mutations/user_mutations.py +55 -26
- phoenix/server/api/queries.py +501 -617
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +141 -87
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +349 -101
- phoenix/server/api/routers/v1/__init__.py +22 -4
- phoenix/server/api/routers/v1/annotation_configs.py +19 -30
- phoenix/server/api/routers/v1/annotations.py +455 -13
- phoenix/server/api/routers/v1/datasets.py +355 -68
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +20 -28
- phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
- phoenix/server/api/routers/v1/experiment_runs.py +335 -59
- phoenix/server/api/routers/v1/experiments.py +475 -47
- phoenix/server/api/routers/v1/projects.py +16 -50
- phoenix/server/api/routers/v1/prompts.py +50 -39
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +156 -96
- phoenix/server/api/routers/v1/traces.py +51 -77
- phoenix/server/api/routers/v1/users.py +64 -24
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +257 -93
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/Dataset.py +199 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +215 -68
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +120 -70
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/GenerativeProvider.py +1 -1
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +218 -185
- phoenix/server/api/types/ProjectSession.py +146 -29
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +130 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +184 -53
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +128 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +154 -36
- phoenix/server/authorization.py +5 -4
- phoenix/server/bearer_auth.py +13 -5
- phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
- phoenix/server/daemons/generative_model_store.py +61 -9
- phoenix/server/daemons/span_cost_calculator.py +10 -8
- phoenix/server/dml_event.py +13 -0
- phoenix/server/email/sender.py +29 -2
- phoenix/server/grpc_server.py +9 -9
- phoenix/server/jwt_store.py +8 -6
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +9 -3
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +43 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +51 -53
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
- phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +7 -1
- phoenix/server/thread_server.py +1 -2
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +44 -8
- phoenix/settings.py +2 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/query.py +2 -0
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
- phoenix/server/static/assets/pages-Creyamao.js +0 -8612
- phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
- phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
- phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
--- a/phoenix/server/api/dataloaders/annotation_summaries.py
+++ b/phoenix/server/api/dataloaders/annotation_summaries.py
@@ -13,6 +13,7 @@ from phoenix.db import models
 from phoenix.server.api.dataloaders.cache import TwoTierCache
 from phoenix.server.api.input_types.TimeRange import TimeRange
 from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
+from phoenix.server.session_filters import get_filtered_session_rowids_subquery
 from phoenix.server.types import DbSessionFactory
 from phoenix.trace.dsl import SpanFilter

@@ -20,27 +21,41 @@ Kind: TypeAlias = Literal["span", "trace"]
 ProjectRowId: TypeAlias = int
 TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
 FilterCondition: TypeAlias = Optional[str]
+SessionFilterCondition: TypeAlias = Optional[str]
 AnnotationName: TypeAlias = str

-Segment: TypeAlias = tuple[Kind, ProjectRowId, TimeInterval, FilterCondition]
+Segment: TypeAlias = tuple[
+    Kind,
+    ProjectRowId,
+    TimeInterval,
+    FilterCondition,
+    SessionFilterCondition,
+]
 Param: TypeAlias = AnnotationName

-Key: TypeAlias = tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition, AnnotationName]
+Key: TypeAlias = tuple[
+    Kind,
+    ProjectRowId,
+    Optional[TimeRange],
+    FilterCondition,
+    SessionFilterCondition,
+    AnnotationName,
+]
 Result: TypeAlias = Optional[AnnotationSummary]
 ResultPosition: TypeAlias = int
 DEFAULT_VALUE: Result = None


 def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
-    kind, project_rowid, time_range, filter_condition, eval_name = key
+    kind, project_rowid, time_range, filter_condition, session_filter_condition, eval_name = key
     interval = (
         (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
     )
-    return (kind, project_rowid, interval, filter_condition), eval_name
+    return (kind, project_rowid, interval, filter_condition, session_filter_condition), eval_name


 _Section: TypeAlias = tuple[ProjectRowId, AnnotationName, Kind]
-_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition]
+_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, SessionFilterCondition]


 class AnnotationSummaryCache(
@@ -61,8 +76,21 @@ class AnnotationSummaryCache(
             del self._cache[section]

     def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
-        (kind, project_rowid, interval, filter_condition), eval_name = _cache_key_fn(key)
-        return (project_rowid, eval_name, kind), (interval, filter_condition)
+        (
+            (
+                kind,
+                project_rowid,
+                interval,
+                filter_condition,
+                session_filter_condition,
+            ),
+            annotation_name,
+        ) = _cache_key_fn(key)
+        return (project_rowid, annotation_name, kind), (
+            interval,
+            filter_condition,
+            session_filter_condition,
+        )


 class AnnotationSummaryDataLoader(DataLoader[Key, Result]):
@@ -102,7 +130,9 @@ def _get_stmt(
     segment: Segment,
     *annotation_names: Param,
 ) -> Select[Any]:
-    kind, project_rowid, (start_time, end_time), filter_condition = segment
+    kind, project_rowid, (start_time, end_time), filter_condition, session_filter_condition = (
+        segment
+    )

     annotation_model: Union[Type[models.SpanAnnotation], Type[models.TraceAnnotation]]
     entity_model: Union[Type[models.Span], Type[models.Trace]]
@@ -144,6 +174,19 @@ def _get_stmt(
         entity_count_query = entity_count_query.where(
             cast(Type[models.Trace], entity_model).project_rowid == project_rowid
         )
+    else:
+        assert_never(kind)
+
+    if session_filter_condition:
+        filtered_session_rowids = get_filtered_session_rowids_subquery(
+            session_filter_condition=session_filter_condition,
+            project_rowids=[project_rowid],
+            start_time=start_time,
+            end_time=end_time,
+        )
+        entity_count_query = entity_count_query.where(
+            models.Trace.project_session_rowid.in_(filtered_session_rowids)
+        )

     entity_count_query = entity_count_query.where(
         or_(score_column.is_not(None), label_column.is_not(None))
@@ -186,6 +229,15 @@ def _get_stmt(
     else:
         assert_never(kind)

+    if session_filter_condition:
+        filtered_session_rowids = get_filtered_session_rowids_subquery(
+            session_filter_condition=session_filter_condition,
+            project_rowids=[project_rowid],
+            start_time=start_time,
+            end_time=end_time,
+        )
+        base_stmt = base_stmt.where(models.Trace.project_session_rowid.in_(filtered_session_rowids))
+
     base_stmt = base_stmt.where(or_(score_column.is_not(None), label_column.is_not(None)))
     base_stmt = base_stmt.where(name_column.in_(annotation_names))
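Net effect of these hunks: a session-level filter condition now travels with every cache key, so summaries computed under different session filters can no longer collide in the cache. A minimal standalone sketch of the key split (plain tuples stand in for Phoenix's TimeRange type, and the filter strings are invented):

from datetime import datetime
from typing import Optional

# Sketch of the widened _cache_key_fn: the Key now carries a session filter,
# and both halves of the split -- the Segment that addresses a cache section
# and the Param (annotation name) -- must round-trip it.
Key = tuple[str, int, Optional[tuple[datetime, datetime]], Optional[str], Optional[str], str]

def cache_key_fn(key: Key):
    kind, project_rowid, time_range, filter_condition, session_filter_condition, name = key
    interval = time_range if time_range is not None else (None, None)
    return (kind, project_rowid, interval, filter_condition, session_filter_condition), name

segment, param = cache_key_fn(("span", 1, None, None, "user_id == 'abc'", "correctness"))
assert segment == ("span", 1, (None, None), None, "user_id == 'abc'")
assert param == "correctness"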
--- /dev/null
+++ b/phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py
@@ -0,0 +1,50 @@
+from typing import Optional
+
+from sqlalchemy import func, select, tuple_
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ExperimentID: TypeAlias = int
+DatasetExampleID: TypeAlias = int
+RunLatency: TypeAlias = float
+Key: TypeAlias = tuple[ExperimentID, DatasetExampleID]
+Result: TypeAlias = Optional[RunLatency]
+
+
+class AverageExperimentRepeatedRunGroupLatencyDataLoader(DataLoader[Key, Result]):
+    def __init__(
+        self,
+        db: DbSessionFactory,
+    ) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        average_latency_query = (
+            select(
+                models.ExperimentRun.experiment_id.label("experiment_id"),
+                models.ExperimentRun.dataset_example_id.label("example_id"),
+                func.avg(models.ExperimentRun.latency_ms).label("average_repetition_latency_ms"),
+            )
+            .select_from(models.ExperimentRun)
+            .where(
+                tuple_(
+                    models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id
+                ).in_(set(keys))
+            )
+            .group_by(models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id)
+        )
+        async with self._db() as session:
+            average_run_latencies_ms = {
+                (experiment_id, example_id): average_run_latency_ms
+                async for experiment_id, example_id, average_run_latency_ms in await session.stream(
+                    average_latency_query
+                )
+            }
+        return [
+            average_run_latencies_ms.get((experiment_id, example_id))
+            for experiment_id, example_id in keys
+        ]
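The new loader follows the usual DataLoader batching contract: one grouped query per batch of keys, and None (rather than an error) for (experiment_id, dataset_example_id) pairs that have no runs. A hypothetical in-memory stand-in for the GROUP BY query (rows and values invented):

from statistics import mean
from typing import Optional

# (experiment_id, dataset_example_id, latency_ms) rows, values invented.
rows = [(1, 10, 100.0), (1, 10, 300.0), (1, 11, 50.0)]

def load_fn(keys: list[tuple[int, int]]) -> list[Optional[float]]:
    wanted = set(keys)  # mirrors the tuple_(...).in_(set(keys)) predicate
    groups: dict[tuple[int, int], list[float]] = {}
    for experiment_id, example_id, latency_ms in rows:
        if (experiment_id, example_id) in wanted:
            groups.setdefault((experiment_id, example_id), []).append(latency_ms)
    averages = {k: mean(v) for k, v in groups.items()}
    return [averages.get(k) for k in keys]  # None for groups with no runs

assert load_fn([(1, 10), (1, 11), (2, 99)]) == [200.0, 50.0, None]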
--- a/phoenix/server/api/dataloaders/average_experiment_run_latency.py
+++ b/phoenix/server/api/dataloaders/average_experiment_run_latency.py
@@ -23,32 +23,25 @@ class AverageExperimentRunLatencyDataLoader(DataLoader[Key, Result]):

     async def _load_fn(self, keys: list[Key]) -> list[Result]:
         experiment_ids = keys
-        resolved_experiment_ids = (
-            select(models.Experiment.id)
-            .where(models.Experiment.id.in_(set(experiment_ids)))
-            .subquery()
-        )
-        query = (
+        average_repetition_latency_ms = (
             select(
-                resolved_experiment_ids.c.id,
-                func.avg(
-                    func.extract("epoch", models.ExperimentRun.end_time)
-                    - func.extract("epoch", models.ExperimentRun.start_time)
-                ),
+                models.ExperimentRun.experiment_id.label("experiment_id"),
+                func.avg(models.ExperimentRun.latency_ms).label("average_repetition_latency_ms"),
             )
-            .join(
-                models.ExperimentRun,
-                models.ExperimentRun.experiment_id == resolved_experiment_ids.c.id,
-                isouter=True,
-            )
-            .group_by(resolved_experiment_ids.c.id)
+            .select_from(models.ExperimentRun)
+            .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
+            .group_by(models.ExperimentRun.dataset_example_id, models.ExperimentRun.experiment_id)
+            .subquery()
         )
+        query = select(
+            average_repetition_latency_ms.c.experiment_id,
+            func.avg(average_repetition_latency_ms.c.average_repetition_latency_ms).label(
+                "average_run_latency_ms"
+            ),
+        ).group_by(average_repetition_latency_ms.c.experiment_id)
         async with self._db() as session:
-            avg_latencies = {
-                experiment_id: avg_latency
-                async for experiment_id, avg_latency in await session.stream(query)
+            average_run_latencies_ms = {
+                experiment_id: average_run_latency_ms
+                async for experiment_id, average_run_latency_ms in await session.stream(query)
             }
-        return [
-            avg_latencies.get(experiment_id, ValueError(f"Unknown experiment: {experiment_id}"))
-            for experiment_id in keys
-        ]
+        return [average_run_latencies_ms.get(experiment_id) for experiment_id in keys]
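Two behavioral changes are worth noting: latency now comes from the stored latency_ms column rather than from epoch arithmetic on the run timestamps, and the experiment average is a mean of per-example repetition means, so an example that was repeated many times no longer dominates. A worked example of the two-level average (invented numbers):

from statistics import mean

# Example A ran 3 times, example B once.
latencies_ms = {"A": [100.0, 200.0, 600.0], "B": [100.0]}

flat_average = mean(x for runs in latencies_ms.values() for x in runs)  # 1000 / 4
two_level = mean(mean(runs) for runs in latencies_ms.values())          # (300 + 100) / 2

assert flat_average == 250.0
assert two_level == 200.0  # repeated runs no longer dominate the experiment average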
--- a/phoenix/server/api/dataloaders/cache/two_tier_cache.py
+++ b/phoenix/server/api/dataloaders/cache/two_tier_cache.py
@@ -7,7 +7,7 @@ single-tier system we would need to check all the keys to see if they are in the
 subset that we want to invalidate.
 """

-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from asyncio import Future
 from collections.abc import Callable
 from typing import Any, Generic, Optional, TypeVar
@@ -25,7 +25,6 @@ _SubKey = TypeVar("_SubKey")
 class TwoTierCache(
     AbstractCache[_Key, _Result],
     Generic[_Key, _Result, _Section, _SubKey],
-    ABC,
 ):
     def __init__(
         self,
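The hunk context preserves the module docstring's rationale: keys are split into a coarse "section" and a fine "sub-key" so that invalidating a section drops all of its entries at once instead of scanning every cached key. A dict-of-dicts sketch of the idea (not the strawberry AbstractCache interface):

# Outer dict keyed by a coarse section; invalidation is a single pop
# instead of a scan over every cached key.
cache: dict[tuple, dict[tuple, float]] = {}

def put(section: tuple, subkey: tuple, value: float) -> None:
    cache.setdefault(section, {})[subkey] = value

def invalidate(section: tuple) -> None:
    cache.pop(section, None)  # drops every subkey in the section at once

put((1, "correctness", "span"), ((None, None), None, None), 0.9)
invalidate((1, "correctness", "span"))
assert not cache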
--- /dev/null
+++ b/phoenix/server/api/dataloaders/dataset_dataset_splits.py
@@ -0,0 +1,52 @@
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+DatasetID: TypeAlias = int
+Key: TypeAlias = DatasetID
+Result: TypeAlias = list[models.DatasetSplit]
+
+
+class DatasetDatasetSplitsDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(
+            load_fn=self._load_fn,
+        )
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        dataset_ids = keys
+        async with self._db() as session:
+            splits: dict[DatasetID, dict[int, models.DatasetSplit]] = {
+                dataset_id: {} for dataset_id in dataset_ids
+            }
+
+            async for dataset_id, split in await session.stream(
+                select(models.DatasetExample.dataset_id, models.DatasetSplit)
+                .select_from(models.DatasetSplit)
+                .join(
+                    models.DatasetSplitDatasetExample,
+                    onclause=(
+                        models.DatasetSplit.id == models.DatasetSplitDatasetExample.dataset_split_id
+                    ),
+                )
+                .join(
+                    models.DatasetExample,
+                    onclause=(
+                        models.DatasetSplitDatasetExample.dataset_example_id
+                        == models.DatasetExample.id
+                    ),
+                )
+                .where(models.DatasetExample.dataset_id.in_(dataset_ids))
+            ):
+                # Use dict to deduplicate splits by split.id
+                if dataset_id in splits:
+                    splits[dataset_id][split.id] = split
+
+        return [
+            sorted(splits.get(dataset_id, {}).values(), key=lambda x: x.name)
+            for dataset_id in keys
+        ]
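The inner dict keyed by split.id matters because the join goes through DatasetSplitDatasetExample: a split attached to many examples of the same dataset comes back once per example. A toy illustration with plain tuples (values invented):

# (dataset_id, (split_id, split_name)) rows as the join above would yield them.
rows = [(1, (7, "train")), (1, (7, "train")), (1, (3, "test"))]

splits: dict[int, dict[int, str]] = {1: {}}
for dataset_id, (split_id, name) in rows:
    splits[dataset_id][split_id] = name  # duplicates collapse on split_id

assert sorted(splits[1].values()) == ["test", "train"]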
--- a/phoenix/server/api/dataloaders/dataset_example_revisions.py
+++ b/phoenix/server/api/dataloaders/dataset_example_revisions.py
@@ -91,7 +91,6 @@ class DatasetExampleRevisionsDataLoader(DataLoader[Key, Result]):
                 onclause=revision_ids.c.version_id == models.DatasetVersion.id,
                 isouter=True,  # keep rows where the version id is null
             )
-            .where(models.DatasetExampleRevision.revision_kind != "DELETE")
         )
         async with self._db() as session:
             results = {
--- /dev/null
+++ b/phoenix/server/api/dataloaders/dataset_example_splits.py
@@ -0,0 +1,40 @@
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ExampleID: TypeAlias = int
+Key: TypeAlias = ExampleID
+Result: TypeAlias = list[models.DatasetSplit]
+
+
+class DatasetExampleSplitsDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(
+            load_fn=self._load_fn,
+        )
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        example_ids = keys
+        async with self._db() as session:
+            splits: dict[ExampleID, list[models.DatasetSplit]] = {}
+
+            async for example_id, split in await session.stream(
+                select(models.DatasetSplitDatasetExample.dataset_example_id, models.DatasetSplit)
+                .select_from(models.DatasetSplit)
+                .join(
+                    models.DatasetSplitDatasetExample,
+                    onclause=(
+                        models.DatasetSplit.id == models.DatasetSplitDatasetExample.dataset_split_id
+                    ),
+                )
+                .where(models.DatasetSplitDatasetExample.dataset_example_id.in_(example_ids))
+            ):
+                if example_id not in splits:
+                    splits[example_id] = []
+                splits[example_id].append(split)
+
+        return [sorted(splits.get(example_id, []), key=lambda x: x.name) for example_id in keys]
--- /dev/null
+++ b/phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py
@@ -0,0 +1,47 @@
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ExperimentRunID: TypeAlias = int
+DatasetExampleID: TypeAlias = int
+DatasetVersionID: TypeAlias = int
+Key: TypeAlias = ExperimentRunID
+Result: TypeAlias = tuple[models.DatasetExample, DatasetVersionID]
+
+
+class DatasetExamplesAndVersionsByExperimentRunDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        experiment_run_ids = set(keys)
+        examples_and_versions_query = (
+            select(
+                models.ExperimentRun.id.label("experiment_run_id"),
+                models.DatasetExample,
+                models.Experiment.dataset_version_id.label("dataset_version_id"),
+            )
+            .select_from(models.ExperimentRun)
+            .join(
+                models.DatasetExample,
+                models.DatasetExample.id == models.ExperimentRun.dataset_example_id,
+            )
+            .join(
+                models.Experiment,
+                models.Experiment.id == models.ExperimentRun.experiment_id,
+            )
+            .where(models.ExperimentRun.id.in_(experiment_run_ids))
+        )
+        async with self._db() as session:
+            examples_and_versions = {
+                experiment_run_id: (example, version_id)
+                for experiment_run_id, example, version_id in (
+                    await session.execute(examples_and_versions_query)
+                ).all()
+            }
+
+        return [examples_and_versions[key] for key in keys]
--- /dev/null
+++ b/phoenix/server/api/dataloaders/dataset_labels.py
@@ -0,0 +1,36 @@
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+DatasetID: TypeAlias = int
+Key: TypeAlias = DatasetID
+Result: TypeAlias = list[models.DatasetLabel]
+
+
+class DatasetLabelsDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        dataset_ids = keys
+        async with self._db() as session:
+            labels: dict[Key, Result] = {}
+            for dataset_id, label in await session.execute(
+                select(models.DatasetsDatasetLabel.dataset_id, models.DatasetLabel)
+                .select_from(models.DatasetLabel)
+                .join(
+                    models.DatasetsDatasetLabel,
+                    models.DatasetLabel.id == models.DatasetsDatasetLabel.dataset_label_id,
+                )
+                .where(models.DatasetsDatasetLabel.dataset_id.in_(dataset_ids))
+            ):
+                if dataset_id not in labels:
+                    labels[dataset_id] = []
+                labels[dataset_id].append(label)
+        return [
+            sorted(labels.get(dataset_id, []), key=lambda label: label.name) for dataset_id in keys
+        ]
--- a/phoenix/server/api/dataloaders/document_evaluation_summaries.py
+++ b/phoenix/server/api/dataloaders/document_evaluation_summaries.py
@@ -10,7 +10,7 @@ from strawberry.dataloader import AbstractCache, DataLoader
 from typing_extensions import TypeAlias

 from phoenix.db import models
-from phoenix.db.helpers import SupportedSQLDialect, num_docs_col
+from phoenix.db.helpers import SupportedSQLDialect
 from phoenix.metrics.retrieval_metrics import RetrievalMetrics
 from phoenix.server.api.dataloaders.cache import TwoTierCache
 from phoenix.server.api.input_types.TimeRange import TimeRange
@@ -122,7 +122,7 @@ def _get_stmt(
         select(
             mda.name,
             models.Span.id,
-            num_docs_col(dialect),
+            models.Span.num_documents.label("num_docs"),
             mda.score,
             mda.document_position,
         )
--- a/phoenix/server/api/dataloaders/document_evaluations.py
+++ b/phoenix/server/api/dataloaders/document_evaluations.py
@@ -5,11 +5,10 @@ from strawberry.dataloader import DataLoader
 from typing_extensions import TypeAlias

 from phoenix.db import models
-from phoenix.server.api.types.Evaluation import DocumentEvaluation
 from phoenix.server.types import DbSessionFactory

 Key: TypeAlias = int
-Result: TypeAlias = list[DocumentEvaluation]
+Result: TypeAlias = list[models.DocumentAnnotation]


 class DocumentEvaluationsDataLoader(DataLoader[Key, Result]):
@@ -18,14 +17,12 @@ class DocumentEvaluationsDataLoader(DataLoader[Key, Result]):
         self._db = db

     async def _load_fn(self, keys: list[Key]) -> list[Result]:
-        document_evaluations_by_id: defaultdict[Key, Result] = defaultdict(list)
+        document_annotations_by_id: defaultdict[Key, Result] = defaultdict(list)
         mda = models.DocumentAnnotation
         async with self._db() as session:
-            data = await session.stream_scalars(
-                select(mda).where(mda.span_rowid.in_(keys)).where(mda.annotator_kind == "LLM")
-            )
+            data = await session.stream_scalars(select(mda).where(mda.span_rowid.in_(keys)))
             async for document_evaluation in data:
-                document_evaluations_by_id[document_evaluation.span_rowid].append(
-                    DocumentEvaluation(document_evaluation)
+                document_annotations_by_id[document_evaluation.span_rowid].append(
+                    document_evaluation
                 )
-        return [document_evaluations_by_id[key] for key in keys]
+        return [document_annotations_by_id[key] for key in keys]
--- a/phoenix/server/api/dataloaders/experiment_annotation_summaries.py
+++ b/phoenix/server/api/dataloaders/experiment_annotation_summaries.py
@@ -2,7 +2,7 @@ from collections import defaultdict
 from dataclasses import dataclass
 from typing import Optional

-from sqlalchemy import func, select
+from sqlalchemy import and_, func, select
 from strawberry.dataloader import AbstractCache, DataLoader
 from typing_extensions import TypeAlias

@@ -37,43 +37,97 @@ class ExperimentAnnotationSummaryDataLoader(DataLoader[Key, Result]):
     async def _load_fn(self, keys: list[Key]) -> list[Result]:
         experiment_ids = keys
         summaries: defaultdict[ExperimentID, Result] = defaultdict(list)
+        repetition_mean_scores_by_example_subquery = (
+            select(
+                models.ExperimentRun.experiment_id.label("experiment_id"),
+                models.ExperimentRunAnnotation.name.label("annotation_name"),
+                func.avg(models.ExperimentRunAnnotation.score).label("mean_repetition_score"),
+            )
+            .select_from(models.ExperimentRunAnnotation)
+            .join(
+                models.ExperimentRun,
+                models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
+            )
+            .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
+            .group_by(
+                models.ExperimentRun.experiment_id,
+                models.ExperimentRun.dataset_example_id,
+                models.ExperimentRunAnnotation.name,
+            )
+            .subquery()
+            .alias("repetition_mean_scores_by_example")
+        )
+        repetition_mean_scores_subquery = (
+            select(
+                repetition_mean_scores_by_example_subquery.c.experiment_id.label("experiment_id"),
+                repetition_mean_scores_by_example_subquery.c.annotation_name.label(
+                    "annotation_name"
+                ),
+                func.avg(repetition_mean_scores_by_example_subquery.c.mean_repetition_score).label(
+                    "mean_score"
+                ),
+            )
+            .select_from(repetition_mean_scores_by_example_subquery)
+            .group_by(
+                repetition_mean_scores_by_example_subquery.c.experiment_id,
+                repetition_mean_scores_by_example_subquery.c.annotation_name,
+            )
+            .subquery()
+            .alias("repetition_mean_scores")
+        )
+        repetitions_subquery = (
+            select(
+                models.ExperimentRun.experiment_id.label("experiment_id"),
+                models.ExperimentRunAnnotation.name.label("annotation_name"),
+                func.min(models.ExperimentRunAnnotation.score).label("min_score"),
+                func.max(models.ExperimentRunAnnotation.score).label("max_score"),
+                func.count().label("count"),
+                func.count(models.ExperimentRunAnnotation.error).label("error_count"),
+            )
+            .select_from(models.ExperimentRunAnnotation)
+            .join(
+                models.ExperimentRun,
+                models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
+            )
+            .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
+            .group_by(models.ExperimentRun.experiment_id, models.ExperimentRunAnnotation.name)
+            .subquery()
+        )
+        run_scores_query = (
+            select(
+                repetition_mean_scores_subquery.c.experiment_id.label("experiment_id"),
+                repetition_mean_scores_subquery.c.annotation_name.label("annotation_name"),
+                repetition_mean_scores_subquery.c.mean_score.label("mean_score"),
+                repetitions_subquery.c.min_score.label("min_score"),
+                repetitions_subquery.c.max_score.label("max_score"),
+                repetitions_subquery.c.count.label("count_"),
+                repetitions_subquery.c.error_count.label("error_count"),
+            )
+            .select_from(repetition_mean_scores_subquery)
+            .join(
+                repetitions_subquery,
+                and_(
+                    repetitions_subquery.c.experiment_id
+                    == repetition_mean_scores_subquery.c.experiment_id,
+                    repetitions_subquery.c.annotation_name
+                    == repetition_mean_scores_subquery.c.annotation_name,
+                ),
+            )
+            .order_by(repetition_mean_scores_subquery.c.annotation_name)
+        )
         async with self._db() as session:
-            async for (
-                experiment_id,
-                annotation_name,
-                min_score,
-                max_score,
-                mean_score,
-                count,
-                error_count,
-            ) in await session.stream(
-                select(
-                    models.ExperimentRun.experiment_id,
-                    models.ExperimentRunAnnotation.name,
-                    func.min(models.ExperimentRunAnnotation.score),
-                    func.max(models.ExperimentRunAnnotation.score),
-                    func.avg(models.ExperimentRunAnnotation.score),
-                    func.count(),
-                    func.count(models.ExperimentRunAnnotation.error),
-                )
-                .join(
-                    models.ExperimentRun,
-                    models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
-                )
-                .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
-                .group_by(models.ExperimentRun.experiment_id, models.ExperimentRunAnnotation.name)
-            ):
-                summaries[experiment_id].append(
+            async for scores_tuple in await session.stream(run_scores_query):
+                summaries[scores_tuple.experiment_id].append(
                     ExperimentAnnotationSummary(
-                        annotation_name=annotation_name,
-                        min_score=min_score,
-                        max_score=max_score,
-                        mean_score=mean_score,
-                        count=count,
-                        error_count=error_count,
+                        annotation_name=scores_tuple.annotation_name,
+                        min_score=scores_tuple.min_score,
+                        max_score=scores_tuple.max_score,
+                        mean_score=scores_tuple.mean_score,
+                        count=scores_tuple.count_,
+                        error_count=scores_tuple.error_count,
                     )
                 )
         return [
             sorted(summaries[experiment_id], key=lambda summary: summary.annotation_name)
-            for experiment_id in keys
+            for experiment_id in experiment_ids
         ]
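As with run latency, the mean score is now repetition-aware: the repetitions of each example are averaged first, then those per-example means are averaged across examples, while min, max, and the counts are still taken over every repetition. A worked example with invented scores:

from statistics import mean

# One experiment, one annotation name; example A scored twice, example B once.
scores_by_example = {"A": [0.0, 1.0], "B": [1.0]}

all_scores = [s for scores in scores_by_example.values() for s in scores]
summary = {
    "mean_score": mean(mean(v) for v in scores_by_example.values()),  # (0.5 + 1.0) / 2
    "min_score": min(all_scores),  # still over every repetition
    "max_score": max(all_scores),
    "count": len(all_scores),
}
assert summary == {"mean_score": 0.75, "min_score": 0.0, "max_score": 1.0, "count": 3}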
--- /dev/null
+++ b/phoenix/server/api/dataloaders/experiment_dataset_splits.py
@@ -0,0 +1,43 @@
+from sqlalchemy import select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ExperimentID: TypeAlias = int
+Key: TypeAlias = ExperimentID
+Result: TypeAlias = list[models.DatasetSplit]
+
+
+class ExperimentDatasetSplitsDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(
+            load_fn=self._load_fn,
+        )
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        experiment_ids = keys
+        async with self._db() as session:
+            splits: dict[ExperimentID, list[models.DatasetSplit]] = {}
+
+            async for experiment_id, split in await session.stream(
+                select(models.ExperimentDatasetSplit.experiment_id, models.DatasetSplit)
+                .select_from(models.DatasetSplit)
+                .join(
+                    models.ExperimentDatasetSplit,
+                    onclause=(
+                        models.DatasetSplit.id == models.ExperimentDatasetSplit.dataset_split_id
+                    ),
+                )
+                .where(models.ExperimentDatasetSplit.experiment_id.in_(experiment_ids))
+            ):
+                if experiment_id not in splits:
+                    splits[experiment_id] = []
+                splits[experiment_id].append(split)
+
+        return [
+            sorted(splits.get(experiment_id, []), key=lambda x: x.name)
+            for experiment_id in keys
+        ]