arize-phoenix 12.8.0__py3-none-any.whl → 12.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (70) hide show
  1. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
  2. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +70 -67
  3. phoenix/config.py +131 -9
  4. phoenix/db/engines.py +127 -14
  5. phoenix/db/iam_auth.py +64 -0
  6. phoenix/db/pg_config.py +10 -0
  7. phoenix/server/api/context.py +23 -0
  8. phoenix/server/api/dataloaders/__init__.py +6 -0
  9. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
  10. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  11. phoenix/server/api/dataloaders/span_costs.py +3 -9
  12. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  13. phoenix/server/api/helpers/playground_clients.py +3 -3
  14. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  15. phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
  16. phoenix/server/api/mutations/api_key_mutations.py +2 -15
  17. phoenix/server/api/mutations/chat_mutations.py +3 -2
  18. phoenix/server/api/mutations/dataset_label_mutations.py +12 -6
  19. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  20. phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
  21. phoenix/server/api/mutations/model_mutations.py +4 -4
  22. phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
  23. phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
  24. phoenix/server/api/mutations/prompt_mutations.py +24 -117
  25. phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
  26. phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
  27. phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
  28. phoenix/server/api/mutations/user_mutations.py +4 -4
  29. phoenix/server/api/queries.py +65 -210
  30. phoenix/server/api/subscriptions.py +4 -4
  31. phoenix/server/api/types/Annotation.py +90 -23
  32. phoenix/server/api/types/ApiKey.py +13 -17
  33. phoenix/server/api/types/Dataset.py +88 -48
  34. phoenix/server/api/types/DatasetExample.py +34 -30
  35. phoenix/server/api/types/DatasetLabel.py +47 -13
  36. phoenix/server/api/types/DatasetSplit.py +87 -21
  37. phoenix/server/api/types/DatasetVersion.py +49 -4
  38. phoenix/server/api/types/DocumentAnnotation.py +182 -62
  39. phoenix/server/api/types/Experiment.py +146 -55
  40. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
  41. phoenix/server/api/types/ExperimentRun.py +118 -61
  42. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  43. phoenix/server/api/types/GenerativeModel.py +95 -42
  44. phoenix/server/api/types/ModelInterface.py +7 -2
  45. phoenix/server/api/types/PlaygroundModel.py +12 -2
  46. phoenix/server/api/types/Project.py +70 -75
  47. phoenix/server/api/types/ProjectSession.py +69 -37
  48. phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
  49. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  50. phoenix/server/api/types/Prompt.py +82 -44
  51. phoenix/server/api/types/PromptLabel.py +47 -13
  52. phoenix/server/api/types/PromptVersion.py +11 -8
  53. phoenix/server/api/types/PromptVersionTag.py +65 -25
  54. phoenix/server/api/types/Span.py +116 -115
  55. phoenix/server/api/types/SpanAnnotation.py +189 -42
  56. phoenix/server/api/types/SystemApiKey.py +65 -1
  57. phoenix/server/api/types/Trace.py +45 -44
  58. phoenix/server/api/types/TraceAnnotation.py +144 -48
  59. phoenix/server/api/types/User.py +103 -33
  60. phoenix/server/api/types/UserApiKey.py +73 -26
  61. phoenix/server/app.py +29 -0
  62. phoenix/server/static/.vite/manifest.json +9 -9
  63. phoenix/server/static/assets/{components-Bem6_7MW.js → components-v927s3NF.js} +427 -397
  64. phoenix/server/static/assets/{index-NdiXbuNL.js → index-DrD9eSrN.js} +9 -5
  65. phoenix/server/static/assets/{pages-CEJgMVKU.js → pages-GVybXa_W.js} +489 -486
  66. phoenix/version.py +1 -1
  67. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
  68. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
  69. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
  70. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,8 @@
1
- from typing import Optional
1
+ from datetime import datetime
2
+ from math import isfinite
3
+ from typing import TYPE_CHECKING, Annotated, Optional
2
4
 
3
5
  import strawberry
4
- from strawberry import Private
5
6
  from strawberry.relay import GlobalID, Node, NodeID
6
7
  from strawberry.scalars import JSON
7
8
  from strawberry.types import Info
@@ -12,56 +13,202 @@ from phoenix.server.api.context import Context
12
13
  from .Annotation import Annotation
13
14
  from .AnnotationSource import AnnotationSource
14
15
  from .AnnotatorKind import AnnotatorKind
15
- from .User import User, to_gql_user
16
+
17
+ if TYPE_CHECKING:
18
+ from .Span import Span
19
+ from .User import User
16
20
 
17
21
 
18
22
  @strawberry.type
19
23
  class SpanAnnotation(Node, Annotation):
20
- id_attr: NodeID[int]
21
- user_id: Private[Optional[int]]
22
- annotator_kind: AnnotatorKind
23
- metadata: JSON
24
- span_rowid: Private[Optional[int]]
25
- source: AnnotationSource
26
- identifier: str
24
+ id: NodeID[int]
25
+ db_record: strawberry.Private[Optional[models.SpanAnnotation]] = None
26
+
27
+ def __post_init__(self) -> None:
28
+ if self.db_record and self.id != self.db_record.id:
29
+ raise ValueError("SpanAnnotation ID mismatch")
30
+
31
+ @strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
32
+ async def name(
33
+ self,
34
+ info: Info[Context, None],
35
+ ) -> str:
36
+ if self.db_record:
37
+ val = self.db_record.name
38
+ else:
39
+ val = await info.context.data_loaders.span_annotation_fields.load(
40
+ (self.id, models.SpanAnnotation.name),
41
+ )
42
+ return val
43
+
44
+ @strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
45
+ async def annotator_kind(
46
+ self,
47
+ info: Info[Context, None],
48
+ ) -> AnnotatorKind:
49
+ if self.db_record:
50
+ val = self.db_record.annotator_kind
51
+ else:
52
+ val = await info.context.data_loaders.span_annotation_fields.load(
53
+ (self.id, models.SpanAnnotation.annotator_kind),
54
+ )
55
+ return AnnotatorKind(val)
56
+
57
+ @strawberry.field(
58
+ description="Value of the annotation in the form of a string, e.g. "
59
+ "'helpful' or 'not helpful'. Note that the label is not necessarily binary."
60
+ ) # type: ignore
61
+ async def label(
62
+ self,
63
+ info: Info[Context, None],
64
+ ) -> Optional[str]:
65
+ if self.db_record:
66
+ val = self.db_record.label
67
+ else:
68
+ val = await info.context.data_loaders.span_annotation_fields.load(
69
+ (self.id, models.SpanAnnotation.label),
70
+ )
71
+ return val
72
+
73
+ @strawberry.field(
74
+ description="Value of the annotation in the form of a numeric score.",
75
+ ) # type: ignore
76
+ async def score(
77
+ self,
78
+ info: Info[Context, None],
79
+ ) -> Optional[float]:
80
+ if self.db_record:
81
+ val = self.db_record.score
82
+ else:
83
+ val = await info.context.data_loaders.span_annotation_fields.load(
84
+ (self.id, models.SpanAnnotation.score),
85
+ )
86
+ return val if val is not None and isfinite(val) else None
87
+
88
+ @strawberry.field(
89
+ description="The annotator's explanation for the annotation result (i.e. "
90
+ "score or label, or both) given to the subject."
91
+ ) # type: ignore
92
+ async def explanation(
93
+ self,
94
+ info: Info[Context, None],
95
+ ) -> Optional[str]:
96
+ if self.db_record:
97
+ val = self.db_record.explanation
98
+ else:
99
+ val = await info.context.data_loaders.span_annotation_fields.load(
100
+ (self.id, models.SpanAnnotation.explanation),
101
+ )
102
+ return val
103
+
104
+ @strawberry.field(description="Metadata about the annotation.") # type: ignore
105
+ async def metadata(
106
+ self,
107
+ info: Info[Context, None],
108
+ ) -> JSON:
109
+ if self.db_record:
110
+ val = self.db_record.metadata_
111
+ else:
112
+ val = await info.context.data_loaders.span_annotation_fields.load(
113
+ (self.id, models.SpanAnnotation.metadata_),
114
+ )
115
+ return val
116
+
117
+ @strawberry.field(description="The source of the annotation.") # type: ignore
118
+ async def source(
119
+ self,
120
+ info: Info[Context, None],
121
+ ) -> AnnotationSource:
122
+ if self.db_record:
123
+ val = self.db_record.source
124
+ else:
125
+ val = await info.context.data_loaders.span_annotation_fields.load(
126
+ (self.id, models.SpanAnnotation.source),
127
+ )
128
+ return AnnotationSource(val)
129
+
130
+ @strawberry.field(description="The identifier of the annotation.") # type: ignore
131
+ async def identifier(
132
+ self,
133
+ info: Info[Context, None],
134
+ ) -> str:
135
+ if self.db_record:
136
+ val = self.db_record.identifier
137
+ else:
138
+ val = await info.context.data_loaders.span_annotation_fields.load(
139
+ (self.id, models.SpanAnnotation.identifier),
140
+ )
141
+ return val
142
+
143
+ @strawberry.field(description="The date and time when the annotation was created.") # type: ignore
144
+ async def created_at(
145
+ self,
146
+ info: Info[Context, None],
147
+ ) -> datetime:
148
+ if self.db_record:
149
+ val = self.db_record.created_at
150
+ else:
151
+ val = await info.context.data_loaders.span_annotation_fields.load(
152
+ (self.id, models.SpanAnnotation.created_at),
153
+ )
154
+ return val
155
+
156
+ @strawberry.field(description="The date and time when the annotation was last updated.") # type: ignore
157
+ async def updated_at(
158
+ self,
159
+ info: Info[Context, None],
160
+ ) -> datetime:
161
+ if self.db_record:
162
+ val = self.db_record.updated_at
163
+ else:
164
+ val = await info.context.data_loaders.span_annotation_fields.load(
165
+ (self.id, models.SpanAnnotation.updated_at),
166
+ )
167
+ return val
27
168
 
28
169
  @strawberry.field
29
- async def span_id(self) -> GlobalID:
170
+ async def span_id(
171
+ self,
172
+ info: Info[Context, None],
173
+ ) -> GlobalID:
30
174
  from phoenix.server.api.types.Span import Span
31
175
 
32
- return GlobalID(type_name=Span.__name__, node_id=str(self.span_rowid))
176
+ if self.db_record:
177
+ span_rowid = self.db_record.span_rowid
178
+ else:
179
+ span_rowid = await info.context.data_loaders.span_annotation_fields.load(
180
+ (self.id, models.SpanAnnotation.span_rowid),
181
+ )
182
+ return GlobalID(type_name=Span.__name__, node_id=str(span_rowid))
33
183
 
34
- @strawberry.field
184
+ @strawberry.field(description="The span associated with the annotation.") # type: ignore
185
+ async def span(
186
+ self,
187
+ info: Info[Context, None],
188
+ ) -> Annotated["Span", strawberry.lazy(".Span")]:
189
+ if self.db_record:
190
+ span_rowid = self.db_record.span_rowid
191
+ else:
192
+ span_rowid = await info.context.data_loaders.span_annotation_fields.load(
193
+ (self.id, models.SpanAnnotation.span_rowid),
194
+ )
195
+ from .Span import Span
196
+
197
+ return Span(id=span_rowid)
198
+
199
+ @strawberry.field(description="The user that produced the annotation.") # type: ignore
35
200
  async def user(
36
201
  self,
37
202
  info: Info[Context, None],
38
- ) -> Optional[User]:
39
- if self.user_id is None:
203
+ ) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
204
+ if self.db_record:
205
+ user_id = self.db_record.user_id
206
+ else:
207
+ user_id = await info.context.data_loaders.span_annotation_fields.load(
208
+ (self.id, models.SpanAnnotation.user_id),
209
+ )
210
+ if user_id is None:
40
211
  return None
41
- user = await info.context.data_loaders.users.load(self.user_id)
42
- if user is None:
43
- return None
44
- return to_gql_user(user)
45
-
46
-
47
- def to_gql_span_annotation(
48
- annotation: models.SpanAnnotation,
49
- ) -> SpanAnnotation:
50
- """
51
- Converts an ORM span annotation to a GraphQL SpanAnnotation.
52
- """
53
- return SpanAnnotation(
54
- id_attr=annotation.id,
55
- user_id=annotation.user_id,
56
- span_rowid=annotation.span_rowid,
57
- name=annotation.name,
58
- annotator_kind=AnnotatorKind(annotation.annotator_kind),
59
- label=annotation.label,
60
- score=annotation.score,
61
- explanation=annotation.explanation,
62
- metadata=annotation.metadata_,
63
- source=AnnotationSource(annotation.source),
64
- identifier=annotation.identifier,
65
- created_at=annotation.created_at,
66
- updated_at=annotation.updated_at,
67
- )
212
+ from .User import User
213
+
214
+ return User(id=user_id)
@@ -1,9 +1,73 @@
1
+ from datetime import datetime
2
+ from typing import Optional
3
+
1
4
  import strawberry
2
5
  from strawberry.relay import Node, NodeID
6
+ from strawberry.types import Info
7
+
8
+ from phoenix.db.models import ApiKey as OrmApiKey
9
+ from phoenix.server.api.context import Context
3
10
 
4
11
  from .ApiKey import ApiKey
5
12
 
6
13
 
7
14
  @strawberry.type
8
15
  class SystemApiKey(ApiKey, Node):
9
- id_attr: NodeID[int]
16
+ id: NodeID[int]
17
+ db_record: strawberry.Private[Optional[OrmApiKey]] = None
18
+
19
+ def __post_init__(self) -> None:
20
+ if self.db_record and self.id != self.db_record.id:
21
+ raise ValueError("SystemApiKey ID mismatch")
22
+
23
+ @strawberry.field(description="Name of the API key.") # type: ignore
24
+ async def name(
25
+ self,
26
+ info: Info[Context, None],
27
+ ) -> str:
28
+ if self.db_record:
29
+ val = self.db_record.name
30
+ else:
31
+ val = await info.context.data_loaders.user_api_key_fields.load(
32
+ (self.id, OrmApiKey.name),
33
+ )
34
+ return val
35
+
36
+ @strawberry.field(description="Description of the API key.") # type: ignore
37
+ async def description(
38
+ self,
39
+ info: Info[Context, None],
40
+ ) -> Optional[str]:
41
+ if self.db_record:
42
+ val = self.db_record.description
43
+ else:
44
+ val = await info.context.data_loaders.user_api_key_fields.load(
45
+ (self.id, OrmApiKey.description),
46
+ )
47
+ return val
48
+
49
+ @strawberry.field(description="The date and time the API key was created.") # type: ignore
50
+ async def created_at(
51
+ self,
52
+ info: Info[Context, None],
53
+ ) -> datetime:
54
+ if self.db_record:
55
+ val = self.db_record.created_at
56
+ else:
57
+ val = await info.context.data_loaders.user_api_key_fields.load(
58
+ (self.id, OrmApiKey.created_at),
59
+ )
60
+ return val
61
+
62
+ @strawberry.field(description="The date and time the API key will expire.") # type: ignore
63
+ async def expires_at(
64
+ self,
65
+ info: Info[Context, None],
66
+ ) -> Optional[datetime]:
67
+ if self.db_record:
68
+ val = self.db_record.expires_at
69
+ else:
70
+ val = await info.context.data_loaders.user_api_key_fields.load(
71
+ (self.id, OrmApiKey.expires_at),
72
+ )
73
+ return val
@@ -9,7 +9,7 @@ import pandas as pd
9
9
  import strawberry
10
10
  from openinference.semconv.trace import SpanAttributes
11
11
  from sqlalchemy import desc, select
12
- from strawberry import ID, UNSET, Private, lazy
12
+ from strawberry import ID, UNSET, lazy
13
13
  from strawberry.relay import Connection, GlobalID, Node, NodeID
14
14
  from strawberry.types import Info
15
15
  from typing_extensions import TypeAlias
@@ -29,7 +29,7 @@ from phoenix.server.api.types.SortDir import SortDir
29
29
  from phoenix.server.api.types.Span import Span
30
30
  from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
31
31
  from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
32
- from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
32
+ from phoenix.server.api.types.TraceAnnotation import TraceAnnotation
33
33
 
34
34
  if TYPE_CHECKING:
35
35
  from phoenix.server.api.types.Project import Project
@@ -41,11 +41,11 @@ TraceRowId: TypeAlias = int
41
41
 
42
42
  @strawberry.type
43
43
  class Trace(Node):
44
- trace_rowid: NodeID[TraceRowId]
45
- db_trace: Private[models.Trace] = UNSET
44
+ id: NodeID[TraceRowId]
45
+ db_record: strawberry.Private[Optional[models.Trace]] = None
46
46
 
47
47
  def __post_init__(self) -> None:
48
- if self.db_trace and self.trace_rowid != self.db_trace.id:
48
+ if self.db_record and self.id != self.db_record.id:
49
49
  raise ValueError("Trace ID mismatch")
50
50
 
51
51
  @strawberry.field
@@ -53,11 +53,11 @@ class Trace(Node):
53
53
  self,
54
54
  info: Info[Context, None],
55
55
  ) -> ID:
56
- if self.db_trace:
57
- trace_id = self.db_trace.trace_id
56
+ if self.db_record:
57
+ trace_id = self.db_record.trace_id
58
58
  else:
59
59
  trace_id = await info.context.data_loaders.trace_fields.load(
60
- (self.trace_rowid, models.Trace.trace_id),
60
+ (self.id, models.Trace.trace_id),
61
61
  )
62
62
  return ID(trace_id)
63
63
 
@@ -66,11 +66,11 @@ class Trace(Node):
66
66
  self,
67
67
  info: Info[Context, None],
68
68
  ) -> datetime:
69
- if self.db_trace:
70
- start_time = self.db_trace.start_time
69
+ if self.db_record:
70
+ start_time = self.db_record.start_time
71
71
  else:
72
72
  start_time = await info.context.data_loaders.trace_fields.load(
73
- (self.trace_rowid, models.Trace.start_time),
73
+ (self.id, models.Trace.start_time),
74
74
  )
75
75
  return start_time
76
76
 
@@ -79,11 +79,11 @@ class Trace(Node):
79
79
  self,
80
80
  info: Info[Context, None],
81
81
  ) -> datetime:
82
- if self.db_trace:
83
- end_time = self.db_trace.end_time
82
+ if self.db_record:
83
+ end_time = self.db_record.end_time
84
84
  else:
85
85
  end_time = await info.context.data_loaders.trace_fields.load(
86
- (self.trace_rowid, models.Trace.end_time),
86
+ (self.id, models.Trace.end_time),
87
87
  )
88
88
  return end_time
89
89
 
@@ -92,11 +92,11 @@ class Trace(Node):
92
92
  self,
93
93
  info: Info[Context, None],
94
94
  ) -> Optional[float]:
95
- if self.db_trace:
96
- latency_ms = self.db_trace.latency_ms
95
+ if self.db_record:
96
+ latency_ms = self.db_record.latency_ms
97
97
  else:
98
98
  latency_ms = await info.context.data_loaders.trace_fields.load(
99
- (self.trace_rowid, models.Trace.latency_ms),
99
+ (self.id, models.Trace.latency_ms),
100
100
  )
101
101
  return latency_ms
102
102
 
@@ -105,26 +105,26 @@ class Trace(Node):
105
105
  self,
106
106
  info: Info[Context, None],
107
107
  ) -> Annotated["Project", strawberry.lazy(".Project")]:
108
- if self.db_trace:
109
- project_rowid = self.db_trace.project_rowid
108
+ if self.db_record:
109
+ project_rowid = self.db_record.project_rowid
110
110
  else:
111
111
  project_rowid = await info.context.data_loaders.trace_fields.load(
112
- (self.trace_rowid, models.Trace.project_rowid),
112
+ (self.id, models.Trace.project_rowid),
113
113
  )
114
114
  from phoenix.server.api.types.Project import Project
115
115
 
116
- return Project(project_rowid=project_rowid)
116
+ return Project(id=project_rowid)
117
117
 
118
118
  @strawberry.field
119
119
  async def project_id(
120
120
  self,
121
121
  info: Info[Context, None],
122
122
  ) -> GlobalID:
123
- if self.db_trace:
124
- project_rowid = self.db_trace.project_rowid
123
+ if self.db_record:
124
+ project_rowid = self.db_record.project_rowid
125
125
  else:
126
126
  project_rowid = await info.context.data_loaders.trace_fields.load(
127
- (self.trace_rowid, models.Trace.project_rowid),
127
+ (self.id, models.Trace.project_rowid),
128
128
  )
129
129
  from phoenix.server.api.types.Project import Project
130
130
 
@@ -135,11 +135,11 @@ class Trace(Node):
135
135
  self,
136
136
  info: Info[Context, None],
137
137
  ) -> Optional[GlobalID]:
138
- if self.db_trace:
139
- project_session_rowid = self.db_trace.project_session_rowid
138
+ if self.db_record:
139
+ project_session_rowid = self.db_record.project_session_rowid
140
140
  else:
141
141
  project_session_rowid = await info.context.data_loaders.trace_fields.load(
142
- (self.trace_rowid, models.Trace.project_session_rowid),
142
+ (self.id, models.Trace.project_session_rowid),
143
143
  )
144
144
  if project_session_rowid is None:
145
145
  return None
@@ -152,39 +152,40 @@ class Trace(Node):
152
152
  self,
153
153
  info: Info[Context, None],
154
154
  ) -> Union[Annotated["ProjectSession", lazy(".ProjectSession")], None]:
155
- if self.db_trace:
156
- project_session_rowid = self.db_trace.project_session_rowid
155
+ if self.db_record:
156
+ project_session_rowid = self.db_record.project_session_rowid
157
157
  else:
158
158
  project_session_rowid = await info.context.data_loaders.trace_fields.load(
159
- (self.trace_rowid, models.Trace.project_session_rowid),
159
+ (self.id, models.Trace.project_session_rowid),
160
160
  )
161
161
  if project_session_rowid is None:
162
162
  return None
163
- from phoenix.server.api.types.ProjectSession import to_gql_project_session
164
163
 
165
164
  stmt = select(models.ProjectSession).filter_by(id=project_session_rowid)
166
165
  async with info.context.db() as session:
167
166
  project_session = await session.scalar(stmt)
168
167
  if project_session is None:
169
168
  return None
170
- return to_gql_project_session(project_session)
169
+ from .ProjectSession import ProjectSession
170
+
171
+ return ProjectSession(id=project_session.id, db_record=project_session)
171
172
 
172
173
  @strawberry.field
173
174
  async def root_span(
174
175
  self,
175
176
  info: Info[Context, None],
176
177
  ) -> Optional[Span]:
177
- span_rowid = await info.context.data_loaders.trace_root_spans.load(self.trace_rowid)
178
+ span_rowid = await info.context.data_loaders.trace_root_spans.load(self.id)
178
179
  if span_rowid is None:
179
180
  return None
180
- return Span(span_rowid=span_rowid)
181
+ return Span(id=span_rowid)
181
182
 
182
183
  @strawberry.field
183
184
  async def num_spans(
184
185
  self,
185
186
  info: Info[Context, None],
186
187
  ) -> int:
187
- return await info.context.data_loaders.num_spans_per_trace.load(self.trace_rowid)
188
+ return await info.context.data_loaders.num_spans_per_trace.load(self.id)
188
189
 
189
190
  @strawberry.field
190
191
  async def spans(
@@ -204,7 +205,7 @@ class Trace(Node):
204
205
  stmt = (
205
206
  select(models.Span.id)
206
207
  .join(models.Trace)
207
- .where(models.Trace.id == self.trace_rowid)
208
+ .where(models.Trace.id == self.id)
208
209
  # Sort descending because the root span tends to show up later
209
210
  # in the ingestion process.
210
211
  .order_by(desc(models.Span.id))
@@ -212,7 +213,7 @@ class Trace(Node):
212
213
  )
213
214
  async with info.context.db() as session:
214
215
  span_rowids = await session.stream_scalars(stmt)
215
- data = [Span(span_rowid=span_rowid) async for span_rowid in span_rowids]
216
+ data = [Span(id=span_rowid) async for span_rowid in span_rowids]
216
217
  return connection_from_list(data=data, args=args)
217
218
 
218
219
  @strawberry.field(description="Annotations associated with the trace.") # type: ignore
@@ -222,7 +223,7 @@ class Trace(Node):
222
223
  sort: Optional[TraceAnnotationSort] = None,
223
224
  ) -> list[TraceAnnotation]:
224
225
  async with info.context.db() as session:
225
- stmt = select(models.TraceAnnotation).filter_by(trace_rowid=self.trace_rowid)
226
+ stmt = select(models.TraceAnnotation).filter_by(trace_rowid=self.id)
226
227
  if sort:
227
228
  sort_col = getattr(models.TraceAnnotation, sort.col.value)
228
229
  if sort.dir is SortDir.desc:
@@ -232,7 +233,9 @@ class Trace(Node):
232
233
  else:
233
234
  stmt = stmt.order_by(models.TraceAnnotation.created_at.desc())
234
235
  annotations = await session.scalars(stmt)
235
- return [to_gql_trace_annotation(annotation) for annotation in annotations]
236
+ return [
237
+ TraceAnnotation(id=annotation.id, db_record=annotation) for annotation in annotations
238
+ ]
236
239
 
237
240
  @strawberry.field(description="Summarizes each annotation (by name) associated with the trace") # type: ignore
238
241
  async def trace_annotation_summaries(
@@ -257,9 +260,7 @@ class Trace(Node):
257
260
  - data: A list of dictionaries with label statistics
258
261
  """
259
262
  # Load all annotations for this span from the data loader
260
- annotations = await info.context.data_loaders.trace_annotations_by_trace.load(
261
- self.trace_rowid
262
- )
263
+ annotations = await info.context.data_loaders.trace_annotations_by_trace.load(self.id)
263
264
 
264
265
  # Apply filter if provided to narrow down the annotations
265
266
  if filter:
@@ -296,7 +297,7 @@ class Trace(Node):
296
297
  info: Info[Context, None],
297
298
  ) -> SpanCostSummary:
298
299
  loader = info.context.data_loaders.span_cost_summary_by_trace
299
- summary = await loader.load(self.trace_rowid)
300
+ summary = await loader.load(self.id)
300
301
  return SpanCostSummary(
301
302
  prompt=CostBreakdown(
302
303
  tokens=summary.prompt.tokens,
@@ -318,7 +319,7 @@ class Trace(Node):
318
319
  info: Info[Context, None],
319
320
  ) -> list[SpanCostDetailSummaryEntry]:
320
321
  loader = info.context.data_loaders.span_cost_detail_summary_entries_by_trace
321
- entries = await loader.load(self.trace_rowid)
322
+ entries = await loader.load(self.id)
322
323
  return [
323
324
  SpanCostDetailSummaryEntry(
324
325
  token_type=entry.token_type,