arize-phoenix 8.2.2__py3-none-any.whl → 8.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.4.0.dist-info}/METADATA +4 -3
- {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.4.0.dist-info}/RECORD +33 -32
- phoenix/config.py +32 -5
- phoenix/db/models.py +1 -25
- phoenix/server/api/context.py +6 -2
- phoenix/server/api/dataloaders/__init__.py +4 -2
- phoenix/server/api/dataloaders/num_child_spans.py +35 -0
- phoenix/server/api/dataloaders/{span_fields.py → table_fields.py} +21 -19
- phoenix/server/api/helpers/playground_clients.py +4 -0
- phoenix/server/api/helpers/prompts/models.py +1 -0
- phoenix/server/api/queries.py +80 -20
- phoenix/server/api/types/Experiment.py +2 -4
- phoenix/server/api/types/ExperimentRun.py +2 -2
- phoenix/server/api/types/ExperimentRunAnnotation.py +2 -2
- phoenix/server/api/types/Project.py +67 -38
- phoenix/server/api/types/ProjectSession.py +2 -2
- phoenix/server/api/types/Span.py +35 -2
- phoenix/server/api/types/Trace.py +98 -30
- phoenix/server/app.py +6 -2
- phoenix/server/static/.vite/manifest.json +40 -40
- phoenix/server/static/assets/{components-MeFAEc1z.js → components-BgFPI6sn.js} +166 -167
- phoenix/server/static/assets/{index-BSRuZ-_J.js → index-CIkk8uHr.js} +9 -9
- phoenix/server/static/assets/{pages-NrL4hb9q.js → pages-CmDiPH1A.js} +660 -623
- phoenix/server/static/assets/{vendor-Cqfydjep.js → vendor-CRRq3WgM.js} +116 -116
- phoenix/server/static/assets/{vendor-arizeai-WnerlUPN.js → vendor-arizeai-Dq64X_0o.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-D-ZZKLFq.js → vendor-codemirror-C1oevlym.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-KY97ZPfK.js → vendor-recharts-CPj01S89.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-D5K9GnFn.js → vendor-shiki-aY7rz1pm.js} +1 -1
- phoenix/version.py +1 -1
- {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.4.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.4.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.4.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.4.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -8,35 +8,36 @@ from typing_extensions import TypeAlias
|
|
|
8
8
|
from phoenix.db import models
|
|
9
9
|
from phoenix.server.types import DbSessionFactory
|
|
10
10
|
|
|
11
|
-
|
|
11
|
+
RowId: TypeAlias = int
|
|
12
12
|
|
|
13
|
-
Key: TypeAlias = tuple[
|
|
13
|
+
Key: TypeAlias = tuple[RowId, QueryableAttribute[Any]]
|
|
14
14
|
Result: TypeAlias = Any
|
|
15
15
|
|
|
16
|
-
|
|
17
16
|
_ResultColumnPosition: TypeAlias = int
|
|
18
17
|
_AttrStrIdentifier: TypeAlias = str
|
|
19
18
|
|
|
20
19
|
|
|
21
|
-
class
|
|
22
|
-
def __init__(self, db: DbSessionFactory) -> None:
|
|
20
|
+
class TableFieldsDataLoader(DataLoader[Key, Result]):
|
|
21
|
+
def __init__(self, db: DbSessionFactory, table: type[models.Base]) -> None:
|
|
23
22
|
super().__init__(load_fn=self._load_fn)
|
|
24
23
|
self._db = db
|
|
24
|
+
self._table = table
|
|
25
25
|
|
|
26
26
|
async def _load_fn(self, keys: Iterable[Key]) -> list[Union[Result, ValueError]]:
|
|
27
|
-
result: dict[tuple[
|
|
28
|
-
stmt, attr_strs = _get_stmt(keys)
|
|
27
|
+
result: dict[tuple[RowId, _AttrStrIdentifier], Result] = {}
|
|
28
|
+
stmt, attr_strs = _get_stmt(keys, self._table)
|
|
29
29
|
async with self._db() as session:
|
|
30
30
|
data = await session.stream(stmt)
|
|
31
31
|
async for row in data:
|
|
32
|
-
|
|
32
|
+
rowid: RowId = row[0] # models.Span's primary key
|
|
33
33
|
for i, value in enumerate(row[1:]):
|
|
34
|
-
result[
|
|
35
|
-
return [result.get((
|
|
34
|
+
result[rowid, attr_strs[i]] = value
|
|
35
|
+
return [result.get((rowid, str(attr))) for rowid, attr in keys]
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
def _get_stmt(
|
|
39
|
-
keys: Iterable[
|
|
39
|
+
keys: Iterable[tuple[RowId, QueryableAttribute[Any]]],
|
|
40
|
+
table: type[models.Base],
|
|
40
41
|
) -> tuple[
|
|
41
42
|
Select[Any],
|
|
42
43
|
dict[_ResultColumnPosition, _AttrStrIdentifier],
|
|
@@ -52,7 +53,8 @@ def _get_stmt(
|
|
|
52
53
|
|
|
53
54
|
Args:
|
|
54
55
|
keys (list[Key]): A list of tuples, where each tuple contains an integer ID, i.e. the
|
|
55
|
-
primary key of
|
|
56
|
+
primary key of table, and a QueryableAttribute.
|
|
57
|
+
table (models.Base): The table to query.
|
|
56
58
|
|
|
57
59
|
Returns:
|
|
58
60
|
tuple: A tuple containing:
|
|
@@ -61,16 +63,16 @@ def _get_stmt(
|
|
|
61
63
|
at the second column (because the first column is the span's primary key)--in the
|
|
62
64
|
result to the attribute's string identifier.
|
|
63
65
|
"""
|
|
64
|
-
|
|
66
|
+
rowids: set[RowId] = set()
|
|
65
67
|
attrs: dict[_AttrStrIdentifier, QueryableAttribute[Any]] = {}
|
|
66
68
|
joins = set()
|
|
67
|
-
for
|
|
68
|
-
|
|
69
|
+
for rowid, attr in keys:
|
|
70
|
+
rowids.add(rowid)
|
|
69
71
|
attrs[str(attr)] = attr
|
|
70
|
-
if (entity := attr.parent.entity) is not
|
|
72
|
+
if (entity := attr.parent.entity) is not table:
|
|
71
73
|
joins.add(entity)
|
|
72
|
-
stmt = select(
|
|
73
|
-
for
|
|
74
|
-
stmt = stmt.join(
|
|
74
|
+
stmt = select(table.id).where(table.id.in_(rowids))
|
|
75
|
+
for other_table in joins:
|
|
76
|
+
stmt = stmt.join(other_table)
|
|
75
77
|
identifiers, columns = zip(*attrs.items())
|
|
76
78
|
return stmt.add_columns(*columns), dict(enumerate(identifiers))
|
|
@@ -806,6 +806,10 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
|
|
|
806
806
|
raise NotImplementedError
|
|
807
807
|
elif isinstance(event, anthropic_streaming._types.CitationEvent):
|
|
808
808
|
raise NotImplementedError
|
|
809
|
+
elif isinstance(event, anthropic_streaming._types.ThinkingEvent):
|
|
810
|
+
raise NotImplementedError
|
|
811
|
+
elif isinstance(event, anthropic_streaming._types.SignatureEvent):
|
|
812
|
+
raise NotImplementedError
|
|
809
813
|
else:
|
|
810
814
|
assert_never(event)
|
|
811
815
|
|
|
@@ -348,6 +348,7 @@ class AnthropicToolDefinition(PromptModel):
|
|
|
348
348
|
class PromptOpenAIInvocationParametersContent(PromptModel):
|
|
349
349
|
temperature: float = UNDEFINED
|
|
350
350
|
max_tokens: int = UNDEFINED
|
|
351
|
+
max_completion_tokens: int = UNDEFINED
|
|
351
352
|
frequency_penalty: float = UNDEFINED
|
|
352
353
|
presence_penalty: float = UNDEFINED
|
|
353
354
|
top_p: float = UNDEFINED
|
phoenix/server/api/queries.py
CHANGED
|
@@ -1,19 +1,21 @@
|
|
|
1
1
|
from collections import defaultdict
|
|
2
2
|
from datetime import datetime
|
|
3
|
-
from typing import Optional, Union
|
|
3
|
+
from typing import Iterable, Iterator, Optional, Union, cast
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import numpy.typing as npt
|
|
7
7
|
import strawberry
|
|
8
|
-
from sqlalchemy import and_, distinct, func, select
|
|
8
|
+
from sqlalchemy import and_, distinct, func, select, text
|
|
9
9
|
from sqlalchemy.orm import joinedload
|
|
10
10
|
from starlette.authentication import UnauthenticatedUser
|
|
11
11
|
from strawberry import ID, UNSET
|
|
12
12
|
from strawberry.relay import Connection, GlobalID, Node
|
|
13
13
|
from strawberry.types import Info
|
|
14
|
-
from typing_extensions import Annotated, TypeAlias
|
|
14
|
+
from typing_extensions import Annotated, TypeAlias, assert_never
|
|
15
15
|
|
|
16
|
+
from phoenix.config import ENV_PHOENIX_SQL_DATABASE_SCHEMA, getenv
|
|
16
17
|
from phoenix.db import enums, models
|
|
18
|
+
from phoenix.db.helpers import SupportedSQLDialect
|
|
17
19
|
from phoenix.db.models import DatasetExample as OrmExample
|
|
18
20
|
from phoenix.db.models import DatasetExampleRevision as OrmRevision
|
|
19
21
|
from phoenix.db.models import DatasetVersion as OrmVersion
|
|
@@ -66,7 +68,7 @@ from phoenix.server.api.types.PromptVersion import PromptVersion, to_gql_prompt_
|
|
|
66
68
|
from phoenix.server.api.types.SortDir import SortDir
|
|
67
69
|
from phoenix.server.api.types.Span import Span
|
|
68
70
|
from phoenix.server.api.types.SystemApiKey import SystemApiKey
|
|
69
|
-
from phoenix.server.api.types.Trace import
|
|
71
|
+
from phoenix.server.api.types.Trace import Trace
|
|
70
72
|
from phoenix.server.api.types.User import User, to_gql_user
|
|
71
73
|
from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
|
|
72
74
|
from phoenix.server.api.types.UserRole import UserRole
|
|
@@ -81,6 +83,12 @@ class ModelsInput:
|
|
|
81
83
|
model_name: Optional[str] = None
|
|
82
84
|
|
|
83
85
|
|
|
86
|
+
@strawberry.type
|
|
87
|
+
class DbTableStats:
|
|
88
|
+
table_name: str
|
|
89
|
+
num_bytes: int
|
|
90
|
+
|
|
91
|
+
|
|
84
92
|
@strawberry.type
|
|
85
93
|
class Query:
|
|
86
94
|
@strawberry.field
|
|
@@ -236,10 +244,8 @@ class Query:
|
|
|
236
244
|
projects = await session.stream_scalars(stmt)
|
|
237
245
|
data = [
|
|
238
246
|
Project(
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
gradient_start_color=project.gradient_start_color,
|
|
242
|
-
gradient_end_color=project.gradient_end_color,
|
|
247
|
+
project_rowid=project.id,
|
|
248
|
+
db_project=project,
|
|
243
249
|
)
|
|
244
250
|
async for project in projects
|
|
245
251
|
]
|
|
@@ -448,21 +454,14 @@ class Query:
|
|
|
448
454
|
embedding_dimension = info.context.model.embedding_dimensions[node_id]
|
|
449
455
|
return to_gql_embedding_dimension(node_id, embedding_dimension)
|
|
450
456
|
elif type_name == "Project":
|
|
451
|
-
project_stmt = select(
|
|
452
|
-
models.Project.id,
|
|
453
|
-
models.Project.name,
|
|
454
|
-
models.Project.gradient_start_color,
|
|
455
|
-
models.Project.gradient_end_color,
|
|
456
|
-
).where(models.Project.id == node_id)
|
|
457
|
+
project_stmt = select(models.Project).filter_by(id=node_id)
|
|
457
458
|
async with info.context.db() as session:
|
|
458
|
-
project =
|
|
459
|
+
project = await session.scalar(project_stmt)
|
|
459
460
|
if project is None:
|
|
460
461
|
raise NotFound(f"Unknown project: {id}")
|
|
461
462
|
return Project(
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
gradient_start_color=project.gradient_start_color,
|
|
465
|
-
gradient_end_color=project.gradient_end_color,
|
|
463
|
+
project_rowid=project.id,
|
|
464
|
+
db_project=project,
|
|
466
465
|
)
|
|
467
466
|
elif type_name == "Trace":
|
|
468
467
|
trace_stmt = select(models.Trace).filter_by(id=node_id)
|
|
@@ -470,7 +469,7 @@ class Query:
|
|
|
470
469
|
trace = await session.scalar(trace_stmt)
|
|
471
470
|
if trace is None:
|
|
472
471
|
raise NotFound(f"Unknown trace: {id}")
|
|
473
|
-
return
|
|
472
|
+
return Trace(trace_rowid=trace.id, db_trace=trace)
|
|
474
473
|
elif type_name == Span.__name__:
|
|
475
474
|
span_stmt = (
|
|
476
475
|
select(models.Span)
|
|
@@ -789,3 +788,64 @@ class Query:
|
|
|
789
788
|
return to_gql_clusters(
|
|
790
789
|
clustered_events=clustered_events,
|
|
791
790
|
)
|
|
791
|
+
|
|
792
|
+
@strawberry.field
|
|
793
|
+
async def db_table_stats(
|
|
794
|
+
self,
|
|
795
|
+
info: Info[Context, None],
|
|
796
|
+
) -> list[DbTableStats]:
|
|
797
|
+
if info.context.db.dialect is SupportedSQLDialect.SQLITE:
|
|
798
|
+
stmt = text("SELECT name, sum(pgsize) FROM dbstat group by name;")
|
|
799
|
+
async with info.context.db() as session:
|
|
800
|
+
stats = cast(Iterable[tuple[str, int]], await session.execute(stmt))
|
|
801
|
+
stats = _consolidate_sqlite_db_table_stats(stats)
|
|
802
|
+
elif info.context.db.dialect is SupportedSQLDialect.POSTGRESQL:
|
|
803
|
+
stmt = text(f"""\
|
|
804
|
+
SELECT c.relname, pg_total_relation_size(c.oid)
|
|
805
|
+
FROM pg_class as c
|
|
806
|
+
INNER JOIN pg_namespace as n ON n.oid = c.relnamespace
|
|
807
|
+
WHERE c.relkind = 'r'
|
|
808
|
+
AND n.nspname = '{getenv(ENV_PHOENIX_SQL_DATABASE_SCHEMA) or "public"}';
|
|
809
|
+
""")
|
|
810
|
+
async with info.context.db() as session:
|
|
811
|
+
stats = cast(Iterable[tuple[str, int]], await session.execute(stmt))
|
|
812
|
+
else:
|
|
813
|
+
assert_never(info.context.db.dialect)
|
|
814
|
+
return [
|
|
815
|
+
DbTableStats(table_name=table_name, num_bytes=num_bytes)
|
|
816
|
+
for table_name, num_bytes in stats
|
|
817
|
+
]
|
|
818
|
+
|
|
819
|
+
|
|
820
|
+
def _consolidate_sqlite_db_table_stats(
|
|
821
|
+
stats: Iterable[tuple[str, int]],
|
|
822
|
+
) -> Iterator[tuple[str, int]]:
|
|
823
|
+
"""
|
|
824
|
+
Consolidate SQLite database stats by combining indexes with their respective tables.
|
|
825
|
+
"""
|
|
826
|
+
aggregate: dict[str, int] = {}
|
|
827
|
+
for name, num_bytes in stats:
|
|
828
|
+
# Skip internal SQLite tables and indexes.
|
|
829
|
+
if name.startswith("ix_") or name.startswith("sqlite_"):
|
|
830
|
+
continue
|
|
831
|
+
aggregate[name] = num_bytes
|
|
832
|
+
for name, num_bytes in stats:
|
|
833
|
+
# Combine indexes with their respective tables.
|
|
834
|
+
for flag in ["sqlite_autoindex_", "ix_"]:
|
|
835
|
+
if not name.startswith(flag):
|
|
836
|
+
continue
|
|
837
|
+
if parent := _longest_matching_prefix(name[len(flag) :], aggregate.keys()):
|
|
838
|
+
aggregate[parent] += num_bytes
|
|
839
|
+
break
|
|
840
|
+
yield from aggregate.items()
|
|
841
|
+
|
|
842
|
+
|
|
843
|
+
def _longest_matching_prefix(s: str, prefixes: Iterable[str]) -> str:
|
|
844
|
+
"""
|
|
845
|
+
Return the longest prefix of s that matches any of the given prefixes.
|
|
846
|
+
"""
|
|
847
|
+
longest = ""
|
|
848
|
+
for prefix in prefixes:
|
|
849
|
+
if s.startswith(prefix) and len(prefix) > len(longest):
|
|
850
|
+
longest = prefix
|
|
851
|
+
return longest
|
|
@@ -122,10 +122,8 @@ class Experiment(Node):
|
|
|
122
122
|
return None
|
|
123
123
|
|
|
124
124
|
return Project(
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
gradient_start_color=db_project.gradient_start_color,
|
|
128
|
-
gradient_end_color=db_project.gradient_end_color,
|
|
125
|
+
project_rowid=db_project.id,
|
|
126
|
+
db_project=db_project,
|
|
129
127
|
)
|
|
130
128
|
|
|
131
129
|
@strawberry.field
|
|
@@ -20,7 +20,7 @@ from phoenix.server.api.types.pagination import (
|
|
|
20
20
|
CursorString,
|
|
21
21
|
connection_from_list,
|
|
22
22
|
)
|
|
23
|
-
from phoenix.server.api.types.Trace import Trace
|
|
23
|
+
from phoenix.server.api.types.Trace import Trace
|
|
24
24
|
|
|
25
25
|
if TYPE_CHECKING:
|
|
26
26
|
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
@@ -64,7 +64,7 @@ class ExperimentRun(Node):
|
|
|
64
64
|
dataloader = info.context.data_loaders.trace_by_trace_ids
|
|
65
65
|
if (trace := await dataloader.load(self.trace_id)) is None:
|
|
66
66
|
return None
|
|
67
|
-
return
|
|
67
|
+
return Trace(trace_rowid=trace.id, db_trace=trace)
|
|
68
68
|
|
|
69
69
|
@strawberry.field
|
|
70
70
|
async def example(
|
|
@@ -8,7 +8,7 @@ from strawberry.scalars import JSON
|
|
|
8
8
|
|
|
9
9
|
from phoenix.db import models
|
|
10
10
|
from phoenix.server.api.types.AnnotatorKind import ExperimentRunAnnotatorKind
|
|
11
|
-
from phoenix.server.api.types.Trace import Trace
|
|
11
|
+
from phoenix.server.api.types.Trace import Trace
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
@strawberry.type
|
|
@@ -32,7 +32,7 @@ class ExperimentRunAnnotation(Node):
|
|
|
32
32
|
dataloader = info.context.data_loaders.trace_by_trace_ids
|
|
33
33
|
if (trace := await dataloader.load(self.trace_id)) is None:
|
|
34
34
|
return None
|
|
35
|
-
return
|
|
35
|
+
return Trace(trace_rowid=trace.id, db_trace=trace)
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
def to_gql_experiment_run_annotation(
|
|
@@ -8,7 +8,7 @@ from openinference.semconv.trace import SpanAttributes
|
|
|
8
8
|
from sqlalchemy import desc, distinct, func, or_, select
|
|
9
9
|
from sqlalchemy.sql.elements import ColumnElement
|
|
10
10
|
from sqlalchemy.sql.expression import tuple_
|
|
11
|
-
from strawberry import ID, UNSET
|
|
11
|
+
from strawberry import ID, UNSET, Private
|
|
12
12
|
from strawberry.relay import Connection, Node, NodeID
|
|
13
13
|
from strawberry.types import Info
|
|
14
14
|
from typing_extensions import assert_never
|
|
@@ -33,7 +33,7 @@ from phoenix.server.api.types.pagination import (
|
|
|
33
33
|
from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
|
|
34
34
|
from phoenix.server.api.types.SortDir import SortDir
|
|
35
35
|
from phoenix.server.api.types.Span import Span
|
|
36
|
-
from phoenix.server.api.types.Trace import Trace
|
|
36
|
+
from phoenix.server.api.types.Trace import Trace
|
|
37
37
|
from phoenix.server.api.types.ValidationResult import ValidationResult
|
|
38
38
|
from phoenix.trace.dsl import SpanFilter
|
|
39
39
|
|
|
@@ -41,10 +41,51 @@ from phoenix.trace.dsl import SpanFilter
|
|
|
41
41
|
@strawberry.type
|
|
42
42
|
class Project(Node):
|
|
43
43
|
_table: ClassVar[type[models.Base]] = models.Project
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
44
|
+
project_rowid: NodeID[int]
|
|
45
|
+
db_project: Private[models.Project] = UNSET
|
|
46
|
+
|
|
47
|
+
def __post_init__(self) -> None:
|
|
48
|
+
if self.db_project and self.project_rowid != self.db_project.id:
|
|
49
|
+
raise ValueError("Project ID mismatch")
|
|
50
|
+
|
|
51
|
+
@strawberry.field
|
|
52
|
+
async def name(
|
|
53
|
+
self,
|
|
54
|
+
info: Info[Context, None],
|
|
55
|
+
) -> str:
|
|
56
|
+
if self.db_project:
|
|
57
|
+
name = self.db_project.name
|
|
58
|
+
else:
|
|
59
|
+
name = await info.context.data_loaders.project_fields.load(
|
|
60
|
+
(self.project_rowid, models.Project.name),
|
|
61
|
+
)
|
|
62
|
+
return name
|
|
63
|
+
|
|
64
|
+
@strawberry.field
|
|
65
|
+
async def gradient_start_color(
|
|
66
|
+
self,
|
|
67
|
+
info: Info[Context, None],
|
|
68
|
+
) -> str:
|
|
69
|
+
if self.db_project:
|
|
70
|
+
gradient_start_color = self.db_project.gradient_start_color
|
|
71
|
+
else:
|
|
72
|
+
gradient_start_color = await info.context.data_loaders.project_fields.load(
|
|
73
|
+
(self.project_rowid, models.Project.gradient_start_color),
|
|
74
|
+
)
|
|
75
|
+
return gradient_start_color
|
|
76
|
+
|
|
77
|
+
@strawberry.field
|
|
78
|
+
async def gradient_end_color(
|
|
79
|
+
self,
|
|
80
|
+
info: Info[Context, None],
|
|
81
|
+
) -> str:
|
|
82
|
+
if self.db_project:
|
|
83
|
+
gradient_end_color = self.db_project.gradient_end_color
|
|
84
|
+
else:
|
|
85
|
+
gradient_end_color = await info.context.data_loaders.project_fields.load(
|
|
86
|
+
(self.project_rowid, models.Project.gradient_end_color),
|
|
87
|
+
)
|
|
88
|
+
return gradient_end_color
|
|
48
89
|
|
|
49
90
|
@strawberry.field
|
|
50
91
|
async def start_time(
|
|
@@ -52,7 +93,7 @@ class Project(Node):
|
|
|
52
93
|
info: Info[Context, None],
|
|
53
94
|
) -> Optional[datetime]:
|
|
54
95
|
start_time = await info.context.data_loaders.min_start_or_max_end_times.load(
|
|
55
|
-
(self.
|
|
96
|
+
(self.project_rowid, "start"),
|
|
56
97
|
)
|
|
57
98
|
start_time, _ = right_open_time_range(start_time, None)
|
|
58
99
|
return start_time
|
|
@@ -63,7 +104,7 @@ class Project(Node):
|
|
|
63
104
|
info: Info[Context, None],
|
|
64
105
|
) -> Optional[datetime]:
|
|
65
106
|
end_time = await info.context.data_loaders.min_start_or_max_end_times.load(
|
|
66
|
-
(self.
|
|
107
|
+
(self.project_rowid, "end"),
|
|
67
108
|
)
|
|
68
109
|
_, end_time = right_open_time_range(None, end_time)
|
|
69
110
|
return end_time
|
|
@@ -76,7 +117,7 @@ class Project(Node):
|
|
|
76
117
|
filter_condition: Optional[str] = UNSET,
|
|
77
118
|
) -> int:
|
|
78
119
|
return await info.context.data_loaders.record_counts.load(
|
|
79
|
-
("span", self.
|
|
120
|
+
("span", self.project_rowid, time_range, filter_condition),
|
|
80
121
|
)
|
|
81
122
|
|
|
82
123
|
@strawberry.field
|
|
@@ -86,7 +127,7 @@ class Project(Node):
|
|
|
86
127
|
time_range: Optional[TimeRange] = UNSET,
|
|
87
128
|
) -> int:
|
|
88
129
|
return await info.context.data_loaders.record_counts.load(
|
|
89
|
-
("trace", self.
|
|
130
|
+
("trace", self.project_rowid, time_range, None),
|
|
90
131
|
)
|
|
91
132
|
|
|
92
133
|
@strawberry.field
|
|
@@ -97,7 +138,7 @@ class Project(Node):
|
|
|
97
138
|
filter_condition: Optional[str] = UNSET,
|
|
98
139
|
) -> int:
|
|
99
140
|
return await info.context.data_loaders.token_counts.load(
|
|
100
|
-
("total", self.
|
|
141
|
+
("total", self.project_rowid, time_range, filter_condition),
|
|
101
142
|
)
|
|
102
143
|
|
|
103
144
|
@strawberry.field
|
|
@@ -108,7 +149,7 @@ class Project(Node):
|
|
|
108
149
|
filter_condition: Optional[str] = UNSET,
|
|
109
150
|
) -> int:
|
|
110
151
|
return await info.context.data_loaders.token_counts.load(
|
|
111
|
-
("prompt", self.
|
|
152
|
+
("prompt", self.project_rowid, time_range, filter_condition),
|
|
112
153
|
)
|
|
113
154
|
|
|
114
155
|
@strawberry.field
|
|
@@ -119,7 +160,7 @@ class Project(Node):
|
|
|
119
160
|
filter_condition: Optional[str] = UNSET,
|
|
120
161
|
) -> int:
|
|
121
162
|
return await info.context.data_loaders.token_counts.load(
|
|
122
|
-
("completion", self.
|
|
163
|
+
("completion", self.project_rowid, time_range, filter_condition),
|
|
123
164
|
)
|
|
124
165
|
|
|
125
166
|
@strawberry.field
|
|
@@ -132,7 +173,7 @@ class Project(Node):
|
|
|
132
173
|
return await info.context.data_loaders.latency_ms_quantile.load(
|
|
133
174
|
(
|
|
134
175
|
"trace",
|
|
135
|
-
self.
|
|
176
|
+
self.project_rowid,
|
|
136
177
|
time_range,
|
|
137
178
|
None,
|
|
138
179
|
probability,
|
|
@@ -150,7 +191,7 @@ class Project(Node):
|
|
|
150
191
|
return await info.context.data_loaders.latency_ms_quantile.load(
|
|
151
192
|
(
|
|
152
193
|
"span",
|
|
153
|
-
self.
|
|
194
|
+
self.project_rowid,
|
|
154
195
|
time_range,
|
|
155
196
|
filter_condition,
|
|
156
197
|
probability,
|
|
@@ -162,12 +203,12 @@ class Project(Node):
|
|
|
162
203
|
stmt = (
|
|
163
204
|
select(models.Trace)
|
|
164
205
|
.where(models.Trace.trace_id == str(trace_id))
|
|
165
|
-
.where(models.Trace.project_rowid == self.
|
|
206
|
+
.where(models.Trace.project_rowid == self.project_rowid)
|
|
166
207
|
)
|
|
167
208
|
async with info.context.db() as session:
|
|
168
209
|
if (trace := await session.scalar(stmt)) is None:
|
|
169
210
|
return None
|
|
170
|
-
return
|
|
211
|
+
return Trace(trace_rowid=trace.id, db_trace=trace)
|
|
171
212
|
|
|
172
213
|
@strawberry.field
|
|
173
214
|
async def spans(
|
|
@@ -185,7 +226,7 @@ class Project(Node):
|
|
|
185
226
|
stmt = (
|
|
186
227
|
select(models.Span.id)
|
|
187
228
|
.join(models.Trace)
|
|
188
|
-
.where(models.Trace.project_rowid == self.
|
|
229
|
+
.where(models.Trace.project_rowid == self.project_rowid)
|
|
189
230
|
)
|
|
190
231
|
if time_range:
|
|
191
232
|
if time_range.start:
|
|
@@ -264,7 +305,7 @@ class Project(Node):
|
|
|
264
305
|
filter_io_substring: Optional[str] = UNSET,
|
|
265
306
|
) -> Connection[ProjectSession]:
|
|
266
307
|
table = models.ProjectSession
|
|
267
|
-
stmt = select(table).filter_by(project_id=self.
|
|
308
|
+
stmt = select(table).filter_by(project_id=self.project_rowid)
|
|
268
309
|
if time_range:
|
|
269
310
|
if time_range.start:
|
|
270
311
|
stmt = stmt.where(time_range.start <= table.start_time)
|
|
@@ -382,7 +423,7 @@ class Project(Node):
|
|
|
382
423
|
stmt = (
|
|
383
424
|
select(distinct(models.TraceAnnotation.name))
|
|
384
425
|
.join(models.Trace)
|
|
385
|
-
.where(models.Trace.project_rowid == self.
|
|
426
|
+
.where(models.Trace.project_rowid == self.project_rowid)
|
|
386
427
|
)
|
|
387
428
|
async with info.context.db() as session:
|
|
388
429
|
return list(await session.scalars(stmt))
|
|
@@ -399,7 +440,7 @@ class Project(Node):
|
|
|
399
440
|
select(distinct(models.SpanAnnotation.name))
|
|
400
441
|
.join(models.Span)
|
|
401
442
|
.join(models.Trace, models.Span.trace_rowid == models.Trace.id)
|
|
402
|
-
.where(models.Trace.project_rowid == self.
|
|
443
|
+
.where(models.Trace.project_rowid == self.project_rowid)
|
|
403
444
|
)
|
|
404
445
|
async with info.context.db() as session:
|
|
405
446
|
return list(await session.scalars(stmt))
|
|
@@ -416,7 +457,7 @@ class Project(Node):
|
|
|
416
457
|
select(distinct(models.DocumentAnnotation.name))
|
|
417
458
|
.join(models.Span)
|
|
418
459
|
.join(models.Trace, models.Span.trace_rowid == models.Trace.id)
|
|
419
|
-
.where(models.Trace.project_rowid == self.
|
|
460
|
+
.where(models.Trace.project_rowid == self.project_rowid)
|
|
420
461
|
.where(models.DocumentAnnotation.annotator_kind == "LLM")
|
|
421
462
|
)
|
|
422
463
|
if span_id:
|
|
@@ -432,7 +473,7 @@ class Project(Node):
|
|
|
432
473
|
time_range: Optional[TimeRange] = UNSET,
|
|
433
474
|
) -> Optional[AnnotationSummary]:
|
|
434
475
|
return await info.context.data_loaders.annotation_summaries.load(
|
|
435
|
-
("trace", self.
|
|
476
|
+
("trace", self.project_rowid, time_range, None, annotation_name),
|
|
436
477
|
)
|
|
437
478
|
|
|
438
479
|
@strawberry.field
|
|
@@ -444,7 +485,7 @@ class Project(Node):
|
|
|
444
485
|
filter_condition: Optional[str] = UNSET,
|
|
445
486
|
) -> Optional[AnnotationSummary]:
|
|
446
487
|
return await info.context.data_loaders.annotation_summaries.load(
|
|
447
|
-
("span", self.
|
|
488
|
+
("span", self.project_rowid, time_range, filter_condition, annotation_name),
|
|
448
489
|
)
|
|
449
490
|
|
|
450
491
|
@strawberry.field
|
|
@@ -456,7 +497,7 @@ class Project(Node):
|
|
|
456
497
|
filter_condition: Optional[str] = UNSET,
|
|
457
498
|
) -> Optional[DocumentEvaluationSummary]:
|
|
458
499
|
return await info.context.data_loaders.document_evaluation_summaries.load(
|
|
459
|
-
(self.
|
|
500
|
+
(self.project_rowid, time_range, filter_condition, evaluation_name),
|
|
460
501
|
)
|
|
461
502
|
|
|
462
503
|
@strawberry.field
|
|
@@ -464,7 +505,7 @@ class Project(Node):
|
|
|
464
505
|
self,
|
|
465
506
|
info: Info[Context, None],
|
|
466
507
|
) -> Optional[datetime]:
|
|
467
|
-
return info.context.last_updated_at.get(self._table, self.
|
|
508
|
+
return info.context.last_updated_at.get(self._table, self.project_rowid)
|
|
468
509
|
|
|
469
510
|
@strawberry.field
|
|
470
511
|
async def validate_span_filter_condition(self, condition: str) -> ValidationResult:
|
|
@@ -483,17 +524,5 @@ class Project(Node):
|
|
|
483
524
|
)
|
|
484
525
|
|
|
485
526
|
|
|
486
|
-
def to_gql_project(project: models.Project) -> Project:
|
|
487
|
-
"""
|
|
488
|
-
Converts an ORM project to a GraphQL Project.
|
|
489
|
-
"""
|
|
490
|
-
return Project(
|
|
491
|
-
id_attr=project.id,
|
|
492
|
-
name=project.name,
|
|
493
|
-
gradient_start_color=project.gradient_start_color,
|
|
494
|
-
gradient_end_color=project.gradient_end_color,
|
|
495
|
-
)
|
|
496
|
-
|
|
497
|
-
|
|
498
527
|
INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
|
|
499
528
|
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
|
|
@@ -93,7 +93,7 @@ class ProjectSession(Node):
|
|
|
93
93
|
after: Optional[CursorString] = UNSET,
|
|
94
94
|
before: Optional[CursorString] = UNSET,
|
|
95
95
|
) -> Connection[Annotated["Trace", lazy(".Trace")]]:
|
|
96
|
-
from phoenix.server.api.types.Trace import
|
|
96
|
+
from phoenix.server.api.types.Trace import Trace
|
|
97
97
|
|
|
98
98
|
args = ConnectionArgs(
|
|
99
99
|
first=first,
|
|
@@ -109,7 +109,7 @@ class ProjectSession(Node):
|
|
|
109
109
|
)
|
|
110
110
|
async with info.context.db() as session:
|
|
111
111
|
traces = await session.stream_scalars(stmt)
|
|
112
|
-
data = [
|
|
112
|
+
data = [Trace(trace_rowid=trace.id, db_trace=trace) async for trace in traces]
|
|
113
113
|
return connection_from_list(data=data, args=args)
|
|
114
114
|
|
|
115
115
|
@strawberry.field
|
phoenix/server/api/types/Span.py
CHANGED
|
@@ -37,6 +37,7 @@ from phoenix.trace.attributes import get_attribute_value
|
|
|
37
37
|
|
|
38
38
|
if TYPE_CHECKING:
|
|
39
39
|
from phoenix.server.api.types.Project import Project
|
|
40
|
+
from phoenix.server.api.types.Trace import Trace
|
|
40
41
|
|
|
41
42
|
|
|
42
43
|
@strawberry.enum
|
|
@@ -216,6 +217,34 @@ class Span(Node):
|
|
|
216
217
|
)
|
|
217
218
|
return SpanKind(value)
|
|
218
219
|
|
|
220
|
+
@strawberry.field
|
|
221
|
+
async def span_id(
|
|
222
|
+
self,
|
|
223
|
+
info: Info[Context, None],
|
|
224
|
+
) -> ID:
|
|
225
|
+
if self.db_span:
|
|
226
|
+
span_id = self.db_span.span_id
|
|
227
|
+
else:
|
|
228
|
+
span_id = await info.context.data_loaders.span_fields.load(
|
|
229
|
+
(self.span_rowid, models.Span.span_id),
|
|
230
|
+
)
|
|
231
|
+
return ID(span_id)
|
|
232
|
+
|
|
233
|
+
@strawberry.field
|
|
234
|
+
async def trace(
|
|
235
|
+
self,
|
|
236
|
+
info: Info[Context, None],
|
|
237
|
+
) -> Annotated["Trace", strawberry.lazy(".Trace")]:
|
|
238
|
+
if self.db_span:
|
|
239
|
+
trace_rowid = self.db_span.trace_rowid
|
|
240
|
+
else:
|
|
241
|
+
trace_rowid = await info.context.data_loaders.span_fields.load(
|
|
242
|
+
(self.span_rowid, models.Span.trace_rowid),
|
|
243
|
+
)
|
|
244
|
+
from phoenix.server.api.types.Trace import Trace
|
|
245
|
+
|
|
246
|
+
return Trace(trace_rowid=trace_rowid)
|
|
247
|
+
|
|
219
248
|
@strawberry.field
|
|
220
249
|
async def context(
|
|
221
250
|
self,
|
|
@@ -508,6 +537,10 @@ class Span(Node):
|
|
|
508
537
|
(self.span_rowid, evaluation_name or None, num_documents),
|
|
509
538
|
)
|
|
510
539
|
|
|
540
|
+
@strawberry.field
|
|
541
|
+
async def num_child_spans(self, info: Info[Context, None]) -> int:
|
|
542
|
+
return await info.context.data_loaders.num_child_spans.load(self.span_rowid)
|
|
543
|
+
|
|
511
544
|
@strawberry.field(
|
|
512
545
|
description="All descendant spans (children, grandchildren, etc.)",
|
|
513
546
|
) # type: ignore
|
|
@@ -561,11 +594,11 @@ class Span(Node):
|
|
|
561
594
|
) -> Annotated[
|
|
562
595
|
"Project", strawberry.lazy("phoenix.server.api.types.Project")
|
|
563
596
|
]: # use lazy types to avoid circular import: https://strawberry.rocks/docs/types/lazy
|
|
564
|
-
from phoenix.server.api.types.Project import
|
|
597
|
+
from phoenix.server.api.types.Project import Project
|
|
565
598
|
|
|
566
599
|
span_id = self.span_rowid
|
|
567
600
|
project = await info.context.data_loaders.span_projects.load(span_id)
|
|
568
|
-
return
|
|
601
|
+
return Project(project_rowid=project.id, db_project=project)
|
|
569
602
|
|
|
570
603
|
@strawberry.field(description="Indicates if the span is contained in any dataset") # type: ignore
|
|
571
604
|
async def contained_in_dataset(
|