arize-phoenix 6.2.0__py3-none-any.whl → 7.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix has been flagged as potentially problematic; see the package registry's advisory page for details.

Files changed (46):
  1. {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.0.dist-info}/METADATA +4 -4
  2. {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.0.dist-info}/RECORD +44 -31
  3. phoenix/db/insertion/span.py +65 -30
  4. phoenix/db/migrations/data_migration_scripts/__init__.py +0 -0
  5. phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
  6. phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
  7. phoenix/db/models.py +27 -0
  8. phoenix/server/api/context.py +15 -2
  9. phoenix/server/api/dataloaders/__init__.py +14 -2
  10. phoenix/server/api/dataloaders/session_io.py +75 -0
  11. phoenix/server/api/dataloaders/session_num_traces.py +30 -0
  12. phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
  13. phoenix/server/api/dataloaders/session_token_usages.py +41 -0
  14. phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
  15. phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
  16. phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
  17. phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
  18. phoenix/server/api/mutations/project_mutations.py +12 -2
  19. phoenix/server/api/queries.py +14 -9
  20. phoenix/server/api/types/ExperimentRun.py +3 -4
  21. phoenix/server/api/types/ExperimentRunAnnotation.py +3 -4
  22. phoenix/server/api/types/Project.py +150 -12
  23. phoenix/server/api/types/ProjectSession.py +139 -0
  24. phoenix/server/api/types/Span.py +6 -19
  25. phoenix/server/api/types/SpanIOValue.py +15 -0
  26. phoenix/server/api/types/TokenUsage.py +11 -0
  27. phoenix/server/api/types/Trace.py +59 -2
  28. phoenix/server/app.py +15 -2
  29. phoenix/server/static/.vite/manifest.json +36 -36
  30. phoenix/server/static/assets/{components-DrwPLMB6.js → components-DKH6AzJw.js} +276 -276
  31. phoenix/server/static/assets/index-DLV87qiO.js +93 -0
  32. phoenix/server/static/assets/{pages-Cmqh2i4E.js → pages-CVY3Nv4Z.js} +611 -290
  33. phoenix/server/static/assets/{vendor-Cdrqqth8.js → vendor-Cb3zlNNd.js} +45 -45
  34. phoenix/server/static/assets/{vendor-arizeai-BSCL03yQ.js → vendor-arizeai-Buo4e1A6.js} +2 -2
  35. phoenix/server/static/assets/{vendor-codemirror-Utqu7Snw.js → vendor-codemirror-BuAQiUVf.js} +1 -1
  36. phoenix/server/static/assets/{vendor-recharts-CNNUvc5T.js → vendor-recharts-Cl9dK5tC.js} +1 -1
  37. phoenix/server/static/assets/{vendor-shiki-B6bHerDK.js → vendor-shiki-CazYUixL.js} +1 -1
  38. phoenix/trace/fixtures.py +8 -0
  39. phoenix/trace/schemas.py +16 -0
  40. phoenix/version.py +1 -1
  41. phoenix/server/api/dataloaders/trace_row_ids.py +0 -33
  42. phoenix/server/static/assets/index-CTN-OfBU.js +0 -93
  43. {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.0.dist-info}/WHEEL +0 -0
  44. {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.0.dist-info}/entry_points.txt +0 -0
  45. {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  46. {arize_phoenix-6.2.0.dist-info → arize_phoenix-7.0.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/types/Project.py CHANGED
@@ -1,23 +1,26 @@
1
1
  import operator
2
2
  from datetime import datetime
3
- from typing import (
4
- Any,
5
- ClassVar,
6
- Optional,
7
- )
3
+ from typing import Any, ClassVar, Optional
8
4
 
9
5
  import strawberry
10
6
  from aioitertools.itertools import islice
11
- from sqlalchemy import and_, desc, distinct, select
7
+ from openinference.semconv.trace import SpanAttributes
8
+ from sqlalchemy import and_, desc, distinct, func, or_, select
12
9
  from sqlalchemy.orm import contains_eager
10
+ from sqlalchemy.sql.elements import ColumnElement
13
11
  from sqlalchemy.sql.expression import tuple_
14
12
  from strawberry import ID, UNSET
15
13
  from strawberry.relay import Connection, Node, NodeID
16
14
  from strawberry.types import Info
15
+ from typing_extensions import assert_never
17
16
 
18
17
  from phoenix.datetime_utils import right_open_time_range
19
18
  from phoenix.db import models
20
19
  from phoenix.server.api.context import Context
20
+ from phoenix.server.api.input_types.ProjectSessionSort import (
21
+ ProjectSessionColumn,
22
+ ProjectSessionSort,
23
+ )
21
24
  from phoenix.server.api.input_types.SpanSort import SpanSort, SpanSortConfig
22
25
  from phoenix.server.api.input_types.TimeRange import TimeRange
23
26
  from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
@@ -28,9 +31,10 @@ from phoenix.server.api.types.pagination import (
28
31
  CursorString,
29
32
  connection_from_cursors_and_nodes,
30
33
  )
34
+ from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
31
35
  from phoenix.server.api.types.SortDir import SortDir
32
36
  from phoenix.server.api.types.Span import Span, to_gql_span
33
- from phoenix.server.api.types.Trace import Trace
37
+ from phoenix.server.api.types.Trace import Trace, to_gql_trace
34
38
  from phoenix.server.api.types.ValidationResult import ValidationResult
35
39
  from phoenix.trace.dsl import SpanFilter
36
40
 
@@ -127,7 +131,13 @@ class Project(Node):
127
131
  time_range: Optional[TimeRange] = UNSET,
128
132
  ) -> Optional[float]:
129
133
  return await info.context.data_loaders.latency_ms_quantile.load(
130
- ("trace", self.id_attr, time_range, None, probability),
134
+ (
135
+ "trace",
136
+ self.id_attr,
137
+ time_range,
138
+ None,
139
+ probability,
140
+ ),
131
141
  )
132
142
 
133
143
  @strawberry.field
@@ -139,20 +149,26 @@ class Project(Node):
139
149
  filter_condition: Optional[str] = UNSET,
140
150
  ) -> Optional[float]:
141
151
  return await info.context.data_loaders.latency_ms_quantile.load(
142
- ("span", self.id_attr, time_range, filter_condition, probability),
152
+ (
153
+ "span",
154
+ self.id_attr,
155
+ time_range,
156
+ filter_condition,
157
+ probability,
158
+ ),
143
159
  )
144
160
 
145
161
  @strawberry.field
146
162
  async def trace(self, trace_id: ID, info: Info[Context, None]) -> Optional[Trace]:
147
163
  stmt = (
148
- select(models.Trace.id)
164
+ select(models.Trace)
149
165
  .where(models.Trace.trace_id == str(trace_id))
150
166
  .where(models.Trace.project_rowid == self.id_attr)
151
167
  )
152
168
  async with info.context.db() as session:
153
- if (id_attr := await session.scalar(stmt)) is None:
169
+ if (trace := await session.scalar(stmt)) is None:
154
170
  return None
155
- return Trace(id_attr=id_attr, trace_id=trace_id, project_rowid=self.id_attr)
171
+ return to_gql_trace(trace)
156
172
 
157
173
  @strawberry.field
158
174
  async def spans(
@@ -241,6 +257,124 @@ class Project(Node):
241
257
  has_next_page=has_next_page,
242
258
  )
243
259
 
260
+ @strawberry.field
261
+ async def sessions(
262
+ self,
263
+ info: Info[Context, None],
264
+ time_range: Optional[TimeRange] = UNSET,
265
+ first: Optional[int] = 50,
266
+ after: Optional[CursorString] = UNSET,
267
+ sort: Optional[ProjectSessionSort] = UNSET,
268
+ filter_io_substring: Optional[str] = UNSET,
269
+ ) -> Connection[ProjectSession]:
270
+ table = models.ProjectSession
271
+ stmt = select(table).filter_by(project_id=self.id_attr)
272
+ if time_range:
273
+ if time_range.start:
274
+ stmt = stmt.where(time_range.start <= table.start_time)
275
+ if time_range.end:
276
+ stmt = stmt.where(table.start_time < time_range.end)
277
+ if filter_io_substring:
278
+ filter_subq = (
279
+ stmt.with_only_columns(distinct(table.id).label("id"))
280
+ .join_from(table, models.Trace)
281
+ .join_from(models.Trace, models.Span)
282
+ .where(models.Span.parent_id.is_(None))
283
+ .where(
284
+ or_(
285
+ models.TextContains(
286
+ models.Span.attributes[INPUT_VALUE].as_string(),
287
+ filter_io_substring,
288
+ ),
289
+ models.TextContains(
290
+ models.Span.attributes[OUTPUT_VALUE].as_string(),
291
+ filter_io_substring,
292
+ ),
293
+ )
294
+ )
295
+ ).subquery()
296
+ stmt = stmt.join(filter_subq, table.id == filter_subq.c.id)
297
+ if sort:
298
+ key: ColumnElement[Any]
299
+ if sort.col is ProjectSessionColumn.startTime:
300
+ key = table.start_time.label("key")
301
+ elif sort.col is ProjectSessionColumn.endTime:
302
+ key = table.end_time.label("key")
303
+ elif (
304
+ sort.col is ProjectSessionColumn.tokenCountTotal
305
+ or sort.col is ProjectSessionColumn.numTraces
306
+ ):
307
+ if sort.col is ProjectSessionColumn.tokenCountTotal:
308
+ sort_subq = (
309
+ select(
310
+ models.Trace.project_session_rowid.label("id"),
311
+ func.sum(models.Span.cumulative_llm_token_count_total).label("key"),
312
+ )
313
+ .join_from(models.Trace, models.Span)
314
+ .where(models.Span.parent_id.is_(None))
315
+ .group_by(models.Trace.project_session_rowid)
316
+ ).subquery()
317
+ elif sort.col is ProjectSessionColumn.numTraces:
318
+ sort_subq = (
319
+ select(
320
+ models.Trace.project_session_rowid.label("id"),
321
+ func.count(models.Trace.id).label("key"),
322
+ ).group_by(models.Trace.project_session_rowid)
323
+ ).subquery()
324
+ else:
325
+ assert_never(sort.col)
326
+ key = sort_subq.c.key
327
+ stmt = stmt.join(sort_subq, table.id == sort_subq.c.id)
328
+ else:
329
+ assert_never(sort.col)
330
+ stmt = stmt.add_columns(key)
331
+ if sort.dir is SortDir.asc:
332
+ stmt = stmt.order_by(key.asc(), table.id.asc())
333
+ else:
334
+ stmt = stmt.order_by(key.desc(), table.id.desc())
335
+ if after:
336
+ cursor = Cursor.from_string(after)
337
+ assert cursor.sort_column is not None
338
+ compare = operator.lt if sort.dir is SortDir.desc else operator.gt
339
+ stmt = stmt.where(
340
+ compare(
341
+ tuple_(key, table.id),
342
+ (cursor.sort_column.value, cursor.rowid),
343
+ )
344
+ )
345
+ else:
346
+ stmt = stmt.order_by(table.id.desc())
347
+ if after:
348
+ cursor = Cursor.from_string(after)
349
+ stmt = stmt.where(table.id < cursor.rowid)
350
+ if first:
351
+ stmt = stmt.limit(
352
+ first + 1 # over-fetch by one to determine whether there's a next page
353
+ )
354
+ cursors_and_nodes = []
355
+ async with info.context.db() as session:
356
+ records = await session.stream(stmt)
357
+ async for record in islice(records, first):
358
+ project_session = record[0]
359
+ cursor = Cursor(rowid=project_session.id)
360
+ if sort:
361
+ assert len(record) > 1
362
+ cursor.sort_column = CursorSortColumn(
363
+ type=sort.col.data_type,
364
+ value=record[1],
365
+ )
366
+ cursors_and_nodes.append((cursor, to_gql_project_session(project_session)))
367
+ has_next_page = True
368
+ try:
369
+ await records.__anext__()
370
+ except StopAsyncIteration:
371
+ has_next_page = False
372
+ return connection_from_cursors_and_nodes(
373
+ cursors_and_nodes,
374
+ has_previous_page=False,
375
+ has_next_page=has_next_page,
376
+ )
377
+
244
378
  @strawberry.field(
245
379
  description="Names of all available annotations for traces. "
246
380
  "(The list contains no duplicates.)"
@@ -363,3 +497,7 @@ def to_gql_project(project: models.Project) -> Project:
363
497
  gradient_start_color=project.gradient_start_color,
364
498
  gradient_end_color=project.gradient_end_color,
365
499
  )
500
+
501
+
502
+ INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
503
+ OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
phoenix/server/api/types/ProjectSession.py ADDED
@@ -0,0 +1,139 @@
1
+ from datetime import datetime
2
+ from typing import TYPE_CHECKING, Annotated, ClassVar, Optional, Type
3
+
4
+ import strawberry
5
+ from openinference.semconv.trace import SpanAttributes
6
+ from sqlalchemy import select
7
+ from strawberry import UNSET, Info, Private, lazy
8
+ from strawberry.relay import Connection, GlobalID, Node, NodeID
9
+
10
+ from phoenix.db import models
11
+ from phoenix.server.api.context import Context
12
+ from phoenix.server.api.types.MimeType import MimeType
13
+ from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list
14
+ from phoenix.server.api.types.SpanIOValue import SpanIOValue
15
+ from phoenix.server.api.types.TokenUsage import TokenUsage
16
+
17
+ if TYPE_CHECKING:
18
+ from phoenix.server.api.types.Trace import Trace
19
+
20
+
21
+ @strawberry.type
22
+ class ProjectSession(Node):
23
+ _table: ClassVar[Type[models.ProjectSession]] = models.ProjectSession
24
+ id_attr: NodeID[int]
25
+ project_rowid: Private[int]
26
+ session_id: str
27
+ start_time: datetime
28
+ end_time: datetime
29
+
30
+ @strawberry.field
31
+ async def project_id(self) -> GlobalID:
32
+ from phoenix.server.api.types.Project import Project
33
+
34
+ return GlobalID(type_name=Project.__name__, node_id=str(self.project_rowid))
35
+
36
+ @strawberry.field
37
+ async def num_traces(
38
+ self,
39
+ info: Info[Context, None],
40
+ ) -> int:
41
+ return await info.context.data_loaders.session_num_traces.load(self.id_attr)
42
+
43
+ @strawberry.field
44
+ async def num_traces_with_error(
45
+ self,
46
+ info: Info[Context, None],
47
+ ) -> int:
48
+ return await info.context.data_loaders.session_num_traces_with_error.load(self.id_attr)
49
+
50
+ @strawberry.field
51
+ async def first_input(
52
+ self,
53
+ info: Info[Context, None],
54
+ ) -> Optional[SpanIOValue]:
55
+ record = await info.context.data_loaders.session_first_inputs.load(self.id_attr)
56
+ if record is None:
57
+ return None
58
+ return SpanIOValue(
59
+ mime_type=MimeType(record.mime_type.value),
60
+ value=record.value,
61
+ )
62
+
63
+ @strawberry.field
64
+ async def last_output(
65
+ self,
66
+ info: Info[Context, None],
67
+ ) -> Optional[SpanIOValue]:
68
+ record = await info.context.data_loaders.session_last_outputs.load(self.id_attr)
69
+ if record is None:
70
+ return None
71
+ return SpanIOValue(
72
+ mime_type=MimeType(record.mime_type.value),
73
+ value=record.value,
74
+ )
75
+
76
+ @strawberry.field
77
+ async def token_usage(
78
+ self,
79
+ info: Info[Context, None],
80
+ ) -> TokenUsage:
81
+ usage = await info.context.data_loaders.session_token_usages.load(self.id_attr)
82
+ return TokenUsage(
83
+ prompt=usage.prompt,
84
+ completion=usage.completion,
85
+ )
86
+
87
+ @strawberry.field
88
+ async def traces(
89
+ self,
90
+ info: Info[Context, None],
91
+ first: Optional[int] = 50,
92
+ last: Optional[int] = UNSET,
93
+ after: Optional[CursorString] = UNSET,
94
+ before: Optional[CursorString] = UNSET,
95
+ ) -> Connection[Annotated["Trace", lazy(".Trace")]]:
96
+ from phoenix.server.api.types.Trace import to_gql_trace
97
+
98
+ args = ConnectionArgs(
99
+ first=first,
100
+ after=after if isinstance(after, CursorString) else None,
101
+ last=last,
102
+ before=before if isinstance(before, CursorString) else None,
103
+ )
104
+ stmt = (
105
+ select(models.Trace)
106
+ .filter_by(project_session_rowid=self.id_attr)
107
+ .order_by(models.Trace.start_time)
108
+ .limit(first)
109
+ )
110
+ async with info.context.db() as session:
111
+ traces = await session.stream_scalars(stmt)
112
+ data = [to_gql_trace(trace) async for trace in traces]
113
+ return connection_from_list(data=data, args=args)
114
+
115
+ @strawberry.field
116
+ async def trace_latency_ms_quantile(
117
+ self,
118
+ info: Info[Context, None],
119
+ probability: float,
120
+ ) -> Optional[float]:
121
+ return await info.context.data_loaders.session_trace_latency_ms_quantile.load(
122
+ (self.id_attr, probability)
123
+ )
124
+
125
+
126
+ def to_gql_project_session(project_session: models.ProjectSession) -> ProjectSession:
127
+ return ProjectSession(
128
+ id_attr=project_session.id,
129
+ session_id=project_session.session_id,
130
+ start_time=project_session.start_time,
131
+ project_rowid=project_session.project_id,
132
+ end_time=project_session.end_time,
133
+ )
134
+
135
+
136
+ INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
137
+ INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE.split(".")
138
+ OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
139
+ OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE.split(".")
phoenix/server/api/types/Span.py CHANGED
@@ -24,17 +24,16 @@ from phoenix.server.api.input_types.SpanAnnotationSort import (
24
24
  SpanAnnotationColumn,
25
25
  SpanAnnotationSort,
26
26
  )
27
+ from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics
28
+ from phoenix.server.api.types.Evaluation import DocumentEvaluation
29
+ from phoenix.server.api.types.ExampleRevisionInterface import ExampleRevision
27
30
  from phoenix.server.api.types.GenerativeProvider import GenerativeProvider
31
+ from phoenix.server.api.types.MimeType import MimeType
28
32
  from phoenix.server.api.types.SortDir import SortDir
29
- from phoenix.server.api.types.SpanAnnotation import to_gql_span_annotation
33
+ from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
34
+ from phoenix.server.api.types.SpanIOValue import SpanIOValue
30
35
  from phoenix.trace.attributes import get_attribute_value
31
36
 
32
- from .DocumentRetrievalMetrics import DocumentRetrievalMetrics
33
- from .Evaluation import DocumentEvaluation
34
- from .ExampleRevisionInterface import ExampleRevision
35
- from .MimeType import MimeType
36
- from .SpanAnnotation import SpanAnnotation
37
-
38
37
  if TYPE_CHECKING:
39
38
  from phoenix.server.api.types.Project import Project
40
39
 
@@ -71,18 +70,6 @@ class SpanContext:
71
70
  span_id: ID
72
71
 
73
72
 
74
- @strawberry.type
75
- class SpanIOValue:
76
- mime_type: MimeType
77
- value: str
78
-
79
- @strawberry.field(
80
- description="Truncate value up to `chars` characters, appending '...' if truncated.",
81
- ) # type: ignore
82
- def truncated_value(self, chars: int = 100) -> str:
83
- return f"{self.value[: max(0, chars - 3)]}..." if len(self.value) > chars else self.value
84
-
85
-
86
73
  @strawberry.enum
87
74
  class SpanStatusCode(Enum):
88
75
  OK = "OK"
phoenix/server/api/types/SpanIOValue.py ADDED
@@ -0,0 +1,15 @@
1
+ import strawberry
2
+
3
+ from phoenix.server.api.types.MimeType import MimeType
4
+
5
+
6
+ @strawberry.type
7
+ class SpanIOValue:
8
+ mime_type: MimeType
9
+ value: str
10
+
11
+ @strawberry.field(
12
+ description="Truncate value up to `chars` characters, appending '...' if truncated.",
13
+ ) # type: ignore
14
+ def truncated_value(self, chars: int = 100) -> str:
15
+ return f"{self.value[: max(0, chars - 3)]}..." if len(self.value) > chars else self.value
phoenix/server/api/types/TokenUsage.py ADDED
@@ -0,0 +1,11 @@
1
+ import strawberry
2
+
3
+
4
+ @strawberry.type
5
+ class TokenUsage:
6
+ prompt: int = 0
7
+ completion: int = 0
8
+
9
+ @strawberry.field
10
+ async def total(self) -> int:
11
+ return self.prompt + self.completion
phoenix/server/api/types/Trace.py CHANGED
@@ -1,11 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Optional
3
+ from datetime import datetime
4
+ from typing import TYPE_CHECKING, Annotated, Optional, Union
4
5
 
5
6
  import strawberry
7
+ from openinference.semconv.trace import SpanAttributes
6
8
  from sqlalchemy import desc, select
7
9
  from sqlalchemy.orm import contains_eager
8
- from strawberry import UNSET, Private
10
+ from strawberry import UNSET, Private, lazy
9
11
  from strawberry.relay import Connection, GlobalID, Node, NodeID
10
12
  from strawberry.types import Info
11
13
 
@@ -21,12 +23,18 @@ from phoenix.server.api.types.SortDir import SortDir
21
23
  from phoenix.server.api.types.Span import Span, to_gql_span
22
24
  from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
23
25
 
26
+ if TYPE_CHECKING:
27
+ from phoenix.server.api.types.ProjectSession import ProjectSession
28
+
24
29
 
25
30
  @strawberry.type
26
31
  class Trace(Node):
27
32
  id_attr: NodeID[int]
28
33
  project_rowid: Private[int]
34
+ project_session_rowid: Private[Optional[int]]
29
35
  trace_id: str
36
+ start_time: datetime
37
+ end_time: datetime
30
38
 
31
39
  @strawberry.field
32
40
  async def project_id(self) -> GlobalID:
@@ -34,6 +42,40 @@ class Trace(Node):
34
42
 
35
43
  return GlobalID(type_name=Project.__name__, node_id=str(self.project_rowid))
36
44
 
45
+ @strawberry.field
46
+ async def project_session_id(self) -> Optional[GlobalID]:
47
+ if self.project_session_rowid is None:
48
+ return None
49
+ from phoenix.server.api.types.ProjectSession import ProjectSession
50
+
51
+ return GlobalID(type_name=ProjectSession.__name__, node_id=str(self.project_session_rowid))
52
+
53
+ @strawberry.field
54
+ async def session(
55
+ self,
56
+ info: Info[Context, None],
57
+ ) -> Union[Annotated["ProjectSession", lazy(".ProjectSession")], None]:
58
+ if self.project_session_rowid is None:
59
+ return None
60
+ from phoenix.server.api.types.ProjectSession import to_gql_project_session
61
+
62
+ stmt = select(models.ProjectSession).filter_by(id=self.project_session_rowid)
63
+ async with info.context.db() as session:
64
+ project_session = await session.scalar(stmt)
65
+ if project_session is None:
66
+ return None
67
+ return to_gql_project_session(project_session)
68
+
69
+ @strawberry.field
70
+ async def root_span(
71
+ self,
72
+ info: Info[Context, None],
73
+ ) -> Optional[Span]:
74
+ span = await info.context.data_loaders.trace_root_spans.load(self.id_attr)
75
+ if span is None:
76
+ return None
77
+ return to_gql_span(span)
78
+
37
79
  @strawberry.field
38
80
  async def spans(
39
81
  self,
@@ -82,3 +124,18 @@ class Trace(Node):
82
124
  stmt = stmt.order_by(models.TraceAnnotation.created_at.desc())
83
125
  annotations = await session.scalars(stmt)
84
126
  return [to_gql_trace_annotation(annotation) for annotation in annotations]
127
+
128
+
129
+ def to_gql_trace(trace: models.Trace) -> Trace:
130
+ return Trace(
131
+ id_attr=trace.id,
132
+ project_rowid=trace.project_rowid,
133
+ project_session_rowid=trace.project_session_rowid,
134
+ trace_id=trace.trace_id,
135
+ start_time=trace.start_time,
136
+ end_time=trace.end_time,
137
+ )
138
+
139
+
140
+ INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
141
+ OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
phoenix/server/app.py CHANGED
@@ -87,12 +87,18 @@ from phoenix.server.api.dataloaders import (
87
87
  MinStartOrMaxEndTimeDataLoader,
88
88
  ProjectByNameDataLoader,
89
89
  RecordCountDataLoader,
90
+ SessionIODataLoader,
91
+ SessionNumTracesDataLoader,
92
+ SessionNumTracesWithErrorDataLoader,
93
+ SessionTokenUsagesDataLoader,
94
+ SessionTraceLatencyMsQuantileDataLoader,
90
95
  SpanAnnotationsDataLoader,
91
96
  SpanDatasetExamplesDataLoader,
92
97
  SpanDescendantsDataLoader,
93
98
  SpanProjectsDataLoader,
94
99
  TokenCountDataLoader,
95
- TraceRowIdsDataLoader,
100
+ TraceByTraceIdsDataLoader,
101
+ TraceRootSpansDataLoader,
96
102
  UserRolesDataLoader,
97
103
  UsersDataLoader,
98
104
  )
@@ -609,6 +615,12 @@ def create_graphql_router(
609
615
  db,
610
616
  cache_map=cache_for_dataloaders.record_count if cache_for_dataloaders else None,
611
617
  ),
618
+ session_first_inputs=SessionIODataLoader(db, "first_input"),
619
+ session_last_outputs=SessionIODataLoader(db, "last_output"),
620
+ session_num_traces=SessionNumTracesDataLoader(db),
621
+ session_num_traces_with_error=SessionNumTracesWithErrorDataLoader(db),
622
+ session_token_usages=SessionTokenUsagesDataLoader(db),
623
+ session_trace_latency_ms_quantile=SessionTraceLatencyMsQuantileDataLoader(db),
612
624
  span_annotations=SpanAnnotationsDataLoader(db),
613
625
  span_dataset_examples=SpanDatasetExamplesDataLoader(db),
614
626
  span_descendants=SpanDescendantsDataLoader(db),
@@ -617,7 +629,8 @@ def create_graphql_router(
617
629
  db,
618
630
  cache_map=cache_for_dataloaders.token_count if cache_for_dataloaders else None,
619
631
  ),
620
- trace_row_ids=TraceRowIdsDataLoader(db),
632
+ trace_by_trace_ids=TraceByTraceIdsDataLoader(db),
633
+ trace_root_spans=TraceRootSpansDataLoader(db),
621
634
  project_by_name=ProjectByNameDataLoader(db),
622
635
  users=UsersDataLoader(db),
623
636
  user_roles=UserRolesDataLoader(db),