arize-phoenix 6.2.0__py3-none-any.whl → 7.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (47) hide show
  1. {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.1.dist-info}/METADATA +4 -4
  2. {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.1.dist-info}/RECORD +45 -32
  3. phoenix/db/engines.py +14 -2
  4. phoenix/db/insertion/span.py +65 -30
  5. phoenix/db/migrations/data_migration_scripts/__init__.py +0 -0
  6. phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
  7. phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
  8. phoenix/db/models.py +27 -0
  9. phoenix/server/api/context.py +15 -2
  10. phoenix/server/api/dataloaders/__init__.py +14 -2
  11. phoenix/server/api/dataloaders/session_io.py +75 -0
  12. phoenix/server/api/dataloaders/session_num_traces.py +30 -0
  13. phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
  14. phoenix/server/api/dataloaders/session_token_usages.py +41 -0
  15. phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
  16. phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
  17. phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
  18. phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
  19. phoenix/server/api/mutations/project_mutations.py +12 -2
  20. phoenix/server/api/queries.py +14 -9
  21. phoenix/server/api/types/ExperimentRun.py +3 -4
  22. phoenix/server/api/types/ExperimentRunAnnotation.py +3 -4
  23. phoenix/server/api/types/Project.py +150 -12
  24. phoenix/server/api/types/ProjectSession.py +139 -0
  25. phoenix/server/api/types/Span.py +6 -19
  26. phoenix/server/api/types/SpanIOValue.py +15 -0
  27. phoenix/server/api/types/TokenUsage.py +11 -0
  28. phoenix/server/api/types/Trace.py +59 -2
  29. phoenix/server/app.py +15 -2
  30. phoenix/server/static/.vite/manifest.json +36 -36
  31. phoenix/server/static/assets/{components-DrwPLMB6.js → components-DKH6AzJw.js} +276 -276
  32. phoenix/server/static/assets/index-DLV87qiO.js +93 -0
  33. phoenix/server/static/assets/{pages-Cmqh2i4E.js → pages-CVY3Nv4Z.js} +611 -290
  34. phoenix/server/static/assets/{vendor-Cdrqqth8.js → vendor-Cb3zlNNd.js} +45 -45
  35. phoenix/server/static/assets/{vendor-arizeai-BSCL03yQ.js → vendor-arizeai-Buo4e1A6.js} +2 -2
  36. phoenix/server/static/assets/{vendor-codemirror-Utqu7Snw.js → vendor-codemirror-BuAQiUVf.js} +1 -1
  37. phoenix/server/static/assets/{vendor-recharts-CNNUvc5T.js → vendor-recharts-Cl9dK5tC.js} +1 -1
  38. phoenix/server/static/assets/{vendor-shiki-B6bHerDK.js → vendor-shiki-CazYUixL.js} +1 -1
  39. phoenix/trace/fixtures.py +8 -0
  40. phoenix/trace/schemas.py +16 -0
  41. phoenix/version.py +1 -1
  42. phoenix/server/api/dataloaders/trace_row_ids.py +0 -33
  43. phoenix/server/static/assets/index-CTN-OfBU.js +0 -93
  44. {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.1.dist-info}/WHEEL +0 -0
  45. {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.1.dist-info}/entry_points.txt +0 -0
  46. {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.1.dist-info}/licenses/IP_NOTICE +0 -0
  47. {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.1.dist-info}/licenses/LICENSE +0 -0
phoenix/db/models.py CHANGED
@@ -156,14 +156,37 @@ class Project(Base):
156
156
  )
157
157
 
158
158
 
159
+ class ProjectSession(Base):
160
+ __tablename__ = "project_sessions"
161
+ id: Mapped[int] = mapped_column(primary_key=True)
162
+ session_id: Mapped[str] = mapped_column(String, nullable=False, unique=True)
163
+ project_id: Mapped[int] = mapped_column(
164
+ ForeignKey("projects.id", ondelete="CASCADE"),
165
+ nullable=False,
166
+ index=True,
167
+ )
168
+ start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True, nullable=False)
169
+ end_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True, nullable=False)
170
+ traces: Mapped[list["Trace"]] = relationship(
171
+ "Trace",
172
+ back_populates="project_session",
173
+ uselist=True,
174
+ )
175
+
176
+
159
177
  class Trace(Base):
160
178
  __tablename__ = "traces"
161
179
  id: Mapped[int] = mapped_column(primary_key=True)
162
180
  project_rowid: Mapped[int] = mapped_column(
163
181
  ForeignKey("projects.id", ondelete="CASCADE"),
182
+ nullable=False,
164
183
  index=True,
165
184
  )
166
185
  trace_id: Mapped[str]
186
+ project_session_rowid: Mapped[Optional[int]] = mapped_column(
187
+ ForeignKey("project_sessions.id", ondelete="CASCADE"),
188
+ index=True,
189
+ )
167
190
  start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
168
191
  end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
169
192
 
@@ -188,6 +211,10 @@ class Trace(Base):
188
211
  cascade="all, delete-orphan",
189
212
  uselist=True,
190
213
  )
214
+ project_session: Mapped[ProjectSession] = relationship(
215
+ "ProjectSession",
216
+ back_populates="traces",
217
+ )
191
218
  experiment_runs: Mapped[list["ExperimentRun"]] = relationship(
192
219
  primaryjoin="foreign(ExperimentRun.trace_id) == Trace.trace_id",
193
220
  back_populates="trace",
@@ -31,12 +31,18 @@ from phoenix.server.api.dataloaders import (
31
31
  MinStartOrMaxEndTimeDataLoader,
32
32
  ProjectByNameDataLoader,
33
33
  RecordCountDataLoader,
34
+ SessionIODataLoader,
35
+ SessionNumTracesDataLoader,
36
+ SessionNumTracesWithErrorDataLoader,
37
+ SessionTokenUsagesDataLoader,
38
+ SessionTraceLatencyMsQuantileDataLoader,
34
39
  SpanAnnotationsDataLoader,
35
40
  SpanDatasetExamplesDataLoader,
36
41
  SpanDescendantsDataLoader,
37
42
  SpanProjectsDataLoader,
38
43
  TokenCountDataLoader,
39
- TraceRowIdsDataLoader,
44
+ TraceByTraceIdsDataLoader,
45
+ TraceRootSpansDataLoader,
40
46
  UserRolesDataLoader,
41
47
  UsersDataLoader,
42
48
  )
@@ -68,12 +74,19 @@ class DataLoaders:
68
74
  latency_ms_quantile: LatencyMsQuantileDataLoader
69
75
  min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
70
76
  record_counts: RecordCountDataLoader
77
+ session_first_inputs: SessionIODataLoader
78
+ session_last_outputs: SessionIODataLoader
79
+ session_num_traces: SessionNumTracesDataLoader
80
+ session_num_traces_with_error: SessionNumTracesWithErrorDataLoader
81
+ session_token_usages: SessionTokenUsagesDataLoader
82
+ session_trace_latency_ms_quantile: SessionTraceLatencyMsQuantileDataLoader
71
83
  span_annotations: SpanAnnotationsDataLoader
72
84
  span_dataset_examples: SpanDatasetExamplesDataLoader
73
85
  span_descendants: SpanDescendantsDataLoader
74
86
  span_projects: SpanProjectsDataLoader
75
87
  token_counts: TokenCountDataLoader
76
- trace_row_ids: TraceRowIdsDataLoader
88
+ trace_by_trace_ids: TraceByTraceIdsDataLoader
89
+ trace_root_spans: TraceRootSpansDataLoader
77
90
  project_by_name: ProjectByNameDataLoader
78
91
  users: UsersDataLoader
79
92
  user_roles: UserRolesDataLoader
@@ -19,12 +19,18 @@ from .latency_ms_quantile import LatencyMsQuantileCache, LatencyMsQuantileDataLo
19
19
  from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMaxEndTimeDataLoader
20
20
  from .project_by_name import ProjectByNameDataLoader
21
21
  from .record_counts import RecordCountCache, RecordCountDataLoader
22
+ from .session_io import SessionIODataLoader
23
+ from .session_num_traces import SessionNumTracesDataLoader
24
+ from .session_num_traces_with_error import SessionNumTracesWithErrorDataLoader
25
+ from .session_token_usages import SessionTokenUsagesDataLoader
26
+ from .session_trace_latency_ms_quantile import SessionTraceLatencyMsQuantileDataLoader
22
27
  from .span_annotations import SpanAnnotationsDataLoader
23
28
  from .span_dataset_examples import SpanDatasetExamplesDataLoader
24
29
  from .span_descendants import SpanDescendantsDataLoader
25
30
  from .span_projects import SpanProjectsDataLoader
26
31
  from .token_counts import TokenCountCache, TokenCountDataLoader
27
- from .trace_row_ids import TraceRowIdsDataLoader
32
+ from .trace_by_trace_ids import TraceByTraceIdsDataLoader
33
+ from .trace_root_spans import TraceRootSpansDataLoader
28
34
  from .user_roles import UserRolesDataLoader
29
35
  from .users import UsersDataLoader
30
36
 
@@ -45,11 +51,17 @@ __all__ = [
45
51
  "LatencyMsQuantileDataLoader",
46
52
  "MinStartOrMaxEndTimeDataLoader",
47
53
  "RecordCountDataLoader",
54
+ "SessionIODataLoader",
55
+ "SessionNumTracesDataLoader",
56
+ "SessionNumTracesWithErrorDataLoader",
57
+ "SessionTokenUsagesDataLoader",
58
+ "SessionTraceLatencyMsQuantileDataLoader",
48
59
  "SpanDatasetExamplesDataLoader",
49
60
  "SpanDescendantsDataLoader",
50
61
  "SpanProjectsDataLoader",
51
62
  "TokenCountDataLoader",
52
- "TraceRowIdsDataLoader",
63
+ "TraceByTraceIdsDataLoader",
64
+ "TraceRootSpansDataLoader",
53
65
  "ProjectByNameDataLoader",
54
66
  "SpanAnnotationsDataLoader",
55
67
  "UsersDataLoader",
@@ -0,0 +1,75 @@
1
+ from functools import cached_property
2
+ from typing import Literal, Optional, cast
3
+
4
+ from openinference.semconv.trace import SpanAttributes
5
+ from sqlalchemy import Select, func, select
6
+ from strawberry.dataloader import DataLoader
7
+ from typing_extensions import TypeAlias, assert_never
8
+
9
+ from phoenix.db import models
10
+ from phoenix.server.types import DbSessionFactory
11
+ from phoenix.trace.schemas import MimeType, SpanIOValue
12
+
13
+ Key: TypeAlias = int
14
+ Result: TypeAlias = Optional[SpanIOValue]
15
+
16
+ Kind = Literal["first_input", "last_output"]
17
+
18
+
19
+ class SessionIODataLoader(DataLoader[Key, Result]):
20
+ def __init__(self, db: DbSessionFactory, kind: Kind) -> None:
21
+ super().__init__(load_fn=self._load_fn)
22
+ self._db = db
23
+ self._kind = kind
24
+
25
+ @cached_property
26
+ def _subq(self) -> Select[tuple[Optional[int], str, str, int]]:
27
+ stmt = (
28
+ select(models.Trace.project_session_rowid.label("id_"))
29
+ .join_from(models.Span, models.Trace)
30
+ .where(models.Span.parent_id.is_(None))
31
+ )
32
+ if self._kind == "first_input":
33
+ stmt = stmt.add_columns(
34
+ models.Span.attributes[INPUT_VALUE].label("value"),
35
+ models.Span.attributes[INPUT_MIME_TYPE].label("mime_type"),
36
+ func.row_number()
37
+ .over(
38
+ partition_by=models.Trace.project_session_rowid,
39
+ order_by=[models.Trace.start_time.asc(), models.Trace.id.asc()],
40
+ )
41
+ .label("rank"),
42
+ )
43
+ elif self._kind == "last_output":
44
+ stmt = stmt.add_columns(
45
+ models.Span.attributes[OUTPUT_VALUE].label("value"),
46
+ models.Span.attributes[OUTPUT_MIME_TYPE].label("mime_type"),
47
+ func.row_number()
48
+ .over(
49
+ partition_by=models.Trace.project_session_rowid,
50
+ order_by=[models.Trace.start_time.desc(), models.Trace.id.desc()],
51
+ )
52
+ .label("rank"),
53
+ )
54
+ else:
55
+ assert_never(self._kind)
56
+ return cast(Select[tuple[Optional[int], str, str, int]], stmt)
57
+
58
+ def _stmt(self, *keys: Key) -> Select[tuple[int, str, str]]:
59
+ subq = self._subq.where(models.Trace.project_session_rowid.in_(keys)).subquery()
60
+ return select(subq.c.id_, subq.c.value, subq.c.mime_type).filter_by(rank=1)
61
+
62
+ async def _load_fn(self, keys: list[Key]) -> list[Result]:
63
+ async with self._db() as session:
64
+ result: dict[Key, SpanIOValue] = {
65
+ id_: SpanIOValue(value=value, mime_type=MimeType(mime_type))
66
+ async for id_, value, mime_type in await session.stream(self._stmt(*keys))
67
+ if id_ is not None
68
+ }
69
+ return [result.get(key) for key in keys]
70
+
71
+
72
+ INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
73
+ INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE.split(".")
74
+ OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
75
+ OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE.split(".")
@@ -0,0 +1,30 @@
1
+ from sqlalchemy import func, select
2
+ from strawberry.dataloader import DataLoader
3
+ from typing_extensions import TypeAlias
4
+
5
+ from phoenix.db import models
6
+ from phoenix.server.types import DbSessionFactory
7
+
8
+ Key: TypeAlias = int
9
+ Result: TypeAlias = int
10
+
11
+
12
+ class SessionNumTracesDataLoader(DataLoader[Key, Result]):
13
+ def __init__(self, db: DbSessionFactory) -> None:
14
+ super().__init__(load_fn=self._load_fn)
15
+ self._db = db
16
+
17
+ async def _load_fn(self, keys: list[Key]) -> list[Result]:
18
+ stmt = (
19
+ select(
20
+ models.Trace.project_session_rowid.label("id_"),
21
+ func.count(models.Trace.id).label("value"),
22
+ )
23
+ .group_by(models.Trace.project_session_rowid)
24
+ .where(models.Trace.project_session_rowid.in_(keys))
25
+ )
26
+ async with self._db() as session:
27
+ result: dict[Key, int] = {
28
+ id_: value async for id_, value in await session.stream(stmt) if id_ is not None
29
+ }
30
+ return [result.get(key, 0) for key in keys]
@@ -0,0 +1,32 @@
1
+ from sqlalchemy import func, select
2
+ from strawberry.dataloader import DataLoader
3
+ from typing_extensions import TypeAlias
4
+
5
+ from phoenix.db import models
6
+ from phoenix.server.types import DbSessionFactory
7
+
8
+ Key: TypeAlias = int
9
+ Result: TypeAlias = int
10
+
11
+
12
+ class SessionNumTracesWithErrorDataLoader(DataLoader[Key, Result]):
13
+ def __init__(self, db: DbSessionFactory) -> None:
14
+ super().__init__(load_fn=self._load_fn)
15
+ self._db = db
16
+
17
+ async def _load_fn(self, keys: list[Key]) -> list[Result]:
18
+ stmt = (
19
+ select(
20
+ models.Trace.project_session_rowid.label("id_"),
21
+ func.count(models.Trace.id).label("value"),
22
+ )
23
+ .join(models.Span)
24
+ .group_by(models.Trace.project_session_rowid)
25
+ .where(models.Span.cumulative_error_count > 0)
26
+ .where(models.Trace.project_session_rowid.in_(keys))
27
+ )
28
+ async with self._db() as session:
29
+ result: dict[Key, int] = {
30
+ id_: value async for id_, value in await session.stream(stmt) if id_ is not None
31
+ }
32
+ return [result.get(key, 0) for key in keys]
@@ -0,0 +1,41 @@
1
+ from sqlalchemy import func, select
2
+ from sqlalchemy.sql.functions import coalesce
3
+ from strawberry.dataloader import DataLoader
4
+ from typing_extensions import TypeAlias
5
+
6
+ from phoenix.db import models
7
+ from phoenix.server.types import DbSessionFactory
8
+ from phoenix.trace.schemas import TokenUsage
9
+
10
+ Key: TypeAlias = int
11
+ Result: TypeAlias = TokenUsage
12
+
13
+
14
+ class SessionTokenUsagesDataLoader(DataLoader[Key, Result]):
15
+ def __init__(self, db: DbSessionFactory) -> None:
16
+ super().__init__(load_fn=self._load_fn)
17
+ self._db = db
18
+
19
+ async def _load_fn(self, keys: list[Key]) -> list[Result]:
20
+ stmt = (
21
+ select(
22
+ models.Trace.project_session_rowid.label("id_"),
23
+ func.sum(coalesce(models.Span.cumulative_llm_token_count_prompt, 0)).label(
24
+ "prompt"
25
+ ),
26
+ func.sum(coalesce(models.Span.cumulative_llm_token_count_completion, 0)).label(
27
+ "completion"
28
+ ),
29
+ )
30
+ .join_from(models.Span, models.Trace)
31
+ .where(models.Span.parent_id.is_(None))
32
+ .where(models.Trace.project_session_rowid.in_(keys))
33
+ .group_by(models.Trace.project_session_rowid)
34
+ )
35
+ async with self._db() as session:
36
+ result: dict[Key, TokenUsage] = {
37
+ id_: TokenUsage(prompt=prompt, completion=completion)
38
+ async for id_, prompt, completion in await session.stream(stmt)
39
+ if id_ is not None
40
+ }
41
+ return [result.get(key, TokenUsage()) for key in keys]
@@ -0,0 +1,55 @@
1
+ from collections import defaultdict
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+ from aioitertools.itertools import groupby
6
+ from sqlalchemy import select
7
+ from strawberry.dataloader import DataLoader
8
+ from typing_extensions import TypeAlias
9
+
10
+ from phoenix.db import models
11
+ from phoenix.server.types import DbSessionFactory
12
+
13
+ SessionId: TypeAlias = int
14
+ Probability: TypeAlias = float
15
+ QuantileValue: TypeAlias = float
16
+
17
+ Key: TypeAlias = tuple[SessionId, Probability]
18
+ Result: TypeAlias = Optional[QuantileValue]
19
+ ResultPosition: TypeAlias = int
20
+
21
+ DEFAULT_VALUE: Result = None
22
+
23
+
24
+ class SessionTraceLatencyMsQuantileDataLoader(DataLoader[Key, Result]):
25
+ def __init__(self, db: DbSessionFactory) -> None:
26
+ super().__init__(load_fn=self._load_fn)
27
+ self._db = db
28
+
29
+ async def _load_fn(self, keys: list[Key]) -> list[Result]:
30
+ results: list[Result] = [DEFAULT_VALUE] * len(keys)
31
+ argument_position_map: defaultdict[
32
+ SessionId, defaultdict[Probability, list[ResultPosition]]
33
+ ] = defaultdict(lambda: defaultdict(list))
34
+ session_rowids = {session_id for session_id, _ in keys}
35
+ for position, (session_id, probability) in enumerate(keys):
36
+ argument_position_map[session_id][probability].append(position)
37
+ stmt = (
38
+ select(
39
+ models.Trace.project_session_rowid,
40
+ models.Trace.latency_ms,
41
+ )
42
+ .where(models.Trace.project_session_rowid.in_(session_rowids))
43
+ .order_by(models.Trace.project_session_rowid)
44
+ )
45
+ async with self._db() as session:
46
+ data = await session.stream(stmt)
47
+ async for project_session_rowid, group in groupby(
48
+ data, lambda row: row.project_session_rowid
49
+ ):
50
+ session_latencies = [row.latency_ms for row in group]
51
+ for probability, positions in argument_position_map[project_session_rowid].items():
52
+ quantile_value = np.quantile(session_latencies, probability)
53
+ for position in positions:
54
+ results[position] = quantile_value
55
+ return results
@@ -0,0 +1,25 @@
1
+ from typing import List, Optional
2
+
3
+ from sqlalchemy import select
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db import models
8
+ from phoenix.server.types import DbSessionFactory
9
+
10
+ Key: TypeAlias = str
11
+ Result: TypeAlias = Optional[models.Trace]
12
+
13
+
14
+ class TraceByTraceIdsDataLoader(DataLoader[Key, Result]):
15
+ def __init__(self, db: DbSessionFactory) -> None:
16
+ super().__init__(load_fn=self._load_fn)
17
+ self._db = db
18
+
19
+ async def _load_fn(self, keys: List[Key]) -> List[Result]:
20
+ stmt = select(models.Trace).where(models.Trace.trace_id.in_(keys))
21
+ async with self._db() as session:
22
+ result: dict[Key, models.Trace] = {
23
+ trace.trace_id: trace async for trace in await session.stream_scalars(stmt)
24
+ }
25
+ return [result.get(trace_id) for trace_id in keys]
@@ -0,0 +1,32 @@
1
+ from typing import List, Optional
2
+
3
+ from sqlalchemy import select
4
+ from sqlalchemy.orm import contains_eager
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.types import DbSessionFactory
10
+
11
+ Key: TypeAlias = int
12
+ Result: TypeAlias = Optional[models.Span]
13
+
14
+
15
+ class TraceRootSpansDataLoader(DataLoader[Key, Result]):
16
+ def __init__(self, db: DbSessionFactory) -> None:
17
+ super().__init__(load_fn=self._load_fn)
18
+ self._db = db
19
+
20
+ async def _load_fn(self, keys: List[Key]) -> List[Result]:
21
+ stmt = (
22
+ select(models.Span)
23
+ .join(models.Trace)
24
+ .where(models.Span.parent_id.is_(None))
25
+ .where(models.Trace.id.in_(keys))
26
+ .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id))
27
+ )
28
+ async with self._db() as session:
29
+ result: dict[Key, models.Span] = {
30
+ span.trace_rowid: span async for span in await session.stream_scalars(stmt)
31
+ }
32
+ return [result.get(key) for key in keys]
@@ -0,0 +1,29 @@
1
+ from enum import Enum, auto
2
+
3
+ import strawberry
4
+ from typing_extensions import assert_never
5
+
6
+ from phoenix.server.api.types.pagination import CursorSortColumnDataType
7
+ from phoenix.server.api.types.SortDir import SortDir
8
+
9
+
10
+ @strawberry.enum
11
+ class ProjectSessionColumn(Enum):
12
+ startTime = auto()
13
+ endTime = auto()
14
+ tokenCountTotal = auto()
15
+ numTraces = auto()
16
+
17
+ @property
18
+ def data_type(self) -> CursorSortColumnDataType:
19
+ if self is ProjectSessionColumn.tokenCountTotal or self is ProjectSessionColumn.numTraces:
20
+ return CursorSortColumnDataType.INT
21
+ if self is ProjectSessionColumn.startTime or self is ProjectSessionColumn.endTime:
22
+ return CursorSortColumnDataType.DATETIME
23
+ assert_never(self)
24
+
25
+
26
+ @strawberry.input(description="The sort key and direction for ProjectSession connections.")
27
+ class ProjectSessionSort:
28
+ col: ProjectSessionColumn
29
+ dir: SortDir
@@ -38,10 +38,20 @@ class ProjectMutationMixin:
38
38
  project_id = from_global_id_with_expected_type(
39
39
  global_id=input.id, expected_type_name="Project"
40
40
  )
41
- delete_statement = delete(models.Trace).where(models.Trace.project_rowid == project_id)
41
+ delete_statement = (
42
+ delete(models.Trace)
43
+ .where(models.Trace.project_rowid == project_id)
44
+ .returning(models.Trace.project_session_rowid)
45
+ )
42
46
  if input.end_time:
43
47
  delete_statement = delete_statement.where(models.Trace.start_time < input.end_time)
44
48
  async with info.context.db() as session:
45
- await session.execute(delete_statement)
49
+ deleted_trace_project_session_ids = await session.scalars(delete_statement)
50
+ if deleted_trace_project_session_ids:
51
+ await session.execute(
52
+ delete(models.ProjectSession).where(
53
+ models.ProjectSession.id.in_(set(deleted_trace_project_session_ids))
54
+ )
55
+ )
46
56
  info.context.event_queue.put(SpanDeleteEvent((project_id,)))
47
57
  return Query()
@@ -78,10 +78,11 @@ from phoenix.server.api.types.pagination import (
78
78
  connection_from_list,
79
79
  )
80
80
  from phoenix.server.api.types.Project import Project
81
+ from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
81
82
  from phoenix.server.api.types.SortDir import SortDir
82
83
  from phoenix.server.api.types.Span import Span, to_gql_span
83
84
  from phoenix.server.api.types.SystemApiKey import SystemApiKey
84
- from phoenix.server.api.types.Trace import Trace
85
+ from phoenix.server.api.types.Trace import to_gql_trace
85
86
  from phoenix.server.api.types.User import User, to_gql_user
86
87
  from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
87
88
  from phoenix.server.api.types.UserRole import UserRole
@@ -445,17 +446,12 @@ class Query:
445
446
  gradient_end_color=project.gradient_end_color,
446
447
  )
447
448
  elif type_name == "Trace":
448
- trace_stmt = select(
449
- models.Trace.id,
450
- models.Trace.project_rowid,
451
- ).where(models.Trace.id == node_id)
449
+ trace_stmt = select(models.Trace).filter_by(id=node_id)
452
450
  async with info.context.db() as session:
453
- trace = (await session.execute(trace_stmt)).first()
451
+ trace = await session.scalar(trace_stmt)
454
452
  if trace is None:
455
453
  raise NotFound(f"Unknown trace: {id}")
456
- return Trace(
457
- id_attr=trace.id, trace_id=trace.trace_id, project_rowid=trace.project_rowid
458
- )
454
+ return to_gql_trace(trace)
459
455
  elif type_name == Span.__name__:
460
456
  span_stmt = (
461
457
  select(models.Span)
@@ -544,6 +540,15 @@ class Query:
544
540
  ):
545
541
  raise NotFound(f"Unknown user: {id}")
546
542
  return to_gql_user(user)
543
+ elif type_name == ProjectSession.__name__:
544
+ async with info.context.db() as session:
545
+ if not (
546
+ project_session := await session.scalar(
547
+ select(models.ProjectSession).filter_by(id=node_id)
548
+ )
549
+ ):
550
+ raise NotFound(f"Unknown user: {id}")
551
+ return to_gql_project_session(project_session)
547
552
  raise NotFound(f"Unknown node type: {type_name}")
548
553
 
549
554
  @strawberry.field
@@ -20,7 +20,7 @@ from phoenix.server.api.types.pagination import (
20
20
  CursorString,
21
21
  connection_from_list,
22
22
  )
23
- from phoenix.server.api.types.Trace import Trace
23
+ from phoenix.server.api.types.Trace import Trace, to_gql_trace
24
24
 
25
25
  if TYPE_CHECKING:
26
26
  from phoenix.server.api.types.DatasetExample import DatasetExample
@@ -61,11 +61,10 @@ class ExperimentRun(Node):
61
61
  async def trace(self, info: Info) -> Optional[Trace]:
62
62
  if not self.trace_id:
63
63
  return None
64
- dataloader = info.context.data_loaders.trace_row_ids
64
+ dataloader = info.context.data_loaders.trace_by_trace_ids
65
65
  if (trace := await dataloader.load(self.trace_id)) is None:
66
66
  return None
67
- trace_rowid, project_rowid = trace
68
- return Trace(id_attr=trace_rowid, trace_id=self.trace_id, project_rowid=project_rowid)
67
+ return to_gql_trace(trace)
69
68
 
70
69
  @strawberry.field
71
70
  async def example(
@@ -8,7 +8,7 @@ from strawberry.scalars import JSON
8
8
 
9
9
  from phoenix.db import models
10
10
  from phoenix.server.api.types.AnnotatorKind import ExperimentRunAnnotatorKind
11
- from phoenix.server.api.types.Trace import Trace
11
+ from phoenix.server.api.types.Trace import Trace, to_gql_trace
12
12
 
13
13
 
14
14
  @strawberry.type
@@ -29,11 +29,10 @@ class ExperimentRunAnnotation(Node):
29
29
  async def trace(self, info: Info) -> Optional[Trace]:
30
30
  if not self.trace_id:
31
31
  return None
32
- dataloader = info.context.data_loaders.trace_row_ids
32
+ dataloader = info.context.data_loaders.trace_by_trace_ids
33
33
  if (trace := await dataloader.load(self.trace_id)) is None:
34
34
  return None
35
- trace_row_id, project_row_id = trace
36
- return Trace(id_attr=trace_row_id, trace_id=self.trace_id, project_rowid=project_row_id)
35
+ return to_gql_trace(trace)
37
36
 
38
37
 
39
38
  def to_gql_experiment_run_annotation(