arize-phoenix 8.2.2__py3-none-any.whl → 8.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; see the registry's page for this release for more details.

Files changed (27):
  1. {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.3.0.dist-info}/METADATA +3 -3
  2. {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.3.0.dist-info}/RECORD +27 -27
  3. phoenix/config.py +32 -5
  4. phoenix/db/models.py +1 -25
  5. phoenix/server/api/context.py +4 -2
  6. phoenix/server/api/dataloaders/__init__.py +2 -2
  7. phoenix/server/api/dataloaders/{span_fields.py → table_fields.py} +21 -19
  8. phoenix/server/api/helpers/playground_clients.py +4 -0
  9. phoenix/server/api/helpers/prompts/models.py +1 -0
  10. phoenix/server/api/queries.py +8 -17
  11. phoenix/server/api/types/Experiment.py +2 -4
  12. phoenix/server/api/types/ExperimentRun.py +2 -2
  13. phoenix/server/api/types/ExperimentRunAnnotation.py +2 -2
  14. phoenix/server/api/types/Project.py +67 -38
  15. phoenix/server/api/types/ProjectSession.py +2 -2
  16. phoenix/server/api/types/Span.py +31 -2
  17. phoenix/server/api/types/Trace.py +98 -30
  18. phoenix/server/app.py +4 -2
  19. phoenix/server/static/.vite/manifest.json +9 -9
  20. phoenix/server/static/assets/{components-MeFAEc1z.js → components-T5K9z49d.js} +3 -3
  21. phoenix/server/static/assets/{index-BSRuZ-_J.js → index-DvHwFF8e.js} +2 -2
  22. phoenix/server/static/assets/{pages-NrL4hb9q.js → pages-CY3ZXSHj.js} +375 -356
  23. phoenix/version.py +1 -1
  24. {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.3.0.dist-info}/WHEEL +0 -0
  25. {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.3.0.dist-info}/entry_points.txt +0 -0
  26. {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.3.0.dist-info}/licenses/IP_NOTICE +0 -0
  27. {arize_phoenix-8.2.2.dist-info → arize_phoenix-8.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -122,10 +122,8 @@ class Experiment(Node):
122
122
  return None
123
123
 
124
124
  return Project(
125
- id_attr=db_project.id,
126
- name=db_project.name,
127
- gradient_start_color=db_project.gradient_start_color,
128
- gradient_end_color=db_project.gradient_end_color,
125
+ project_rowid=db_project.id,
126
+ db_project=db_project,
129
127
  )
130
128
 
131
129
  @strawberry.field
@@ -20,7 +20,7 @@ from phoenix.server.api.types.pagination import (
20
20
  CursorString,
21
21
  connection_from_list,
22
22
  )
23
- from phoenix.server.api.types.Trace import Trace, to_gql_trace
23
+ from phoenix.server.api.types.Trace import Trace
24
24
 
25
25
  if TYPE_CHECKING:
26
26
  from phoenix.server.api.types.DatasetExample import DatasetExample
@@ -64,7 +64,7 @@ class ExperimentRun(Node):
64
64
  dataloader = info.context.data_loaders.trace_by_trace_ids
65
65
  if (trace := await dataloader.load(self.trace_id)) is None:
66
66
  return None
67
- return to_gql_trace(trace)
67
+ return Trace(trace_rowid=trace.id, db_trace=trace)
68
68
 
69
69
  @strawberry.field
70
70
  async def example(
@@ -8,7 +8,7 @@ from strawberry.scalars import JSON
8
8
 
9
9
  from phoenix.db import models
10
10
  from phoenix.server.api.types.AnnotatorKind import ExperimentRunAnnotatorKind
11
- from phoenix.server.api.types.Trace import Trace, to_gql_trace
11
+ from phoenix.server.api.types.Trace import Trace
12
12
 
13
13
 
14
14
  @strawberry.type
@@ -32,7 +32,7 @@ class ExperimentRunAnnotation(Node):
32
32
  dataloader = info.context.data_loaders.trace_by_trace_ids
33
33
  if (trace := await dataloader.load(self.trace_id)) is None:
34
34
  return None
35
- return to_gql_trace(trace)
35
+ return Trace(trace_rowid=trace.id, db_trace=trace)
36
36
 
37
37
 
38
38
  def to_gql_experiment_run_annotation(
@@ -8,7 +8,7 @@ from openinference.semconv.trace import SpanAttributes
8
8
  from sqlalchemy import desc, distinct, func, or_, select
9
9
  from sqlalchemy.sql.elements import ColumnElement
10
10
  from sqlalchemy.sql.expression import tuple_
11
- from strawberry import ID, UNSET
11
+ from strawberry import ID, UNSET, Private
12
12
  from strawberry.relay import Connection, Node, NodeID
13
13
  from strawberry.types import Info
14
14
  from typing_extensions import assert_never
@@ -33,7 +33,7 @@ from phoenix.server.api.types.pagination import (
33
33
  from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
34
34
  from phoenix.server.api.types.SortDir import SortDir
35
35
  from phoenix.server.api.types.Span import Span
36
- from phoenix.server.api.types.Trace import Trace, to_gql_trace
36
+ from phoenix.server.api.types.Trace import Trace
37
37
  from phoenix.server.api.types.ValidationResult import ValidationResult
38
38
  from phoenix.trace.dsl import SpanFilter
39
39
 
@@ -41,10 +41,51 @@ from phoenix.trace.dsl import SpanFilter
41
41
  @strawberry.type
42
42
  class Project(Node):
43
43
  _table: ClassVar[type[models.Base]] = models.Project
44
- id_attr: NodeID[int]
45
- name: str
46
- gradient_start_color: str
47
- gradient_end_color: str
44
+ project_rowid: NodeID[int]
45
+ db_project: Private[models.Project] = UNSET
46
+
47
+ def __post_init__(self) -> None:
48
+ if self.db_project and self.project_rowid != self.db_project.id:
49
+ raise ValueError("Project ID mismatch")
50
+
51
+ @strawberry.field
52
+ async def name(
53
+ self,
54
+ info: Info[Context, None],
55
+ ) -> str:
56
+ if self.db_project:
57
+ name = self.db_project.name
58
+ else:
59
+ name = await info.context.data_loaders.project_fields.load(
60
+ (self.project_rowid, models.Project.name),
61
+ )
62
+ return name
63
+
64
+ @strawberry.field
65
+ async def gradient_start_color(
66
+ self,
67
+ info: Info[Context, None],
68
+ ) -> str:
69
+ if self.db_project:
70
+ gradient_start_color = self.db_project.gradient_start_color
71
+ else:
72
+ gradient_start_color = await info.context.data_loaders.project_fields.load(
73
+ (self.project_rowid, models.Project.gradient_start_color),
74
+ )
75
+ return gradient_start_color
76
+
77
+ @strawberry.field
78
+ async def gradient_end_color(
79
+ self,
80
+ info: Info[Context, None],
81
+ ) -> str:
82
+ if self.db_project:
83
+ gradient_end_color = self.db_project.gradient_end_color
84
+ else:
85
+ gradient_end_color = await info.context.data_loaders.project_fields.load(
86
+ (self.project_rowid, models.Project.gradient_end_color),
87
+ )
88
+ return gradient_end_color
48
89
 
49
90
  @strawberry.field
50
91
  async def start_time(
@@ -52,7 +93,7 @@ class Project(Node):
52
93
  info: Info[Context, None],
53
94
  ) -> Optional[datetime]:
54
95
  start_time = await info.context.data_loaders.min_start_or_max_end_times.load(
55
- (self.id_attr, "start"),
96
+ (self.project_rowid, "start"),
56
97
  )
57
98
  start_time, _ = right_open_time_range(start_time, None)
58
99
  return start_time
@@ -63,7 +104,7 @@ class Project(Node):
63
104
  info: Info[Context, None],
64
105
  ) -> Optional[datetime]:
65
106
  end_time = await info.context.data_loaders.min_start_or_max_end_times.load(
66
- (self.id_attr, "end"),
107
+ (self.project_rowid, "end"),
67
108
  )
68
109
  _, end_time = right_open_time_range(None, end_time)
69
110
  return end_time
@@ -76,7 +117,7 @@ class Project(Node):
76
117
  filter_condition: Optional[str] = UNSET,
77
118
  ) -> int:
78
119
  return await info.context.data_loaders.record_counts.load(
79
- ("span", self.id_attr, time_range, filter_condition),
120
+ ("span", self.project_rowid, time_range, filter_condition),
80
121
  )
81
122
 
82
123
  @strawberry.field
@@ -86,7 +127,7 @@ class Project(Node):
86
127
  time_range: Optional[TimeRange] = UNSET,
87
128
  ) -> int:
88
129
  return await info.context.data_loaders.record_counts.load(
89
- ("trace", self.id_attr, time_range, None),
130
+ ("trace", self.project_rowid, time_range, None),
90
131
  )
91
132
 
92
133
  @strawberry.field
@@ -97,7 +138,7 @@ class Project(Node):
97
138
  filter_condition: Optional[str] = UNSET,
98
139
  ) -> int:
99
140
  return await info.context.data_loaders.token_counts.load(
100
- ("total", self.id_attr, time_range, filter_condition),
141
+ ("total", self.project_rowid, time_range, filter_condition),
101
142
  )
102
143
 
103
144
  @strawberry.field
@@ -108,7 +149,7 @@ class Project(Node):
108
149
  filter_condition: Optional[str] = UNSET,
109
150
  ) -> int:
110
151
  return await info.context.data_loaders.token_counts.load(
111
- ("prompt", self.id_attr, time_range, filter_condition),
152
+ ("prompt", self.project_rowid, time_range, filter_condition),
112
153
  )
113
154
 
114
155
  @strawberry.field
@@ -119,7 +160,7 @@ class Project(Node):
119
160
  filter_condition: Optional[str] = UNSET,
120
161
  ) -> int:
121
162
  return await info.context.data_loaders.token_counts.load(
122
- ("completion", self.id_attr, time_range, filter_condition),
163
+ ("completion", self.project_rowid, time_range, filter_condition),
123
164
  )
124
165
 
125
166
  @strawberry.field
@@ -132,7 +173,7 @@ class Project(Node):
132
173
  return await info.context.data_loaders.latency_ms_quantile.load(
133
174
  (
134
175
  "trace",
135
- self.id_attr,
176
+ self.project_rowid,
136
177
  time_range,
137
178
  None,
138
179
  probability,
@@ -150,7 +191,7 @@ class Project(Node):
150
191
  return await info.context.data_loaders.latency_ms_quantile.load(
151
192
  (
152
193
  "span",
153
- self.id_attr,
194
+ self.project_rowid,
154
195
  time_range,
155
196
  filter_condition,
156
197
  probability,
@@ -162,12 +203,12 @@ class Project(Node):
162
203
  stmt = (
163
204
  select(models.Trace)
164
205
  .where(models.Trace.trace_id == str(trace_id))
165
- .where(models.Trace.project_rowid == self.id_attr)
206
+ .where(models.Trace.project_rowid == self.project_rowid)
166
207
  )
167
208
  async with info.context.db() as session:
168
209
  if (trace := await session.scalar(stmt)) is None:
169
210
  return None
170
- return to_gql_trace(trace)
211
+ return Trace(trace_rowid=trace.id, db_trace=trace)
171
212
 
172
213
  @strawberry.field
173
214
  async def spans(
@@ -185,7 +226,7 @@ class Project(Node):
185
226
  stmt = (
186
227
  select(models.Span.id)
187
228
  .join(models.Trace)
188
- .where(models.Trace.project_rowid == self.id_attr)
229
+ .where(models.Trace.project_rowid == self.project_rowid)
189
230
  )
190
231
  if time_range:
191
232
  if time_range.start:
@@ -264,7 +305,7 @@ class Project(Node):
264
305
  filter_io_substring: Optional[str] = UNSET,
265
306
  ) -> Connection[ProjectSession]:
266
307
  table = models.ProjectSession
267
- stmt = select(table).filter_by(project_id=self.id_attr)
308
+ stmt = select(table).filter_by(project_id=self.project_rowid)
268
309
  if time_range:
269
310
  if time_range.start:
270
311
  stmt = stmt.where(time_range.start <= table.start_time)
@@ -382,7 +423,7 @@ class Project(Node):
382
423
  stmt = (
383
424
  select(distinct(models.TraceAnnotation.name))
384
425
  .join(models.Trace)
385
- .where(models.Trace.project_rowid == self.id_attr)
426
+ .where(models.Trace.project_rowid == self.project_rowid)
386
427
  )
387
428
  async with info.context.db() as session:
388
429
  return list(await session.scalars(stmt))
@@ -399,7 +440,7 @@ class Project(Node):
399
440
  select(distinct(models.SpanAnnotation.name))
400
441
  .join(models.Span)
401
442
  .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
402
- .where(models.Trace.project_rowid == self.id_attr)
443
+ .where(models.Trace.project_rowid == self.project_rowid)
403
444
  )
404
445
  async with info.context.db() as session:
405
446
  return list(await session.scalars(stmt))
@@ -416,7 +457,7 @@ class Project(Node):
416
457
  select(distinct(models.DocumentAnnotation.name))
417
458
  .join(models.Span)
418
459
  .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
419
- .where(models.Trace.project_rowid == self.id_attr)
460
+ .where(models.Trace.project_rowid == self.project_rowid)
420
461
  .where(models.DocumentAnnotation.annotator_kind == "LLM")
421
462
  )
422
463
  if span_id:
@@ -432,7 +473,7 @@ class Project(Node):
432
473
  time_range: Optional[TimeRange] = UNSET,
433
474
  ) -> Optional[AnnotationSummary]:
434
475
  return await info.context.data_loaders.annotation_summaries.load(
435
- ("trace", self.id_attr, time_range, None, annotation_name),
476
+ ("trace", self.project_rowid, time_range, None, annotation_name),
436
477
  )
437
478
 
438
479
  @strawberry.field
@@ -444,7 +485,7 @@ class Project(Node):
444
485
  filter_condition: Optional[str] = UNSET,
445
486
  ) -> Optional[AnnotationSummary]:
446
487
  return await info.context.data_loaders.annotation_summaries.load(
447
- ("span", self.id_attr, time_range, filter_condition, annotation_name),
488
+ ("span", self.project_rowid, time_range, filter_condition, annotation_name),
448
489
  )
449
490
 
450
491
  @strawberry.field
@@ -456,7 +497,7 @@ class Project(Node):
456
497
  filter_condition: Optional[str] = UNSET,
457
498
  ) -> Optional[DocumentEvaluationSummary]:
458
499
  return await info.context.data_loaders.document_evaluation_summaries.load(
459
- (self.id_attr, time_range, filter_condition, evaluation_name),
500
+ (self.project_rowid, time_range, filter_condition, evaluation_name),
460
501
  )
461
502
 
462
503
  @strawberry.field
@@ -464,7 +505,7 @@ class Project(Node):
464
505
  self,
465
506
  info: Info[Context, None],
466
507
  ) -> Optional[datetime]:
467
- return info.context.last_updated_at.get(self._table, self.id_attr)
508
+ return info.context.last_updated_at.get(self._table, self.project_rowid)
468
509
 
469
510
  @strawberry.field
470
511
  async def validate_span_filter_condition(self, condition: str) -> ValidationResult:
@@ -483,17 +524,5 @@ class Project(Node):
483
524
  )
484
525
 
485
526
 
486
- def to_gql_project(project: models.Project) -> Project:
487
- """
488
- Converts an ORM project to a GraphQL Project.
489
- """
490
- return Project(
491
- id_attr=project.id,
492
- name=project.name,
493
- gradient_start_color=project.gradient_start_color,
494
- gradient_end_color=project.gradient_end_color,
495
- )
496
-
497
-
498
527
  INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
499
528
  OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
@@ -93,7 +93,7 @@ class ProjectSession(Node):
93
93
  after: Optional[CursorString] = UNSET,
94
94
  before: Optional[CursorString] = UNSET,
95
95
  ) -> Connection[Annotated["Trace", lazy(".Trace")]]:
96
- from phoenix.server.api.types.Trace import to_gql_trace
96
+ from phoenix.server.api.types.Trace import Trace
97
97
 
98
98
  args = ConnectionArgs(
99
99
  first=first,
@@ -109,7 +109,7 @@ class ProjectSession(Node):
109
109
  )
110
110
  async with info.context.db() as session:
111
111
  traces = await session.stream_scalars(stmt)
112
- data = [to_gql_trace(trace) async for trace in traces]
112
+ data = [Trace(trace_rowid=trace.id, db_trace=trace) async for trace in traces]
113
113
  return connection_from_list(data=data, args=args)
114
114
 
115
115
  @strawberry.field
@@ -37,6 +37,7 @@ from phoenix.trace.attributes import get_attribute_value
37
37
 
38
38
  if TYPE_CHECKING:
39
39
  from phoenix.server.api.types.Project import Project
40
+ from phoenix.server.api.types.Trace import Trace
40
41
 
41
42
 
42
43
  @strawberry.enum
@@ -216,6 +217,34 @@ class Span(Node):
216
217
  )
217
218
  return SpanKind(value)
218
219
 
220
+ @strawberry.field
221
+ async def span_id(
222
+ self,
223
+ info: Info[Context, None],
224
+ ) -> ID:
225
+ if self.db_span:
226
+ span_id = self.db_span.span_id
227
+ else:
228
+ span_id = await info.context.data_loaders.span_fields.load(
229
+ (self.span_rowid, models.Span.span_id),
230
+ )
231
+ return ID(span_id)
232
+
233
+ @strawberry.field
234
+ async def trace(
235
+ self,
236
+ info: Info[Context, None],
237
+ ) -> Annotated["Trace", strawberry.lazy(".Trace")]:
238
+ if self.db_span:
239
+ trace_rowid = self.db_span.trace_rowid
240
+ else:
241
+ trace_rowid = await info.context.data_loaders.span_fields.load(
242
+ (self.span_rowid, models.Span.trace_rowid),
243
+ )
244
+ from phoenix.server.api.types.Trace import Trace
245
+
246
+ return Trace(trace_rowid=trace_rowid)
247
+
219
248
  @strawberry.field
220
249
  async def context(
221
250
  self,
@@ -561,11 +590,11 @@ class Span(Node):
561
590
  ) -> Annotated[
562
591
  "Project", strawberry.lazy("phoenix.server.api.types.Project")
563
592
  ]: # use lazy types to avoid circular import: https://strawberry.rocks/docs/types/lazy
564
- from phoenix.server.api.types.Project import to_gql_project
593
+ from phoenix.server.api.types.Project import Project
565
594
 
566
595
  span_id = self.span_rowid
567
596
  project = await info.context.data_loaders.span_projects.load(span_id)
568
- return to_gql_project(project)
597
+ return Project(project_rowid=project.id, db_project=project)
569
598
 
570
599
  @strawberry.field(description="Indicates if the span is contained in any dataset") # type: ignore
571
600
  async def contained_in_dataset(
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Annotated, Optional, Union
6
6
  import strawberry
7
7
  from openinference.semconv.trace import SpanAttributes
8
8
  from sqlalchemy import desc, select
9
- from strawberry import UNSET, Private, lazy
9
+ from strawberry import ID, UNSET, Private, lazy
10
10
  from strawberry.relay import Connection, GlobalID, Node, NodeID
11
11
  from strawberry.types import Info
12
12
  from typing_extensions import TypeAlias
@@ -24,6 +24,7 @@ from phoenix.server.api.types.Span import Span
24
24
  from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
25
25
 
26
26
  if TYPE_CHECKING:
27
+ from phoenix.server.api.types.Project import Project
27
28
  from phoenix.server.api.types.ProjectSession import ProjectSession
28
29
 
29
30
  ProjectRowId: TypeAlias = int
@@ -33,49 +34,127 @@ TraceRowId: TypeAlias = int
33
34
  @strawberry.type
34
35
  class Trace(Node):
35
36
  trace_rowid: NodeID[TraceRowId]
36
- project_rowid: Private[ProjectRowId]
37
- project_session_rowid: Private[Optional[int]]
38
- trace_id: str
39
- start_time: datetime
40
- end_time: datetime
37
+ db_trace: Private[models.Trace] = UNSET
38
+
39
+ def __post_init__(self) -> None:
40
+ if self.db_trace and self.trace_rowid != self.db_trace.id:
41
+ raise ValueError("Trace ID mismatch")
42
+
43
+ @strawberry.field
44
+ async def trace_id(
45
+ self,
46
+ info: Info[Context, None],
47
+ ) -> ID:
48
+ if self.db_trace:
49
+ trace_id = self.db_trace.trace_id
50
+ else:
51
+ trace_id = await info.context.data_loaders.trace_fields.load(
52
+ (self.trace_rowid, models.Trace.trace_id),
53
+ )
54
+ return ID(trace_id)
55
+
56
+ @strawberry.field
57
+ async def start_time(
58
+ self,
59
+ info: Info[Context, None],
60
+ ) -> datetime:
61
+ if self.db_trace:
62
+ start_time = self.db_trace.start_time
63
+ else:
64
+ start_time = await info.context.data_loaders.trace_fields.load(
65
+ (self.trace_rowid, models.Trace.start_time),
66
+ )
67
+ return start_time
68
+
69
+ @strawberry.field
70
+ async def end_time(
71
+ self,
72
+ info: Info[Context, None],
73
+ ) -> datetime:
74
+ if self.db_trace:
75
+ end_time = self.db_trace.end_time
76
+ else:
77
+ end_time = await info.context.data_loaders.trace_fields.load(
78
+ (self.trace_rowid, models.Trace.end_time),
79
+ )
80
+ return end_time
41
81
 
42
82
  @strawberry.field
43
83
  async def latency_ms(
44
84
  self,
45
85
  info: Info[Context, None],
46
86
  ) -> Optional[float]:
47
- async with info.context.db() as session:
48
- latency = await session.scalar(
49
- select(
50
- models.Trace.latency_ms,
51
- ).where(models.Trace.id == self.trace_rowid)
87
+ if self.db_trace:
88
+ latency_ms = self.db_trace.latency_ms
89
+ else:
90
+ latency_ms = await info.context.data_loaders.trace_fields.load(
91
+ (self.trace_rowid, models.Trace.latency_ms),
52
92
  )
53
- return latency
93
+ return latency_ms
54
94
 
55
95
  @strawberry.field
56
- async def project_id(self) -> GlobalID:
96
+ async def project(
97
+ self,
98
+ info: Info[Context, None],
99
+ ) -> Annotated["Project", strawberry.lazy(".Project")]:
100
+ if self.db_trace:
101
+ project_rowid = self.db_trace.project_rowid
102
+ else:
103
+ project_rowid = await info.context.data_loaders.trace_fields.load(
104
+ (self.trace_rowid, models.Trace.project_rowid),
105
+ )
106
+ from phoenix.server.api.types.Project import Project
107
+
108
+ return Project(project_rowid=project_rowid)
109
+
110
+ @strawberry.field
111
+ async def project_id(
112
+ self,
113
+ info: Info[Context, None],
114
+ ) -> GlobalID:
115
+ if self.db_trace:
116
+ project_rowid = self.db_trace.project_rowid
117
+ else:
118
+ project_rowid = await info.context.data_loaders.trace_fields.load(
119
+ (self.trace_rowid, models.Trace.project_rowid),
120
+ )
57
121
  from phoenix.server.api.types.Project import Project
58
122
 
59
- return GlobalID(type_name=Project.__name__, node_id=str(self.project_rowid))
123
+ return GlobalID(type_name=Project.__name__, node_id=str(project_rowid))
60
124
 
61
125
  @strawberry.field
62
- async def project_session_id(self) -> Optional[GlobalID]:
63
- if self.project_session_rowid is None:
126
+ async def project_session_id(
127
+ self,
128
+ info: Info[Context, None],
129
+ ) -> Optional[GlobalID]:
130
+ if self.db_trace:
131
+ project_session_rowid = self.db_trace.project_session_rowid
132
+ else:
133
+ project_session_rowid = await info.context.data_loaders.trace_fields.load(
134
+ (self.trace_rowid, models.Trace.project_session_rowid),
135
+ )
136
+ if project_session_rowid is None:
64
137
  return None
65
138
  from phoenix.server.api.types.ProjectSession import ProjectSession
66
139
 
67
- return GlobalID(type_name=ProjectSession.__name__, node_id=str(self.project_session_rowid))
140
+ return GlobalID(type_name=ProjectSession.__name__, node_id=str(project_session_rowid))
68
141
 
69
142
  @strawberry.field
70
143
  async def session(
71
144
  self,
72
145
  info: Info[Context, None],
73
146
  ) -> Union[Annotated["ProjectSession", lazy(".ProjectSession")], None]:
74
- if self.project_session_rowid is None:
147
+ if self.db_trace:
148
+ project_session_rowid = self.db_trace.project_session_rowid
149
+ else:
150
+ project_session_rowid = await info.context.data_loaders.trace_fields.load(
151
+ (self.trace_rowid, models.Trace.project_session_rowid),
152
+ )
153
+ if project_session_rowid is None:
75
154
  return None
76
155
  from phoenix.server.api.types.ProjectSession import to_gql_project_session
77
156
 
78
- stmt = select(models.ProjectSession).filter_by(id=self.project_session_rowid)
157
+ stmt = select(models.ProjectSession).filter_by(id=project_session_rowid)
79
158
  async with info.context.db() as session:
80
159
  project_session = await session.scalar(stmt)
81
160
  if project_session is None:
@@ -141,16 +220,5 @@ class Trace(Node):
141
220
  return [to_gql_trace_annotation(annotation) for annotation in annotations]
142
221
 
143
222
 
144
- def to_gql_trace(trace: models.Trace) -> Trace:
145
- return Trace(
146
- trace_rowid=trace.id,
147
- project_rowid=trace.project_rowid,
148
- project_session_rowid=trace.project_session_rowid,
149
- trace_id=trace.trace_id,
150
- start_time=trace.start_time,
151
- end_time=trace.end_time,
152
- )
153
-
154
-
155
223
  INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
156
224
  OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
phoenix/server/app.py CHANGED
@@ -97,8 +97,8 @@ from phoenix.server.api.dataloaders import (
97
97
  SpanByIdDataLoader,
98
98
  SpanDatasetExamplesDataLoader,
99
99
  SpanDescendantsDataLoader,
100
- SpanFieldsDataLoader,
101
100
  SpanProjectsDataLoader,
101
+ TableFieldsDataLoader,
102
102
  TokenCountDataLoader,
103
103
  TraceByTraceIdsDataLoader,
104
104
  TraceRootSpansDataLoader,
@@ -614,6 +614,7 @@ def create_graphql_router(
614
614
  else None
615
615
  ),
616
616
  ),
617
+ project_fields=TableFieldsDataLoader(db, models.Project),
617
618
  prompt_version_sequence_number=PromptVersionSequenceNumberDataLoader(db),
618
619
  record_counts=RecordCountDataLoader(
619
620
  db,
@@ -626,7 +627,7 @@ def create_graphql_router(
626
627
  session_token_usages=SessionTokenUsagesDataLoader(db),
627
628
  session_trace_latency_ms_quantile=SessionTraceLatencyMsQuantileDataLoader(db),
628
629
  span_annotations=SpanAnnotationsDataLoader(db),
629
- span_fields=SpanFieldsDataLoader(db),
630
+ span_fields=TableFieldsDataLoader(db, models.Span),
630
631
  span_by_id=SpanByIdDataLoader(db),
631
632
  span_dataset_examples=SpanDatasetExamplesDataLoader(db),
632
633
  span_descendants=SpanDescendantsDataLoader(db),
@@ -636,6 +637,7 @@ def create_graphql_router(
636
637
  cache_map=cache_for_dataloaders.token_count if cache_for_dataloaders else None,
637
638
  ),
638
639
  trace_by_trace_ids=TraceByTraceIdsDataLoader(db),
640
+ trace_fields=TableFieldsDataLoader(db, models.Trace),
639
641
  trace_root_spans=TraceRootSpansDataLoader(db),
640
642
  project_by_name=ProjectByNameDataLoader(db),
641
643
  users=UsersDataLoader(db),
@@ -1,22 +1,22 @@
1
1
  {
2
- "_components-MeFAEc1z.js": {
3
- "file": "assets/components-MeFAEc1z.js",
2
+ "_components-T5K9z49d.js": {
3
+ "file": "assets/components-T5K9z49d.js",
4
4
  "name": "components",
5
5
  "imports": [
6
6
  "_vendor-Cqfydjep.js",
7
- "_pages-NrL4hb9q.js",
7
+ "_pages-CY3ZXSHj.js",
8
8
  "_vendor-arizeai-WnerlUPN.js",
9
9
  "_vendor-codemirror-D-ZZKLFq.js",
10
10
  "_vendor-three-C-AGeJYv.js"
11
11
  ]
12
12
  },
13
- "_pages-NrL4hb9q.js": {
14
- "file": "assets/pages-NrL4hb9q.js",
13
+ "_pages-CY3ZXSHj.js": {
14
+ "file": "assets/pages-CY3ZXSHj.js",
15
15
  "name": "pages",
16
16
  "imports": [
17
17
  "_vendor-Cqfydjep.js",
18
18
  "_vendor-arizeai-WnerlUPN.js",
19
- "_components-MeFAEc1z.js",
19
+ "_components-T5K9z49d.js",
20
20
  "_vendor-codemirror-D-ZZKLFq.js",
21
21
  "_vendor-recharts-KY97ZPfK.js"
22
22
  ]
@@ -69,15 +69,15 @@
69
69
  "name": "vendor-three"
70
70
  },
71
71
  "index.tsx": {
72
- "file": "assets/index-BSRuZ-_J.js",
72
+ "file": "assets/index-DvHwFF8e.js",
73
73
  "name": "index",
74
74
  "src": "index.tsx",
75
75
  "isEntry": true,
76
76
  "imports": [
77
77
  "_vendor-Cqfydjep.js",
78
78
  "_vendor-arizeai-WnerlUPN.js",
79
- "_pages-NrL4hb9q.js",
80
- "_components-MeFAEc1z.js",
79
+ "_pages-CY3ZXSHj.js",
80
+ "_components-T5K9z49d.js",
81
81
  "_vendor-three-C-AGeJYv.js",
82
82
  "_vendor-codemirror-D-ZZKLFq.js",
83
83
  "_vendor-shiki-D5K9GnFn.js",