arize-phoenix 11.32.1__py3-none-any.whl → 11.34.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/METADATA +1 -1
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/RECORD +57 -50
- phoenix/config.py +44 -0
- phoenix/db/bulk_inserter.py +111 -116
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/context.py +20 -0
- phoenix/server/api/dataloaders/__init__.py +20 -0
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +59 -0
- phoenix/server/api/dataloaders/experiment_repetition_counts.py +39 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/helpers/playground_clients.py +4 -0
- phoenix/server/api/mutations/prompt_label_mutations.py +67 -58
- phoenix/server/api/queries.py +52 -37
- phoenix/server/api/routers/v1/documents.py +1 -1
- phoenix/server/api/routers/v1/evaluations.py +4 -4
- phoenix/server/api/routers/v1/experiment_runs.py +1 -1
- phoenix/server/api/routers/v1/experiments.py +1 -1
- phoenix/server/api/routers/v1/spans.py +2 -2
- phoenix/server/api/routers/v1/traces.py +18 -3
- phoenix/server/api/types/DatasetExample.py +49 -1
- phoenix/server/api/types/Experiment.py +12 -2
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +146 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +12 -19
- phoenix/server/api/types/Prompt.py +11 -0
- phoenix/server/api/types/PromptLabel.py +2 -19
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/app.py +78 -20
- phoenix/server/cost_tracking/model_cost_manifest.json +1 -1
- phoenix/server/daemons/span_cost_calculator.py +10 -8
- phoenix/server/grpc_server.py +9 -9
- phoenix/server/prometheus.py +30 -6
- phoenix/server/static/.vite/manifest.json +43 -43
- phoenix/server/static/assets/components-CdQiQTvs.js +5778 -0
- phoenix/server/static/assets/{index-D1FDMBMV.js → index-B1VuXYRI.js} +12 -21
- phoenix/server/static/assets/pages-CnfZ3RhB.js +9163 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-Cfrr9FCF.js +903 -0
- phoenix/server/static/assets/{vendor-arizeai-DsYDNOqt.js → vendor-arizeai-Dz0kN-lQ.js} +4 -4
- phoenix/server/static/assets/vendor-codemirror-ClqtONZQ.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-BTHn5Y2R.js → vendor-recharts-D6kvOpmb.js} +2 -2
- phoenix/server/static/assets/{vendor-shiki-BAcocHFl.js → vendor-shiki-xSOiKxt0.js} +1 -1
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +13 -0
- phoenix/trace/projects.py +1 -2
- phoenix/version.py +1 -1
- phoenix/server/static/assets/components-Cs9c4Nxp.js +0 -5698
- phoenix/server/static/assets/pages-Cbj9SjBx.js +0 -8928
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-RdRDaQiR.js +0 -905
- phoenix/server/static/assets/vendor-codemirror-BzJDUbEx.js +0 -25
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py
ADDED
@@ -0,0 +1,77 @@
+from dataclasses import dataclass
+from typing import Optional
+
+from sqlalchemy import func, select, tuple_
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ExperimentID: TypeAlias = int
+DatasetExampleID: TypeAlias = int
+AnnotationName: TypeAlias = str
+MeanAnnotationScore: TypeAlias = float
+
+
+@dataclass
+class AnnotationSummary:
+    annotation_name: AnnotationName
+    mean_score: Optional[MeanAnnotationScore]
+
+
+Key: TypeAlias = tuple[ExperimentID, DatasetExampleID]
+Result: TypeAlias = list[AnnotationSummary]
+
+
+class ExperimentRepeatedRunGroupAnnotationSummariesDataLoader(DataLoader[Key, Result]):
+    def __init__(
+        self,
+        db: DbSessionFactory,
+    ) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        annotation_summaries_query = (
+            select(
+                models.ExperimentRun.experiment_id.label("experiment_id"),
+                models.ExperimentRun.dataset_example_id.label("dataset_example_id"),
+                models.ExperimentRunAnnotation.name.label("annotation_name"),
+                func.avg(models.ExperimentRunAnnotation.score).label("mean_score"),
+            )
+            .select_from(models.ExperimentRunAnnotation)
+            .join(
+                models.ExperimentRun,
+                models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
+            )
+            .where(
+                tuple_(
+                    models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id
+                ).in_(set(keys))
+            )
+            .group_by(
+                models.ExperimentRun.experiment_id,
+                models.ExperimentRun.dataset_example_id,
+                models.ExperimentRunAnnotation.name,
+            )
+        )
+        async with self._db() as session:
+            annotation_summaries = (await session.execute(annotation_summaries_query)).all()
+        annotation_summaries_by_key: dict[Key, list[AnnotationSummary]] = {}
+        for summary in annotation_summaries:
+            key = (summary.experiment_id, summary.dataset_example_id)
+            gql_summary = AnnotationSummary(
+                annotation_name=summary.annotation_name,
+                mean_score=summary.mean_score,
+            )
+            if key not in annotation_summaries_by_key:
+                annotation_summaries_by_key[key] = []
+            annotation_summaries_by_key[key].append(gql_summary)
+        return [
+            sorted(
+                annotation_summaries_by_key.get(key, []),
+                key=lambda summary: summary.annotation_name,
+            )
+            for key in keys
+        ]
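All four dataloaders added in this release follow the same strawberry DataLoader pattern: keys requested by individual resolvers during one event-loop tick are collected and handed to a single _load_fn call, which issues one batched query and returns results in key order. A minimal, self-contained sketch of that batching behavior (the load_mean_scores function and the sample keys and scores below are hypothetical, not code from the wheel):

import asyncio
from typing import Optional

from strawberry.dataloader import DataLoader

# Hypothetical stand-in for the batched mean-score query.
FAKE_MEAN_SCORES: dict[tuple[int, int], float] = {(1, 10): 0.75, (1, 11): 0.5}


async def load_mean_scores(keys: list[tuple[int, int]]) -> list[Optional[float]]:
    # DataLoader hands over every key requested in the current tick in one call,
    # so a single query can serve many concurrent resolvers.
    print(f"one batched call for {len(keys)} keys")
    return [FAKE_MEAN_SCORES.get(key) for key in keys]  # results in key order


async def main() -> None:
    loader = DataLoader(load_fn=load_mean_scores)
    scores = await asyncio.gather(
        loader.load((1, 10)), loader.load((1, 11)), loader.load((2, 10))
    )
    print(scores)  # [0.75, 0.5, None]


asyncio.run(main())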
phoenix/server/api/dataloaders/experiment_repeated_run_groups.py
ADDED
@@ -0,0 +1,59 @@
+from dataclasses import dataclass
+
+from sqlalchemy import select, tuple_
+from sqlalchemy.orm import joinedload
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ExperimentID: TypeAlias = int
+DatasetExampleID: TypeAlias = int
+Key: TypeAlias = tuple[ExperimentID, DatasetExampleID]
+
+
+@dataclass
+class ExperimentRepeatedRunGroup:
+    experiment_rowid: int
+    dataset_example_rowid: int
+    runs: list[models.ExperimentRun]
+
+
+Result: TypeAlias = ExperimentRepeatedRunGroup
+
+
+class ExperimentRepeatedRunGroupsDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        repeated_run_groups_query = (
+            select(models.ExperimentRun)
+            .where(
+                tuple_(
+                    models.ExperimentRun.experiment_id,
+                    models.ExperimentRun.dataset_example_id,
+                ).in_(set(keys))
+            )
+            .order_by(models.ExperimentRun.repetition_number)
+            .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
+        )
+
+        async with self._db() as session:
+            runs_by_key: dict[Key, list[models.ExperimentRun]] = {}
+            for run in (await session.scalars(repeated_run_groups_query)).all():
+                key = (run.experiment_id, run.dataset_example_id)
+                if key not in runs_by_key:
+                    runs_by_key[key] = []
+                runs_by_key[key].append(run)
+
+        return [
+            ExperimentRepeatedRunGroup(
+                experiment_rowid=experiment_id,
+                dataset_example_rowid=dataset_example_id,
+                runs=runs_by_key.get((experiment_id, dataset_example_id), []),
+            )
+            for (experiment_id, dataset_example_id) in keys
+        ]
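Both composite-key loaders filter with SQLAlchemy's tuple_(...).in_(...), which renders a row-value IN clause and therefore relies on the backend supporting row-value comparisons (PostgreSQL, and SQLite 3.15+). A small sketch of the construct against a made-up table (only the construct is the point; the table is not from the Phoenix schema):

from sqlalchemy import Column, Integer, MetaData, Table, select, tuple_

metadata = MetaData()
# Illustrative table with just the two key columns.
runs = Table(
    "runs",
    metadata,
    Column("experiment_id", Integer),
    Column("dataset_example_id", Integer),
)

keys = {(1, 10), (1, 11), (2, 10)}
stmt = select(runs).where(
    tuple_(runs.c.experiment_id, runs.c.dataset_example_id).in_(keys)
)
# Compiles to roughly:
#   SELECT runs.experiment_id, runs.dataset_example_id FROM runs
#   WHERE (runs.experiment_id, runs.dataset_example_id) IN ((1, 10), (1, 11), (2, 10))
print(stmt)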
phoenix/server/api/dataloaders/experiment_repetition_counts.py
ADDED
@@ -0,0 +1,39 @@
+from sqlalchemy import func, select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+ExperimentID: TypeAlias = int
+RepetitionCount: TypeAlias = int
+Key: TypeAlias = ExperimentID
+Result: TypeAlias = RepetitionCount
+
+
+class ExperimentRepetitionCountsDataLoader(DataLoader[Key, Result]):
+    def __init__(
+        self,
+        db: DbSessionFactory,
+    ) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        experiment_ids = keys
+        repetition_counts_query = (
+            select(
+                models.ExperimentRun.experiment_id,
+                func.max(models.ExperimentRun.repetition_number).label("repetition_count"),
+            )
+            .group_by(models.ExperimentRun.experiment_id)
+            .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
+        )
+        async with self._db() as session:
+            repetition_counts = {
+                experiment_id: repetition_count
+                for experiment_id, repetition_count in await session.execute(
+                    repetition_counts_query
+                )
+            }
+        return [repetition_counts.get(experiment_id, 0) for experiment_id in keys]
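The repetition count is read off as max(repetition_number) per experiment rather than count(*), which assumes repetitions are numbered 1..N per example; experiments with no runs fall back to 0. A tiny sqlite3 check of that semantics (the table name and rows are illustrative only):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE experiment_runs (experiment_id INT, repetition_number INT)")
conn.executemany(
    "INSERT INTO experiment_runs VALUES (?, ?)",
    [(1, 1), (1, 2), (1, 3), (2, 1)],  # experiment 1 has 3 repetitions, experiment 2 has 1
)
rows = conn.execute(
    "SELECT experiment_id, MAX(repetition_number) FROM experiment_runs"
    " WHERE experiment_id IN (1, 2, 3) GROUP BY experiment_id"
).fetchall()
counts = dict(rows)
# Experiments absent from the result set (no runs yet) default to 0, as in the loader.
print([counts.get(experiment_id, 0) for experiment_id in (1, 2, 3)])  # [3, 1, 0]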
phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py
ADDED
@@ -0,0 +1,64 @@
+from collections import defaultdict
+
+from sqlalchemy import func, select, tuple_
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
+from phoenix.server.types import DbSessionFactory
+
+ExperimentId: TypeAlias = int
+DatasetExampleId: TypeAlias = int
+Key: TypeAlias = tuple[ExperimentId, DatasetExampleId]
+Result: TypeAlias = SpanCostSummary
+
+
+class SpanCostSummaryByExperimentRepeatedRunGroupDataLoader(DataLoader[Key, Result]):
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        stmt = (
+            select(
+                models.ExperimentRun.experiment_id,
+                models.ExperimentRun.dataset_example_id,
+                func.sum(models.SpanCost.prompt_cost).label("prompt_cost"),
+                func.sum(models.SpanCost.completion_cost).label("completion_cost"),
+                func.sum(models.SpanCost.total_cost).label("total_cost"),
+                func.sum(models.SpanCost.prompt_tokens).label("prompt_tokens"),
+                func.sum(models.SpanCost.completion_tokens).label("completion_tokens"),
+                func.sum(models.SpanCost.total_tokens).label("total_tokens"),
+            )
+            .select_from(models.ExperimentRun)
+            .join(models.Trace, models.ExperimentRun.trace_id == models.Trace.trace_id)
+            .join(models.SpanCost, models.SpanCost.trace_rowid == models.Trace.id)
+            .where(
+                tuple_(
+                    models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id
+                ).in_(set(keys))
+            )
+            .group_by(models.ExperimentRun.experiment_id, models.ExperimentRun.dataset_example_id)
+        )
+
+        results: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
+        async with self._db() as session:
+            data = await session.stream(stmt)
+            async for (
+                experiment_id,
+                dataset_example_id,
+                prompt_cost,
+                completion_cost,
+                total_cost,
+                prompt_tokens,
+                completion_tokens,
+                total_tokens,
+            ) in data:
+                summary = SpanCostSummary(
+                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
+                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
+                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
+                )
+                results[(experiment_id, dataset_example_id)] = summary
+        return [results.get(key, SpanCostSummary()) for key in keys]
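SpanCostSummary and CostBreakdown are imported from phoenix.server.api.dataloaders.types, which this diff does not show. The stand-in dataclasses below are assumptions used only to illustrate how one aggregated SUM row per (experiment, example) group maps onto a summary object, and how a group with no costed spans falls back to an empty summary:

from dataclasses import dataclass, field
from typing import Optional


# Hypothetical stand-ins; the real definitions live in
# phoenix.server.api.dataloaders.types and may differ.
@dataclass
class CostBreakdown:
    tokens: Optional[float] = None
    cost: Optional[float] = None


@dataclass
class SpanCostSummary:
    prompt: CostBreakdown = field(default_factory=CostBreakdown)
    completion: CostBreakdown = field(default_factory=CostBreakdown)
    total: CostBreakdown = field(default_factory=CostBreakdown)


# One made-up aggregated row for a single repeated-run group.
row = {"prompt_tokens": 120, "prompt_cost": 0.003, "completion_tokens": 40, "completion_cost": 0.006}
summary = SpanCostSummary(
    prompt=CostBreakdown(tokens=row["prompt_tokens"], cost=row["prompt_cost"]),
    completion=CostBreakdown(tokens=row["completion_tokens"], cost=row["completion_cost"]),
)
empty = SpanCostSummary()  # fallback for a key with no costed spans
print(summary.prompt.cost, empty.total.tokens)  # 0.003 None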
phoenix/server/api/helpers/playground_clients.py
CHANGED
@@ -1669,7 +1669,11 @@ class AnthropicReasoningStreamingClient(AnthropicStreamingClient):
         provider_key=GenerativeProviderKey.GOOGLE,
         model_names=[
             PROVIDER_DEFAULT,
+            "gemini-2.5-flash",
+            "gemini-2.5-flash-lite",
+            "gemini-2.5-pro",
             "gemini-2.5-pro-preview-03-25",
+            "gemini-2.0-flash",
             "gemini-2.0-flash-lite",
             "gemini-2.0-flash-001",
             "gemini-2.0-flash-thinking-exp-01-21",
phoenix/server/api/mutations/prompt_label_mutations.py
CHANGED
@@ -10,12 +10,10 @@ from strawberry.relay import GlobalID
 from strawberry.types import Info
 
 from phoenix.db import models
-from phoenix.db.types.identifier import Identifier as IdentifierModel
 from phoenix.server.api.auth import IsLocked, IsNotReadOnly
 from phoenix.server.api.context import Context
 from phoenix.server.api.exceptions import Conflict, NotFound
 from phoenix.server.api.queries import Query
-from phoenix.server.api.types.Identifier import Identifier
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.types.Prompt import Prompt
 from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
@@ -23,37 +21,49 @@ from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
 
 @strawberry.input
 class CreatePromptLabelInput:
-    name:
+    name: str
     description: Optional[str] = None
+    color: str
 
 
 @strawberry.input
 class PatchPromptLabelInput:
     prompt_label_id: GlobalID
-    name: Optional[
+    name: Optional[str] = None
     description: Optional[str] = None
 
 
 @strawberry.input
-class
-
+class DeletePromptLabelsInput:
+    prompt_label_ids: list[GlobalID]
 
 
 @strawberry.input
-class
+class SetPromptLabelsInput:
     prompt_id: GlobalID
-
+    prompt_label_ids: list[GlobalID]
 
 
 @strawberry.input
-class
+class UnsetPromptLabelsInput:
     prompt_id: GlobalID
-
+    prompt_label_ids: list[GlobalID]
 
 
 @strawberry.type
 class PromptLabelMutationPayload:
-
+    prompt_labels: list["PromptLabel"]
+    query: "Query"
+
+
+@strawberry.type
+class PromptLabelDeleteMutationPayload:
+    deleted_prompt_label_ids: list["GlobalID"]
+    query: "Query"
+
+
+@strawberry.type
+class PromptLabelAssociationMutationPayload:
     query: "Query"
 
 
@@ -64,17 +74,18 @@ class PromptLabelMutationMixin:
         self, info: Info[Context, None], input: CreatePromptLabelInput
     ) -> PromptLabelMutationPayload:
         async with info.context.db() as session:
-
-
+            label_orm = models.PromptLabel(
+                name=input.name, description=input.description, color=input.color
+            )
            session.add(label_orm)
 
            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
-                raise Conflict(f"A prompt label named '{name}' already exists.")
+                raise Conflict(f"A prompt label named '{input.name}' already exists.")
 
            return PromptLabelMutationPayload(
-
+                prompt_labels=[to_gql_prompt_label(label_orm)],
                query=Query(),
            )
 
@@ -82,7 +93,6 @@ class PromptLabelMutationMixin:
     async def patch_prompt_label(
         self, info: Info[Context, None], input: PatchPromptLabelInput
     ) -> PromptLabelMutationPayload:
-        validated_name = IdentifierModel.model_validate(str(input.name)) if input.name else None
         async with info.context.db() as session:
             label_id = from_global_id_with_expected_type(
                 input.prompt_label_id, PromptLabel.__name__
@@ -92,8 +102,8 @@ class PromptLabelMutationMixin:
             if not label_orm:
                 raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
 
-            if
-                label_orm.name =
+            if input.name is not None:
+                label_orm.name = input.name
             if input.description is not None:
                 label_orm.description = input.description
 
@@ -103,46 +113,48 @@ class PromptLabelMutationMixin:
                raise Conflict("Error patching PromptLabel. Possibly a name conflict?")
 
            return PromptLabelMutationPayload(
-
+                prompt_labels=[to_gql_prompt_label(label_orm)],
                query=Query(),
            )
 
     @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
-    async def
-        self, info: Info[Context, None], input:
-    ) ->
+    async def delete_prompt_labels(
+        self, info: Info[Context, None], input: DeletePromptLabelsInput
+    ) -> PromptLabelDeleteMutationPayload:
         """
         Deletes a PromptLabel (and any crosswalk references).
         """
         async with info.context.db() as session:
-
-
-
-
-
-
-            if result.rowcount == 0:
-                raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
+            label_ids = [
+                from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
+                for prompt_label_id in input.prompt_label_ids
+            ]
+            stmt = delete(models.PromptLabel).where(models.PromptLabel.id.in_(label_ids))
+            await session.execute(stmt)
 
            await session.commit()
 
-            return
-
+            return PromptLabelDeleteMutationPayload(
+                deleted_prompt_label_ids=input.prompt_label_ids,
                query=Query(),
            )
 
     @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
-    async def
-        self, info: Info[Context, None], input:
-    ) ->
+    async def set_prompt_labels(
+        self, info: Info[Context, None], input: SetPromptLabelsInput
+    ) -> PromptLabelAssociationMutationPayload:
         async with info.context.db() as session:
             prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
-
-
-
+            label_ids = [
+                from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
+                for prompt_label_id in input.prompt_label_ids
+            ]
 
-
-
+            crosswalk_items = [
+                models.PromptPromptLabel(prompt_id=prompt_id, prompt_label_id=label_id)
+                for label_id in label_ids
+            ]
+            session.add_all(crosswalk_items)
 
            try:
                await session.commit()
@@ -152,41 +164,38 @@ class PromptLabelMutationMixin:
                # - Foreign key violation => prompt_id or label_id doesn't exist
                raise Conflict("Failed to associate PromptLabel with Prompt.") from e
 
-
-            if not label_orm:
-                raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
-
-            return PromptLabelMutationPayload(
-                prompt_label=to_gql_prompt_label(label_orm),
+            return PromptLabelAssociationMutationPayload(
                query=Query(),
            )
 
     @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
-    async def
-        self, info: Info[Context, None], input:
-    ) ->
+    async def unset_prompt_labels(
+        self, info: Info[Context, None], input: UnsetPromptLabelsInput
+    ) -> PromptLabelAssociationMutationPayload:
         """
         Unsets a PromptLabel from a Prompt by removing the row in the crosswalk.
         """
         async with info.context.db() as session:
             prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
-
-
-
+            label_ids = [
+                from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
+                for prompt_label_id in input.prompt_label_ids
+            ]
 
            stmt = delete(models.PromptPromptLabel).where(
                (models.PromptPromptLabel.prompt_id == prompt_id)
-                & (models.PromptPromptLabel.prompt_label_id
+                & (models.PromptPromptLabel.prompt_label_id.in_(label_ids))
            )
            result = await session.execute(stmt)
 
-            if result.rowcount
-
+            if result.rowcount != len(label_ids):
+                label_ids_str = ", ".join(str(i) for i in label_ids)
+                raise NotFound(
+                    f"No association between prompt={prompt_id} and labels={label_ids_str}."
+                )
 
            await session.commit()
 
-
-            return PromptLabelMutationPayload(
-                prompt_label=to_gql_prompt_label(label_orm) if label_orm else None,
+            return PromptLabelAssociationMutationPayload(
                query=Query(),
            )
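unset_prompt_labels now removes the crosswalk rows for all requested labels in one DELETE and treats a rowcount smaller than the number of requested label IDs as a missing association. A small sqlite3 sketch of that rowcount check (the table name and IDs are illustrative, not the Phoenix schema):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE prompt_prompt_labels (prompt_id INT, prompt_label_id INT)")
conn.executemany("INSERT INTO prompt_prompt_labels VALUES (?, ?)", [(1, 5), (1, 6)])

label_ids = [5, 6, 7]  # label 7 is not associated with prompt 1
cur = conn.execute(
    "DELETE FROM prompt_prompt_labels WHERE prompt_id = ? AND prompt_label_id IN (?, ?, ?)",
    [1, *label_ids],
)
# Only two rows matched, so the mutation would raise NotFound instead of committing.
print(cur.rowcount, cur.rowcount != len(label_ids))  # 2 True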
phoenix/server/api/queries.py
CHANGED
@@ -56,15 +56,25 @@ from phoenix.server.api.types.EmbeddingDimension import (
     to_gql_embedding_dimension,
 )
 from phoenix.server.api.types.Event import create_event_id, unpack_event_id
-from phoenix.server.api.types.Experiment import Experiment
-from phoenix.server.api.types.ExperimentComparison import
+from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
+from phoenix.server.api.types.ExperimentComparison import (
+    ExperimentComparison,
+)
+from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
+    ExperimentRepeatedRunGroup,
+    parse_experiment_repeated_run_group_node_id,
+)
 from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
 from phoenix.server.api.types.Functionality import Functionality
 from phoenix.server.api.types.GenerativeModel import GenerativeModel, to_gql_generative_model
 from phoenix.server.api.types.GenerativeProvider import GenerativeProvider, GenerativeProviderKey
 from phoenix.server.api.types.InferenceModel import InferenceModel
 from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
-from phoenix.server.api.types.node import
+from phoenix.server.api.types.node import (
+    from_global_id,
+    from_global_id_with_expected_type,
+    is_global_id,
+)
 from phoenix.server.api.types.pagination import (
     ConnectionArgs,
     Cursor,
@@ -513,11 +523,12 @@ class Query:
 
        cursors_and_nodes = []
        for example in examples:
-
+            repeated_run_groups = []
            for experiment_id in experiment_rowids:
-
-
-
+                repeated_run_groups.append(
+                    ExperimentRepeatedRunGroup(
+                        experiment_rowid=experiment_id,
+                        dataset_example_rowid=example.id,
                        runs=[
                            to_gql_experiment_run(run)
                            for run in sorted(
@@ -533,7 +544,7 @@ class Query:
                    created_at=example.created_at,
                    version_id=base_experiment.dataset_version_id,
                ),
-
+                repeated_run_groups=repeated_run_groups,
            )
            cursors_and_nodes.append((Cursor(rowid=example.id), experiment_comparison))
 
@@ -863,8 +874,37 @@ class Query:
        return InferenceModel()
 
    @strawberry.field
-    async def node(self, id:
-
+    async def node(self, id: strawberry.ID, info: Info[Context, None]) -> Node:
+        if not is_global_id(id):
+            try:
+                experiment_rowid, dataset_example_rowid = (
+                    parse_experiment_repeated_run_group_node_id(id)
+                )
+            except Exception:
+                raise NotFound(f"Unknown node: {id}")
+
+            async with info.context.db() as session:
+                runs = (
+                    await session.scalars(
+                        select(models.ExperimentRun)
+                        .where(models.ExperimentRun.experiment_id == experiment_rowid)
+                        .where(models.ExperimentRun.dataset_example_id == dataset_example_rowid)
+                        .order_by(models.ExperimentRun.repetition_number.asc())
+                        .options(
+                            joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
+                        )
+                    )
+                ).all()
+            if not runs:
+                raise NotFound(f"Unknown experiment or dataset example: {id}")
+            return ExperimentRepeatedRunGroup(
+                experiment_rowid=experiment_rowid,
+                dataset_example_rowid=dataset_example_rowid,
+                runs=[to_gql_experiment_run(run) for run in runs],
+            )
+
+        global_id = GlobalID.from_id(id)
+        type_name, node_id = from_global_id(global_id)
        if type_name == "Dimension":
            dimension = info.context.model.scalar_dimensions[node_id]
            return to_gql_dimension(node_id, dimension)
@@ -909,26 +949,9 @@ class Query:
            return to_gql_dataset(dataset)
        elif type_name == DatasetExample.__name__:
            example_id = node_id
-            latest_revision_id = (
-                select(func.max(models.DatasetExampleRevision.id))
-                .where(models.DatasetExampleRevision.dataset_example_id == example_id)
-                .scalar_subquery()
-            )
            async with info.context.db() as session:
                example = await session.scalar(
-                    select(models.DatasetExample)
-                    .join(
-                        models.DatasetExampleRevision,
-                        onclause=models.DatasetExampleRevision.dataset_example_id
-                        == models.DatasetExample.id,
-                    )
-                    .where(
-                        and_(
-                            models.DatasetExample.id == example_id,
-                            models.DatasetExampleRevision.id == latest_revision_id,
-                            models.DatasetExampleRevision.revision_kind != "DELETE",
-                        )
-                    )
+                    select(models.DatasetExample).where(models.DatasetExample.id == example_id)
                )
                if not example:
                    raise NotFound(f"Unknown dataset example: {id}")
@@ -943,15 +966,7 @@ class Query:
                )
                if not experiment:
                    raise NotFound(f"Unknown experiment: {id}")
-                return
-                    id_attr=experiment.id,
-                    name=experiment.name,
-                    project_name=experiment.project_name,
-                    description=experiment.description,
-                    created_at=experiment.created_at,
-                    updated_at=experiment.updated_at,
-                    metadata=experiment.metadata_,
-                )
+                return to_gql_experiment(experiment)
        elif type_name == ExperimentRun.__name__:
            async with info.context.db() as session:
                if not (
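The node() resolver now first checks is_global_id(id); anything that is not a Relay GlobalID is handed to parse_experiment_repeated_run_group_node_id, defined in the new ExperimentRepeatedRunGroup.py (not reproduced in this diff). The sketch below is only a guess at the general shape of such a composite-ID encode/parse pair; the actual format is not shown here:

# Hypothetical illustration; the real helpers live in ExperimentRepeatedRunGroup.py.
def format_group_node_id(experiment_rowid: int, dataset_example_rowid: int) -> str:
    return f"ExperimentRepeatedRunGroup:{experiment_rowid}:{dataset_example_rowid}"


def parse_group_node_id(node_id: str) -> tuple[int, int]:
    prefix, experiment_rowid, dataset_example_rowid = node_id.split(":")
    if prefix != "ExperimentRepeatedRunGroup":
        raise ValueError(f"not a repeated-run-group id: {node_id}")
    return int(experiment_rowid), int(dataset_example_rowid)


assert parse_group_node_id(format_group_node_id(7, 42)) == (7, 42)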
phoenix/server/api/routers/v1/documents.py
CHANGED
@@ -82,7 +82,7 @@ async def annotate_span_documents(
        annotation.as_precursor(user_id=user_id) for annotation in span_document_annotations
    ]
    if not sync:
-        await request.state.
+        await request.state.enqueue_annotations(*precursors)
        return AnnotateSpanDocumentsResponseBody(data=[])
 
    span_ids = {p.span_id for p in precursors}