arize-phoenix 8.2.2__py3-none-any.whl → 8.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (33)
  1. {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.4.0.dist-info}/METADATA +4 -3
  2. {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.4.0.dist-info}/RECORD +33 -32
  3. phoenix/config.py +32 -5
  4. phoenix/db/models.py +1 -25
  5. phoenix/server/api/context.py +6 -2
  6. phoenix/server/api/dataloaders/__init__.py +4 -2
  7. phoenix/server/api/dataloaders/num_child_spans.py +35 -0
  8. phoenix/server/api/dataloaders/{span_fields.py → table_fields.py} +21 -19
  9. phoenix/server/api/helpers/playground_clients.py +4 -0
  10. phoenix/server/api/helpers/prompts/models.py +1 -0
  11. phoenix/server/api/queries.py +80 -20
  12. phoenix/server/api/types/Experiment.py +2 -4
  13. phoenix/server/api/types/ExperimentRun.py +2 -2
  14. phoenix/server/api/types/ExperimentRunAnnotation.py +2 -2
  15. phoenix/server/api/types/Project.py +67 -38
  16. phoenix/server/api/types/ProjectSession.py +2 -2
  17. phoenix/server/api/types/Span.py +35 -2
  18. phoenix/server/api/types/Trace.py +98 -30
  19. phoenix/server/app.py +6 -2
  20. phoenix/server/static/.vite/manifest.json +40 -40
  21. phoenix/server/static/assets/{components-MeFAEc1z.js → components-BgFPI6sn.js} +166 -167
  22. phoenix/server/static/assets/{index-BSRuZ-_J.js → index-CIkk8uHr.js} +9 -9
  23. phoenix/server/static/assets/{pages-NrL4hb9q.js → pages-CmDiPH1A.js} +660 -623
  24. phoenix/server/static/assets/{vendor-Cqfydjep.js → vendor-CRRq3WgM.js} +116 -116
  25. phoenix/server/static/assets/{vendor-arizeai-WnerlUPN.js → vendor-arizeai-Dq64X_0o.js} +1 -1
  26. phoenix/server/static/assets/{vendor-codemirror-D-ZZKLFq.js → vendor-codemirror-C1oevlym.js} +1 -1
  27. phoenix/server/static/assets/{vendor-recharts-KY97ZPfK.js → vendor-recharts-CPj01S89.js} +1 -1
  28. phoenix/server/static/assets/{vendor-shiki-D5K9GnFn.js → vendor-shiki-aY7rz1pm.js} +1 -1
  29. phoenix/version.py +1 -1
  30. {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.4.0.dist-info}/WHEEL +0 -0
  31. {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.4.0.dist-info}/entry_points.txt +0 -0
  32. {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.4.0.dist-info}/licenses/IP_NOTICE +0 -0
  33. {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -8,35 +8,36 @@ from typing_extensions import TypeAlias
8
8
  from phoenix.db import models
9
9
  from phoenix.server.types import DbSessionFactory
10
10
 
11
- SpanRowId: TypeAlias = int
11
+ RowId: TypeAlias = int
12
12
 
13
- Key: TypeAlias = tuple[SpanRowId, QueryableAttribute[Any]]
13
+ Key: TypeAlias = tuple[RowId, QueryableAttribute[Any]]
14
14
  Result: TypeAlias = Any
15
15
 
16
-
17
16
  _ResultColumnPosition: TypeAlias = int
18
17
  _AttrStrIdentifier: TypeAlias = str
19
18
 
20
19
 
21
- class SpanFieldsDataLoader(DataLoader[Key, Result]):
22
- def __init__(self, db: DbSessionFactory) -> None:
20
+ class TableFieldsDataLoader(DataLoader[Key, Result]):
21
+ def __init__(self, db: DbSessionFactory, table: type[models.Base]) -> None:
23
22
  super().__init__(load_fn=self._load_fn)
24
23
  self._db = db
24
+ self._table = table
25
25
 
26
26
  async def _load_fn(self, keys: Iterable[Key]) -> list[Union[Result, ValueError]]:
27
- result: dict[tuple[SpanRowId, _AttrStrIdentifier], Result] = {}
28
- stmt, attr_strs = _get_stmt(keys)
27
+ result: dict[tuple[RowId, _AttrStrIdentifier], Result] = {}
28
+ stmt, attr_strs = _get_stmt(keys, self._table)
29
29
  async with self._db() as session:
30
30
  data = await session.stream(stmt)
31
31
  async for row in data:
32
- span_rowid: SpanRowId = row[0] # models.Span's primary key
32
+ rowid: RowId = row[0] # models.Span's primary key
33
33
  for i, value in enumerate(row[1:]):
34
- result[span_rowid, attr_strs[i]] = value
35
- return [result.get((span_rowid, str(attr))) for span_rowid, attr in keys]
34
+ result[rowid, attr_strs[i]] = value
35
+ return [result.get((rowid, str(attr))) for rowid, attr in keys]
36
36
 
37
37
 
38
38
  def _get_stmt(
39
- keys: Iterable[Key],
39
+ keys: Iterable[tuple[RowId, QueryableAttribute[Any]]],
40
+ table: type[models.Base],
40
41
  ) -> tuple[
41
42
  Select[Any],
42
43
  dict[_ResultColumnPosition, _AttrStrIdentifier],
@@ -52,7 +53,8 @@ def _get_stmt(
52
53
 
53
54
  Args:
54
55
  keys (list[Key]): A list of tuples, where each tuple contains an integer ID, i.e. the
55
- primary key of models.Span, and a QueryableAttribute.
56
+ primary key of table, and a QueryableAttribute.
57
+ table (models.Base): The table to query.
56
58
 
57
59
  Returns:
58
60
  tuple: A tuple containing:
@@ -61,16 +63,16 @@ def _get_stmt(
61
63
  at the second column (because the first column is the span's primary key)--in the
62
64
  result to the attribute's string identifier.
63
65
  """
64
- span_rowids: set[SpanRowId] = set()
66
+ rowids: set[RowId] = set()
65
67
  attrs: dict[_AttrStrIdentifier, QueryableAttribute[Any]] = {}
66
68
  joins = set()
67
- for span_rowid, attr in keys:
68
- span_rowids.add(span_rowid)
69
+ for rowid, attr in keys:
70
+ rowids.add(rowid)
69
71
  attrs[str(attr)] = attr
70
- if (entity := attr.parent.entity) is not models.Span:
72
+ if (entity := attr.parent.entity) is not table:
71
73
  joins.add(entity)
72
- stmt = select(models.Span.id).where(models.Span.id.in_(span_rowids))
73
- for table in joins:
74
- stmt = stmt.join(table)
74
+ stmt = select(table.id).where(table.id.in_(rowids))
75
+ for other_table in joins:
76
+ stmt = stmt.join(other_table)
75
77
  identifiers, columns = zip(*attrs.items())
76
78
  return stmt.add_columns(*columns), dict(enumerate(identifiers))
@@ -806,6 +806,10 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
806
806
  raise NotImplementedError
807
807
  elif isinstance(event, anthropic_streaming._types.CitationEvent):
808
808
  raise NotImplementedError
809
+ elif isinstance(event, anthropic_streaming._types.ThinkingEvent):
810
+ raise NotImplementedError
811
+ elif isinstance(event, anthropic_streaming._types.SignatureEvent):
812
+ raise NotImplementedError
809
813
  else:
810
814
  assert_never(event)
811
815
 
@@ -348,6 +348,7 @@ class AnthropicToolDefinition(PromptModel):
348
348
  class PromptOpenAIInvocationParametersContent(PromptModel):
349
349
  temperature: float = UNDEFINED
350
350
  max_tokens: int = UNDEFINED
351
+ max_completion_tokens: int = UNDEFINED
351
352
  frequency_penalty: float = UNDEFINED
352
353
  presence_penalty: float = UNDEFINED
353
354
  top_p: float = UNDEFINED
@@ -1,19 +1,21 @@
1
1
  from collections import defaultdict
2
2
  from datetime import datetime
3
- from typing import Optional, Union
3
+ from typing import Iterable, Iterator, Optional, Union, cast
4
4
 
5
5
  import numpy as np
6
6
  import numpy.typing as npt
7
7
  import strawberry
8
- from sqlalchemy import and_, distinct, func, select
8
+ from sqlalchemy import and_, distinct, func, select, text
9
9
  from sqlalchemy.orm import joinedload
10
10
  from starlette.authentication import UnauthenticatedUser
11
11
  from strawberry import ID, UNSET
12
12
  from strawberry.relay import Connection, GlobalID, Node
13
13
  from strawberry.types import Info
14
- from typing_extensions import Annotated, TypeAlias
14
+ from typing_extensions import Annotated, TypeAlias, assert_never
15
15
 
16
+ from phoenix.config import ENV_PHOENIX_SQL_DATABASE_SCHEMA, getenv
16
17
  from phoenix.db import enums, models
18
+ from phoenix.db.helpers import SupportedSQLDialect
17
19
  from phoenix.db.models import DatasetExample as OrmExample
18
20
  from phoenix.db.models import DatasetExampleRevision as OrmRevision
19
21
  from phoenix.db.models import DatasetVersion as OrmVersion
@@ -66,7 +68,7 @@ from phoenix.server.api.types.PromptVersion import PromptVersion, to_gql_prompt_
66
68
  from phoenix.server.api.types.SortDir import SortDir
67
69
  from phoenix.server.api.types.Span import Span
68
70
  from phoenix.server.api.types.SystemApiKey import SystemApiKey
69
- from phoenix.server.api.types.Trace import to_gql_trace
71
+ from phoenix.server.api.types.Trace import Trace
70
72
  from phoenix.server.api.types.User import User, to_gql_user
71
73
  from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
72
74
  from phoenix.server.api.types.UserRole import UserRole
@@ -81,6 +83,12 @@ class ModelsInput:
81
83
  model_name: Optional[str] = None
82
84
 
83
85
 
86
+ @strawberry.type
87
+ class DbTableStats:
88
+ table_name: str
89
+ num_bytes: int
90
+
91
+
84
92
  @strawberry.type
85
93
  class Query:
86
94
  @strawberry.field
@@ -236,10 +244,8 @@ class Query:
236
244
  projects = await session.stream_scalars(stmt)
237
245
  data = [
238
246
  Project(
239
- id_attr=project.id,
240
- name=project.name,
241
- gradient_start_color=project.gradient_start_color,
242
- gradient_end_color=project.gradient_end_color,
247
+ project_rowid=project.id,
248
+ db_project=project,
243
249
  )
244
250
  async for project in projects
245
251
  ]
@@ -448,21 +454,14 @@ class Query:
448
454
  embedding_dimension = info.context.model.embedding_dimensions[node_id]
449
455
  return to_gql_embedding_dimension(node_id, embedding_dimension)
450
456
  elif type_name == "Project":
451
- project_stmt = select(
452
- models.Project.id,
453
- models.Project.name,
454
- models.Project.gradient_start_color,
455
- models.Project.gradient_end_color,
456
- ).where(models.Project.id == node_id)
457
+ project_stmt = select(models.Project).filter_by(id=node_id)
457
458
  async with info.context.db() as session:
458
- project = (await session.execute(project_stmt)).first()
459
+ project = await session.scalar(project_stmt)
459
460
  if project is None:
460
461
  raise NotFound(f"Unknown project: {id}")
461
462
  return Project(
462
- id_attr=project.id,
463
- name=project.name,
464
- gradient_start_color=project.gradient_start_color,
465
- gradient_end_color=project.gradient_end_color,
463
+ project_rowid=project.id,
464
+ db_project=project,
466
465
  )
467
466
  elif type_name == "Trace":
468
467
  trace_stmt = select(models.Trace).filter_by(id=node_id)
@@ -470,7 +469,7 @@ class Query:
470
469
  trace = await session.scalar(trace_stmt)
471
470
  if trace is None:
472
471
  raise NotFound(f"Unknown trace: {id}")
473
- return to_gql_trace(trace)
472
+ return Trace(trace_rowid=trace.id, db_trace=trace)
474
473
  elif type_name == Span.__name__:
475
474
  span_stmt = (
476
475
  select(models.Span)
@@ -789,3 +788,64 @@ class Query:
789
788
  return to_gql_clusters(
790
789
  clustered_events=clustered_events,
791
790
  )
791
+
792
+ @strawberry.field
793
+ async def db_table_stats(
794
+ self,
795
+ info: Info[Context, None],
796
+ ) -> list[DbTableStats]:
797
+ if info.context.db.dialect is SupportedSQLDialect.SQLITE:
798
+ stmt = text("SELECT name, sum(pgsize) FROM dbstat group by name;")
799
+ async with info.context.db() as session:
800
+ stats = cast(Iterable[tuple[str, int]], await session.execute(stmt))
801
+ stats = _consolidate_sqlite_db_table_stats(stats)
802
+ elif info.context.db.dialect is SupportedSQLDialect.POSTGRESQL:
803
+ stmt = text(f"""\
804
+ SELECT c.relname, pg_total_relation_size(c.oid)
805
+ FROM pg_class as c
806
+ INNER JOIN pg_namespace as n ON n.oid = c.relnamespace
807
+ WHERE c.relkind = 'r'
808
+ AND n.nspname = '{getenv(ENV_PHOENIX_SQL_DATABASE_SCHEMA) or "public"}';
809
+ """)
810
+ async with info.context.db() as session:
811
+ stats = cast(Iterable[tuple[str, int]], await session.execute(stmt))
812
+ else:
813
+ assert_never(info.context.db.dialect)
814
+ return [
815
+ DbTableStats(table_name=table_name, num_bytes=num_bytes)
816
+ for table_name, num_bytes in stats
817
+ ]
818
+
819
+
820
+ def _consolidate_sqlite_db_table_stats(
821
+ stats: Iterable[tuple[str, int]],
822
+ ) -> Iterator[tuple[str, int]]:
823
+ """
824
+ Consolidate SQLite database stats by combining indexes with their respective tables.
825
+ """
826
+ aggregate: dict[str, int] = {}
827
+ for name, num_bytes in stats:
828
+ # Skip internal SQLite tables and indexes.
829
+ if name.startswith("ix_") or name.startswith("sqlite_"):
830
+ continue
831
+ aggregate[name] = num_bytes
832
+ for name, num_bytes in stats:
833
+ # Combine indexes with their respective tables.
834
+ for flag in ["sqlite_autoindex_", "ix_"]:
835
+ if not name.startswith(flag):
836
+ continue
837
+ if parent := _longest_matching_prefix(name[len(flag) :], aggregate.keys()):
838
+ aggregate[parent] += num_bytes
839
+ break
840
+ yield from aggregate.items()
841
+
842
+
843
+ def _longest_matching_prefix(s: str, prefixes: Iterable[str]) -> str:
844
+ """
845
+ Return the longest prefix of s that matches any of the given prefixes.
846
+ """
847
+ longest = ""
848
+ for prefix in prefixes:
849
+ if s.startswith(prefix) and len(prefix) > len(longest):
850
+ longest = prefix
851
+ return longest
@@ -122,10 +122,8 @@ class Experiment(Node):
122
122
  return None
123
123
 
124
124
  return Project(
125
- id_attr=db_project.id,
126
- name=db_project.name,
127
- gradient_start_color=db_project.gradient_start_color,
128
- gradient_end_color=db_project.gradient_end_color,
125
+ project_rowid=db_project.id,
126
+ db_project=db_project,
129
127
  )
130
128
 
131
129
  @strawberry.field
@@ -20,7 +20,7 @@ from phoenix.server.api.types.pagination import (
20
20
  CursorString,
21
21
  connection_from_list,
22
22
  )
23
- from phoenix.server.api.types.Trace import Trace, to_gql_trace
23
+ from phoenix.server.api.types.Trace import Trace
24
24
 
25
25
  if TYPE_CHECKING:
26
26
  from phoenix.server.api.types.DatasetExample import DatasetExample
@@ -64,7 +64,7 @@ class ExperimentRun(Node):
64
64
  dataloader = info.context.data_loaders.trace_by_trace_ids
65
65
  if (trace := await dataloader.load(self.trace_id)) is None:
66
66
  return None
67
- return to_gql_trace(trace)
67
+ return Trace(trace_rowid=trace.id, db_trace=trace)
68
68
 
69
69
  @strawberry.field
70
70
  async def example(
@@ -8,7 +8,7 @@ from strawberry.scalars import JSON
8
8
 
9
9
  from phoenix.db import models
10
10
  from phoenix.server.api.types.AnnotatorKind import ExperimentRunAnnotatorKind
11
- from phoenix.server.api.types.Trace import Trace, to_gql_trace
11
+ from phoenix.server.api.types.Trace import Trace
12
12
 
13
13
 
14
14
  @strawberry.type
@@ -32,7 +32,7 @@ class ExperimentRunAnnotation(Node):
32
32
  dataloader = info.context.data_loaders.trace_by_trace_ids
33
33
  if (trace := await dataloader.load(self.trace_id)) is None:
34
34
  return None
35
- return to_gql_trace(trace)
35
+ return Trace(trace_rowid=trace.id, db_trace=trace)
36
36
 
37
37
 
38
38
  def to_gql_experiment_run_annotation(
@@ -8,7 +8,7 @@ from openinference.semconv.trace import SpanAttributes
8
8
  from sqlalchemy import desc, distinct, func, or_, select
9
9
  from sqlalchemy.sql.elements import ColumnElement
10
10
  from sqlalchemy.sql.expression import tuple_
11
- from strawberry import ID, UNSET
11
+ from strawberry import ID, UNSET, Private
12
12
  from strawberry.relay import Connection, Node, NodeID
13
13
  from strawberry.types import Info
14
14
  from typing_extensions import assert_never
@@ -33,7 +33,7 @@ from phoenix.server.api.types.pagination import (
33
33
  from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
34
34
  from phoenix.server.api.types.SortDir import SortDir
35
35
  from phoenix.server.api.types.Span import Span
36
- from phoenix.server.api.types.Trace import Trace, to_gql_trace
36
+ from phoenix.server.api.types.Trace import Trace
37
37
  from phoenix.server.api.types.ValidationResult import ValidationResult
38
38
  from phoenix.trace.dsl import SpanFilter
39
39
 
@@ -41,10 +41,51 @@ from phoenix.trace.dsl import SpanFilter
41
41
  @strawberry.type
42
42
  class Project(Node):
43
43
  _table: ClassVar[type[models.Base]] = models.Project
44
- id_attr: NodeID[int]
45
- name: str
46
- gradient_start_color: str
47
- gradient_end_color: str
44
+ project_rowid: NodeID[int]
45
+ db_project: Private[models.Project] = UNSET
46
+
47
+ def __post_init__(self) -> None:
48
+ if self.db_project and self.project_rowid != self.db_project.id:
49
+ raise ValueError("Project ID mismatch")
50
+
51
+ @strawberry.field
52
+ async def name(
53
+ self,
54
+ info: Info[Context, None],
55
+ ) -> str:
56
+ if self.db_project:
57
+ name = self.db_project.name
58
+ else:
59
+ name = await info.context.data_loaders.project_fields.load(
60
+ (self.project_rowid, models.Project.name),
61
+ )
62
+ return name
63
+
64
+ @strawberry.field
65
+ async def gradient_start_color(
66
+ self,
67
+ info: Info[Context, None],
68
+ ) -> str:
69
+ if self.db_project:
70
+ gradient_start_color = self.db_project.gradient_start_color
71
+ else:
72
+ gradient_start_color = await info.context.data_loaders.project_fields.load(
73
+ (self.project_rowid, models.Project.gradient_start_color),
74
+ )
75
+ return gradient_start_color
76
+
77
+ @strawberry.field
78
+ async def gradient_end_color(
79
+ self,
80
+ info: Info[Context, None],
81
+ ) -> str:
82
+ if self.db_project:
83
+ gradient_end_color = self.db_project.gradient_end_color
84
+ else:
85
+ gradient_end_color = await info.context.data_loaders.project_fields.load(
86
+ (self.project_rowid, models.Project.gradient_end_color),
87
+ )
88
+ return gradient_end_color
48
89
 
49
90
  @strawberry.field
50
91
  async def start_time(
@@ -52,7 +93,7 @@ class Project(Node):
52
93
  info: Info[Context, None],
53
94
  ) -> Optional[datetime]:
54
95
  start_time = await info.context.data_loaders.min_start_or_max_end_times.load(
55
- (self.id_attr, "start"),
96
+ (self.project_rowid, "start"),
56
97
  )
57
98
  start_time, _ = right_open_time_range(start_time, None)
58
99
  return start_time
@@ -63,7 +104,7 @@ class Project(Node):
63
104
  info: Info[Context, None],
64
105
  ) -> Optional[datetime]:
65
106
  end_time = await info.context.data_loaders.min_start_or_max_end_times.load(
66
- (self.id_attr, "end"),
107
+ (self.project_rowid, "end"),
67
108
  )
68
109
  _, end_time = right_open_time_range(None, end_time)
69
110
  return end_time
@@ -76,7 +117,7 @@ class Project(Node):
76
117
  filter_condition: Optional[str] = UNSET,
77
118
  ) -> int:
78
119
  return await info.context.data_loaders.record_counts.load(
79
- ("span", self.id_attr, time_range, filter_condition),
120
+ ("span", self.project_rowid, time_range, filter_condition),
80
121
  )
81
122
 
82
123
  @strawberry.field
@@ -86,7 +127,7 @@ class Project(Node):
86
127
  time_range: Optional[TimeRange] = UNSET,
87
128
  ) -> int:
88
129
  return await info.context.data_loaders.record_counts.load(
89
- ("trace", self.id_attr, time_range, None),
130
+ ("trace", self.project_rowid, time_range, None),
90
131
  )
91
132
 
92
133
  @strawberry.field
@@ -97,7 +138,7 @@ class Project(Node):
97
138
  filter_condition: Optional[str] = UNSET,
98
139
  ) -> int:
99
140
  return await info.context.data_loaders.token_counts.load(
100
- ("total", self.id_attr, time_range, filter_condition),
141
+ ("total", self.project_rowid, time_range, filter_condition),
101
142
  )
102
143
 
103
144
  @strawberry.field
@@ -108,7 +149,7 @@ class Project(Node):
108
149
  filter_condition: Optional[str] = UNSET,
109
150
  ) -> int:
110
151
  return await info.context.data_loaders.token_counts.load(
111
- ("prompt", self.id_attr, time_range, filter_condition),
152
+ ("prompt", self.project_rowid, time_range, filter_condition),
112
153
  )
113
154
 
114
155
  @strawberry.field
@@ -119,7 +160,7 @@ class Project(Node):
119
160
  filter_condition: Optional[str] = UNSET,
120
161
  ) -> int:
121
162
  return await info.context.data_loaders.token_counts.load(
122
- ("completion", self.id_attr, time_range, filter_condition),
163
+ ("completion", self.project_rowid, time_range, filter_condition),
123
164
  )
124
165
 
125
166
  @strawberry.field
@@ -132,7 +173,7 @@ class Project(Node):
132
173
  return await info.context.data_loaders.latency_ms_quantile.load(
133
174
  (
134
175
  "trace",
135
- self.id_attr,
176
+ self.project_rowid,
136
177
  time_range,
137
178
  None,
138
179
  probability,
@@ -150,7 +191,7 @@ class Project(Node):
150
191
  return await info.context.data_loaders.latency_ms_quantile.load(
151
192
  (
152
193
  "span",
153
- self.id_attr,
194
+ self.project_rowid,
154
195
  time_range,
155
196
  filter_condition,
156
197
  probability,
@@ -162,12 +203,12 @@ class Project(Node):
162
203
  stmt = (
163
204
  select(models.Trace)
164
205
  .where(models.Trace.trace_id == str(trace_id))
165
- .where(models.Trace.project_rowid == self.id_attr)
206
+ .where(models.Trace.project_rowid == self.project_rowid)
166
207
  )
167
208
  async with info.context.db() as session:
168
209
  if (trace := await session.scalar(stmt)) is None:
169
210
  return None
170
- return to_gql_trace(trace)
211
+ return Trace(trace_rowid=trace.id, db_trace=trace)
171
212
 
172
213
  @strawberry.field
173
214
  async def spans(
@@ -185,7 +226,7 @@ class Project(Node):
185
226
  stmt = (
186
227
  select(models.Span.id)
187
228
  .join(models.Trace)
188
- .where(models.Trace.project_rowid == self.id_attr)
229
+ .where(models.Trace.project_rowid == self.project_rowid)
189
230
  )
190
231
  if time_range:
191
232
  if time_range.start:
@@ -264,7 +305,7 @@ class Project(Node):
264
305
  filter_io_substring: Optional[str] = UNSET,
265
306
  ) -> Connection[ProjectSession]:
266
307
  table = models.ProjectSession
267
- stmt = select(table).filter_by(project_id=self.id_attr)
308
+ stmt = select(table).filter_by(project_id=self.project_rowid)
268
309
  if time_range:
269
310
  if time_range.start:
270
311
  stmt = stmt.where(time_range.start <= table.start_time)
@@ -382,7 +423,7 @@ class Project(Node):
382
423
  stmt = (
383
424
  select(distinct(models.TraceAnnotation.name))
384
425
  .join(models.Trace)
385
- .where(models.Trace.project_rowid == self.id_attr)
426
+ .where(models.Trace.project_rowid == self.project_rowid)
386
427
  )
387
428
  async with info.context.db() as session:
388
429
  return list(await session.scalars(stmt))
@@ -399,7 +440,7 @@ class Project(Node):
399
440
  select(distinct(models.SpanAnnotation.name))
400
441
  .join(models.Span)
401
442
  .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
402
- .where(models.Trace.project_rowid == self.id_attr)
443
+ .where(models.Trace.project_rowid == self.project_rowid)
403
444
  )
404
445
  async with info.context.db() as session:
405
446
  return list(await session.scalars(stmt))
@@ -416,7 +457,7 @@ class Project(Node):
416
457
  select(distinct(models.DocumentAnnotation.name))
417
458
  .join(models.Span)
418
459
  .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
419
- .where(models.Trace.project_rowid == self.id_attr)
460
+ .where(models.Trace.project_rowid == self.project_rowid)
420
461
  .where(models.DocumentAnnotation.annotator_kind == "LLM")
421
462
  )
422
463
  if span_id:
@@ -432,7 +473,7 @@ class Project(Node):
432
473
  time_range: Optional[TimeRange] = UNSET,
433
474
  ) -> Optional[AnnotationSummary]:
434
475
  return await info.context.data_loaders.annotation_summaries.load(
435
- ("trace", self.id_attr, time_range, None, annotation_name),
476
+ ("trace", self.project_rowid, time_range, None, annotation_name),
436
477
  )
437
478
 
438
479
  @strawberry.field
@@ -444,7 +485,7 @@ class Project(Node):
444
485
  filter_condition: Optional[str] = UNSET,
445
486
  ) -> Optional[AnnotationSummary]:
446
487
  return await info.context.data_loaders.annotation_summaries.load(
447
- ("span", self.id_attr, time_range, filter_condition, annotation_name),
488
+ ("span", self.project_rowid, time_range, filter_condition, annotation_name),
448
489
  )
449
490
 
450
491
  @strawberry.field
@@ -456,7 +497,7 @@ class Project(Node):
456
497
  filter_condition: Optional[str] = UNSET,
457
498
  ) -> Optional[DocumentEvaluationSummary]:
458
499
  return await info.context.data_loaders.document_evaluation_summaries.load(
459
- (self.id_attr, time_range, filter_condition, evaluation_name),
500
+ (self.project_rowid, time_range, filter_condition, evaluation_name),
460
501
  )
461
502
 
462
503
  @strawberry.field
@@ -464,7 +505,7 @@ class Project(Node):
464
505
  self,
465
506
  info: Info[Context, None],
466
507
  ) -> Optional[datetime]:
467
- return info.context.last_updated_at.get(self._table, self.id_attr)
508
+ return info.context.last_updated_at.get(self._table, self.project_rowid)
468
509
 
469
510
  @strawberry.field
470
511
  async def validate_span_filter_condition(self, condition: str) -> ValidationResult:
@@ -483,17 +524,5 @@ class Project(Node):
483
524
  )
484
525
 
485
526
 
486
- def to_gql_project(project: models.Project) -> Project:
487
- """
488
- Converts an ORM project to a GraphQL Project.
489
- """
490
- return Project(
491
- id_attr=project.id,
492
- name=project.name,
493
- gradient_start_color=project.gradient_start_color,
494
- gradient_end_color=project.gradient_end_color,
495
- )
496
-
497
-
498
527
  INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
499
528
  OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
@@ -93,7 +93,7 @@ class ProjectSession(Node):
93
93
  after: Optional[CursorString] = UNSET,
94
94
  before: Optional[CursorString] = UNSET,
95
95
  ) -> Connection[Annotated["Trace", lazy(".Trace")]]:
96
- from phoenix.server.api.types.Trace import to_gql_trace
96
+ from phoenix.server.api.types.Trace import Trace
97
97
 
98
98
  args = ConnectionArgs(
99
99
  first=first,
@@ -109,7 +109,7 @@ class ProjectSession(Node):
109
109
  )
110
110
  async with info.context.db() as session:
111
111
  traces = await session.stream_scalars(stmt)
112
- data = [to_gql_trace(trace) async for trace in traces]
112
+ data = [Trace(trace_rowid=trace.id, db_trace=trace) async for trace in traces]
113
113
  return connection_from_list(data=data, args=args)
114
114
 
115
115
  @strawberry.field
@@ -37,6 +37,7 @@ from phoenix.trace.attributes import get_attribute_value
37
37
 
38
38
  if TYPE_CHECKING:
39
39
  from phoenix.server.api.types.Project import Project
40
+ from phoenix.server.api.types.Trace import Trace
40
41
 
41
42
 
42
43
  @strawberry.enum
@@ -216,6 +217,34 @@ class Span(Node):
216
217
  )
217
218
  return SpanKind(value)
218
219
 
220
+ @strawberry.field
221
+ async def span_id(
222
+ self,
223
+ info: Info[Context, None],
224
+ ) -> ID:
225
+ if self.db_span:
226
+ span_id = self.db_span.span_id
227
+ else:
228
+ span_id = await info.context.data_loaders.span_fields.load(
229
+ (self.span_rowid, models.Span.span_id),
230
+ )
231
+ return ID(span_id)
232
+
233
+ @strawberry.field
234
+ async def trace(
235
+ self,
236
+ info: Info[Context, None],
237
+ ) -> Annotated["Trace", strawberry.lazy(".Trace")]:
238
+ if self.db_span:
239
+ trace_rowid = self.db_span.trace_rowid
240
+ else:
241
+ trace_rowid = await info.context.data_loaders.span_fields.load(
242
+ (self.span_rowid, models.Span.trace_rowid),
243
+ )
244
+ from phoenix.server.api.types.Trace import Trace
245
+
246
+ return Trace(trace_rowid=trace_rowid)
247
+
219
248
  @strawberry.field
220
249
  async def context(
221
250
  self,
@@ -508,6 +537,10 @@ class Span(Node):
508
537
  (self.span_rowid, evaluation_name or None, num_documents),
509
538
  )
510
539
 
540
+ @strawberry.field
541
+ async def num_child_spans(self, info: Info[Context, None]) -> int:
542
+ return await info.context.data_loaders.num_child_spans.load(self.span_rowid)
543
+
511
544
  @strawberry.field(
512
545
  description="All descendant spans (children, grandchildren, etc.)",
513
546
  ) # type: ignore
@@ -561,11 +594,11 @@ class Span(Node):
561
594
  ) -> Annotated[
562
595
  "Project", strawberry.lazy("phoenix.server.api.types.Project")
563
596
  ]: # use lazy types to avoid circular import: https://strawberry.rocks/docs/types/lazy
564
- from phoenix.server.api.types.Project import to_gql_project
597
+ from phoenix.server.api.types.Project import Project
565
598
 
566
599
  span_id = self.span_rowid
567
600
  project = await info.context.data_loaders.span_projects.load(span_id)
568
- return to_gql_project(project)
601
+ return Project(project_rowid=project.id, db_project=project)
569
602
 
570
603
  @strawberry.field(description="Indicates if the span is contained in any dataset") # type: ignore
571
604
  async def contained_in_dataset(