arize-phoenix 12.8.0 → 12.9.0 (py3-none-any wheels)

This diff shows the contents of two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.

Warning: this release has been flagged as potentially problematic.


This version of arize-phoenix might be problematic; further details are available via the link on the original diff page.

Files changed (70)
  1. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
  2. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +70 -67
  3. phoenix/config.py +131 -9
  4. phoenix/db/engines.py +127 -14
  5. phoenix/db/iam_auth.py +64 -0
  6. phoenix/db/pg_config.py +10 -0
  7. phoenix/server/api/context.py +23 -0
  8. phoenix/server/api/dataloaders/__init__.py +6 -0
  9. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
  10. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  11. phoenix/server/api/dataloaders/span_costs.py +3 -9
  12. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  13. phoenix/server/api/helpers/playground_clients.py +3 -3
  14. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  15. phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
  16. phoenix/server/api/mutations/api_key_mutations.py +2 -15
  17. phoenix/server/api/mutations/chat_mutations.py +3 -2
  18. phoenix/server/api/mutations/dataset_label_mutations.py +12 -6
  19. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  20. phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
  21. phoenix/server/api/mutations/model_mutations.py +4 -4
  22. phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
  23. phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
  24. phoenix/server/api/mutations/prompt_mutations.py +24 -117
  25. phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
  26. phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
  27. phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
  28. phoenix/server/api/mutations/user_mutations.py +4 -4
  29. phoenix/server/api/queries.py +65 -210
  30. phoenix/server/api/subscriptions.py +4 -4
  31. phoenix/server/api/types/Annotation.py +90 -23
  32. phoenix/server/api/types/ApiKey.py +13 -17
  33. phoenix/server/api/types/Dataset.py +88 -48
  34. phoenix/server/api/types/DatasetExample.py +34 -30
  35. phoenix/server/api/types/DatasetLabel.py +47 -13
  36. phoenix/server/api/types/DatasetSplit.py +87 -21
  37. phoenix/server/api/types/DatasetVersion.py +49 -4
  38. phoenix/server/api/types/DocumentAnnotation.py +182 -62
  39. phoenix/server/api/types/Experiment.py +146 -55
  40. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
  41. phoenix/server/api/types/ExperimentRun.py +118 -61
  42. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  43. phoenix/server/api/types/GenerativeModel.py +95 -42
  44. phoenix/server/api/types/ModelInterface.py +7 -2
  45. phoenix/server/api/types/PlaygroundModel.py +12 -2
  46. phoenix/server/api/types/Project.py +70 -75
  47. phoenix/server/api/types/ProjectSession.py +69 -37
  48. phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
  49. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  50. phoenix/server/api/types/Prompt.py +82 -44
  51. phoenix/server/api/types/PromptLabel.py +47 -13
  52. phoenix/server/api/types/PromptVersion.py +11 -8
  53. phoenix/server/api/types/PromptVersionTag.py +65 -25
  54. phoenix/server/api/types/Span.py +116 -115
  55. phoenix/server/api/types/SpanAnnotation.py +189 -42
  56. phoenix/server/api/types/SystemApiKey.py +65 -1
  57. phoenix/server/api/types/Trace.py +45 -44
  58. phoenix/server/api/types/TraceAnnotation.py +144 -48
  59. phoenix/server/api/types/User.py +103 -33
  60. phoenix/server/api/types/UserApiKey.py +73 -26
  61. phoenix/server/app.py +29 -0
  62. phoenix/server/static/.vite/manifest.json +9 -9
  63. phoenix/server/static/assets/{components-Bem6_7MW.js → components-v927s3NF.js} +427 -397
  64. phoenix/server/static/assets/{index-NdiXbuNL.js → index-DrD9eSrN.js} +9 -5
  65. phoenix/server/static/assets/{pages-CEJgMVKU.js → pages-GVybXa_W.js} +489 -486
  66. phoenix/version.py +1 -1
  67. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
  68. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
  69. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
  70. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,13 +1,13 @@
1
1
  from collections import defaultdict
2
2
  from dataclasses import asdict, dataclass
3
3
  from datetime import datetime
4
- from typing import TYPE_CHECKING, Annotated, ClassVar, Optional, Type
4
+ from typing import TYPE_CHECKING, Annotated, Optional
5
5
 
6
6
  import pandas as pd
7
7
  import strawberry
8
8
  from openinference.semconv.trace import SpanAttributes
9
9
  from sqlalchemy import select
10
- from strawberry import UNSET, Info, Private, lazy
10
+ from strawberry import UNSET, Info, lazy
11
11
  from strawberry.relay import Connection, Node, NodeID
12
12
 
13
13
  from phoenix.db import models
@@ -30,12 +30,51 @@ if TYPE_CHECKING:
30
30
 
31
31
  @strawberry.type
32
32
  class ProjectSession(Node):
33
- _table: ClassVar[Type[models.ProjectSession]] = models.ProjectSession
34
- id_attr: NodeID[int]
35
- project_rowid: Private[int]
36
- session_id: str
37
- start_time: datetime
38
- end_time: datetime
33
+ id: NodeID[int]
34
+ db_record: strawberry.Private[Optional[models.ProjectSession]] = None
35
+
36
+ def __post_init__(self) -> None:
37
+ if self.db_record and self.id != self.db_record.id:
38
+ raise ValueError("ProjectSession ID mismatch")
39
+
40
+ @strawberry.field
41
+ async def session_id(
42
+ self,
43
+ info: Info[Context, None],
44
+ ) -> str:
45
+ if self.db_record:
46
+ val = self.db_record.session_id
47
+ else:
48
+ val = await info.context.data_loaders.project_session_fields.load(
49
+ (self.id, models.ProjectSession.session_id),
50
+ )
51
+ return val
52
+
53
+ @strawberry.field
54
+ async def start_time(
55
+ self,
56
+ info: Info[Context, None],
57
+ ) -> datetime:
58
+ if self.db_record:
59
+ val = self.db_record.start_time
60
+ else:
61
+ val = await info.context.data_loaders.project_session_fields.load(
62
+ (self.id, models.ProjectSession.start_time),
63
+ )
64
+ return val
65
+
66
+ @strawberry.field
67
+ async def end_time(
68
+ self,
69
+ info: Info[Context, None],
70
+ ) -> datetime:
71
+ if self.db_record:
72
+ val = self.db_record.end_time
73
+ else:
74
+ val = await info.context.data_loaders.project_session_fields.load(
75
+ (self.id, models.ProjectSession.end_time),
76
+ )
77
+ return val
39
78
 
40
79
  @strawberry.field
41
80
  async def project(
@@ -44,28 +83,34 @@ class ProjectSession(Node):
44
83
  ) -> Annotated["Project", lazy(".Project")]:
45
84
  from phoenix.server.api.types.Project import Project
46
85
 
47
- return Project(project_rowid=self.project_rowid)
86
+ if self.db_record:
87
+ project_rowid = self.db_record.project_id
88
+ else:
89
+ project_rowid = await info.context.data_loaders.project_session_fields.load(
90
+ (self.id, models.ProjectSession.project_id),
91
+ )
92
+ return Project(id=project_rowid)
48
93
 
49
94
  @strawberry.field
50
95
  async def num_traces(
51
96
  self,
52
97
  info: Info[Context, None],
53
98
  ) -> int:
54
- return await info.context.data_loaders.session_num_traces.load(self.id_attr)
99
+ return await info.context.data_loaders.session_num_traces.load(self.id)
55
100
 
56
101
  @strawberry.field
57
102
  async def num_traces_with_error(
58
103
  self,
59
104
  info: Info[Context, None],
60
105
  ) -> int:
61
- return await info.context.data_loaders.session_num_traces_with_error.load(self.id_attr)
106
+ return await info.context.data_loaders.session_num_traces_with_error.load(self.id)
62
107
 
63
108
  @strawberry.field
64
109
  async def first_input(
65
110
  self,
66
111
  info: Info[Context, None],
67
112
  ) -> Optional[SpanIOValue]:
68
- record = await info.context.data_loaders.session_first_inputs.load(self.id_attr)
113
+ record = await info.context.data_loaders.session_first_inputs.load(self.id)
69
114
  if record is None:
70
115
  return None
71
116
  return SpanIOValue(
@@ -78,7 +123,7 @@ class ProjectSession(Node):
78
123
  self,
79
124
  info: Info[Context, None],
80
125
  ) -> Optional[SpanIOValue]:
81
- record = await info.context.data_loaders.session_last_outputs.load(self.id_attr)
126
+ record = await info.context.data_loaders.session_last_outputs.load(self.id)
82
127
  if record is None:
83
128
  return None
84
129
  return SpanIOValue(
@@ -91,7 +136,7 @@ class ProjectSession(Node):
91
136
  self,
92
137
  info: Info[Context, None],
93
138
  ) -> TokenUsage:
94
- usage = await info.context.data_loaders.session_token_usages.load(self.id_attr)
139
+ usage = await info.context.data_loaders.session_token_usages.load(self.id)
95
140
  return TokenUsage(
96
141
  prompt=usage.prompt,
97
142
  completion=usage.completion,
@@ -116,12 +161,12 @@ class ProjectSession(Node):
116
161
  )
117
162
  stmt = (
118
163
  select(models.Trace)
119
- .filter_by(project_session_rowid=self.id_attr)
164
+ .filter_by(project_session_rowid=self.id)
120
165
  .order_by(models.Trace.start_time)
121
166
  )
122
167
  async with info.context.db() as session:
123
168
  traces = await session.stream_scalars(stmt)
124
- data = [Trace(trace_rowid=trace.id, db_trace=trace) async for trace in traces]
169
+ data = [Trace(id=trace.id, db_record=trace) async for trace in traces]
125
170
  return connection_from_list(data=data, args=args)
126
171
 
127
172
  @strawberry.field
@@ -131,7 +176,7 @@ class ProjectSession(Node):
131
176
  probability: float,
132
177
  ) -> Optional[float]:
133
178
  return await info.context.data_loaders.session_trace_latency_ms_quantile.load(
134
- (self.id_attr, probability)
179
+ (self.id, probability)
135
180
  )
136
181
 
137
182
  @strawberry.field
@@ -140,7 +185,7 @@ class ProjectSession(Node):
140
185
  info: Info[Context, None],
141
186
  ) -> SpanCostSummary:
142
187
  loader = info.context.data_loaders.span_cost_summary_by_project_session
143
- summary = await loader.load(self.id_attr)
188
+ summary = await loader.load(self.id)
144
189
  return SpanCostSummary(
145
190
  prompt=CostBreakdown(
146
191
  tokens=summary.prompt.tokens,
@@ -162,7 +207,7 @@ class ProjectSession(Node):
162
207
  info: Info[Context, None],
163
208
  ) -> list[SpanCostDetailSummaryEntry]:
164
209
  loader = info.context.data_loaders.span_cost_detail_summary_entries_by_project_session
165
- summary = await loader.load(self.id_attr)
210
+ summary = await loader.load(self.id)
166
211
  return [
167
212
  SpanCostDetailSummaryEntry(
168
213
  token_type=entry.token_type,
@@ -181,15 +226,14 @@ class ProjectSession(Node):
181
226
  info: Info[Context, None],
182
227
  ) -> list[Annotated["ProjectSessionAnnotation", lazy(".ProjectSessionAnnotation")]]:
183
228
  """Get all annotations for this session."""
184
- from phoenix.server.api.types.ProjectSessionAnnotation import (
185
- to_gql_project_session_annotation,
186
- )
229
+ from .ProjectSessionAnnotation import ProjectSessionAnnotation
187
230
 
188
- stmt = select(models.ProjectSessionAnnotation).filter_by(project_session_id=self.id_attr)
231
+ stmt = select(models.ProjectSessionAnnotation).filter_by(project_session_id=self.id)
189
232
  async with info.context.db() as session:
190
233
  annotations = await session.stream_scalars(stmt)
191
234
  return [
192
- to_gql_project_session_annotation(annotation) async for annotation in annotations
235
+ ProjectSessionAnnotation(id=annotation.id, db_record=annotation)
236
+ async for annotation in annotations
193
237
  ]
194
238
 
195
239
  @strawberry.field(
@@ -217,9 +261,7 @@ class ProjectSession(Node):
217
261
  - data: A list of dictionaries with label statistics
218
262
  """
219
263
  # Load all annotations for this span from the data loader
220
- annotations = await info.context.data_loaders.session_annotations_by_session.load(
221
- self.id_attr
222
- )
264
+ annotations = await info.context.data_loaders.session_annotations_by_session.load(self.id)
223
265
 
224
266
  # Apply filter if provided to narrow down the annotations
225
267
  if filter:
@@ -251,16 +293,6 @@ class ProjectSession(Node):
251
293
  return result
252
294
 
253
295
 
254
- def to_gql_project_session(project_session: models.ProjectSession) -> ProjectSession:
255
- return ProjectSession(
256
- id_attr=project_session.id,
257
- session_id=project_session.session_id,
258
- start_time=project_session.start_time,
259
- project_rowid=project_session.project_id,
260
- end_time=project_session.end_time,
261
- )
262
-
263
-
264
296
  INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
265
297
  INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE.split(".")
266
298
  OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
@@ -1,7 +1,7 @@
1
- from typing import Optional
1
+ from math import isfinite
2
+ from typing import TYPE_CHECKING, Annotated, Optional
2
3
 
3
4
  import strawberry
4
- from strawberry import Private
5
5
  from strawberry.relay import GlobalID, Node, NodeID
6
6
  from strawberry.scalars import JSON
7
7
  from strawberry.types import Info
@@ -10,59 +10,178 @@ from phoenix.db import models
10
10
  from phoenix.server.api.context import Context
11
11
  from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
12
12
 
13
+ from .Annotation import Annotation
13
14
  from .AnnotationSource import AnnotationSource
14
- from .User import User, to_gql_user
15
+
16
+ if TYPE_CHECKING:
17
+ from .ProjectSession import ProjectSession
18
+ from .User import User
15
19
 
16
20
 
17
21
  @strawberry.type
18
- class ProjectSessionAnnotation(Node):
19
- id_attr: NodeID[int]
20
- user_id: Private[Optional[int]]
21
- name: str
22
- annotator_kind: AnnotatorKind
23
- label: Optional[str]
24
- score: Optional[float]
25
- explanation: Optional[str]
26
- metadata: JSON
27
- _project_session_id: Private[Optional[int]]
28
- identifier: str
29
- source: AnnotationSource
30
-
31
- @strawberry.field
32
- async def project_session_id(self) -> GlobalID:
22
+ class ProjectSessionAnnotation(Node, Annotation):
23
+ id: NodeID[int]
24
+ db_record: strawberry.Private[Optional[models.ProjectSessionAnnotation]] = None
25
+
26
+ def __post_init__(self) -> None:
27
+ if self.db_record and self.id != self.db_record.id:
28
+ raise ValueError("ProjectSessionAnnotation ID mismatch")
29
+
30
+ @strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
31
+ async def name(
32
+ self,
33
+ info: Info[Context, None],
34
+ ) -> str:
35
+ if self.db_record:
36
+ val = self.db_record.name
37
+ else:
38
+ val = await info.context.data_loaders.project_session_annotation_fields.load(
39
+ (self.id, models.ProjectSessionAnnotation.name),
40
+ )
41
+ return val
42
+
43
+ @strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
44
+ async def annotator_kind(
45
+ self,
46
+ info: Info[Context, None],
47
+ ) -> AnnotatorKind:
48
+ if self.db_record:
49
+ val = self.db_record.annotator_kind
50
+ else:
51
+ val = await info.context.data_loaders.project_session_annotation_fields.load(
52
+ (self.id, models.ProjectSessionAnnotation.annotator_kind),
53
+ )
54
+ return AnnotatorKind(val)
55
+
56
+ @strawberry.field(
57
+ description="Value of the annotation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary." # noqa: E501
58
+ ) # type: ignore
59
+ async def label(
60
+ self,
61
+ info: Info[Context, None],
62
+ ) -> Optional[str]:
63
+ if self.db_record:
64
+ val = self.db_record.label
65
+ else:
66
+ val = await info.context.data_loaders.project_session_annotation_fields.load(
67
+ (self.id, models.ProjectSessionAnnotation.label),
68
+ )
69
+ return val
70
+
71
+ @strawberry.field(description="Value of the annotation in the form of a numeric score.") # type: ignore
72
+ async def score(
73
+ self,
74
+ info: Info[Context, None],
75
+ ) -> Optional[float]:
76
+ if self.db_record:
77
+ val = self.db_record.score
78
+ else:
79
+ val = await info.context.data_loaders.project_session_annotation_fields.load(
80
+ (self.id, models.ProjectSessionAnnotation.score),
81
+ )
82
+ return val if val is not None and isfinite(val) else None
83
+
84
+ @strawberry.field(
85
+ description="The annotator's explanation for the annotation result (i.e. score or label, or both) given to the subject." # noqa: E501
86
+ ) # type: ignore
87
+ async def explanation(
88
+ self,
89
+ info: Info[Context, None],
90
+ ) -> Optional[str]:
91
+ if self.db_record:
92
+ val = self.db_record.explanation
93
+ else:
94
+ val = await info.context.data_loaders.project_session_annotation_fields.load(
95
+ (self.id, models.ProjectSessionAnnotation.explanation),
96
+ )
97
+ return val
98
+
99
+ @strawberry.field(description="Metadata about the annotation.") # type: ignore
100
+ async def metadata(
101
+ self,
102
+ info: Info[Context, None],
103
+ ) -> JSON:
104
+ if self.db_record:
105
+ val = self.db_record.metadata_
106
+ else:
107
+ val = await info.context.data_loaders.project_session_annotation_fields.load(
108
+ (self.id, models.ProjectSessionAnnotation.metadata_),
109
+ )
110
+ return val
111
+
112
+ @strawberry.field(description="The identifier of the annotation.") # type: ignore
113
+ async def identifier(
114
+ self,
115
+ info: Info[Context, None],
116
+ ) -> str:
117
+ if self.db_record:
118
+ val = self.db_record.identifier
119
+ else:
120
+ val = await info.context.data_loaders.project_session_annotation_fields.load(
121
+ (self.id, models.ProjectSessionAnnotation.identifier),
122
+ )
123
+ return val
124
+
125
+ @strawberry.field(description="The source of the annotation.") # type: ignore
126
+ async def source(
127
+ self,
128
+ info: Info[Context, None],
129
+ ) -> AnnotationSource:
130
+ if self.db_record:
131
+ val = self.db_record.source
132
+ else:
133
+ val = await info.context.data_loaders.project_session_annotation_fields.load(
134
+ (self.id, models.ProjectSessionAnnotation.source),
135
+ )
136
+ return AnnotationSource(val)
137
+
138
+ @strawberry.field(description="The project session associated with the annotation.") # type: ignore
139
+ async def project_session_id(
140
+ self,
141
+ info: Info[Context, None],
142
+ ) -> GlobalID:
33
143
  from phoenix.server.api.types.ProjectSession import ProjectSession
34
144
 
35
- return GlobalID(type_name=ProjectSession.__name__, node_id=str(self._project_session_id))
145
+ if self.db_record:
146
+ project_session_id = self.db_record.project_session_id
147
+ else:
148
+ project_session_id = (
149
+ await info.context.data_loaders.project_session_annotation_fields.load(
150
+ (self.id, models.ProjectSessionAnnotation.project_session_id),
151
+ )
152
+ )
153
+ return GlobalID(type_name=ProjectSession.__name__, node_id=str(project_session_id))
154
+
155
+ @strawberry.field(description="The project session associated with the annotation.") # type: ignore
156
+ async def project_session(
157
+ self,
158
+ info: Info[Context, None],
159
+ ) -> Annotated["ProjectSession", strawberry.lazy(".ProjectSession")]:
160
+ if self.db_record:
161
+ project_session_id = self.db_record.project_session_id
162
+ else:
163
+ project_session_id = (
164
+ await info.context.data_loaders.project_session_annotation_fields.load(
165
+ (self.id, models.ProjectSessionAnnotation.project_session_id),
166
+ )
167
+ )
168
+ from .ProjectSession import ProjectSession
169
+
170
+ return ProjectSession(id=project_session_id)
36
171
 
37
- @strawberry.field
172
+ @strawberry.field(description="The user that produced the annotation.") # type: ignore
38
173
  async def user(
39
174
  self,
40
175
  info: Info[Context, None],
41
- ) -> Optional[User]:
42
- if self.user_id is None:
176
+ ) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
177
+ if self.db_record:
178
+ user_id = self.db_record.user_id
179
+ else:
180
+ user_id = await info.context.data_loaders.project_session_annotation_fields.load(
181
+ (self.id, models.ProjectSessionAnnotation.user_id),
182
+ )
183
+ if user_id is None:
43
184
  return None
44
- user = await info.context.data_loaders.users.load(self.user_id)
45
- if user is None:
46
- return None
47
- return to_gql_user(user)
48
-
49
-
50
- def to_gql_project_session_annotation(
51
- annotation: models.ProjectSessionAnnotation,
52
- ) -> ProjectSessionAnnotation:
53
- """
54
- Converts an ORM projectSession annotation to a GraphQL ProjectSessionAnnotation.
55
- """
56
- return ProjectSessionAnnotation(
57
- id_attr=annotation.id,
58
- user_id=annotation.user_id,
59
- _project_session_id=annotation.project_session_id,
60
- name=annotation.name,
61
- annotator_kind=AnnotatorKind(annotation.annotator_kind),
62
- label=annotation.label,
63
- score=annotation.score,
64
- explanation=annotation.explanation,
65
- metadata=JSON(annotation.metadata_),
66
- identifier=annotation.identifier,
67
- source=AnnotationSource(annotation.source),
68
- )
185
+ from .User import User
186
+
187
+ return User(id=user_id)
@@ -106,5 +106,5 @@ class ProjectTraceRetentionPolicy(Node):
106
106
  project_rowids = await info.context.data_loaders.projects_by_trace_retention_policy_id.load(
107
107
  self.id
108
108
  )
109
- data = [Project(project_rowid=project_rowid) for project_rowid in project_rowids]
109
+ data = [Project(id=project_rowid) for project_rowid in project_rowids]
110
110
  return connection_from_list(data=data, args=args)
@@ -20,21 +20,76 @@ from phoenix.server.api.types.pagination import (
20
20
  connection_from_list,
21
21
  )
22
22
 
23
- from .PromptLabel import PromptLabel, to_gql_prompt_label
23
+ from .PromptLabel import PromptLabel
24
24
  from .PromptVersion import (
25
25
  PromptVersion,
26
26
  to_gql_prompt_version,
27
27
  )
28
- from .PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
28
+ from .PromptVersionTag import PromptVersionTag
29
29
 
30
30
 
31
31
  @strawberry.type
32
32
  class Prompt(Node):
33
- id_attr: NodeID[int]
34
- source_prompt_id: Optional[GlobalID]
35
- name: Identifier
36
- description: Optional[str]
37
- created_at: datetime
33
+ id: NodeID[int]
34
+ db_record: strawberry.Private[Optional[models.Prompt]] = None
35
+
36
+ def __post_init__(self) -> None:
37
+ if self.db_record and self.id != self.db_record.id:
38
+ raise ValueError("Prompt ID mismatch")
39
+
40
+ @strawberry.field
41
+ async def source_prompt_id(
42
+ self,
43
+ info: Info[Context, None],
44
+ ) -> Optional[GlobalID]:
45
+ if self.db_record:
46
+ source_id = self.db_record.source_prompt_id
47
+ else:
48
+ source_id = await info.context.data_loaders.prompt_fields.load(
49
+ (self.id, models.Prompt.source_prompt_id),
50
+ )
51
+ if not source_id:
52
+ return None
53
+ return GlobalID(Prompt.__name__, str(source_id))
54
+
55
+ @strawberry.field
56
+ async def name(
57
+ self,
58
+ info: Info[Context, None],
59
+ ) -> Identifier:
60
+ if self.db_record:
61
+ val = self.db_record.name
62
+ else:
63
+ val = await info.context.data_loaders.prompt_fields.load(
64
+ (self.id, models.Prompt.name),
65
+ )
66
+ return Identifier(val.root)
67
+
68
+ @strawberry.field
69
+ async def description(
70
+ self,
71
+ info: Info[Context, None],
72
+ ) -> Optional[str]:
73
+ if self.db_record:
74
+ val = self.db_record.description
75
+ else:
76
+ val = await info.context.data_loaders.prompt_fields.load(
77
+ (self.id, models.Prompt.description),
78
+ )
79
+ return val
80
+
81
+ @strawberry.field
82
+ async def created_at(
83
+ self,
84
+ info: Info[Context, None],
85
+ ) -> datetime:
86
+ if self.db_record:
87
+ val = self.db_record.created_at
88
+ else:
89
+ val = await info.context.data_loaders.prompt_fields.load(
90
+ (self.id, models.Prompt.created_at),
91
+ )
92
+ return val
38
93
 
39
94
  @strawberry.field
40
95
  async def version(
@@ -49,7 +104,7 @@ class Prompt(Node):
49
104
  version = await session.scalar(
50
105
  select(models.PromptVersion).where(
51
106
  models.PromptVersion.id == v_id,
52
- models.PromptVersion.prompt_id == self.id_attr,
107
+ models.PromptVersion.prompt_id == self.id,
53
108
  )
54
109
  )
55
110
  if not version:
@@ -61,7 +116,7 @@ class Prompt(Node):
61
116
  raise NotFound(f"Prompt version tag not found: {tag_name}")
62
117
  version = await session.scalar(
63
118
  select(models.PromptVersion)
64
- .where(models.PromptVersion.prompt_id == self.id_attr)
119
+ .where(models.PromptVersion.prompt_id == self.id)
65
120
  .join_from(models.PromptVersion, models.PromptVersionTag)
66
121
  .where(models.PromptVersionTag.name == name)
67
122
  )
@@ -70,7 +125,7 @@ class Prompt(Node):
70
125
  else:
71
126
  stmt = (
72
127
  select(models.PromptVersion)
73
- .where(models.PromptVersion.prompt_id == self.id_attr)
128
+ .where(models.PromptVersion.prompt_id == self.id)
74
129
  .order_by(models.PromptVersion.id.desc())
75
130
  .limit(1)
76
131
  )
@@ -83,10 +138,11 @@ class Prompt(Node):
83
138
  async def version_tags(self, info: Info[Context, None]) -> list[PromptVersionTag]:
84
139
  async with info.context.db() as session:
85
140
  stmt = select(models.PromptVersionTag).where(
86
- models.PromptVersionTag.prompt_id == self.id_attr
141
+ models.PromptVersionTag.prompt_id == self.id
87
142
  )
88
143
  return [
89
- to_gql_prompt_version_tag(tag) async for tag in await session.stream_scalars(stmt)
144
+ PromptVersionTag(id=tag.id, db_record=tag)
145
+ async for tag in await session.stream_scalars(stmt)
90
146
  ]
91
147
 
92
148
  @strawberry.field
@@ -107,7 +163,7 @@ class Prompt(Node):
107
163
  row_number = func.row_number().over(order_by=models.PromptVersion.id).label("row_number")
108
164
  stmt = (
109
165
  select(models.PromptVersion, row_number)
110
- .where(models.PromptVersion.prompt_id == self.id_attr)
166
+ .where(models.PromptVersion.prompt_id == self.id)
111
167
  .order_by(models.PromptVersion.id.desc())
112
168
  )
113
169
  async with info.context.db() as session:
@@ -119,20 +175,19 @@ class Prompt(Node):
119
175
 
120
176
  @strawberry.field
121
177
  async def source_prompt(self, info: Info[Context, None]) -> Optional["Prompt"]:
122
- if not self.source_prompt_id:
178
+ if self.db_record:
179
+ id_ = self.db_record.source_prompt_id
180
+ else:
181
+ id_ = await info.context.data_loaders.prompt_fields.load(
182
+ (self.id, models.Prompt.source_prompt_id),
183
+ )
184
+ if not id_:
123
185
  return None
124
-
125
- source_prompt_id = from_global_id_with_expected_type(
126
- global_id=self.source_prompt_id, expected_type_name=Prompt.__name__
127
- )
128
-
129
186
  async with info.context.db() as session:
130
- source_prompt = await session.scalar(
131
- select(models.Prompt).where(models.Prompt.id == source_prompt_id)
132
- )
133
- if not source_prompt:
134
- raise NotFound(f"Source prompt not found: {self.source_prompt_id}")
135
- return to_gql_prompt_from_orm(source_prompt)
187
+ source_prompt = await session.get(models.Prompt, id_)
188
+ if not source_prompt:
189
+ raise NotFound(f"Source prompt not found: {id_}")
190
+ return Prompt(id=source_prompt.id, db_record=source_prompt)
136
191
 
137
192
  @strawberry.field
138
193
  async def labels(self, info: Info[Context, None]) -> list["PromptLabel"]:
@@ -140,23 +195,6 @@ class Prompt(Node):
140
195
  labels = await session.scalars(
141
196
  select(models.PromptLabel)
142
197
  .join(models.PromptPromptLabel)
143
- .where(models.PromptPromptLabel.prompt_id == self.id_attr)
198
+ .where(models.PromptPromptLabel.prompt_id == self.id)
144
199
  )
145
- return [to_gql_prompt_label(label) for label in labels]
146
-
147
-
148
- def to_gql_prompt_from_orm(orm_model: "models.Prompt") -> Prompt:
149
- if not orm_model.source_prompt_id:
150
- source_prompt_gid = None
151
- else:
152
- source_prompt_gid = GlobalID(
153
- Prompt.__name__,
154
- str(orm_model.source_prompt_id),
155
- )
156
- return Prompt(
157
- id_attr=orm_model.id,
158
- source_prompt_id=source_prompt_gid,
159
- name=Identifier(orm_model.name.root),
160
- description=orm_model.description,
161
- created_at=orm_model.created_at,
162
- )
200
+ return [PromptLabel(id=label.id, db_record=label) for label in labels]