arize-phoenix 11.38.0__py3-none-any.whl → 12.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/METADATA +3 -3
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/RECORD +83 -58
- phoenix/config.py +1 -11
- phoenix/db/bulk_inserter.py +8 -0
- phoenix/db/facilitator.py +1 -1
- phoenix/db/helpers.py +202 -33
- phoenix/db/insertion/dataset.py +7 -0
- phoenix/db/insertion/document_annotation.py +1 -1
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +1 -1
- phoenix/db/insertion/trace_annotation.py +1 -1
- phoenix/db/insertion/types.py +29 -3
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +306 -46
- phoenix/server/api/context.py +15 -2
- phoenix/server/api/dataloaders/__init__.py +8 -2
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/helpers/playground_clients.py +66 -35
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/chat_mutations.py +8 -3
- phoenix/server/api/mutations/dataset_label_mutations.py +291 -0
- phoenix/server/api/mutations/dataset_mutations.py +5 -0
- phoenix/server/api/mutations/dataset_split_mutations.py +423 -0
- phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
- phoenix/server/api/queries.py +53 -0
- phoenix/server/api/routers/auth.py +5 -5
- phoenix/server/api/routers/oauth2.py +5 -23
- phoenix/server/api/routers/v1/__init__.py +2 -0
- phoenix/server/api/routers/v1/annotations.py +320 -0
- phoenix/server/api/routers/v1/datasets.py +5 -0
- phoenix/server/api/routers/v1/experiments.py +10 -3
- phoenix/server/api/routers/v1/sessions.py +111 -0
- phoenix/server/api/routers/v1/traces.py +1 -2
- phoenix/server/api/routers/v1/users.py +7 -0
- phoenix/server/api/subscriptions.py +5 -2
- phoenix/server/api/types/Dataset.py +8 -0
- phoenix/server/api/types/DatasetExample.py +18 -0
- phoenix/server/api/types/DatasetLabel.py +23 -0
- phoenix/server/api/types/DatasetSplit.py +32 -0
- phoenix/server/api/types/Experiment.py +0 -4
- phoenix/server/api/types/Project.py +16 -0
- phoenix/server/api/types/ProjectSession.py +88 -3
- phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
- phoenix/server/api/types/Prompt.py +18 -1
- phoenix/server/api/types/Span.py +5 -5
- phoenix/server/api/types/Trace.py +61 -0
- phoenix/server/app.py +13 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
- phoenix/server/dml_event.py +13 -0
- phoenix/server/static/.vite/manifest.json +39 -39
- phoenix/server/static/assets/{components-BQPHTBfv.js → components-BG6v0EM8.js} +705 -385
- phoenix/server/static/assets/{index-BL5BMgJU.js → index-CSVcULw1.js} +13 -13
- phoenix/server/static/assets/{pages-C0Y17J0T.js → pages-DgaM7kpM.js} +1356 -1155
- phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-BqTEkGQU.js} +183 -183
- phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-DlOj0PQQ.js} +15 -24
- phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-B2PHH5yZ.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts-CKsi4IjN.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-DN26BkKE.js} +1 -1
- phoenix/server/utils.py +74 -0
- phoenix/session/session.py +25 -5
- phoenix/version.py +1 -1
- phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -26,6 +26,7 @@ from typing_extensions import TypeAlias, assert_never
|
|
|
26
26
|
from phoenix.config import PLAYGROUND_PROJECT_NAME
|
|
27
27
|
from phoenix.datetime_utils import local_now, normalize_datetime
|
|
28
28
|
from phoenix.db import models
|
|
29
|
+
from phoenix.db.helpers import insert_experiment_with_examples_snapshot
|
|
29
30
|
from phoenix.server.api.auth import IsLocked, IsNotReadOnly
|
|
30
31
|
from phoenix.server.api.context import Context
|
|
31
32
|
from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
|
|
@@ -43,6 +44,7 @@ from phoenix.server.api.helpers.playground_spans import (
|
|
|
43
44
|
get_db_trace,
|
|
44
45
|
streaming_llm_span,
|
|
45
46
|
)
|
|
47
|
+
from phoenix.server.api.helpers.playground_users import get_user
|
|
46
48
|
from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
|
|
47
49
|
from phoenix.server.api.input_types.ChatCompletionInput import (
|
|
48
50
|
ChatCompletionInput,
|
|
@@ -302,6 +304,7 @@ class Subscription:
|
|
|
302
304
|
description="Traces from prompt playground",
|
|
303
305
|
)
|
|
304
306
|
)
|
|
307
|
+
user_id = get_user(info)
|
|
305
308
|
experiment = models.Experiment(
|
|
306
309
|
dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
|
|
307
310
|
dataset_version_id=resolved_version_id,
|
|
@@ -311,9 +314,9 @@ class Subscription:
|
|
|
311
314
|
repetitions=input.repetitions,
|
|
312
315
|
metadata_=input.experiment_metadata or dict(),
|
|
313
316
|
project_name=project_name,
|
|
317
|
+
user_id=user_id,
|
|
314
318
|
)
|
|
315
|
-
session
|
|
316
|
-
await session.flush()
|
|
319
|
+
await insert_experiment_with_examples_snapshot(session, experiment)
|
|
317
320
|
yield ChatCompletionSubscriptionExperiment(
|
|
318
321
|
experiment=to_gql_experiment(experiment)
|
|
319
322
|
) # eagerly yields experiment so it can be linked by consumers of the subscription
|
|
@@ -18,6 +18,7 @@ from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
|
18
18
|
from phoenix.server.api.types.DatasetExperimentAnnotationSummary import (
|
|
19
19
|
DatasetExperimentAnnotationSummary,
|
|
20
20
|
)
|
|
21
|
+
from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
|
|
21
22
|
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
22
23
|
from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
|
|
23
24
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
@@ -303,6 +304,13 @@ class Dataset(Node):
|
|
|
303
304
|
async for scores_tuple in await session.stream(query)
|
|
304
305
|
]
|
|
305
306
|
|
|
307
|
+
@strawberry.field
|
|
308
|
+
async def labels(self, info: Info[Context, None]) -> list[DatasetLabel]:
|
|
309
|
+
return [
|
|
310
|
+
to_gql_dataset_label(label)
|
|
311
|
+
for label in await info.context.data_loaders.dataset_labels.load(self.id_attr)
|
|
312
|
+
]
|
|
313
|
+
|
|
306
314
|
@strawberry.field
|
|
307
315
|
def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
|
|
308
316
|
return info.context.last_updated_at.get(self._table, self.id_attr)
|
|
@@ -12,6 +12,7 @@ from phoenix.db import models
|
|
|
12
12
|
from phoenix.server.api.context import Context
|
|
13
13
|
from phoenix.server.api.exceptions import BadRequest
|
|
14
14
|
from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
|
|
15
|
+
from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
|
|
15
16
|
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
16
17
|
from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
|
|
17
18
|
ExperimentRepeatedRunGroup,
|
|
@@ -131,3 +132,20 @@ class DatasetExample(Node):
|
|
|
131
132
|
)
|
|
132
133
|
for group in repeated_run_groups
|
|
133
134
|
]
|
|
135
|
+
|
|
136
|
+
@strawberry.field
|
|
137
|
+
async def dataset_splits(
|
|
138
|
+
self,
|
|
139
|
+
info: Info[Context, None],
|
|
140
|
+
) -> list[DatasetSplit]:
|
|
141
|
+
return [
|
|
142
|
+
to_gql_dataset_split(split)
|
|
143
|
+
for split in await info.context.data_loaders.dataset_example_splits.load(self.id_attr)
|
|
144
|
+
]
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def to_gql_dataset_example(example: models.DatasetExample) -> DatasetExample:
|
|
148
|
+
return DatasetExample(
|
|
149
|
+
id_attr=example.id,
|
|
150
|
+
created_at=example.created_at,
|
|
151
|
+
)
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
from strawberry.relay import Node, NodeID
|
|
5
|
+
|
|
6
|
+
from phoenix.db import models
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@strawberry.type
|
|
10
|
+
class DatasetLabel(Node):
|
|
11
|
+
id_attr: NodeID[int]
|
|
12
|
+
name: str
|
|
13
|
+
description: Optional[str]
|
|
14
|
+
color: str
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def to_gql_dataset_label(dataset_label: models.DatasetLabel) -> DatasetLabel:
|
|
18
|
+
return DatasetLabel(
|
|
19
|
+
id_attr=dataset_label.id,
|
|
20
|
+
name=dataset_label.name,
|
|
21
|
+
description=dataset_label.description,
|
|
22
|
+
color=dataset_label.color,
|
|
23
|
+
)
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import ClassVar, Optional
|
|
3
|
+
|
|
4
|
+
import strawberry
|
|
5
|
+
from strawberry.relay import Node, NodeID
|
|
6
|
+
from strawberry.scalars import JSON
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@strawberry.type
|
|
12
|
+
class DatasetSplit(Node):
|
|
13
|
+
_table: ClassVar[type[models.Base]] = models.DatasetSplit
|
|
14
|
+
id_attr: NodeID[int]
|
|
15
|
+
name: str
|
|
16
|
+
description: Optional[str]
|
|
17
|
+
metadata: JSON
|
|
18
|
+
color: str
|
|
19
|
+
created_at: datetime
|
|
20
|
+
updated_at: datetime
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def to_gql_dataset_split(dataset_split: models.DatasetSplit) -> DatasetSplit:
|
|
24
|
+
return DatasetSplit(
|
|
25
|
+
id_attr=dataset_split.id,
|
|
26
|
+
name=dataset_split.name,
|
|
27
|
+
description=dataset_split.description,
|
|
28
|
+
color=dataset_split.color or "#ffffff",
|
|
29
|
+
metadata=dataset_split.metadata_,
|
|
30
|
+
created_at=dataset_split.created_at,
|
|
31
|
+
updated_at=dataset_split.updated_at,
|
|
32
|
+
)
|
|
@@ -193,10 +193,6 @@ class Experiment(Node):
|
|
|
193
193
|
async for token_type, is_prompt, cost, tokens in data
|
|
194
194
|
]
|
|
195
195
|
|
|
196
|
-
@strawberry.field
|
|
197
|
-
async def repetition_count(self, info: Info[Context, None]) -> int:
|
|
198
|
-
return await info.context.data_loaders.experiment_repetition_counts.load(self.id_attr)
|
|
199
|
-
|
|
200
196
|
|
|
201
197
|
def to_gql_experiment(
|
|
202
198
|
experiment: models.Experiment,
|
|
@@ -588,6 +588,22 @@ class Project(Node):
|
|
|
588
588
|
async with info.context.db() as session:
|
|
589
589
|
return list(await session.scalars(stmt))
|
|
590
590
|
|
|
591
|
+
@strawberry.field(
|
|
592
|
+
description="Names of all available annotations for sessions. "
|
|
593
|
+
"(The list contains no duplicates.)"
|
|
594
|
+
) # type: ignore
|
|
595
|
+
async def session_annotation_names(
|
|
596
|
+
self,
|
|
597
|
+
info: Info[Context, None],
|
|
598
|
+
) -> list[str]:
|
|
599
|
+
stmt = (
|
|
600
|
+
select(distinct(models.ProjectSessionAnnotation.name))
|
|
601
|
+
.join(models.ProjectSession)
|
|
602
|
+
.where(models.ProjectSession.project_id == self.project_rowid)
|
|
603
|
+
)
|
|
604
|
+
async with info.context.db() as session:
|
|
605
|
+
return list(await session.scalars(stmt))
|
|
606
|
+
|
|
591
607
|
@strawberry.field(
|
|
592
608
|
description="Names of available document evaluations.",
|
|
593
609
|
) # type: ignore
|
|
@@ -1,14 +1,19 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from dataclasses import asdict, dataclass
|
|
1
3
|
from datetime import datetime
|
|
2
4
|
from typing import TYPE_CHECKING, Annotated, ClassVar, Optional, Type
|
|
3
5
|
|
|
6
|
+
import pandas as pd
|
|
4
7
|
import strawberry
|
|
5
8
|
from openinference.semconv.trace import SpanAttributes
|
|
6
9
|
from sqlalchemy import select
|
|
7
10
|
from strawberry import UNSET, Info, Private, lazy
|
|
8
|
-
from strawberry.relay import Connection,
|
|
11
|
+
from strawberry.relay import Connection, Node, NodeID
|
|
9
12
|
|
|
10
13
|
from phoenix.db import models
|
|
11
14
|
from phoenix.server.api.context import Context
|
|
15
|
+
from phoenix.server.api.input_types.AnnotationFilter import AnnotationFilter, satisfies_filter
|
|
16
|
+
from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
|
|
12
17
|
from phoenix.server.api.types.CostBreakdown import CostBreakdown
|
|
13
18
|
from phoenix.server.api.types.MimeType import MimeType
|
|
14
19
|
from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list
|
|
@@ -18,6 +23,8 @@ from phoenix.server.api.types.SpanIOValue import SpanIOValue
|
|
|
18
23
|
from phoenix.server.api.types.TokenUsage import TokenUsage
|
|
19
24
|
|
|
20
25
|
if TYPE_CHECKING:
|
|
26
|
+
from phoenix.server.api.types.Project import Project
|
|
27
|
+
from phoenix.server.api.types.ProjectSessionAnnotation import ProjectSessionAnnotation
|
|
21
28
|
from phoenix.server.api.types.Trace import Trace
|
|
22
29
|
|
|
23
30
|
|
|
@@ -31,10 +38,13 @@ class ProjectSession(Node):
|
|
|
31
38
|
end_time: datetime
|
|
32
39
|
|
|
33
40
|
@strawberry.field
|
|
34
|
-
async def
|
|
41
|
+
async def project(
|
|
42
|
+
self,
|
|
43
|
+
info: Info[Context, None],
|
|
44
|
+
) -> Annotated["Project", lazy(".Project")]:
|
|
35
45
|
from phoenix.server.api.types.Project import Project
|
|
36
46
|
|
|
37
|
-
return
|
|
47
|
+
return Project(project_rowid=self.project_rowid)
|
|
38
48
|
|
|
39
49
|
@strawberry.field
|
|
40
50
|
async def num_traces(
|
|
@@ -165,6 +175,81 @@ class ProjectSession(Node):
|
|
|
165
175
|
for entry in summary
|
|
166
176
|
]
|
|
167
177
|
|
|
178
|
+
@strawberry.field
|
|
179
|
+
async def session_annotations(
|
|
180
|
+
self,
|
|
181
|
+
info: Info[Context, None],
|
|
182
|
+
) -> list[Annotated["ProjectSessionAnnotation", lazy(".ProjectSessionAnnotation")]]:
|
|
183
|
+
"""Get all annotations for this session."""
|
|
184
|
+
from phoenix.server.api.types.ProjectSessionAnnotation import (
|
|
185
|
+
to_gql_project_session_annotation,
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
stmt = select(models.ProjectSessionAnnotation).filter_by(project_session_id=self.id_attr)
|
|
189
|
+
async with info.context.db() as session:
|
|
190
|
+
annotations = await session.stream_scalars(stmt)
|
|
191
|
+
return [
|
|
192
|
+
to_gql_project_session_annotation(annotation) async for annotation in annotations
|
|
193
|
+
]
|
|
194
|
+
|
|
195
|
+
@strawberry.field(
|
|
196
|
+
description="Summarizes each annotation (by name) associated with the session"
|
|
197
|
+
) # type: ignore
|
|
198
|
+
async def session_annotation_summaries(
|
|
199
|
+
self,
|
|
200
|
+
info: Info[Context, None],
|
|
201
|
+
filter: Optional[AnnotationFilter] = None,
|
|
202
|
+
) -> list[AnnotationSummary]:
|
|
203
|
+
"""
|
|
204
|
+
Retrieves and summarizes annotations associated with this session.
|
|
205
|
+
|
|
206
|
+
This method aggregates annotation data by name and label, calculating metrics
|
|
207
|
+
such as count of occurrences and sum of scores. The results are organized
|
|
208
|
+
into a structured format that can be easily converted to a DataFrame.
|
|
209
|
+
|
|
210
|
+
Args:
|
|
211
|
+
info: GraphQL context information
|
|
212
|
+
filter: Optional filter to apply to annotations before processing
|
|
213
|
+
|
|
214
|
+
Returns:
|
|
215
|
+
A list of AnnotationSummary objects, each containing:
|
|
216
|
+
- name: The name of the annotation
|
|
217
|
+
- data: A list of dictionaries with label statistics
|
|
218
|
+
"""
|
|
219
|
+
# Load all annotations for this session from the data loader
|
|
220
|
+
annotations = await info.context.data_loaders.session_annotations_by_session.load(
|
|
221
|
+
self.id_attr
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
# Apply filter if provided to narrow down the annotations
|
|
225
|
+
if filter:
|
|
226
|
+
annotations = [
|
|
227
|
+
annotation for annotation in annotations if satisfies_filter(annotation, filter)
|
|
228
|
+
]
|
|
229
|
+
|
|
230
|
+
@dataclass
|
|
231
|
+
class Metrics:
|
|
232
|
+
record_count: int = 0
|
|
233
|
+
label_count: int = 0
|
|
234
|
+
score_sum: float = 0
|
|
235
|
+
score_count: int = 0
|
|
236
|
+
|
|
237
|
+
summaries: defaultdict[str, defaultdict[Optional[str], Metrics]] = defaultdict(
|
|
238
|
+
lambda: defaultdict(Metrics)
|
|
239
|
+
)
|
|
240
|
+
for annotation in annotations:
|
|
241
|
+
metrics = summaries[annotation.name][annotation.label]
|
|
242
|
+
metrics.record_count += 1
|
|
243
|
+
metrics.label_count += int(annotation.label is not None)
|
|
244
|
+
metrics.score_sum += annotation.score or 0
|
|
245
|
+
metrics.score_count += int(annotation.score is not None)
|
|
246
|
+
|
|
247
|
+
result: list[AnnotationSummary] = []
|
|
248
|
+
for name, label_metrics in summaries.items():
|
|
249
|
+
rows = [{"label": label, **asdict(metrics)} for label, metrics in label_metrics.items()]
|
|
250
|
+
result.append(AnnotationSummary(name=name, df=pd.DataFrame(rows), simple_avg=True))
|
|
251
|
+
return result
|
|
252
|
+
|
|
168
253
|
|
|
169
254
|
def to_gql_project_session(project_session: models.ProjectSession) -> ProjectSession:
|
|
170
255
|
return ProjectSession(
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
from strawberry import Private
|
|
5
|
+
from strawberry.relay import GlobalID, Node, NodeID
|
|
6
|
+
from strawberry.scalars import JSON
|
|
7
|
+
from strawberry.types import Info
|
|
8
|
+
|
|
9
|
+
from phoenix.db import models
|
|
10
|
+
from phoenix.server.api.context import Context
|
|
11
|
+
from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
|
|
12
|
+
|
|
13
|
+
from .AnnotationSource import AnnotationSource
|
|
14
|
+
from .User import User, to_gql_user
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@strawberry.type
|
|
18
|
+
class ProjectSessionAnnotation(Node):
|
|
19
|
+
id_attr: NodeID[int]
|
|
20
|
+
user_id: Private[Optional[int]]
|
|
21
|
+
name: str
|
|
22
|
+
annotator_kind: AnnotatorKind
|
|
23
|
+
label: Optional[str]
|
|
24
|
+
score: Optional[float]
|
|
25
|
+
explanation: Optional[str]
|
|
26
|
+
metadata: JSON
|
|
27
|
+
_project_session_id: Private[Optional[int]]
|
|
28
|
+
identifier: str
|
|
29
|
+
source: AnnotationSource
|
|
30
|
+
|
|
31
|
+
@strawberry.field
|
|
32
|
+
async def project_session_id(self) -> GlobalID:
|
|
33
|
+
from phoenix.server.api.types.ProjectSession import ProjectSession
|
|
34
|
+
|
|
35
|
+
return GlobalID(type_name=ProjectSession.__name__, node_id=str(self._project_session_id))
|
|
36
|
+
|
|
37
|
+
@strawberry.field
|
|
38
|
+
async def user(
|
|
39
|
+
self,
|
|
40
|
+
info: Info[Context, None],
|
|
41
|
+
) -> Optional[User]:
|
|
42
|
+
if self.user_id is None:
|
|
43
|
+
return None
|
|
44
|
+
user = await info.context.data_loaders.users.load(self.user_id)
|
|
45
|
+
if user is None:
|
|
46
|
+
return None
|
|
47
|
+
return to_gql_user(user)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def to_gql_project_session_annotation(
|
|
51
|
+
annotation: models.ProjectSessionAnnotation,
|
|
52
|
+
) -> ProjectSessionAnnotation:
|
|
53
|
+
"""
|
|
54
|
+
Converts an ORM projectSession annotation to a GraphQL ProjectSessionAnnotation.
|
|
55
|
+
"""
|
|
56
|
+
return ProjectSessionAnnotation(
|
|
57
|
+
id_attr=annotation.id,
|
|
58
|
+
user_id=annotation.user_id,
|
|
59
|
+
_project_session_id=annotation.project_session_id,
|
|
60
|
+
name=annotation.name,
|
|
61
|
+
annotator_kind=AnnotatorKind(annotation.annotator_kind),
|
|
62
|
+
label=annotation.label,
|
|
63
|
+
score=annotation.score,
|
|
64
|
+
explanation=annotation.explanation,
|
|
65
|
+
metadata=JSON(annotation.metadata_),
|
|
66
|
+
identifier=annotation.identifier,
|
|
67
|
+
source=AnnotationSource(annotation.source),
|
|
68
|
+
)
|
|
@@ -9,6 +9,7 @@ from strawberry.relay import Connection, GlobalID, Node, NodeID
|
|
|
9
9
|
from strawberry.types import Info
|
|
10
10
|
|
|
11
11
|
from phoenix.db import models
|
|
12
|
+
from phoenix.db.types.identifier import Identifier as IdentifierModel
|
|
12
13
|
from phoenix.server.api.context import Context
|
|
13
14
|
from phoenix.server.api.exceptions import NotFound
|
|
14
15
|
from phoenix.server.api.types.Identifier import Identifier
|
|
@@ -37,7 +38,10 @@ class Prompt(Node):
|
|
|
37
38
|
|
|
38
39
|
@strawberry.field
|
|
39
40
|
async def version(
|
|
40
|
-
self,
|
|
41
|
+
self,
|
|
42
|
+
info: Info[Context, None],
|
|
43
|
+
version_id: Optional[GlobalID] = None,
|
|
44
|
+
tag_name: Optional[Identifier] = None,
|
|
41
45
|
) -> PromptVersion:
|
|
42
46
|
async with info.context.db() as session:
|
|
43
47
|
if version_id:
|
|
@@ -50,6 +54,19 @@ class Prompt(Node):
|
|
|
50
54
|
)
|
|
51
55
|
if not version:
|
|
52
56
|
raise NotFound(f"Prompt version not found: {version_id}")
|
|
57
|
+
elif tag_name:
|
|
58
|
+
try:
|
|
59
|
+
name = IdentifierModel(tag_name)
|
|
60
|
+
except ValueError:
|
|
61
|
+
raise NotFound(f"Prompt version tag not found: {tag_name}")
|
|
62
|
+
version = await session.scalar(
|
|
63
|
+
select(models.PromptVersion)
|
|
64
|
+
.where(models.PromptVersion.prompt_id == self.id_attr)
|
|
65
|
+
.join_from(models.PromptVersion, models.PromptVersionTag)
|
|
66
|
+
.where(models.PromptVersionTag.name == name)
|
|
67
|
+
)
|
|
68
|
+
if not version:
|
|
69
|
+
raise NotFound(f"This prompt has no associated versions by tag {tag_name}")
|
|
53
70
|
else:
|
|
54
71
|
stmt = (
|
|
55
72
|
select(models.PromptVersion)
|
phoenix/server/api/types/Span.py
CHANGED
|
@@ -23,11 +23,11 @@ from phoenix.server.api.helpers.dataset_helpers import (
|
|
|
23
23
|
get_dataset_example_input,
|
|
24
24
|
get_dataset_example_output,
|
|
25
25
|
)
|
|
26
|
-
from phoenix.server.api.input_types.
|
|
27
|
-
|
|
28
|
-
SpanAnnotationFilter,
|
|
26
|
+
from phoenix.server.api.input_types.AnnotationFilter import (
|
|
27
|
+
AnnotationFilter,
|
|
29
28
|
satisfies_filter,
|
|
30
29
|
)
|
|
30
|
+
from phoenix.server.api.input_types.InvocationParameters import InvocationParameter
|
|
31
31
|
from phoenix.server.api.input_types.SpanAnnotationSort import (
|
|
32
32
|
SpanAnnotationColumn,
|
|
33
33
|
SpanAnnotationSort,
|
|
@@ -547,7 +547,7 @@ class Span(Node):
|
|
|
547
547
|
self,
|
|
548
548
|
info: Info[Context, None],
|
|
549
549
|
sort: Optional[SpanAnnotationSort] = UNSET,
|
|
550
|
-
filter: Optional[
|
|
550
|
+
filter: Optional[AnnotationFilter] = None,
|
|
551
551
|
) -> list[SpanAnnotation]:
|
|
552
552
|
span_id = self.span_rowid
|
|
553
553
|
annotations = await info.context.data_loaders.span_annotations.load(span_id)
|
|
@@ -580,7 +580,7 @@ class Span(Node):
|
|
|
580
580
|
async def span_annotation_summaries(
|
|
581
581
|
self,
|
|
582
582
|
info: Info[Context, None],
|
|
583
|
-
filter: Optional[
|
|
583
|
+
filter: Optional[AnnotationFilter] = None,
|
|
584
584
|
) -> list[AnnotationSummary]:
|
|
585
585
|
"""
|
|
586
586
|
Retrieves and summarizes annotations associated with this span.
|
|
@@ -1,8 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from dataclasses import asdict, dataclass
|
|
3
5
|
from datetime import datetime
|
|
4
6
|
from typing import TYPE_CHECKING, Annotated, Optional, Union
|
|
5
7
|
|
|
8
|
+
import pandas as pd
|
|
6
9
|
import strawberry
|
|
7
10
|
from openinference.semconv.trace import SpanAttributes
|
|
8
11
|
from sqlalchemy import desc, select
|
|
@@ -13,7 +16,9 @@ from typing_extensions import TypeAlias
|
|
|
13
16
|
|
|
14
17
|
from phoenix.db import models
|
|
15
18
|
from phoenix.server.api.context import Context
|
|
19
|
+
from phoenix.server.api.input_types.AnnotationFilter import AnnotationFilter, satisfies_filter
|
|
16
20
|
from phoenix.server.api.input_types.TraceAnnotationSort import TraceAnnotationSort
|
|
21
|
+
from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
|
|
17
22
|
from phoenix.server.api.types.CostBreakdown import CostBreakdown
|
|
18
23
|
from phoenix.server.api.types.pagination import (
|
|
19
24
|
ConnectionArgs,
|
|
@@ -229,6 +234,62 @@ class Trace(Node):
|
|
|
229
234
|
annotations = await session.scalars(stmt)
|
|
230
235
|
return [to_gql_trace_annotation(annotation) for annotation in annotations]
|
|
231
236
|
|
|
237
|
+
@strawberry.field(description="Summarizes each annotation (by name) associated with the trace") # type: ignore
|
|
238
|
+
async def trace_annotation_summaries(
|
|
239
|
+
self,
|
|
240
|
+
info: Info[Context, None],
|
|
241
|
+
filter: Optional[AnnotationFilter] = None,
|
|
242
|
+
) -> list[AnnotationSummary]:
|
|
243
|
+
"""
|
|
244
|
+
Retrieves and summarizes annotations associated with this trace.
|
|
245
|
+
|
|
246
|
+
This method aggregates annotation data by name and label, calculating metrics
|
|
247
|
+
such as count of occurrences and sum of scores. The results are organized
|
|
248
|
+
into a structured format that can be easily converted to a DataFrame.
|
|
249
|
+
|
|
250
|
+
Args:
|
|
251
|
+
info: GraphQL context information
|
|
252
|
+
filter: Optional filter to apply to annotations before processing
|
|
253
|
+
|
|
254
|
+
Returns:
|
|
255
|
+
A list of AnnotationSummary objects, each containing:
|
|
256
|
+
- name: The name of the annotation
|
|
257
|
+
- data: A list of dictionaries with label statistics
|
|
258
|
+
"""
|
|
259
|
+
# Load all annotations for this trace from the data loader
|
|
260
|
+
annotations = await info.context.data_loaders.trace_annotations_by_trace.load(
|
|
261
|
+
self.trace_rowid
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
# Apply filter if provided to narrow down the annotations
|
|
265
|
+
if filter:
|
|
266
|
+
annotations = [
|
|
267
|
+
annotation for annotation in annotations if satisfies_filter(annotation, filter)
|
|
268
|
+
]
|
|
269
|
+
|
|
270
|
+
@dataclass
|
|
271
|
+
class Metrics:
|
|
272
|
+
record_count: int = 0
|
|
273
|
+
label_count: int = 0
|
|
274
|
+
score_sum: float = 0
|
|
275
|
+
score_count: int = 0
|
|
276
|
+
|
|
277
|
+
summaries: defaultdict[str, defaultdict[Optional[str], Metrics]] = defaultdict(
|
|
278
|
+
lambda: defaultdict(Metrics)
|
|
279
|
+
)
|
|
280
|
+
for annotation in annotations:
|
|
281
|
+
metrics = summaries[annotation.name][annotation.label]
|
|
282
|
+
metrics.record_count += 1
|
|
283
|
+
metrics.label_count += int(annotation.label is not None)
|
|
284
|
+
metrics.score_sum += annotation.score or 0
|
|
285
|
+
metrics.score_count += int(annotation.score is not None)
|
|
286
|
+
|
|
287
|
+
result: list[AnnotationSummary] = []
|
|
288
|
+
for name, label_metrics in summaries.items():
|
|
289
|
+
rows = [{"label": label, **asdict(metrics)} for label, metrics in label_metrics.items()]
|
|
290
|
+
result.append(AnnotationSummary(name=name, df=pd.DataFrame(rows), simple_avg=True))
|
|
291
|
+
return result
|
|
292
|
+
|
|
232
293
|
@strawberry.field
|
|
233
294
|
async def cost_summary(
|
|
234
295
|
self,
|
phoenix/server/app.py
CHANGED
|
@@ -67,7 +67,6 @@ from phoenix.config import (
|
|
|
67
67
|
get_env_gql_extension_paths,
|
|
68
68
|
get_env_grpc_interceptor_paths,
|
|
69
69
|
get_env_host,
|
|
70
|
-
get_env_host_root_path,
|
|
71
70
|
get_env_max_spans_queue_size,
|
|
72
71
|
get_env_port,
|
|
73
72
|
get_env_support_email,
|
|
@@ -92,6 +91,7 @@ from phoenix.server.api.dataloaders import (
|
|
|
92
91
|
DatasetExampleRevisionsDataLoader,
|
|
93
92
|
DatasetExamplesAndVersionsByExperimentRunDataLoader,
|
|
94
93
|
DatasetExampleSpansDataLoader,
|
|
94
|
+
DatasetExampleSplitsDataLoader,
|
|
95
95
|
DocumentEvaluationsDataLoader,
|
|
96
96
|
DocumentEvaluationSummaryDataLoader,
|
|
97
97
|
DocumentRetrievalMetricsDataLoader,
|
|
@@ -99,7 +99,6 @@ from phoenix.server.api.dataloaders import (
|
|
|
99
99
|
ExperimentErrorRatesDataLoader,
|
|
100
100
|
ExperimentRepeatedRunGroupAnnotationSummariesDataLoader,
|
|
101
101
|
ExperimentRepeatedRunGroupsDataLoader,
|
|
102
|
-
ExperimentRepetitionCountsDataLoader,
|
|
103
102
|
ExperimentRunAnnotations,
|
|
104
103
|
ExperimentRunCountsDataLoader,
|
|
105
104
|
ExperimentSequenceNumberDataLoader,
|
|
@@ -112,6 +111,7 @@ from phoenix.server.api.dataloaders import (
|
|
|
112
111
|
ProjectIdsByTraceRetentionPolicyIdDataLoader,
|
|
113
112
|
PromptVersionSequenceNumberDataLoader,
|
|
114
113
|
RecordCountDataLoader,
|
|
114
|
+
SessionAnnotationsBySessionDataLoader,
|
|
115
115
|
SessionIODataLoader,
|
|
116
116
|
SessionNumTracesDataLoader,
|
|
117
117
|
SessionNumTracesWithErrorDataLoader,
|
|
@@ -137,12 +137,14 @@ from phoenix.server.api.dataloaders import (
|
|
|
137
137
|
SpanProjectsDataLoader,
|
|
138
138
|
TableFieldsDataLoader,
|
|
139
139
|
TokenCountDataLoader,
|
|
140
|
+
TraceAnnotationsByTraceDataLoader,
|
|
140
141
|
TraceByTraceIdsDataLoader,
|
|
141
142
|
TraceRetentionPolicyIdByProjectIdDataLoader,
|
|
142
143
|
TraceRootSpansDataLoader,
|
|
143
144
|
UserRolesDataLoader,
|
|
144
145
|
UsersDataLoader,
|
|
145
146
|
)
|
|
147
|
+
from phoenix.server.api.dataloaders.dataset_labels import DatasetLabelsDataLoader
|
|
146
148
|
from phoenix.server.api.routers import (
|
|
147
149
|
auth_router,
|
|
148
150
|
create_embeddings_router,
|
|
@@ -173,6 +175,7 @@ from phoenix.server.types import (
|
|
|
173
175
|
LastUpdatedAt,
|
|
174
176
|
TokenStore,
|
|
175
177
|
)
|
|
178
|
+
from phoenix.server.utils import get_root_path, prepend_root_path
|
|
176
179
|
from phoenix.settings import Settings
|
|
177
180
|
from phoenix.trace.fixtures import (
|
|
178
181
|
TracesFixture,
|
|
@@ -281,9 +284,6 @@ class Static(StaticFiles):
|
|
|
281
284
|
return {}
|
|
282
285
|
raise e
|
|
283
286
|
|
|
284
|
-
def _sanitize_basename(self, basename: str) -> str:
|
|
285
|
-
return basename[:-1] if basename.endswith("/") else basename
|
|
286
|
-
|
|
287
287
|
async def get_response(self, path: str, scope: Scope) -> Response:
|
|
288
288
|
# Redirect to the oauth2 login page if basic auth is disabled and auto_login is enabled
|
|
289
289
|
# TODO: this needs to be refactored to be cleaner
|
|
@@ -292,14 +292,10 @@ class Static(StaticFiles):
|
|
|
292
292
|
and self._app_config.basic_auth_disabled
|
|
293
293
|
and self._app_config.auto_login_idp_name
|
|
294
294
|
):
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
str(
|
|
298
|
-
Path(get_env_host_root_path())
|
|
299
|
-
/ f"oauth2/{self._app_config.auto_login_idp_name}/login"
|
|
300
|
-
)
|
|
295
|
+
redirect_path = prepend_root_path(
|
|
296
|
+
scope, f"oauth2/{self._app_config.auto_login_idp_name}/login"
|
|
301
297
|
)
|
|
302
|
-
url =
|
|
298
|
+
url = URL(redirect_path).include_query_params(**Request(scope).query_params)
|
|
303
299
|
return RedirectResponse(url=url)
|
|
304
300
|
try:
|
|
305
301
|
response = await super().get_response(path, scope)
|
|
@@ -316,7 +312,7 @@ class Static(StaticFiles):
|
|
|
316
312
|
"min_dist": self._app_config.min_dist,
|
|
317
313
|
"n_neighbors": self._app_config.n_neighbors,
|
|
318
314
|
"n_samples": self._app_config.n_samples,
|
|
319
|
-
"basename":
|
|
315
|
+
"basename": get_root_path(scope),
|
|
320
316
|
"platform_version": phoenix_version,
|
|
321
317
|
"request": request,
|
|
322
318
|
"is_development": self._app_config.is_development,
|
|
@@ -715,6 +711,8 @@ def create_graphql_router(
|
|
|
715
711
|
dataset_examples_and_versions_by_experiment_run=DatasetExamplesAndVersionsByExperimentRunDataLoader(
|
|
716
712
|
db
|
|
717
713
|
),
|
|
714
|
+
dataset_example_splits=DatasetExampleSplitsDataLoader(db),
|
|
715
|
+
dataset_labels=DatasetLabelsDataLoader(db),
|
|
718
716
|
document_evaluation_summaries=DocumentEvaluationSummaryDataLoader(
|
|
719
717
|
db,
|
|
720
718
|
cache_map=(
|
|
@@ -737,7 +735,6 @@ def create_graphql_router(
|
|
|
737
735
|
db
|
|
738
736
|
),
|
|
739
737
|
experiment_repeated_run_groups=ExperimentRepeatedRunGroupsDataLoader(db),
|
|
740
|
-
experiment_repetition_counts=ExperimentRepetitionCountsDataLoader(db),
|
|
741
738
|
experiment_run_annotations=ExperimentRunAnnotations(db),
|
|
742
739
|
experiment_run_counts=ExperimentRunCountsDataLoader(db),
|
|
743
740
|
experiment_sequence_number=ExperimentSequenceNumberDataLoader(db),
|
|
@@ -769,6 +766,7 @@ def create_graphql_router(
|
|
|
769
766
|
db,
|
|
770
767
|
cache_map=cache_for_dataloaders.record_count if cache_for_dataloaders else None,
|
|
771
768
|
),
|
|
769
|
+
session_annotations_by_session=SessionAnnotationsBySessionDataLoader(db),
|
|
772
770
|
session_first_inputs=SessionIODataLoader(db, "first_input"),
|
|
773
771
|
session_last_outputs=SessionIODataLoader(db, "last_output"),
|
|
774
772
|
session_num_traces=SessionNumTracesDataLoader(db),
|
|
@@ -815,6 +813,7 @@ def create_graphql_router(
|
|
|
815
813
|
db,
|
|
816
814
|
cache_map=cache_for_dataloaders.token_count if cache_for_dataloaders else None,
|
|
817
815
|
),
|
|
816
|
+
trace_annotations_by_trace=TraceAnnotationsByTraceDataLoader(db),
|
|
818
817
|
trace_by_trace_ids=TraceByTraceIdsDataLoader(db),
|
|
819
818
|
trace_fields=TableFieldsDataLoader(db, models.Trace),
|
|
820
819
|
trace_retention_policy_id_by_project_id=TraceRetentionPolicyIdByProjectIdDataLoader(
|