arize-phoenix 6.2.0__py3-none-any.whl → 7.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.1.dist-info}/METADATA +4 -4
- {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.1.dist-info}/RECORD +45 -32
- phoenix/db/engines.py +14 -2
- phoenix/db/insertion/span.py +65 -30
- phoenix/db/migrations/data_migration_scripts/__init__.py +0 -0
- phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
- phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
- phoenix/db/models.py +27 -0
- phoenix/server/api/context.py +15 -2
- phoenix/server/api/dataloaders/__init__.py +14 -2
- phoenix/server/api/dataloaders/session_io.py +75 -0
- phoenix/server/api/dataloaders/session_num_traces.py +30 -0
- phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
- phoenix/server/api/dataloaders/session_token_usages.py +41 -0
- phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
- phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
- phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
- phoenix/server/api/mutations/project_mutations.py +12 -2
- phoenix/server/api/queries.py +14 -9
- phoenix/server/api/types/ExperimentRun.py +3 -4
- phoenix/server/api/types/ExperimentRunAnnotation.py +3 -4
- phoenix/server/api/types/Project.py +150 -12
- phoenix/server/api/types/ProjectSession.py +139 -0
- phoenix/server/api/types/Span.py +6 -19
- phoenix/server/api/types/SpanIOValue.py +15 -0
- phoenix/server/api/types/TokenUsage.py +11 -0
- phoenix/server/api/types/Trace.py +59 -2
- phoenix/server/app.py +15 -2
- phoenix/server/static/.vite/manifest.json +36 -36
- phoenix/server/static/assets/{components-DrwPLMB6.js → components-DKH6AzJw.js} +276 -276
- phoenix/server/static/assets/index-DLV87qiO.js +93 -0
- phoenix/server/static/assets/{pages-Cmqh2i4E.js → pages-CVY3Nv4Z.js} +611 -290
- phoenix/server/static/assets/{vendor-Cdrqqth8.js → vendor-Cb3zlNNd.js} +45 -45
- phoenix/server/static/assets/{vendor-arizeai-BSCL03yQ.js → vendor-arizeai-Buo4e1A6.js} +2 -2
- phoenix/server/static/assets/{vendor-codemirror-Utqu7Snw.js → vendor-codemirror-BuAQiUVf.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-CNNUvc5T.js → vendor-recharts-Cl9dK5tC.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-B6bHerDK.js → vendor-shiki-CazYUixL.js} +1 -1
- phoenix/trace/fixtures.py +8 -0
- phoenix/trace/schemas.py +16 -0
- phoenix/version.py +1 -1
- phoenix/server/api/dataloaders/trace_row_ids.py +0 -33
- phoenix/server/static/assets/index-CTN-OfBU.js +0 -93
- {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.1.dist-info}/WHEEL +0 -0
- {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.1.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.1.dist-info}/licenses/LICENSE +0 -0
phoenix/db/models.py
CHANGED
|
@@ -156,14 +156,37 @@ class Project(Base):
|
|
|
156
156
|
)
|
|
157
157
|
|
|
158
158
|
|
|
159
|
+
class ProjectSession(Base):
    """ORM model for the ``project_sessions`` table.

    Each row carries a unique ``session_id`` string, belongs to exactly one
    project, records start/end timestamps, and exposes the ``Trace`` rows that
    reference it through ``Trace.project_session_rowid``.
    """

    __tablename__ = "project_sessions"

    # Surrogate primary key.
    id: Mapped[int] = mapped_column(primary_key=True)
    # String session identifier; unique across the whole table.
    session_id: Mapped[str] = mapped_column(String, unique=True, nullable=False)
    # Owning project; the database drops the session when the project is deleted.
    project_id: Mapped[int] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), index=True, nullable=False
    )
    # Time bounds of the session (both indexed).
    start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, nullable=False, index=True)
    end_time: Mapped[datetime] = mapped_column(UtcTimeStamp, nullable=False, index=True)
    # Traces linked to this session; mirrors ``Trace.project_session``.
    traces: Mapped[list["Trace"]] = relationship(
        "Trace", uselist=True, back_populates="project_session"
    )
|
|
175
|
+
|
|
176
|
+
|
|
159
177
|
class Trace(Base):
|
|
160
178
|
__tablename__ = "traces"
|
|
161
179
|
id: Mapped[int] = mapped_column(primary_key=True)
|
|
162
180
|
project_rowid: Mapped[int] = mapped_column(
|
|
163
181
|
ForeignKey("projects.id", ondelete="CASCADE"),
|
|
182
|
+
nullable=False,
|
|
164
183
|
index=True,
|
|
165
184
|
)
|
|
166
185
|
trace_id: Mapped[str]
|
|
186
|
+
project_session_rowid: Mapped[Optional[int]] = mapped_column(
|
|
187
|
+
ForeignKey("project_sessions.id", ondelete="CASCADE"),
|
|
188
|
+
index=True,
|
|
189
|
+
)
|
|
167
190
|
start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
|
|
168
191
|
end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
|
|
169
192
|
|
|
@@ -188,6 +211,10 @@ class Trace(Base):
|
|
|
188
211
|
cascade="all, delete-orphan",
|
|
189
212
|
uselist=True,
|
|
190
213
|
)
|
|
214
|
+
project_session: Mapped[ProjectSession] = relationship(
|
|
215
|
+
"ProjectSession",
|
|
216
|
+
back_populates="traces",
|
|
217
|
+
)
|
|
191
218
|
experiment_runs: Mapped[list["ExperimentRun"]] = relationship(
|
|
192
219
|
primaryjoin="foreign(ExperimentRun.trace_id) == Trace.trace_id",
|
|
193
220
|
back_populates="trace",
|
phoenix/server/api/context.py
CHANGED
|
@@ -31,12 +31,18 @@ from phoenix.server.api.dataloaders import (
|
|
|
31
31
|
MinStartOrMaxEndTimeDataLoader,
|
|
32
32
|
ProjectByNameDataLoader,
|
|
33
33
|
RecordCountDataLoader,
|
|
34
|
+
SessionIODataLoader,
|
|
35
|
+
SessionNumTracesDataLoader,
|
|
36
|
+
SessionNumTracesWithErrorDataLoader,
|
|
37
|
+
SessionTokenUsagesDataLoader,
|
|
38
|
+
SessionTraceLatencyMsQuantileDataLoader,
|
|
34
39
|
SpanAnnotationsDataLoader,
|
|
35
40
|
SpanDatasetExamplesDataLoader,
|
|
36
41
|
SpanDescendantsDataLoader,
|
|
37
42
|
SpanProjectsDataLoader,
|
|
38
43
|
TokenCountDataLoader,
|
|
39
|
-
|
|
44
|
+
TraceByTraceIdsDataLoader,
|
|
45
|
+
TraceRootSpansDataLoader,
|
|
40
46
|
UserRolesDataLoader,
|
|
41
47
|
UsersDataLoader,
|
|
42
48
|
)
|
|
@@ -68,12 +74,19 @@ class DataLoaders:
|
|
|
68
74
|
latency_ms_quantile: LatencyMsQuantileDataLoader
|
|
69
75
|
min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
|
|
70
76
|
record_counts: RecordCountDataLoader
|
|
77
|
+
session_first_inputs: SessionIODataLoader
|
|
78
|
+
session_last_outputs: SessionIODataLoader
|
|
79
|
+
session_num_traces: SessionNumTracesDataLoader
|
|
80
|
+
session_num_traces_with_error: SessionNumTracesWithErrorDataLoader
|
|
81
|
+
session_token_usages: SessionTokenUsagesDataLoader
|
|
82
|
+
session_trace_latency_ms_quantile: SessionTraceLatencyMsQuantileDataLoader
|
|
71
83
|
span_annotations: SpanAnnotationsDataLoader
|
|
72
84
|
span_dataset_examples: SpanDatasetExamplesDataLoader
|
|
73
85
|
span_descendants: SpanDescendantsDataLoader
|
|
74
86
|
span_projects: SpanProjectsDataLoader
|
|
75
87
|
token_counts: TokenCountDataLoader
|
|
76
|
-
|
|
88
|
+
trace_by_trace_ids: TraceByTraceIdsDataLoader
|
|
89
|
+
trace_root_spans: TraceRootSpansDataLoader
|
|
77
90
|
project_by_name: ProjectByNameDataLoader
|
|
78
91
|
users: UsersDataLoader
|
|
79
92
|
user_roles: UserRolesDataLoader
|
|
@@ -19,12 +19,18 @@ from .latency_ms_quantile import LatencyMsQuantileCache, LatencyMsQuantileDataLo
|
|
|
19
19
|
from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMaxEndTimeDataLoader
|
|
20
20
|
from .project_by_name import ProjectByNameDataLoader
|
|
21
21
|
from .record_counts import RecordCountCache, RecordCountDataLoader
|
|
22
|
+
from .session_io import SessionIODataLoader
|
|
23
|
+
from .session_num_traces import SessionNumTracesDataLoader
|
|
24
|
+
from .session_num_traces_with_error import SessionNumTracesWithErrorDataLoader
|
|
25
|
+
from .session_token_usages import SessionTokenUsagesDataLoader
|
|
26
|
+
from .session_trace_latency_ms_quantile import SessionTraceLatencyMsQuantileDataLoader
|
|
22
27
|
from .span_annotations import SpanAnnotationsDataLoader
|
|
23
28
|
from .span_dataset_examples import SpanDatasetExamplesDataLoader
|
|
24
29
|
from .span_descendants import SpanDescendantsDataLoader
|
|
25
30
|
from .span_projects import SpanProjectsDataLoader
|
|
26
31
|
from .token_counts import TokenCountCache, TokenCountDataLoader
|
|
27
|
-
from .
|
|
32
|
+
from .trace_by_trace_ids import TraceByTraceIdsDataLoader
|
|
33
|
+
from .trace_root_spans import TraceRootSpansDataLoader
|
|
28
34
|
from .user_roles import UserRolesDataLoader
|
|
29
35
|
from .users import UsersDataLoader
|
|
30
36
|
|
|
@@ -45,11 +51,17 @@ __all__ = [
|
|
|
45
51
|
"LatencyMsQuantileDataLoader",
|
|
46
52
|
"MinStartOrMaxEndTimeDataLoader",
|
|
47
53
|
"RecordCountDataLoader",
|
|
54
|
+
"SessionIODataLoader",
|
|
55
|
+
"SessionNumTracesDataLoader",
|
|
56
|
+
"SessionNumTracesWithErrorDataLoader",
|
|
57
|
+
"SessionTokenUsagesDataLoader",
|
|
58
|
+
"SessionTraceLatencyMsQuantileDataLoader",
|
|
48
59
|
"SpanDatasetExamplesDataLoader",
|
|
49
60
|
"SpanDescendantsDataLoader",
|
|
50
61
|
"SpanProjectsDataLoader",
|
|
51
62
|
"TokenCountDataLoader",
|
|
52
|
-
"
|
|
63
|
+
"TraceByTraceIdsDataLoader",
|
|
64
|
+
"TraceRootSpansDataLoader",
|
|
53
65
|
"ProjectByNameDataLoader",
|
|
54
66
|
"SpanAnnotationsDataLoader",
|
|
55
67
|
"UsersDataLoader",
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from functools import cached_property
|
|
2
|
+
from typing import Literal, Optional, cast
|
|
3
|
+
|
|
4
|
+
from openinference.semconv.trace import SpanAttributes
|
|
5
|
+
from sqlalchemy import Select, func, select
|
|
6
|
+
from strawberry.dataloader import DataLoader
|
|
7
|
+
from typing_extensions import TypeAlias, assert_never
|
|
8
|
+
|
|
9
|
+
from phoenix.db import models
|
|
10
|
+
from phoenix.server.types import DbSessionFactory
|
|
11
|
+
from phoenix.trace.schemas import MimeType, SpanIOValue
|
|
12
|
+
|
|
13
|
+
# Which end of a session to read: the input of its first trace or the output
# of its last trace.
Kind = Literal["first_input", "last_output"]

# Dataloader key: the row id of a ProjectSession.
Key: TypeAlias = int
# Dataloader result: ``None`` when no matching root span was found for the key.
Result: TypeAlias = Optional[SpanIOValue]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SessionIODataLoader(DataLoader[Key, Result]):
    """Batch-load one I/O value per project session.

    With ``kind="first_input"`` the loader returns the input value/MIME type of
    the root span of the session's earliest trace; with ``kind="last_output"``
    it returns the output value/MIME type of the root span of the latest trace.
    Ordering is by ``Trace.start_time`` with ``Trace.id`` as a tie-breaker.
    Sessions with no matching row map to ``None``.
    """

    def __init__(self, db: DbSessionFactory, kind: Kind) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db
        self._kind = kind

    @cached_property
    def _subq(self) -> Select[tuple[Optional[int], str, str, int]]:
        # Pick which attributes to read and which trace counts as "rank 1".
        if self._kind == "first_input":
            value_path, mime_type_path = INPUT_VALUE, INPUT_MIME_TYPE
            ordering = [models.Trace.start_time.asc(), models.Trace.id.asc()]
        elif self._kind == "last_output":
            value_path, mime_type_path = OUTPUT_VALUE, OUTPUT_MIME_TYPE
            ordering = [models.Trace.start_time.desc(), models.Trace.id.desc()]
        else:
            assert_never(self._kind)
        # Number root spans within each session so the wanted one has rank 1.
        rank = (
            func.row_number()
            .over(partition_by=models.Trace.project_session_rowid, order_by=ordering)
            .label("rank")
        )
        ranked = (
            select(
                models.Trace.project_session_rowid.label("id_"),
                models.Span.attributes[value_path].label("value"),
                models.Span.attributes[mime_type_path].label("mime_type"),
                rank,
            )
            .join_from(models.Span, models.Trace)
            .where(models.Span.parent_id.is_(None))
        )
        return cast(Select[tuple[Optional[int], str, str, int]], ranked)

    def _stmt(self, *keys: Key) -> Select[tuple[int, str, str]]:
        ranked = self._subq.where(models.Trace.project_session_rowid.in_(keys)).subquery()
        cols = ranked.c
        return select(cols.id_, cols.value, cols.mime_type).filter_by(rank=1)

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        io_by_session: dict[Key, SpanIOValue] = {}
        async with self._db() as session:
            rows = await session.stream(self._stmt(*keys))
            async for session_rowid, value, mime_type in rows:
                if session_rowid is None:
                    continue
                io_by_session[session_rowid] = SpanIOValue(
                    value=value, mime_type=MimeType(mime_type)
                )
        return [io_by_session.get(key) for key in keys]
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# JSON key paths into ``Span.attributes``, obtained by splitting the dotted
# OpenInference attribute names (e.g. "input.value" -> ["input", "value"]).
INPUT_VALUE, INPUT_MIME_TYPE, OUTPUT_VALUE, OUTPUT_MIME_TYPE = (
    dotted_name.split(".")
    for dotted_name in (
        SpanAttributes.INPUT_VALUE,
        SpanAttributes.INPUT_MIME_TYPE,
        SpanAttributes.OUTPUT_VALUE,
        SpanAttributes.OUTPUT_MIME_TYPE,
    )
)
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from sqlalchemy import func, select
|
|
2
|
+
from strawberry.dataloader import DataLoader
|
|
3
|
+
from typing_extensions import TypeAlias
|
|
4
|
+
|
|
5
|
+
from phoenix.db import models
|
|
6
|
+
from phoenix.server.types import DbSessionFactory
|
|
7
|
+
|
|
8
|
+
Key: TypeAlias = int
|
|
9
|
+
Result: TypeAlias = int
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SessionNumTracesDataLoader(DataLoader[Key, Result]):
    """Batch-load the number of traces attached to each project session.

    Keys are ProjectSession row ids; sessions without traces yield ``0``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        session_rowid = models.Trace.project_session_rowid
        stmt = (
            select(session_rowid.label("id_"), func.count(models.Trace.id).label("value"))
            .where(session_rowid.in_(keys))
            .group_by(session_rowid)
        )
        counts: dict[Key, int] = {}
        async with self._db() as session:
            async for id_, value in await session.stream(stmt):
                if id_ is not None:
                    counts[id_] = value
        return [counts.get(key, 0) for key in keys]
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from sqlalchemy import func, select
|
|
2
|
+
from strawberry.dataloader import DataLoader
|
|
3
|
+
from typing_extensions import TypeAlias
|
|
4
|
+
|
|
5
|
+
from phoenix.db import models
|
|
6
|
+
from phoenix.server.types import DbSessionFactory
|
|
7
|
+
|
|
8
|
+
Key: TypeAlias = int
|
|
9
|
+
Result: TypeAlias = int
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SessionNumTracesWithErrorDataLoader(DataLoader[Key, Result]):
    """Batch-load, per project session, the number of traces that contain an error.

    A trace is counted (once) when at least one of its spans has
    ``cumulative_error_count > 0``. Keys are ProjectSession row ids; sessions
    without such traces yield ``0``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = (
            select(
                models.Trace.project_session_rowid.label("id_"),
                # The join with spans yields one row per erroring span, and an
                # error also raises the cumulative count of every ancestor span.
                # Count DISTINCT trace ids so the value is a trace count rather
                # than a span count.
                func.count(models.Trace.id.distinct()).label("value"),
            )
            .join(models.Span)
            .group_by(models.Trace.project_session_rowid)
            .where(models.Span.cumulative_error_count > 0)
            .where(models.Trace.project_session_rowid.in_(keys))
        )
        async with self._db() as session:
            result: dict[Key, int] = {
                id_: value async for id_, value in await session.stream(stmt) if id_ is not None
            }
        return [result.get(key, 0) for key in keys]
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from sqlalchemy import func, select
|
|
2
|
+
from sqlalchemy.sql.functions import coalesce
|
|
3
|
+
from strawberry.dataloader import DataLoader
|
|
4
|
+
from typing_extensions import TypeAlias
|
|
5
|
+
|
|
6
|
+
from phoenix.db import models
|
|
7
|
+
from phoenix.server.types import DbSessionFactory
|
|
8
|
+
from phoenix.trace.schemas import TokenUsage
|
|
9
|
+
|
|
10
|
+
# Result: summed prompt/completion token usage of the session.
Result: TypeAlias = TokenUsage
# Key: row id of a ProjectSession.
Key: TypeAlias = int
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SessionTokenUsagesDataLoader(DataLoader[Key, Result]):
    """Batch-load total token usage per project session.

    Sums the cumulative prompt and completion token counts of the root spans
    (``parent_id IS NULL``) of each session's traces, treating NULL counts as 0.
    Sessions with no root spans yield a default-constructed ``TokenUsage()``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        prompt_total = func.sum(coalesce(models.Span.cumulative_llm_token_count_prompt, 0))
        completion_total = func.sum(
            coalesce(models.Span.cumulative_llm_token_count_completion, 0)
        )
        stmt = (
            select(
                models.Trace.project_session_rowid.label("id_"),
                prompt_total.label("prompt"),
                completion_total.label("completion"),
            )
            .join_from(models.Span, models.Trace)
            .where(
                models.Span.parent_id.is_(None),
                models.Trace.project_session_rowid.in_(keys),
            )
            .group_by(models.Trace.project_session_rowid)
        )
        usages: dict[Key, TokenUsage] = {}
        async with self._db() as session:
            async for id_, prompt, completion in await session.stream(stmt):
                if id_ is None:
                    continue
                usages[id_] = TokenUsage(prompt=prompt, completion=completion)
        return [usages.get(key, TokenUsage()) for key in keys]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from aioitertools.itertools import groupby
|
|
6
|
+
from sqlalchemy import select
|
|
7
|
+
from strawberry.dataloader import DataLoader
|
|
8
|
+
from typing_extensions import TypeAlias
|
|
9
|
+
|
|
10
|
+
from phoenix.db import models
|
|
11
|
+
from phoenix.server.types import DbSessionFactory
|
|
12
|
+
|
|
13
|
+
SessionId: TypeAlias = int
|
|
14
|
+
Probability: TypeAlias = float
|
|
15
|
+
QuantileValue: TypeAlias = float
|
|
16
|
+
|
|
17
|
+
Key: TypeAlias = tuple[SessionId, Probability]
|
|
18
|
+
Result: TypeAlias = Optional[QuantileValue]
|
|
19
|
+
ResultPosition: TypeAlias = int
|
|
20
|
+
|
|
21
|
+
DEFAULT_VALUE: Result = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SessionTraceLatencyMsQuantileDataLoader(DataLoader[Key, Result]):
    """Batch-load latency quantiles over the traces of each project session.

    Each key is ``(session_rowid, probability)``; its result is
    ``numpy.quantile`` of the session's ``Trace.latency_ms`` values at that
    probability, or ``DEFAULT_VALUE`` when the session has no traces.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        results: list[Result] = [DEFAULT_VALUE for _ in keys]
        # session row id -> probability -> every batch position asking for it,
        # so duplicate keys share a single quantile computation.
        requested: defaultdict[
            SessionId, defaultdict[Probability, list[ResultPosition]]
        ] = defaultdict(lambda: defaultdict(list))
        for index, (session_rowid, probability) in enumerate(keys):
            requested[session_rowid][probability].append(index)
        stmt = (
            select(models.Trace.project_session_rowid, models.Trace.latency_ms)
            .where(models.Trace.project_session_rowid.in_(set(requested)))
            # groupby below requires rows of one session to be contiguous.
            .order_by(models.Trace.project_session_rowid)
        )
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for session_rowid, group in groupby(
                rows, lambda row: row.project_session_rowid
            ):
                latencies = [row.latency_ms for row in group]
                for probability, positions in requested[session_rowid].items():
                    value = np.quantile(latencies, probability)
                    for position in positions:
                        results[position] = value
        return results
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from typing import List, Optional
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select
|
|
4
|
+
from strawberry.dataloader import DataLoader
|
|
5
|
+
from typing_extensions import TypeAlias
|
|
6
|
+
|
|
7
|
+
from phoenix.db import models
|
|
8
|
+
from phoenix.server.types import DbSessionFactory
|
|
9
|
+
|
|
10
|
+
# Result: the matching Trace row, or None when the trace id is unknown.
Result: TypeAlias = Optional[models.Trace]
# Key: the string ``trace_id`` of a trace (not its row id).
Key: TypeAlias = str
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TraceByTraceIdsDataLoader(DataLoader[Key, Result]):
    """Batch-load ``Trace`` rows by their string ``trace_id``.

    Unknown trace ids map to ``None``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        traces_by_id: dict[Key, models.Trace] = {}
        async with self._db() as session:
            async for trace in await session.stream_scalars(
                select(models.Trace).where(models.Trace.trace_id.in_(keys))
            ):
                traces_by_id[trace.trace_id] = trace
        return [traces_by_id.get(trace_id) for trace_id in keys]
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from typing import List, Optional
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select
|
|
4
|
+
from sqlalchemy.orm import contains_eager
|
|
5
|
+
from strawberry.dataloader import DataLoader
|
|
6
|
+
from typing_extensions import TypeAlias
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.types import DbSessionFactory
|
|
10
|
+
|
|
11
|
+
# Result: the trace's root span, or None when no parentless span exists.
Result: TypeAlias = Optional[models.Span]
# Key: row id of a Trace.
Key: TypeAlias = int
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TraceRootSpansDataLoader(DataLoader[Key, Result]):
    """Batch-load the root span (``parent_id IS NULL``) of each trace.

    Keys are Trace row ids. ``Span.trace`` is populated from the join with only
    ``trace_id`` loaded. Traces without a root span map to ``None``; if a trace
    has several parentless spans, the last one streamed is kept.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        stmt = (
            select(models.Span)
            .join(models.Trace)
            .where(models.Span.parent_id.is_(None), models.Trace.id.in_(keys))
            .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id))
        )
        root_span_by_trace: dict[Key, models.Span] = {}
        async with self._db() as session:
            async for span in await session.stream_scalars(stmt):
                root_span_by_trace[span.trace_rowid] = span
        return [root_span_by_trace.get(key) for key in keys]
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from enum import Enum, auto
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
from typing_extensions import assert_never
|
|
5
|
+
|
|
6
|
+
from phoenix.server.api.types.pagination import CursorSortColumnDataType
|
|
7
|
+
from phoenix.server.api.types.SortDir import SortDir
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@strawberry.enum
class ProjectSessionColumn(Enum):
    """Columns by which ProjectSession connections can be sorted."""

    startTime = auto()
    endTime = auto()
    tokenCountTotal = auto()
    numTraces = auto()

    @property
    def data_type(self) -> CursorSortColumnDataType:
        """Return the ``CursorSortColumnDataType`` that matches this column."""
        # Timestamp columns.
        if self is ProjectSessionColumn.startTime or self is ProjectSessionColumn.endTime:
            return CursorSortColumnDataType.DATETIME
        # Integer-valued aggregate columns.
        if self is ProjectSessionColumn.tokenCountTotal or self is ProjectSessionColumn.numTraces:
            return CursorSortColumnDataType.INT
        assert_never(self)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@strawberry.input(
    description="The sort key and direction for ProjectSession connections.",
)
class ProjectSessionSort:
    # Column to order ProjectSession rows by.
    col: ProjectSessionColumn
    # Ascending or descending order.
    dir: SortDir
|
|
@@ -38,10 +38,20 @@ class ProjectMutationMixin:
|
|
|
38
38
|
project_id = from_global_id_with_expected_type(
|
|
39
39
|
global_id=input.id, expected_type_name="Project"
|
|
40
40
|
)
|
|
41
|
-
delete_statement =
|
|
41
|
+
delete_statement = (
|
|
42
|
+
delete(models.Trace)
|
|
43
|
+
.where(models.Trace.project_rowid == project_id)
|
|
44
|
+
.returning(models.Trace.project_session_rowid)
|
|
45
|
+
)
|
|
42
46
|
if input.end_time:
|
|
43
47
|
delete_statement = delete_statement.where(models.Trace.start_time < input.end_time)
|
|
44
48
|
async with info.context.db() as session:
|
|
45
|
-
await session.
|
|
49
|
+
deleted_trace_project_session_ids = await session.scalars(delete_statement)
|
|
50
|
+
if deleted_trace_project_session_ids:
|
|
51
|
+
await session.execute(
|
|
52
|
+
delete(models.ProjectSession).where(
|
|
53
|
+
models.ProjectSession.id.in_(set(deleted_trace_project_session_ids))
|
|
54
|
+
)
|
|
55
|
+
)
|
|
46
56
|
info.context.event_queue.put(SpanDeleteEvent((project_id,)))
|
|
47
57
|
return Query()
|
phoenix/server/api/queries.py
CHANGED
|
@@ -78,10 +78,11 @@ from phoenix.server.api.types.pagination import (
|
|
|
78
78
|
connection_from_list,
|
|
79
79
|
)
|
|
80
80
|
from phoenix.server.api.types.Project import Project
|
|
81
|
+
from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
|
|
81
82
|
from phoenix.server.api.types.SortDir import SortDir
|
|
82
83
|
from phoenix.server.api.types.Span import Span, to_gql_span
|
|
83
84
|
from phoenix.server.api.types.SystemApiKey import SystemApiKey
|
|
84
|
-
from phoenix.server.api.types.Trace import
|
|
85
|
+
from phoenix.server.api.types.Trace import to_gql_trace
|
|
85
86
|
from phoenix.server.api.types.User import User, to_gql_user
|
|
86
87
|
from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
|
|
87
88
|
from phoenix.server.api.types.UserRole import UserRole
|
|
@@ -445,17 +446,12 @@ class Query:
|
|
|
445
446
|
gradient_end_color=project.gradient_end_color,
|
|
446
447
|
)
|
|
447
448
|
elif type_name == "Trace":
|
|
448
|
-
trace_stmt = select(
|
|
449
|
-
models.Trace.id,
|
|
450
|
-
models.Trace.project_rowid,
|
|
451
|
-
).where(models.Trace.id == node_id)
|
|
449
|
+
trace_stmt = select(models.Trace).filter_by(id=node_id)
|
|
452
450
|
async with info.context.db() as session:
|
|
453
|
-
trace =
|
|
451
|
+
trace = await session.scalar(trace_stmt)
|
|
454
452
|
if trace is None:
|
|
455
453
|
raise NotFound(f"Unknown trace: {id}")
|
|
456
|
-
return
|
|
457
|
-
id_attr=trace.id, trace_id=trace.trace_id, project_rowid=trace.project_rowid
|
|
458
|
-
)
|
|
454
|
+
return to_gql_trace(trace)
|
|
459
455
|
elif type_name == Span.__name__:
|
|
460
456
|
span_stmt = (
|
|
461
457
|
select(models.Span)
|
|
@@ -544,6 +540,15 @@ class Query:
|
|
|
544
540
|
):
|
|
545
541
|
raise NotFound(f"Unknown user: {id}")
|
|
546
542
|
return to_gql_user(user)
|
|
543
|
+
elif type_name == ProjectSession.__name__:
|
|
544
|
+
async with info.context.db() as session:
|
|
545
|
+
if not (
|
|
546
|
+
project_session := await session.scalar(
|
|
547
|
+
select(models.ProjectSession).filter_by(id=node_id)
|
|
548
|
+
)
|
|
549
|
+
):
|
|
550
|
+
raise NotFound(f"Unknown user: {id}")
|
|
551
|
+
return to_gql_project_session(project_session)
|
|
547
552
|
raise NotFound(f"Unknown node type: {type_name}")
|
|
548
553
|
|
|
549
554
|
@strawberry.field
|
|
@@ -20,7 +20,7 @@ from phoenix.server.api.types.pagination import (
|
|
|
20
20
|
CursorString,
|
|
21
21
|
connection_from_list,
|
|
22
22
|
)
|
|
23
|
-
from phoenix.server.api.types.Trace import Trace
|
|
23
|
+
from phoenix.server.api.types.Trace import Trace, to_gql_trace
|
|
24
24
|
|
|
25
25
|
if TYPE_CHECKING:
|
|
26
26
|
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
@@ -61,11 +61,10 @@ class ExperimentRun(Node):
|
|
|
61
61
|
async def trace(self, info: Info) -> Optional[Trace]:
|
|
62
62
|
if not self.trace_id:
|
|
63
63
|
return None
|
|
64
|
-
dataloader = info.context.data_loaders.
|
|
64
|
+
dataloader = info.context.data_loaders.trace_by_trace_ids
|
|
65
65
|
if (trace := await dataloader.load(self.trace_id)) is None:
|
|
66
66
|
return None
|
|
67
|
-
|
|
68
|
-
return Trace(id_attr=trace_rowid, trace_id=self.trace_id, project_rowid=project_rowid)
|
|
67
|
+
return to_gql_trace(trace)
|
|
69
68
|
|
|
70
69
|
@strawberry.field
|
|
71
70
|
async def example(
|
|
@@ -8,7 +8,7 @@ from strawberry.scalars import JSON
|
|
|
8
8
|
|
|
9
9
|
from phoenix.db import models
|
|
10
10
|
from phoenix.server.api.types.AnnotatorKind import ExperimentRunAnnotatorKind
|
|
11
|
-
from phoenix.server.api.types.Trace import Trace
|
|
11
|
+
from phoenix.server.api.types.Trace import Trace, to_gql_trace
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
@strawberry.type
|
|
@@ -29,11 +29,10 @@ class ExperimentRunAnnotation(Node):
|
|
|
29
29
|
async def trace(self, info: Info) -> Optional[Trace]:
|
|
30
30
|
if not self.trace_id:
|
|
31
31
|
return None
|
|
32
|
-
dataloader = info.context.data_loaders.
|
|
32
|
+
dataloader = info.context.data_loaders.trace_by_trace_ids
|
|
33
33
|
if (trace := await dataloader.load(self.trace_id)) is None:
|
|
34
34
|
return None
|
|
35
|
-
|
|
36
|
-
return Trace(id_attr=trace_row_id, trace_id=self.trace_id, project_rowid=project_row_id)
|
|
35
|
+
return to_gql_trace(trace)
|
|
37
36
|
|
|
38
37
|
|
|
39
38
|
def to_gql_experiment_run_annotation(
|