arize-phoenix 11.32.1__py3-none-any.whl → 11.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.33.0.dist-info}/METADATA +1 -1
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.33.0.dist-info}/RECORD +48 -41
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/context.py +20 -0
- phoenix/server/api/dataloaders/__init__.py +20 -0
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +59 -0
- phoenix/server/api/dataloaders/experiment_repetition_counts.py +39 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/helpers/playground_clients.py +4 -0
- phoenix/server/api/mutations/prompt_label_mutations.py +67 -58
- phoenix/server/api/queries.py +52 -37
- phoenix/server/api/routers/v1/experiment_runs.py +1 -1
- phoenix/server/api/routers/v1/experiments.py +1 -1
- phoenix/server/api/types/DatasetExample.py +49 -1
- phoenix/server/api/types/Experiment.py +12 -2
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +146 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +12 -19
- phoenix/server/api/types/Prompt.py +11 -0
- phoenix/server/api/types/PromptLabel.py +2 -19
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/app.py +22 -2
- phoenix/server/cost_tracking/model_cost_manifest.json +1 -1
- phoenix/server/static/.vite/manifest.json +43 -43
- phoenix/server/static/assets/components-YTHUASXI.js +5778 -0
- phoenix/server/static/assets/{index-D1FDMBMV.js → index-CugQp26L.js} +12 -21
- phoenix/server/static/assets/pages-4Qu8GNlt.js +9159 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-CRRxHwSp.js +903 -0
- phoenix/server/static/assets/{vendor-arizeai-DsYDNOqt.js → vendor-arizeai-CUN6lRd9.js} +4 -4
- phoenix/server/static/assets/vendor-codemirror-sJXwoqrE.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-BTHn5Y2R.js → vendor-recharts-BT_PeGhc.js} +2 -2
- phoenix/server/static/assets/{vendor-shiki-BAcocHFl.js → vendor-shiki-1F3op0QC.js} +1 -1
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +13 -0
- phoenix/trace/projects.py +1 -2
- phoenix/version.py +1 -1
- phoenix/server/static/assets/components-Cs9c4Nxp.js +0 -5698
- phoenix/server/static/assets/pages-Cbj9SjBx.js +0 -8928
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-RdRDaQiR.js +0 -905
- phoenix/server/static/assets/vendor-codemirror-BzJDUbEx.js +0 -25
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.33.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.33.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.33.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.33.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from sqlalchemy import func, select
|
|
2
|
+
from strawberry.dataloader import DataLoader
|
|
3
|
+
from typing_extensions import TypeAlias
|
|
4
|
+
|
|
5
|
+
from phoenix.db import models
|
|
6
|
+
from phoenix.server.types import DbSessionFactory
|
|
7
|
+
|
|
8
|
+
# The loader is keyed by an experiment's integer ID and resolves to an integer count.
ExperimentID: TypeAlias = int
RepetitionCount: TypeAlias = int
Key: TypeAlias = ExperimentID
Result: TypeAlias = RepetitionCount


class ExperimentRepetitionCountsDataLoader(DataLoader[Key, Result]):
    """Batch-resolve a repetition count for each requested experiment ID.

    The count is the largest ``repetition_number`` found among the
    ``ExperimentRun`` rows of an experiment. An experiment ID with no
    matching rows resolves to ``0``.
    """

    def __init__(
        self,
        db: DbSessionFactory,
    ) -> None:
        """Register the batch function and keep the session factory."""
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return one count per key, in the same order as ``keys``."""
        run = models.ExperimentRun
        stmt = (
            select(
                run.experiment_id,
                func.max(run.repetition_number).label("repetition_count"),
            )
            .where(run.experiment_id.in_(keys))
            .group_by(run.experiment_id)
        )
        count_by_experiment: dict[Key, Result] = {}
        async with self._db() as session:
            rows = await session.execute(stmt)
            for experiment_id, repetition_count in rows:
                count_by_experiment[experiment_id] = repetition_count
        # Experiments absent from the query result fall back to zero.
        return [count_by_experiment.get(key, 0) for key in keys]
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import func, select, tuple_
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db import models
|
|
8
|
+
from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
|
|
9
|
+
from phoenix.server.types import DbSessionFactory
|
|
10
|
+
|
|
11
|
+
# Keys pair an experiment ID with a dataset example ID; values are cost summaries.
ExperimentId: TypeAlias = int
DatasetExampleId: TypeAlias = int
Key: TypeAlias = tuple[ExperimentId, DatasetExampleId]
Result: TypeAlias = SpanCostSummary


class SpanCostSummaryByExperimentRepeatedRunGroupDataLoader(DataLoader[Key, Result]):
    """Batch-load summed span costs per (experiment, dataset example) pair.

    ``SpanCost`` rows are reached from ``ExperimentRun`` through ``Trace``
    and their cost and token columns are summed per pair. A pair with no
    matching rows resolves to a default-constructed ``SpanCostSummary``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        """Register the batch function and keep the session factory."""
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return one summary per key, in the same order as ``keys``."""
        run = models.ExperimentRun
        cost = models.SpanCost
        trace = models.Trace
        stmt = (
            select(
                run.experiment_id,
                run.dataset_example_id,
                func.sum(cost.prompt_cost).label("prompt_cost"),
                func.sum(cost.completion_cost).label("completion_cost"),
                func.sum(cost.total_cost).label("total_cost"),
                func.sum(cost.prompt_tokens).label("prompt_tokens"),
                func.sum(cost.completion_tokens).label("completion_tokens"),
                func.sum(cost.total_tokens).label("total_tokens"),
            )
            .select_from(run)
            .join(trace, run.trace_id == trace.trace_id)
            .join(cost, cost.trace_rowid == trace.id)
            # Duplicate keys are collapsed before being sent to the database.
            .where(tuple_(run.experiment_id, run.dataset_example_id).in_(set(keys)))
            .group_by(run.experiment_id, run.dataset_example_id)
        )

        summaries: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for row in rows:
                (
                    experiment_id,
                    dataset_example_id,
                    prompt_cost,
                    completion_cost,
                    total_cost,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                ) = row
                summaries[(experiment_id, dataset_example_id)] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        # ``.get`` does not invoke the defaultdict factory, so the fallback is explicit.
        return [summaries.get(key, SpanCostSummary()) for key in keys]
|
|
@@ -1669,7 +1669,11 @@ class AnthropicReasoningStreamingClient(AnthropicStreamingClient):
|
|
|
1669
1669
|
provider_key=GenerativeProviderKey.GOOGLE,
|
|
1670
1670
|
model_names=[
|
|
1671
1671
|
PROVIDER_DEFAULT,
|
|
1672
|
+
"gemini-2.5-flash",
|
|
1673
|
+
"gemini-2.5-flash-lite",
|
|
1674
|
+
"gemini-2.5-pro",
|
|
1672
1675
|
"gemini-2.5-pro-preview-03-25",
|
|
1676
|
+
"gemini-2.0-flash",
|
|
1673
1677
|
"gemini-2.0-flash-lite",
|
|
1674
1678
|
"gemini-2.0-flash-001",
|
|
1675
1679
|
"gemini-2.0-flash-thinking-exp-01-21",
|
|
@@ -10,12 +10,10 @@ from strawberry.relay import GlobalID
|
|
|
10
10
|
from strawberry.types import Info
|
|
11
11
|
|
|
12
12
|
from phoenix.db import models
|
|
13
|
-
from phoenix.db.types.identifier import Identifier as IdentifierModel
|
|
14
13
|
from phoenix.server.api.auth import IsLocked, IsNotReadOnly
|
|
15
14
|
from phoenix.server.api.context import Context
|
|
16
15
|
from phoenix.server.api.exceptions import Conflict, NotFound
|
|
17
16
|
from phoenix.server.api.queries import Query
|
|
18
|
-
from phoenix.server.api.types.Identifier import Identifier
|
|
19
17
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
20
18
|
from phoenix.server.api.types.Prompt import Prompt
|
|
21
19
|
from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
|
|
@@ -23,37 +21,49 @@ from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_labe
|
|
|
23
21
|
|
|
24
22
|
@strawberry.input
|
|
25
23
|
class CreatePromptLabelInput:
|
|
26
|
-
name:
|
|
24
|
+
name: str
|
|
27
25
|
description: Optional[str] = None
|
|
26
|
+
color: str
|
|
28
27
|
|
|
29
28
|
|
|
30
29
|
@strawberry.input
|
|
31
30
|
class PatchPromptLabelInput:
|
|
32
31
|
prompt_label_id: GlobalID
|
|
33
|
-
name: Optional[
|
|
32
|
+
name: Optional[str] = None
|
|
34
33
|
description: Optional[str] = None
|
|
35
34
|
|
|
36
35
|
|
|
37
36
|
@strawberry.input
|
|
38
|
-
class
|
|
39
|
-
|
|
37
|
+
class DeletePromptLabelsInput:
|
|
38
|
+
prompt_label_ids: list[GlobalID]
|
|
40
39
|
|
|
41
40
|
|
|
42
41
|
@strawberry.input
|
|
43
|
-
class
|
|
42
|
+
class SetPromptLabelsInput:
|
|
44
43
|
prompt_id: GlobalID
|
|
45
|
-
|
|
44
|
+
prompt_label_ids: list[GlobalID]
|
|
46
45
|
|
|
47
46
|
|
|
48
47
|
@strawberry.input
|
|
49
|
-
class
|
|
48
|
+
class UnsetPromptLabelsInput:
|
|
50
49
|
prompt_id: GlobalID
|
|
51
|
-
|
|
50
|
+
prompt_label_ids: list[GlobalID]
|
|
52
51
|
|
|
53
52
|
|
|
54
53
|
@strawberry.type
|
|
55
54
|
class PromptLabelMutationPayload:
|
|
56
|
-
|
|
55
|
+
prompt_labels: list["PromptLabel"]
|
|
56
|
+
query: "Query"
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@strawberry.type
|
|
60
|
+
class PromptLabelDeleteMutationPayload:
|
|
61
|
+
deleted_prompt_label_ids: list["GlobalID"]
|
|
62
|
+
query: "Query"
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@strawberry.type
|
|
66
|
+
class PromptLabelAssociationMutationPayload:
|
|
57
67
|
query: "Query"
|
|
58
68
|
|
|
59
69
|
|
|
@@ -64,17 +74,18 @@ class PromptLabelMutationMixin:
|
|
|
64
74
|
self, info: Info[Context, None], input: CreatePromptLabelInput
|
|
65
75
|
) -> PromptLabelMutationPayload:
|
|
66
76
|
async with info.context.db() as session:
|
|
67
|
-
|
|
68
|
-
|
|
77
|
+
label_orm = models.PromptLabel(
|
|
78
|
+
name=input.name, description=input.description, color=input.color
|
|
79
|
+
)
|
|
69
80
|
session.add(label_orm)
|
|
70
81
|
|
|
71
82
|
try:
|
|
72
83
|
await session.commit()
|
|
73
84
|
except (PostgreSQLIntegrityError, SQLiteIntegrityError):
|
|
74
|
-
raise Conflict(f"A prompt label named '{name}' already exists.")
|
|
85
|
+
raise Conflict(f"A prompt label named '{input.name}' already exists.")
|
|
75
86
|
|
|
76
87
|
return PromptLabelMutationPayload(
|
|
77
|
-
|
|
88
|
+
prompt_labels=[to_gql_prompt_label(label_orm)],
|
|
78
89
|
query=Query(),
|
|
79
90
|
)
|
|
80
91
|
|
|
@@ -82,7 +93,6 @@ class PromptLabelMutationMixin:
|
|
|
82
93
|
async def patch_prompt_label(
|
|
83
94
|
self, info: Info[Context, None], input: PatchPromptLabelInput
|
|
84
95
|
) -> PromptLabelMutationPayload:
|
|
85
|
-
validated_name = IdentifierModel.model_validate(str(input.name)) if input.name else None
|
|
86
96
|
async with info.context.db() as session:
|
|
87
97
|
label_id = from_global_id_with_expected_type(
|
|
88
98
|
input.prompt_label_id, PromptLabel.__name__
|
|
@@ -92,8 +102,8 @@ class PromptLabelMutationMixin:
|
|
|
92
102
|
if not label_orm:
|
|
93
103
|
raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
|
|
94
104
|
|
|
95
|
-
if
|
|
96
|
-
label_orm.name =
|
|
105
|
+
if input.name is not None:
|
|
106
|
+
label_orm.name = input.name
|
|
97
107
|
if input.description is not None:
|
|
98
108
|
label_orm.description = input.description
|
|
99
109
|
|
|
@@ -103,46 +113,48 @@ class PromptLabelMutationMixin:
|
|
|
103
113
|
raise Conflict("Error patching PromptLabel. Possibly a name conflict?")
|
|
104
114
|
|
|
105
115
|
return PromptLabelMutationPayload(
|
|
106
|
-
|
|
116
|
+
prompt_labels=[to_gql_prompt_label(label_orm)],
|
|
107
117
|
query=Query(),
|
|
108
118
|
)
|
|
109
119
|
|
|
110
120
|
@strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
|
|
111
|
-
async def
|
|
112
|
-
self, info: Info[Context, None], input:
|
|
113
|
-
) ->
|
|
121
|
+
async def delete_prompt_labels(
|
|
122
|
+
self, info: Info[Context, None], input: DeletePromptLabelsInput
|
|
123
|
+
) -> PromptLabelDeleteMutationPayload:
|
|
114
124
|
"""
|
|
115
125
|
Deletes a PromptLabel (and any crosswalk references).
|
|
116
126
|
"""
|
|
117
127
|
async with info.context.db() as session:
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
if result.rowcount == 0:
|
|
125
|
-
raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
|
|
128
|
+
label_ids = [
|
|
129
|
+
from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
|
|
130
|
+
for prompt_label_id in input.prompt_label_ids
|
|
131
|
+
]
|
|
132
|
+
stmt = delete(models.PromptLabel).where(models.PromptLabel.id.in_(label_ids))
|
|
133
|
+
await session.execute(stmt)
|
|
126
134
|
|
|
127
135
|
await session.commit()
|
|
128
136
|
|
|
129
|
-
return
|
|
130
|
-
|
|
137
|
+
return PromptLabelDeleteMutationPayload(
|
|
138
|
+
deleted_prompt_label_ids=input.prompt_label_ids,
|
|
131
139
|
query=Query(),
|
|
132
140
|
)
|
|
133
141
|
|
|
134
142
|
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
135
|
-
async def
|
|
136
|
-
self, info: Info[Context, None], input:
|
|
137
|
-
) ->
|
|
143
|
+
async def set_prompt_labels(
|
|
144
|
+
self, info: Info[Context, None], input: SetPromptLabelsInput
|
|
145
|
+
) -> PromptLabelAssociationMutationPayload:
|
|
138
146
|
async with info.context.db() as session:
|
|
139
147
|
prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
148
|
+
label_ids = [
|
|
149
|
+
from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
|
|
150
|
+
for prompt_label_id in input.prompt_label_ids
|
|
151
|
+
]
|
|
143
152
|
|
|
144
|
-
|
|
145
|
-
|
|
153
|
+
crosswalk_items = [
|
|
154
|
+
models.PromptPromptLabel(prompt_id=prompt_id, prompt_label_id=label_id)
|
|
155
|
+
for label_id in label_ids
|
|
156
|
+
]
|
|
157
|
+
session.add_all(crosswalk_items)
|
|
146
158
|
|
|
147
159
|
try:
|
|
148
160
|
await session.commit()
|
|
@@ -152,41 +164,38 @@ class PromptLabelMutationMixin:
|
|
|
152
164
|
# - Foreign key violation => prompt_id or label_id doesn't exist
|
|
153
165
|
raise Conflict("Failed to associate PromptLabel with Prompt.") from e
|
|
154
166
|
|
|
155
|
-
|
|
156
|
-
if not label_orm:
|
|
157
|
-
raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
|
|
158
|
-
|
|
159
|
-
return PromptLabelMutationPayload(
|
|
160
|
-
prompt_label=to_gql_prompt_label(label_orm),
|
|
167
|
+
return PromptLabelAssociationMutationPayload(
|
|
161
168
|
query=Query(),
|
|
162
169
|
)
|
|
163
170
|
|
|
164
171
|
@strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
|
|
165
|
-
async def
|
|
166
|
-
self, info: Info[Context, None], input:
|
|
167
|
-
) ->
|
|
172
|
+
async def unset_prompt_labels(
|
|
173
|
+
self, info: Info[Context, None], input: UnsetPromptLabelsInput
|
|
174
|
+
) -> PromptLabelAssociationMutationPayload:
|
|
168
175
|
"""
|
|
169
176
|
Unsets a PromptLabel from a Prompt by removing the row in the crosswalk.
|
|
170
177
|
"""
|
|
171
178
|
async with info.context.db() as session:
|
|
172
179
|
prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
180
|
+
label_ids = [
|
|
181
|
+
from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
|
|
182
|
+
for prompt_label_id in input.prompt_label_ids
|
|
183
|
+
]
|
|
176
184
|
|
|
177
185
|
stmt = delete(models.PromptPromptLabel).where(
|
|
178
186
|
(models.PromptPromptLabel.prompt_id == prompt_id)
|
|
179
|
-
& (models.PromptPromptLabel.prompt_label_id
|
|
187
|
+
& (models.PromptPromptLabel.prompt_label_id.in_(label_ids))
|
|
180
188
|
)
|
|
181
189
|
result = await session.execute(stmt)
|
|
182
190
|
|
|
183
|
-
if result.rowcount
|
|
184
|
-
|
|
191
|
+
if result.rowcount != len(label_ids):
|
|
192
|
+
label_ids_str = ", ".join(str(i) for i in label_ids)
|
|
193
|
+
raise NotFound(
|
|
194
|
+
f"No association between prompt={prompt_id} and labels={label_ids_str}."
|
|
195
|
+
)
|
|
185
196
|
|
|
186
197
|
await session.commit()
|
|
187
198
|
|
|
188
|
-
|
|
189
|
-
return PromptLabelMutationPayload(
|
|
190
|
-
prompt_label=to_gql_prompt_label(label_orm) if label_orm else None,
|
|
199
|
+
return PromptLabelAssociationMutationPayload(
|
|
191
200
|
query=Query(),
|
|
192
201
|
)
|
phoenix/server/api/queries.py
CHANGED
|
@@ -56,15 +56,25 @@ from phoenix.server.api.types.EmbeddingDimension import (
|
|
|
56
56
|
to_gql_embedding_dimension,
|
|
57
57
|
)
|
|
58
58
|
from phoenix.server.api.types.Event import create_event_id, unpack_event_id
|
|
59
|
-
from phoenix.server.api.types.Experiment import Experiment
|
|
60
|
-
from phoenix.server.api.types.ExperimentComparison import
|
|
59
|
+
from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
|
|
60
|
+
from phoenix.server.api.types.ExperimentComparison import (
|
|
61
|
+
ExperimentComparison,
|
|
62
|
+
)
|
|
63
|
+
from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
|
|
64
|
+
ExperimentRepeatedRunGroup,
|
|
65
|
+
parse_experiment_repeated_run_group_node_id,
|
|
66
|
+
)
|
|
61
67
|
from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
|
|
62
68
|
from phoenix.server.api.types.Functionality import Functionality
|
|
63
69
|
from phoenix.server.api.types.GenerativeModel import GenerativeModel, to_gql_generative_model
|
|
64
70
|
from phoenix.server.api.types.GenerativeProvider import GenerativeProvider, GenerativeProviderKey
|
|
65
71
|
from phoenix.server.api.types.InferenceModel import InferenceModel
|
|
66
72
|
from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
|
|
67
|
-
from phoenix.server.api.types.node import
|
|
73
|
+
from phoenix.server.api.types.node import (
|
|
74
|
+
from_global_id,
|
|
75
|
+
from_global_id_with_expected_type,
|
|
76
|
+
is_global_id,
|
|
77
|
+
)
|
|
68
78
|
from phoenix.server.api.types.pagination import (
|
|
69
79
|
ConnectionArgs,
|
|
70
80
|
Cursor,
|
|
@@ -513,11 +523,12 @@ class Query:
|
|
|
513
523
|
|
|
514
524
|
cursors_and_nodes = []
|
|
515
525
|
for example in examples:
|
|
516
|
-
|
|
526
|
+
repeated_run_groups = []
|
|
517
527
|
for experiment_id in experiment_rowids:
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
528
|
+
repeated_run_groups.append(
|
|
529
|
+
ExperimentRepeatedRunGroup(
|
|
530
|
+
experiment_rowid=experiment_id,
|
|
531
|
+
dataset_example_rowid=example.id,
|
|
521
532
|
runs=[
|
|
522
533
|
to_gql_experiment_run(run)
|
|
523
534
|
for run in sorted(
|
|
@@ -533,7 +544,7 @@ class Query:
|
|
|
533
544
|
created_at=example.created_at,
|
|
534
545
|
version_id=base_experiment.dataset_version_id,
|
|
535
546
|
),
|
|
536
|
-
|
|
547
|
+
repeated_run_groups=repeated_run_groups,
|
|
537
548
|
)
|
|
538
549
|
cursors_and_nodes.append((Cursor(rowid=example.id), experiment_comparison))
|
|
539
550
|
|
|
@@ -863,8 +874,37 @@ class Query:
|
|
|
863
874
|
return InferenceModel()
|
|
864
875
|
|
|
865
876
|
@strawberry.field
|
|
866
|
-
async def node(self, id:
|
|
867
|
-
|
|
877
|
+
async def node(self, id: strawberry.ID, info: Info[Context, None]) -> Node:
|
|
878
|
+
if not is_global_id(id):
|
|
879
|
+
try:
|
|
880
|
+
experiment_rowid, dataset_example_rowid = (
|
|
881
|
+
parse_experiment_repeated_run_group_node_id(id)
|
|
882
|
+
)
|
|
883
|
+
except Exception:
|
|
884
|
+
raise NotFound(f"Unknown node: {id}")
|
|
885
|
+
|
|
886
|
+
async with info.context.db() as session:
|
|
887
|
+
runs = (
|
|
888
|
+
await session.scalars(
|
|
889
|
+
select(models.ExperimentRun)
|
|
890
|
+
.where(models.ExperimentRun.experiment_id == experiment_rowid)
|
|
891
|
+
.where(models.ExperimentRun.dataset_example_id == dataset_example_rowid)
|
|
892
|
+
.order_by(models.ExperimentRun.repetition_number.asc())
|
|
893
|
+
.options(
|
|
894
|
+
joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
|
|
895
|
+
)
|
|
896
|
+
)
|
|
897
|
+
).all()
|
|
898
|
+
if not runs:
|
|
899
|
+
raise NotFound(f"Unknown experiment or dataset example: {id}")
|
|
900
|
+
return ExperimentRepeatedRunGroup(
|
|
901
|
+
experiment_rowid=experiment_rowid,
|
|
902
|
+
dataset_example_rowid=dataset_example_rowid,
|
|
903
|
+
runs=[to_gql_experiment_run(run) for run in runs],
|
|
904
|
+
)
|
|
905
|
+
|
|
906
|
+
global_id = GlobalID.from_id(id)
|
|
907
|
+
type_name, node_id = from_global_id(global_id)
|
|
868
908
|
if type_name == "Dimension":
|
|
869
909
|
dimension = info.context.model.scalar_dimensions[node_id]
|
|
870
910
|
return to_gql_dimension(node_id, dimension)
|
|
@@ -909,26 +949,9 @@ class Query:
|
|
|
909
949
|
return to_gql_dataset(dataset)
|
|
910
950
|
elif type_name == DatasetExample.__name__:
|
|
911
951
|
example_id = node_id
|
|
912
|
-
latest_revision_id = (
|
|
913
|
-
select(func.max(models.DatasetExampleRevision.id))
|
|
914
|
-
.where(models.DatasetExampleRevision.dataset_example_id == example_id)
|
|
915
|
-
.scalar_subquery()
|
|
916
|
-
)
|
|
917
952
|
async with info.context.db() as session:
|
|
918
953
|
example = await session.scalar(
|
|
919
|
-
select(models.DatasetExample)
|
|
920
|
-
.join(
|
|
921
|
-
models.DatasetExampleRevision,
|
|
922
|
-
onclause=models.DatasetExampleRevision.dataset_example_id
|
|
923
|
-
== models.DatasetExample.id,
|
|
924
|
-
)
|
|
925
|
-
.where(
|
|
926
|
-
and_(
|
|
927
|
-
models.DatasetExample.id == example_id,
|
|
928
|
-
models.DatasetExampleRevision.id == latest_revision_id,
|
|
929
|
-
models.DatasetExampleRevision.revision_kind != "DELETE",
|
|
930
|
-
)
|
|
931
|
-
)
|
|
954
|
+
select(models.DatasetExample).where(models.DatasetExample.id == example_id)
|
|
932
955
|
)
|
|
933
956
|
if not example:
|
|
934
957
|
raise NotFound(f"Unknown dataset example: {id}")
|
|
@@ -943,15 +966,7 @@ class Query:
|
|
|
943
966
|
)
|
|
944
967
|
if not experiment:
|
|
945
968
|
raise NotFound(f"Unknown experiment: {id}")
|
|
946
|
-
return
|
|
947
|
-
id_attr=experiment.id,
|
|
948
|
-
name=experiment.name,
|
|
949
|
-
project_name=experiment.project_name,
|
|
950
|
-
description=experiment.description,
|
|
951
|
-
created_at=experiment.created_at,
|
|
952
|
-
updated_at=experiment.updated_at,
|
|
953
|
-
metadata=experiment.metadata_,
|
|
954
|
-
)
|
|
969
|
+
return to_gql_experiment(experiment)
|
|
955
970
|
elif type_name == ExperimentRun.__name__:
|
|
956
971
|
async with info.context.db() as session:
|
|
957
972
|
if not (
|
|
@@ -27,7 +27,7 @@ class ExperimentRun(V1RoutesBaseModel):
|
|
|
27
27
|
description="The ID of the dataset example used in the experiment run"
|
|
28
28
|
)
|
|
29
29
|
output: Any = Field(description="The output of the experiment task")
|
|
30
|
-
repetition_number: int = Field(description="The repetition number of the experiment run")
|
|
30
|
+
repetition_number: int = Field(description="The repetition number of the experiment run", gt=0)
|
|
31
31
|
start_time: datetime = Field(description="The start time of the experiment run")
|
|
32
32
|
end_time: datetime = Field(description="The end time of the experiment run")
|
|
33
33
|
trace_id: Optional[str] = Field(
|
|
@@ -46,7 +46,7 @@ class Experiment(V1RoutesBaseModel):
|
|
|
46
46
|
dataset_version_id: str = Field(
|
|
47
47
|
description="The ID of the dataset version associated with the experiment"
|
|
48
48
|
)
|
|
49
|
-
repetitions: int = Field(description="Number of times the experiment is repeated")
|
|
49
|
+
repetitions: int = Field(description="Number of times the experiment is repeated", gt=0)
|
|
50
50
|
metadata: dict[str, Any] = Field(description="Metadata of the experiment")
|
|
51
51
|
project_name: Optional[str] = Field(
|
|
52
52
|
description="The name of the project associated with the experiment"
|
|
@@ -10,8 +10,12 @@ from strawberry.types import Info
|
|
|
10
10
|
|
|
11
11
|
from phoenix.db import models
|
|
12
12
|
from phoenix.server.api.context import Context
|
|
13
|
+
from phoenix.server.api.exceptions import BadRequest
|
|
13
14
|
from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
|
|
14
15
|
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
16
|
+
from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
|
|
17
|
+
ExperimentRepeatedRunGroup,
|
|
18
|
+
)
|
|
15
19
|
from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
|
|
16
20
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
17
21
|
from phoenix.server.api.types.pagination import (
|
|
@@ -65,6 +69,7 @@ class DatasetExample(Node):
|
|
|
65
69
|
last: Optional[int] = UNSET,
|
|
66
70
|
after: Optional[CursorString] = UNSET,
|
|
67
71
|
before: Optional[CursorString] = UNSET,
|
|
72
|
+
experiment_ids: Optional[list[GlobalID]] = UNSET,
|
|
68
73
|
) -> Connection[ExperimentRun]:
|
|
69
74
|
args = ConnectionArgs(
|
|
70
75
|
first=first,
|
|
@@ -78,8 +83,51 @@ class DatasetExample(Node):
|
|
|
78
83
|
.options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
|
|
79
84
|
.join(models.Experiment, models.Experiment.id == models.ExperimentRun.experiment_id)
|
|
80
85
|
.where(models.ExperimentRun.dataset_example_id == example_id)
|
|
81
|
-
.order_by(
|
|
86
|
+
.order_by(
|
|
87
|
+
models.ExperimentRun.experiment_id.asc(),
|
|
88
|
+
models.ExperimentRun.repetition_number.asc(),
|
|
89
|
+
)
|
|
82
90
|
)
|
|
91
|
+
if experiment_ids:
|
|
92
|
+
experiment_db_ids = [
|
|
93
|
+
from_global_id_with_expected_type(
|
|
94
|
+
global_id=experiment_id,
|
|
95
|
+
expected_type_name=models.Experiment.__name__,
|
|
96
|
+
)
|
|
97
|
+
for experiment_id in experiment_ids or []
|
|
98
|
+
]
|
|
99
|
+
query = query.where(models.ExperimentRun.experiment_id.in_(experiment_db_ids))
|
|
83
100
|
async with info.context.db() as session:
|
|
84
101
|
runs = (await session.scalars(query)).all()
|
|
85
102
|
return connection_from_list([to_gql_experiment_run(run) for run in runs], args)
|
|
103
|
+
|
|
104
|
+
@strawberry.field
async def experiment_repeated_run_groups(
    self,
    info: Info[Context, None],
    experiment_ids: list[GlobalID],
) -> list[ExperimentRepeatedRunGroup]:
    """Return this example's repeated-run group for each given experiment.

    Each global ID is decoded as an ``Experiment`` ID; a value that fails to
    decode raises ``BadRequest``. The groups are then fetched through the
    ``experiment_repeated_run_groups`` data loader using
    ``(experiment rowid, example rowid)`` keys.
    """
    example_rowid = self.id_attr
    loader_keys = []
    for global_id in experiment_ids:
        try:
            rowid = from_global_id_with_expected_type(
                global_id=global_id,
                expected_type_name=models.Experiment.__name__,
            )
        except Exception:
            raise BadRequest(f"Invalid experiment ID: {global_id}")
        loader_keys.append((rowid, example_rowid))

    loader = info.context.data_loaders.experiment_repeated_run_groups
    loaded_groups = await loader.load_many(loader_keys)

    # Re-wrap each loaded group as a GraphQL type, converting its runs.
    gql_groups = []
    for loaded in loaded_groups:
        gql_runs = [to_gql_experiment_run(run) for run in loaded.runs]
        gql_groups.append(
            ExperimentRepeatedRunGroup(
                experiment_rowid=loaded.experiment_rowid,
                dataset_example_rowid=loaded.dataset_example_rowid,
                runs=gql_runs,
            )
        )
    return gql_groups
|
|
@@ -5,13 +5,14 @@ import strawberry
|
|
|
5
5
|
from sqlalchemy import func, select
|
|
6
6
|
from sqlalchemy.orm import joinedload
|
|
7
7
|
from strawberry import UNSET, Private
|
|
8
|
-
from strawberry.relay import Connection, Node, NodeID
|
|
8
|
+
from strawberry.relay import Connection, GlobalID, Node, NodeID
|
|
9
9
|
from strawberry.scalars import JSON
|
|
10
10
|
from strawberry.types import Info
|
|
11
11
|
|
|
12
12
|
from phoenix.db import models
|
|
13
13
|
from phoenix.server.api.context import Context
|
|
14
14
|
from phoenix.server.api.types.CostBreakdown import CostBreakdown
|
|
15
|
+
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
15
16
|
from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
|
|
16
17
|
from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
|
|
17
18
|
from phoenix.server.api.types.pagination import (
|
|
@@ -32,6 +33,7 @@ class Experiment(Node):
|
|
|
32
33
|
name: str
|
|
33
34
|
project_name: Optional[str]
|
|
34
35
|
description: Optional[str]
|
|
36
|
+
dataset_version_id: GlobalID
|
|
35
37
|
metadata: JSON
|
|
36
38
|
created_at: datetime
|
|
37
39
|
updated_at: datetime
|
|
@@ -71,7 +73,10 @@ class Experiment(Node):
|
|
|
71
73
|
await session.scalars(
|
|
72
74
|
select(models.ExperimentRun)
|
|
73
75
|
.where(models.ExperimentRun.experiment_id == experiment_id)
|
|
74
|
-
.order_by(
|
|
76
|
+
.order_by(
|
|
77
|
+
models.ExperimentRun.dataset_example_id.asc(),
|
|
78
|
+
models.ExperimentRun.repetition_number.asc(),
|
|
79
|
+
)
|
|
75
80
|
.options(
|
|
76
81
|
joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
|
|
77
82
|
)
|
|
@@ -187,6 +192,10 @@ class Experiment(Node):
|
|
|
187
192
|
async for token_type, is_prompt, cost, tokens in data
|
|
188
193
|
]
|
|
189
194
|
|
|
195
|
+
@strawberry.field
async def repetition_count(self, info: Info[Context, None]) -> int:
    """Resolve this experiment's repetition count through the batching data loader."""
    loader = info.context.data_loaders.experiment_repetition_counts
    count = await loader.load(self.id_attr)
    return count
|
|
198
|
+
|
|
190
199
|
|
|
191
200
|
def to_gql_experiment(
|
|
192
201
|
experiment: models.Experiment,
|
|
@@ -201,6 +210,7 @@ def to_gql_experiment(
|
|
|
201
210
|
name=experiment.name,
|
|
202
211
|
project_name=experiment.project_name,
|
|
203
212
|
description=experiment.description,
|
|
213
|
+
dataset_version_id=GlobalID(DatasetVersion.__name__, str(experiment.dataset_version_id)),
|
|
204
214
|
metadata=experiment.metadata_,
|
|
205
215
|
created_at=experiment.created_at,
|
|
206
216
|
updated_at=experiment.updated_at,
|
|
@@ -1,18 +1,12 @@
|
|
|
1
1
|
import strawberry
|
|
2
|
-
from strawberry.relay import
|
|
2
|
+
from strawberry.relay import Node, NodeID
|
|
3
3
|
|
|
4
4
|
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
5
|
-
from phoenix.server.api.types.
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
@strawberry.type
|
|
9
|
-
class RunComparisonItem:
|
|
10
|
-
experiment_id: GlobalID
|
|
11
|
-
runs: list[ExperimentRun]
|
|
5
|
+
from phoenix.server.api.types.ExperimentRepeatedRunGroup import ExperimentRepeatedRunGroup
|
|
12
6
|
|
|
13
7
|
|
|
14
8
|
@strawberry.type
|
|
15
9
|
class ExperimentComparison(Node):
|
|
16
10
|
id_attr: NodeID[int]
|
|
17
11
|
example: DatasetExample
|
|
18
|
-
|
|
12
|
+
repeated_run_groups: list[ExperimentRepeatedRunGroup]
|