arize-phoenix 12.8.0__py3-none-any.whl → 12.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; see the original registry page for more details.

Files changed (70):
  1. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
  2. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +70 -67
  3. phoenix/config.py +131 -9
  4. phoenix/db/engines.py +127 -14
  5. phoenix/db/iam_auth.py +64 -0
  6. phoenix/db/pg_config.py +10 -0
  7. phoenix/server/api/context.py +23 -0
  8. phoenix/server/api/dataloaders/__init__.py +6 -0
  9. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
  10. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  11. phoenix/server/api/dataloaders/span_costs.py +3 -9
  12. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  13. phoenix/server/api/helpers/playground_clients.py +3 -3
  14. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  15. phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
  16. phoenix/server/api/mutations/api_key_mutations.py +2 -15
  17. phoenix/server/api/mutations/chat_mutations.py +3 -2
  18. phoenix/server/api/mutations/dataset_label_mutations.py +12 -6
  19. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  20. phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
  21. phoenix/server/api/mutations/model_mutations.py +4 -4
  22. phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
  23. phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
  24. phoenix/server/api/mutations/prompt_mutations.py +24 -117
  25. phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
  26. phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
  27. phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
  28. phoenix/server/api/mutations/user_mutations.py +4 -4
  29. phoenix/server/api/queries.py +65 -210
  30. phoenix/server/api/subscriptions.py +4 -4
  31. phoenix/server/api/types/Annotation.py +90 -23
  32. phoenix/server/api/types/ApiKey.py +13 -17
  33. phoenix/server/api/types/Dataset.py +88 -48
  34. phoenix/server/api/types/DatasetExample.py +34 -30
  35. phoenix/server/api/types/DatasetLabel.py +47 -13
  36. phoenix/server/api/types/DatasetSplit.py +87 -21
  37. phoenix/server/api/types/DatasetVersion.py +49 -4
  38. phoenix/server/api/types/DocumentAnnotation.py +182 -62
  39. phoenix/server/api/types/Experiment.py +146 -55
  40. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
  41. phoenix/server/api/types/ExperimentRun.py +118 -61
  42. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  43. phoenix/server/api/types/GenerativeModel.py +95 -42
  44. phoenix/server/api/types/ModelInterface.py +7 -2
  45. phoenix/server/api/types/PlaygroundModel.py +12 -2
  46. phoenix/server/api/types/Project.py +70 -75
  47. phoenix/server/api/types/ProjectSession.py +69 -37
  48. phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
  49. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  50. phoenix/server/api/types/Prompt.py +82 -44
  51. phoenix/server/api/types/PromptLabel.py +47 -13
  52. phoenix/server/api/types/PromptVersion.py +11 -8
  53. phoenix/server/api/types/PromptVersionTag.py +65 -25
  54. phoenix/server/api/types/Span.py +116 -115
  55. phoenix/server/api/types/SpanAnnotation.py +189 -42
  56. phoenix/server/api/types/SystemApiKey.py +65 -1
  57. phoenix/server/api/types/Trace.py +45 -44
  58. phoenix/server/api/types/TraceAnnotation.py +144 -48
  59. phoenix/server/api/types/User.py +103 -33
  60. phoenix/server/api/types/UserApiKey.py +73 -26
  61. phoenix/server/app.py +29 -0
  62. phoenix/server/static/.vite/manifest.json +9 -9
  63. phoenix/server/static/assets/{components-Bem6_7MW.js → components-v927s3NF.js} +427 -397
  64. phoenix/server/static/assets/{index-NdiXbuNL.js → index-DrD9eSrN.js} +9 -5
  65. phoenix/server/static/assets/{pages-CEJgMVKU.js → pages-GVybXa_W.js} +489 -486
  66. phoenix/version.py +1 -1
  67. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
  68. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
  69. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
  70. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,6 +1,6 @@
1
1
  import operator
2
2
  from datetime import datetime, timezone
3
- from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, Optional, cast
3
+ from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, cast
4
4
 
5
5
  import strawberry
6
6
  from aioitertools.itertools import groupby, islice
@@ -9,7 +9,7 @@ from sqlalchemy import and_, case, desc, distinct, exists, func, or_, select
9
9
  from sqlalchemy.dialects import postgresql, sqlite
10
10
  from sqlalchemy.sql.expression import tuple_
11
11
  from sqlalchemy.sql.functions import percentile_cont
12
- from strawberry import ID, UNSET, Private, lazy
12
+ from strawberry import ID, UNSET, lazy
13
13
  from strawberry.relay import Connection, Edge, Node, NodeID, PageInfo
14
14
  from strawberry.types import Info
15
15
  from typing_extensions import assert_never
@@ -30,7 +30,7 @@ from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_a
30
30
  from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
31
31
  from phoenix.server.api.types.CostBreakdown import CostBreakdown
32
32
  from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary
33
- from phoenix.server.api.types.GenerativeModel import GenerativeModel, to_gql_generative_model
33
+ from phoenix.server.api.types.GenerativeModel import GenerativeModel
34
34
  from phoenix.server.api.types.pagination import (
35
35
  ConnectionArgs,
36
36
  Cursor,
@@ -40,7 +40,7 @@ from phoenix.server.api.types.pagination import (
40
40
  connection_from_cursors_and_nodes,
41
41
  connection_from_list,
42
42
  )
43
- from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
43
+ from phoenix.server.api.types.ProjectSession import ProjectSession
44
44
  from phoenix.server.api.types.SortDir import SortDir
45
45
  from phoenix.server.api.types.Span import Span
46
46
  from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
@@ -58,12 +58,11 @@ if TYPE_CHECKING:
58
58
 
59
59
  @strawberry.type
60
60
  class Project(Node):
61
- _table: ClassVar[type[models.Base]] = models.Project
62
- project_rowid: NodeID[int]
63
- db_project: Private[models.Project] = UNSET
61
+ id: NodeID[int]
62
+ db_record: strawberry.Private[Optional[models.Project]] = None
64
63
 
65
64
  def __post_init__(self) -> None:
66
- if self.db_project and self.project_rowid != self.db_project.id:
65
+ if self.db_record and self.id != self.db_record.id:
67
66
  raise ValueError("Project ID mismatch")
68
67
 
69
68
  @strawberry.field
@@ -71,11 +70,11 @@ class Project(Node):
71
70
  self,
72
71
  info: Info[Context, None],
73
72
  ) -> str:
74
- if self.db_project:
75
- name = self.db_project.name
73
+ if self.db_record:
74
+ name = self.db_record.name
76
75
  else:
77
76
  name = await info.context.data_loaders.project_fields.load(
78
- (self.project_rowid, models.Project.name),
77
+ (self.id, models.Project.name),
79
78
  )
80
79
  return name
81
80
 
@@ -84,11 +83,11 @@ class Project(Node):
84
83
  self,
85
84
  info: Info[Context, None],
86
85
  ) -> str:
87
- if self.db_project:
88
- gradient_start_color = self.db_project.gradient_start_color
86
+ if self.db_record:
87
+ gradient_start_color = self.db_record.gradient_start_color
89
88
  else:
90
89
  gradient_start_color = await info.context.data_loaders.project_fields.load(
91
- (self.project_rowid, models.Project.gradient_start_color),
90
+ (self.id, models.Project.gradient_start_color),
92
91
  )
93
92
  return gradient_start_color
94
93
 
@@ -97,11 +96,11 @@ class Project(Node):
97
96
  self,
98
97
  info: Info[Context, None],
99
98
  ) -> str:
100
- if self.db_project:
101
- gradient_end_color = self.db_project.gradient_end_color
99
+ if self.db_record:
100
+ gradient_end_color = self.db_record.gradient_end_color
102
101
  else:
103
102
  gradient_end_color = await info.context.data_loaders.project_fields.load(
104
- (self.project_rowid, models.Project.gradient_end_color),
103
+ (self.id, models.Project.gradient_end_color),
105
104
  )
106
105
  return gradient_end_color
107
106
 
@@ -111,7 +110,7 @@ class Project(Node):
111
110
  info: Info[Context, None],
112
111
  ) -> Optional[datetime]:
113
112
  start_time = await info.context.data_loaders.min_start_or_max_end_times.load(
114
- (self.project_rowid, "start"),
113
+ (self.id, "start"),
115
114
  )
116
115
  start_time, _ = right_open_time_range(start_time, None)
117
116
  return start_time
@@ -122,7 +121,7 @@ class Project(Node):
122
121
  info: Info[Context, None],
123
122
  ) -> Optional[datetime]:
124
123
  end_time = await info.context.data_loaders.min_start_or_max_end_times.load(
125
- (self.project_rowid, "end"),
124
+ (self.id, "end"),
126
125
  )
127
126
  _, end_time = right_open_time_range(None, end_time)
128
127
  return end_time
@@ -143,7 +142,7 @@ class Project(Node):
143
142
  return await info.context.data_loaders.record_counts.load(
144
143
  (
145
144
  "span",
146
- self.project_rowid,
145
+ self.id,
147
146
  time_range or None,
148
147
  filter_condition or None,
149
148
  session_filter_condition or None,
@@ -166,7 +165,7 @@ class Project(Node):
166
165
  return await info.context.data_loaders.record_counts.load(
167
166
  (
168
167
  "trace",
169
- self.project_rowid,
168
+ self.id,
170
169
  time_range or None,
171
170
  filter_condition or None,
172
171
  session_filter_condition or None,
@@ -181,7 +180,7 @@ class Project(Node):
181
180
  filter_condition: Optional[str] = UNSET,
182
181
  ) -> float:
183
182
  return await info.context.data_loaders.token_counts.load(
184
- ("total", self.project_rowid, time_range, filter_condition),
183
+ ("total", self.id, time_range, filter_condition),
185
184
  )
186
185
 
187
186
  @strawberry.field
@@ -192,7 +191,7 @@ class Project(Node):
192
191
  filter_condition: Optional[str] = UNSET,
193
192
  ) -> float:
194
193
  return await info.context.data_loaders.token_counts.load(
195
- ("prompt", self.project_rowid, time_range, filter_condition),
194
+ ("prompt", self.id, time_range, filter_condition),
196
195
  )
197
196
 
198
197
  @strawberry.field
@@ -203,7 +202,7 @@ class Project(Node):
203
202
  filter_condition: Optional[str] = UNSET,
204
203
  ) -> float:
205
204
  return await info.context.data_loaders.token_counts.load(
206
- ("completion", self.project_rowid, time_range, filter_condition),
205
+ ("completion", self.id, time_range, filter_condition),
207
206
  )
208
207
 
209
208
  @strawberry.field
@@ -221,7 +220,7 @@ class Project(Node):
221
220
  )
222
221
  summary = await info.context.data_loaders.span_cost_summary_by_project.load(
223
222
  (
224
- self.project_rowid,
223
+ self.id,
225
224
  time_range or None,
226
225
  filter_condition or None,
227
226
  session_filter_condition or None,
@@ -259,7 +258,7 @@ class Project(Node):
259
258
  return await info.context.data_loaders.latency_ms_quantile.load(
260
259
  (
261
260
  "trace",
262
- self.project_rowid,
261
+ self.id,
263
262
  time_range or None,
264
263
  filter_condition or None,
265
264
  session_filter_condition or None,
@@ -284,7 +283,7 @@ class Project(Node):
284
283
  return await info.context.data_loaders.latency_ms_quantile.load(
285
284
  (
286
285
  "span",
287
- self.project_rowid,
286
+ self.id,
288
287
  time_range or None,
289
288
  filter_condition or None,
290
289
  session_filter_condition or None,
@@ -297,12 +296,12 @@ class Project(Node):
297
296
  stmt = (
298
297
  select(models.Trace)
299
298
  .where(models.Trace.trace_id == str(trace_id))
300
- .where(models.Trace.project_rowid == self.project_rowid)
299
+ .where(models.Trace.project_rowid == self.id)
301
300
  )
302
301
  async with info.context.db() as session:
303
302
  if (trace := await session.scalar(stmt)) is None:
304
303
  return None
305
- return Trace(trace_rowid=trace.id, db_trace=trace)
304
+ return Trace(id=trace.id, db_record=trace)
306
305
 
307
306
  @strawberry.field
308
307
  async def spans(
@@ -321,7 +320,7 @@ class Project(Node):
321
320
  if root_spans_only and not filter_condition and sort and sort.col is SpanColumn.startTime:
322
321
  return await _paginate_span_by_trace_start_time(
323
322
  db=info.context.db,
324
- project_rowid=self.project_rowid,
323
+ project_rowid=self.id,
325
324
  time_range=time_range,
326
325
  first=first,
327
326
  after=after,
@@ -332,7 +331,7 @@ class Project(Node):
332
331
  select(models.Span.id)
333
332
  .select_from(models.Span)
334
333
  .join(models.Trace)
335
- .where(models.Trace.project_rowid == self.project_rowid)
334
+ .where(models.Trace.project_rowid == self.id)
336
335
  )
337
336
  if time_range:
338
337
  if time_range.start:
@@ -401,7 +400,7 @@ class Project(Node):
401
400
  type=sort_config.column_data_type,
402
401
  value=span_record[1],
403
402
  )
404
- cursors_and_nodes.append((cursor, Span(span_rowid=span_rowid)))
403
+ cursors_and_nodes.append((cursor, Span(id=span_rowid)))
405
404
  has_next_page = True
406
405
  try:
407
406
  await span_records.__anext__()
@@ -431,12 +430,12 @@ class Project(Node):
431
430
  ans = await session.scalar(
432
431
  select(table).filter_by(
433
432
  session_id=session_id,
434
- project_id=self.project_rowid,
433
+ project_id=self.id,
435
434
  )
436
435
  )
437
436
  if ans:
438
437
  return connection_from_list(
439
- data=[to_gql_project_session(ans)],
438
+ data=[ProjectSession(id=ans.id, db_record=ans)],
440
439
  args=ConnectionArgs(),
441
440
  )
442
441
  elif not filter_io_substring:
@@ -444,7 +443,7 @@ class Project(Node):
444
443
  data=[],
445
444
  args=ConnectionArgs(),
446
445
  )
447
- stmt = select(table).filter_by(project_id=self.project_rowid)
446
+ stmt = select(table).filter_by(project_id=self.id)
448
447
  if time_range:
449
448
  if time_range.start:
450
449
  stmt = stmt.where(time_range.start <= table.start_time)
@@ -453,7 +452,7 @@ class Project(Node):
453
452
  if filter_io_substring:
454
453
  filtered_session_rowids = get_filtered_session_rowids_subquery(
455
454
  session_filter_condition=filter_io_substring,
456
- project_rowids=[self.project_rowid],
455
+ project_rowids=[self.id],
457
456
  start_time=time_range.start if time_range else None,
458
457
  end_time=time_range.end if time_range else None,
459
458
  )
@@ -499,7 +498,9 @@ class Project(Node):
499
498
  type=sort_config.column_data_type,
500
499
  value=record[1],
501
500
  )
502
- cursors_and_nodes.append((cursor, to_gql_project_session(project_session)))
501
+ cursors_and_nodes.append(
502
+ (cursor, ProjectSession(id=project_session.id, db_record=project_session))
503
+ )
503
504
  has_next_page = True
504
505
  try:
505
506
  await records.__anext__()
@@ -522,7 +523,7 @@ class Project(Node):
522
523
  stmt = (
523
524
  select(distinct(models.TraceAnnotation.name))
524
525
  .join(models.Trace)
525
- .where(models.Trace.project_rowid == self.project_rowid)
526
+ .where(models.Trace.project_rowid == self.id)
526
527
  )
527
528
  async with info.context.db() as session:
528
529
  return list(await session.scalars(stmt))
@@ -539,7 +540,7 @@ class Project(Node):
539
540
  select(distinct(models.SpanAnnotation.name))
540
541
  .join(models.Span)
541
542
  .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
542
- .where(models.Trace.project_rowid == self.project_rowid)
543
+ .where(models.Trace.project_rowid == self.id)
543
544
  )
544
545
  async with info.context.db() as session:
545
546
  return list(await session.scalars(stmt))
@@ -555,7 +556,7 @@ class Project(Node):
555
556
  stmt = (
556
557
  select(distinct(models.ProjectSessionAnnotation.name))
557
558
  .join(models.ProjectSession)
558
- .where(models.ProjectSession.project_id == self.project_rowid)
559
+ .where(models.ProjectSession.project_id == self.id)
559
560
  )
560
561
  async with info.context.db() as session:
561
562
  return list(await session.scalars(stmt))
@@ -572,7 +573,7 @@ class Project(Node):
572
573
  select(distinct(models.DocumentAnnotation.name))
573
574
  .join(models.Span)
574
575
  .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
575
- .where(models.Trace.project_rowid == self.project_rowid)
576
+ .where(models.Trace.project_rowid == self.id)
576
577
  .where(models.DocumentAnnotation.annotator_kind == "LLM")
577
578
  )
578
579
  if span_id:
@@ -597,7 +598,7 @@ class Project(Node):
597
598
  return await info.context.data_loaders.annotation_summaries.load(
598
599
  (
599
600
  "trace",
600
- self.project_rowid,
601
+ self.id,
601
602
  time_range or None,
602
603
  filter_condition or None,
603
604
  session_filter_condition or None,
@@ -622,7 +623,7 @@ class Project(Node):
622
623
  return await info.context.data_loaders.annotation_summaries.load(
623
624
  (
624
625
  "span",
625
- self.project_rowid,
626
+ self.id,
626
627
  time_range or None,
627
628
  filter_condition or None,
628
629
  session_filter_condition or None,
@@ -639,7 +640,7 @@ class Project(Node):
639
640
  filter_condition: Optional[str] = UNSET,
640
641
  ) -> Optional[DocumentEvaluationSummary]:
641
642
  return await info.context.data_loaders.document_evaluation_summaries.load(
642
- (self.project_rowid, time_range, filter_condition, evaluation_name),
643
+ (self.id, time_range, filter_condition, evaluation_name),
643
644
  )
644
645
 
645
646
  @strawberry.field
@@ -647,7 +648,7 @@ class Project(Node):
647
648
  self,
648
649
  info: Info[Context, None],
649
650
  ) -> Optional[datetime]:
650
- return info.context.last_updated_at.get(self._table, self.project_rowid)
651
+ return info.context.last_updated_at.get(models.Project, self.id)
651
652
 
652
653
  @strawberry.field
653
654
  async def validate_span_filter_condition(
@@ -708,7 +709,7 @@ class Project(Node):
708
709
  before=before if isinstance(before, CursorString) else None,
709
710
  )
710
711
  loader = info.context.data_loaders.annotation_configs_by_project
711
- configs = await loader.load(self.project_rowid)
712
+ configs = await loader.load(self.id)
712
713
  data = [to_gql_annotation_config(config) for config in configs]
713
714
  return connection_from_list(data=data, args=args)
714
715
 
@@ -719,9 +720,7 @@ class Project(Node):
719
720
  ) -> Annotated["ProjectTraceRetentionPolicy", lazy(".ProjectTraceRetentionPolicy")]:
720
721
  from .ProjectTraceRetentionPolicy import ProjectTraceRetentionPolicy
721
722
 
722
- id_ = await info.context.data_loaders.trace_retention_policy_id_by_project_id.load(
723
- self.project_rowid
724
- )
723
+ id_ = await info.context.data_loaders.trace_retention_policy_id_by_project_id.load(self.id)
725
724
  return ProjectTraceRetentionPolicy(id=id_)
726
725
 
727
726
  @strawberry.field
@@ -729,11 +728,11 @@ class Project(Node):
729
728
  self,
730
729
  info: Info[Context, None],
731
730
  ) -> datetime:
732
- if self.db_project:
733
- created_at = self.db_project.created_at
731
+ if self.db_record:
732
+ created_at = self.db_record.created_at
734
733
  else:
735
734
  created_at = await info.context.data_loaders.project_fields.load(
736
- (self.project_rowid, models.Project.created_at),
735
+ (self.id, models.Project.created_at),
737
736
  )
738
737
  return created_at
739
738
 
@@ -742,11 +741,11 @@ class Project(Node):
742
741
  self,
743
742
  info: Info[Context, None],
744
743
  ) -> datetime:
745
- if self.db_project:
746
- updated_at = self.db_project.updated_at
744
+ if self.db_record:
745
+ updated_at = self.db_record.updated_at
747
746
  else:
748
747
  updated_at = await info.context.data_loaders.project_fields.load(
749
- (self.project_rowid, models.Project.updated_at),
748
+ (self.id, models.Project.updated_at),
750
749
  )
751
750
  return updated_at
752
751
 
@@ -792,7 +791,7 @@ class Project(Node):
792
791
  ),
793
792
  )
794
793
  .join_from(models.Span, models.Trace)
795
- .where(models.Trace.project_rowid == self.project_rowid)
794
+ .where(models.Trace.project_rowid == self.id)
796
795
  .group_by(bucket)
797
796
  .order_by(bucket)
798
797
  )
@@ -866,7 +865,7 @@ class Project(Node):
866
865
  bucket = date_trunc(dialect, field, models.Trace.start_time, utc_offset_minutes)
867
866
  stmt = (
868
867
  select(bucket, func.count(models.Trace.id))
869
- .where(models.Trace.project_rowid == self.project_rowid)
868
+ .where(models.Trace.project_rowid == self.id)
870
869
  .group_by(bucket)
871
870
  .order_by(bucket)
872
871
  )
@@ -949,7 +948,7 @@ class Project(Node):
949
948
  onclause=trace_error_status_counts.c.trace_rowid == models.Trace.id,
950
949
  isouter=True,
951
950
  )
952
- .where(models.Trace.project_rowid == self.project_rowid)
951
+ .where(models.Trace.project_rowid == self.id)
953
952
  .group_by(bucket)
954
953
  .order_by(bucket)
955
954
  )
@@ -1021,7 +1020,7 @@ class Project(Node):
1021
1020
  field = "year"
1022
1021
  bucket = date_trunc(dialect, field, models.Trace.start_time, utc_offset_minutes)
1023
1022
 
1024
- stmt = select(bucket).where(models.Trace.project_rowid == self.project_rowid)
1023
+ stmt = select(bucket).where(models.Trace.project_rowid == self.id)
1025
1024
  if time_range.start:
1026
1025
  stmt = stmt.where(time_range.start <= models.Trace.start_time)
1027
1026
  if time_range.end:
@@ -1136,7 +1135,7 @@ class Project(Node):
1136
1135
  models.SpanCost,
1137
1136
  onclause=models.SpanCost.trace_rowid == models.Trace.id,
1138
1137
  )
1139
- .where(models.Trace.project_rowid == self.project_rowid)
1138
+ .where(models.Trace.project_rowid == self.id)
1140
1139
  .group_by(bucket)
1141
1140
  .order_by(bucket)
1142
1141
  )
@@ -1219,7 +1218,7 @@ class Project(Node):
1219
1218
  models.SpanCost,
1220
1219
  onclause=models.SpanCost.trace_rowid == models.Trace.id,
1221
1220
  )
1222
- .where(models.Trace.project_rowid == self.project_rowid)
1221
+ .where(models.Trace.project_rowid == self.id)
1223
1222
  .group_by(bucket)
1224
1223
  .order_by(bucket)
1225
1224
  )
@@ -1306,7 +1305,7 @@ class Project(Node):
1306
1305
  models.Trace,
1307
1306
  onclause=models.Span.trace_rowid == models.Trace.id,
1308
1307
  )
1309
- .where(models.Trace.project_rowid == self.project_rowid)
1308
+ .where(models.Trace.project_rowid == self.id)
1310
1309
  .group_by(bucket, models.SpanAnnotation.name)
1311
1310
  .order_by(bucket)
1312
1311
  )
@@ -1393,7 +1392,7 @@ class Project(Node):
1393
1392
  models.Trace,
1394
1393
  models.SpanCost.trace_rowid == models.Trace.id,
1395
1394
  )
1396
- .where(models.Trace.project_rowid == self.project_rowid)
1395
+ .where(models.Trace.project_rowid == self.id)
1397
1396
  .where(models.SpanCost.model_id.isnot(None))
1398
1397
  .where(models.SpanCost.span_start_time >= time_range.start)
1399
1398
  .group_by(models.GenerativeModel.id)
@@ -1420,10 +1419,8 @@ class Project(Node):
1420
1419
  start=time_range.start,
1421
1420
  end=time_range.end,
1422
1421
  )
1423
- gql_model = to_gql_generative_model(model)
1424
- gql_model.add_cached_cost_summary(
1425
- self.project_rowid, cache_time_range, cost_summary
1426
- )
1422
+ gql_model = GenerativeModel(id=model.id, db_record=model)
1423
+ gql_model.add_cached_cost_summary(self.id, cache_time_range, cost_summary)
1427
1424
  results.append(gql_model)
1428
1425
  return results
1429
1426
 
@@ -1455,7 +1452,7 @@ class Project(Node):
1455
1452
  models.Trace,
1456
1453
  models.SpanCost.trace_rowid == models.Trace.id,
1457
1454
  )
1458
- .where(models.Trace.project_rowid == self.project_rowid)
1455
+ .where(models.Trace.project_rowid == self.id)
1459
1456
  .where(models.SpanCost.model_id.isnot(None))
1460
1457
  .where(models.SpanCost.span_start_time >= time_range.start)
1461
1458
  .group_by(models.GenerativeModel.id)
@@ -1482,10 +1479,8 @@ class Project(Node):
1482
1479
  start=time_range.start,
1483
1480
  end=time_range.end,
1484
1481
  )
1485
- gql_model = to_gql_generative_model(model)
1486
- gql_model.add_cached_cost_summary(
1487
- self.project_rowid, cache_time_range, cost_summary
1488
- )
1482
+ gql_model = GenerativeModel(id=model.id, db_record=model)
1483
+ gql_model.add_cached_cost_summary(self.id, cache_time_range, cost_summary)
1489
1484
  results.append(gql_model)
1490
1485
  return results
1491
1486
 
@@ -1755,7 +1750,7 @@ async def _paginate_span_by_trace_start_time(
1755
1750
  first_record = group[0]
1756
1751
  # Only create edge if trace has a root span
1757
1752
  if (span_rowid := first_record[2]) is not None:
1758
- edges.append(Edge(node=Span(span_rowid=span_rowid), cursor=str(cursor)))
1753
+ edges.append(Edge(node=Span(id=span_rowid), cursor=str(cursor)))
1759
1754
  has_next_page = True
1760
1755
  try:
1761
1756
  await records.__anext__()
@@ -1798,6 +1793,6 @@ def to_gql_project(project: models.Project) -> Project:
1798
1793
  Converts an ORM project to a GraphQL project.
1799
1794
  """
1800
1795
  return Project(
1801
- project_rowid=project.id,
1802
- db_project=project,
1796
+ id=project.id,
1797
+ db_record=project,
1803
1798
  )