arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (276)
  1. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
  2. arize_phoenix-12.28.1.dist-info/RECORD +499 -0
  3. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +5 -4
  12. phoenix/auth.py +39 -2
  13. phoenix/config.py +1763 -91
  14. phoenix/datetime_utils.py +120 -2
  15. phoenix/db/README.md +595 -25
  16. phoenix/db/bulk_inserter.py +145 -103
  17. phoenix/db/engines.py +140 -33
  18. phoenix/db/enums.py +3 -12
  19. phoenix/db/facilitator.py +302 -35
  20. phoenix/db/helpers.py +1000 -65
  21. phoenix/db/iam_auth.py +64 -0
  22. phoenix/db/insertion/dataset.py +135 -2
  23. phoenix/db/insertion/document_annotation.py +9 -6
  24. phoenix/db/insertion/evaluation.py +2 -3
  25. phoenix/db/insertion/helpers.py +17 -2
  26. phoenix/db/insertion/session_annotation.py +176 -0
  27. phoenix/db/insertion/span.py +15 -11
  28. phoenix/db/insertion/span_annotation.py +3 -4
  29. phoenix/db/insertion/trace_annotation.py +3 -4
  30. phoenix/db/insertion/types.py +50 -20
  31. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  32. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  33. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  34. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  35. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  36. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  37. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  38. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  39. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  40. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  41. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  42. phoenix/db/models.py +669 -56
  43. phoenix/db/pg_config.py +10 -0
  44. phoenix/db/types/model_provider.py +4 -0
  45. phoenix/db/types/token_price_customization.py +29 -0
  46. phoenix/db/types/trace_retention.py +23 -15
  47. phoenix/experiments/evaluators/utils.py +3 -3
  48. phoenix/experiments/functions.py +160 -52
  49. phoenix/experiments/tracing.py +2 -2
  50. phoenix/experiments/types.py +1 -1
  51. phoenix/inferences/inferences.py +1 -2
  52. phoenix/server/api/auth.py +38 -7
  53. phoenix/server/api/auth_messages.py +46 -0
  54. phoenix/server/api/context.py +100 -4
  55. phoenix/server/api/dataloaders/__init__.py +79 -5
  56. phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
  57. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  58. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  59. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  60. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  61. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  62. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  63. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  64. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  65. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  66. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  67. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  68. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  69. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  70. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  71. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  72. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  73. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  74. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  75. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  76. phoenix/server/api/dataloaders/record_counts.py +37 -10
  77. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  78. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  79. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  80. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  81. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  82. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  83. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  84. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
  85. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  86. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  87. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  88. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
  89. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  90. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  91. phoenix/server/api/dataloaders/span_costs.py +29 -0
  92. phoenix/server/api/dataloaders/table_fields.py +2 -2
  93. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  94. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  95. phoenix/server/api/dataloaders/types.py +29 -0
  96. phoenix/server/api/exceptions.py +11 -1
  97. phoenix/server/api/helpers/dataset_helpers.py +5 -1
  98. phoenix/server/api/helpers/playground_clients.py +1243 -292
  99. phoenix/server/api/helpers/playground_registry.py +2 -2
  100. phoenix/server/api/helpers/playground_spans.py +8 -4
  101. phoenix/server/api/helpers/playground_users.py +26 -0
  102. phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
  103. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  104. phoenix/server/api/helpers/prompts/models.py +205 -22
  105. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  106. phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
  107. phoenix/server/api/input_types/CreateProjectInput.py +27 -0
  108. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  109. phoenix/server/api/input_types/DatasetFilter.py +17 -0
  110. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  111. phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
  112. phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
  113. phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
  114. phoenix/server/api/input_types/PromptFilter.py +14 -0
  115. phoenix/server/api/input_types/PromptVersionInput.py +52 -1
  116. phoenix/server/api/input_types/SpanSort.py +44 -7
  117. phoenix/server/api/input_types/TimeBinConfig.py +23 -0
  118. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  119. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  120. phoenix/server/api/mutations/__init__.py +10 -0
  121. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  122. phoenix/server/api/mutations/api_key_mutations.py +19 -23
  123. phoenix/server/api/mutations/chat_mutations.py +154 -47
  124. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  125. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  126. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  127. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  128. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  129. phoenix/server/api/mutations/model_mutations.py +210 -0
  130. phoenix/server/api/mutations/project_mutations.py +49 -10
  131. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  132. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  133. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  134. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  135. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  136. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  137. phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
  138. phoenix/server/api/mutations/trace_mutations.py +47 -3
  139. phoenix/server/api/mutations/user_mutations.py +66 -41
  140. phoenix/server/api/queries.py +768 -293
  141. phoenix/server/api/routers/__init__.py +2 -2
  142. phoenix/server/api/routers/auth.py +154 -88
  143. phoenix/server/api/routers/ldap.py +229 -0
  144. phoenix/server/api/routers/oauth2.py +369 -106
  145. phoenix/server/api/routers/v1/__init__.py +24 -4
  146. phoenix/server/api/routers/v1/annotation_configs.py +23 -31
  147. phoenix/server/api/routers/v1/annotations.py +481 -17
  148. phoenix/server/api/routers/v1/datasets.py +395 -81
  149. phoenix/server/api/routers/v1/documents.py +142 -0
  150. phoenix/server/api/routers/v1/evaluations.py +24 -31
  151. phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
  152. phoenix/server/api/routers/v1/experiment_runs.py +337 -59
  153. phoenix/server/api/routers/v1/experiments.py +479 -48
  154. phoenix/server/api/routers/v1/models.py +7 -0
  155. phoenix/server/api/routers/v1/projects.py +18 -49
  156. phoenix/server/api/routers/v1/prompts.py +54 -40
  157. phoenix/server/api/routers/v1/sessions.py +108 -0
  158. phoenix/server/api/routers/v1/spans.py +1091 -81
  159. phoenix/server/api/routers/v1/traces.py +132 -78
  160. phoenix/server/api/routers/v1/users.py +389 -0
  161. phoenix/server/api/routers/v1/utils.py +3 -7
  162. phoenix/server/api/subscriptions.py +305 -88
  163. phoenix/server/api/types/Annotation.py +90 -23
  164. phoenix/server/api/types/ApiKey.py +13 -17
  165. phoenix/server/api/types/AuthMethod.py +1 -0
  166. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  167. phoenix/server/api/types/CostBreakdown.py +12 -0
  168. phoenix/server/api/types/Dataset.py +226 -72
  169. phoenix/server/api/types/DatasetExample.py +88 -18
  170. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  171. phoenix/server/api/types/DatasetLabel.py +57 -0
  172. phoenix/server/api/types/DatasetSplit.py +98 -0
  173. phoenix/server/api/types/DatasetVersion.py +49 -4
  174. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  175. phoenix/server/api/types/Experiment.py +264 -59
  176. phoenix/server/api/types/ExperimentComparison.py +5 -10
  177. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  178. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  179. phoenix/server/api/types/ExperimentRun.py +169 -65
  180. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  181. phoenix/server/api/types/GenerativeModel.py +245 -3
  182. phoenix/server/api/types/GenerativeProvider.py +70 -11
  183. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  184. phoenix/server/api/types/ModelInterface.py +16 -0
  185. phoenix/server/api/types/PlaygroundModel.py +20 -0
  186. phoenix/server/api/types/Project.py +1278 -216
  187. phoenix/server/api/types/ProjectSession.py +188 -28
  188. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  189. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  190. phoenix/server/api/types/Prompt.py +119 -39
  191. phoenix/server/api/types/PromptLabel.py +42 -25
  192. phoenix/server/api/types/PromptVersion.py +11 -8
  193. phoenix/server/api/types/PromptVersionTag.py +65 -25
  194. phoenix/server/api/types/ServerStatus.py +6 -0
  195. phoenix/server/api/types/Span.py +167 -123
  196. phoenix/server/api/types/SpanAnnotation.py +189 -42
  197. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  198. phoenix/server/api/types/SpanCostSummary.py +10 -0
  199. phoenix/server/api/types/SystemApiKey.py +65 -1
  200. phoenix/server/api/types/TokenPrice.py +16 -0
  201. phoenix/server/api/types/TokenUsage.py +3 -3
  202. phoenix/server/api/types/Trace.py +223 -51
  203. phoenix/server/api/types/TraceAnnotation.py +149 -50
  204. phoenix/server/api/types/User.py +137 -32
  205. phoenix/server/api/types/UserApiKey.py +73 -26
  206. phoenix/server/api/types/node.py +10 -0
  207. phoenix/server/api/types/pagination.py +11 -2
  208. phoenix/server/app.py +290 -45
  209. phoenix/server/authorization.py +38 -3
  210. phoenix/server/bearer_auth.py +34 -24
  211. phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
  212. phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
  213. phoenix/server/cost_tracking/helpers.py +68 -0
  214. phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
  215. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  216. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  217. phoenix/server/daemons/__init__.py +0 -0
  218. phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
  219. phoenix/server/daemons/generative_model_store.py +103 -0
  220. phoenix/server/daemons/span_cost_calculator.py +99 -0
  221. phoenix/server/dml_event.py +17 -0
  222. phoenix/server/dml_event_handler.py +5 -0
  223. phoenix/server/email/sender.py +56 -3
  224. phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
  225. phoenix/server/email/types.py +11 -0
  226. phoenix/server/experiments/__init__.py +0 -0
  227. phoenix/server/experiments/utils.py +14 -0
  228. phoenix/server/grpc_server.py +11 -11
  229. phoenix/server/jwt_store.py +17 -15
  230. phoenix/server/ldap.py +1449 -0
  231. phoenix/server/main.py +26 -10
  232. phoenix/server/oauth2.py +330 -12
  233. phoenix/server/prometheus.py +66 -6
  234. phoenix/server/rate_limiters.py +4 -9
  235. phoenix/server/retention.py +33 -20
  236. phoenix/server/session_filters.py +49 -0
  237. phoenix/server/static/.vite/manifest.json +55 -51
  238. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  239. phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
  240. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  241. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  242. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  243. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  244. phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
  245. phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
  246. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  247. phoenix/server/templates/index.html +40 -6
  248. phoenix/server/thread_server.py +1 -2
  249. phoenix/server/types.py +14 -4
  250. phoenix/server/utils.py +74 -0
  251. phoenix/session/client.py +56 -3
  252. phoenix/session/data_extractor.py +5 -0
  253. phoenix/session/evaluation.py +14 -5
  254. phoenix/session/session.py +45 -9
  255. phoenix/settings.py +5 -0
  256. phoenix/trace/attributes.py +80 -13
  257. phoenix/trace/dsl/helpers.py +90 -1
  258. phoenix/trace/dsl/query.py +8 -6
  259. phoenix/trace/projects.py +5 -0
  260. phoenix/utilities/template_formatters.py +1 -1
  261. phoenix/version.py +1 -1
  262. arize_phoenix-10.0.4.dist-info/RECORD +0 -405
  263. phoenix/server/api/types/Evaluation.py +0 -39
  264. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  265. phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
  266. phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
  267. phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
  268. phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
  269. phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
  270. phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
  271. phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
  272. phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
  273. phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
  274. phoenix/utilities/deprecation.py +0 -31
  275. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  276. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,48 +1,54 @@
-from __future__ import annotations
-
 import operator
-from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Optional
+from datetime import datetime, timezone
+from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, cast
 
 import strawberry
-from aioitertools.itertools import islice
+from aioitertools.itertools import groupby, islice
 from openinference.semconv.trace import SpanAttributes
-from sqlalchemy import desc, distinct, func, or_, select
+from sqlalchemy import and_, case, desc, distinct, exists, func, or_, select
 from sqlalchemy.dialects import postgresql, sqlite
-from sqlalchemy.sql.elements import ColumnElement
 from sqlalchemy.sql.expression import tuple_
-from strawberry import ID, UNSET, Private, lazy
-from strawberry.relay import Connection, Node, NodeID
+from sqlalchemy.sql.functions import percentile_cont
+from strawberry import ID, UNSET, lazy
+from strawberry.relay import Connection, Edge, Node, NodeID, PageInfo
 from strawberry.types import Info
 from typing_extensions import assert_never
 
-from phoenix.datetime_utils import right_open_time_range
+from phoenix.datetime_utils import get_timestamp_range, normalize_datetime, right_open_time_range
 from phoenix.db import models
-from phoenix.db.helpers import SupportedSQLDialect
+from phoenix.db.helpers import SupportedSQLDialect, date_trunc
 from phoenix.server.api.context import Context
+from phoenix.server.api.exceptions import BadRequest
 from phoenix.server.api.input_types.ProjectSessionSort import (
-    ProjectSessionColumn,
     ProjectSessionSort,
+    ProjectSessionSortConfig,
 )
-from phoenix.server.api.input_types.SpanSort import SpanSort, SpanSortConfig
+from phoenix.server.api.input_types.SpanSort import SpanColumn, SpanSort, SpanSortConfig
+from phoenix.server.api.input_types.TimeBinConfig import TimeBinConfig, TimeBinScale
 from phoenix.server.api.input_types.TimeRange import TimeRange
 from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_annotation_config
 from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
+from phoenix.server.api.types.CostBreakdown import CostBreakdown
 from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary
+from phoenix.server.api.types.GenerativeModel import GenerativeModel
 from phoenix.server.api.types.pagination import (
     ConnectionArgs,
     Cursor,
     CursorSortColumn,
+    CursorSortColumnDataType,
     CursorString,
     connection_from_cursors_and_nodes,
     connection_from_list,
 )
-from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
+from phoenix.server.api.types.ProjectSession import ProjectSession
 from phoenix.server.api.types.SortDir import SortDir
 from phoenix.server.api.types.Span import Span
+from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
 from phoenix.server.api.types.TimeSeries import TimeSeries, TimeSeriesDataPoint
 from phoenix.server.api.types.Trace import Trace
 from phoenix.server.api.types.ValidationResult import ValidationResult
+from phoenix.server.session_filters import get_filtered_session_rowids_subquery
+from phoenix.server.types import DbSessionFactory
 from phoenix.trace.dsl import SpanFilter
 
 DEFAULT_PAGE_SIZE = 30
@@ -52,12 +58,11 @@ if TYPE_CHECKING:
 
 @strawberry.type
 class Project(Node):
-    _table: ClassVar[type[models.Base]] = models.Project
-    project_rowid: NodeID[int]
-    db_project: Private[models.Project] = UNSET
+    id: NodeID[int]
+    db_record: strawberry.Private[Optional[models.Project]] = None
 
     def __post_init__(self) -> None:
-        if self.db_project and self.project_rowid != self.db_project.id:
+        if self.db_record and self.id != self.db_record.id:
            raise ValueError("Project ID mismatch")
 
     @strawberry.field
@@ -65,11 +70,11 @@ class Project(Node):
         self,
         info: Info[Context, None],
     ) -> str:
-        if self.db_project:
-            name = self.db_project.name
+        if self.db_record:
+            name = self.db_record.name
         else:
             name = await info.context.data_loaders.project_fields.load(
-                (self.project_rowid, models.Project.name),
+                (self.id, models.Project.name),
             )
         return name
 
@@ -78,11 +83,11 @@ class Project(Node):
         self,
         info: Info[Context, None],
     ) -> str:
-        if self.db_project:
-            gradient_start_color = self.db_project.gradient_start_color
+        if self.db_record:
+            gradient_start_color = self.db_record.gradient_start_color
         else:
             gradient_start_color = await info.context.data_loaders.project_fields.load(
-                (self.project_rowid, models.Project.gradient_start_color),
+                (self.id, models.Project.gradient_start_color),
             )
         return gradient_start_color
 
@@ -91,11 +96,11 @@ class Project(Node):
         self,
         info: Info[Context, None],
     ) -> str:
-        if self.db_project:
-            gradient_end_color = self.db_project.gradient_end_color
+        if self.db_record:
+            gradient_end_color = self.db_record.gradient_end_color
         else:
             gradient_end_color = await info.context.data_loaders.project_fields.load(
-                (self.project_rowid, models.Project.gradient_end_color),
+                (self.id, models.Project.gradient_end_color),
             )
         return gradient_end_color
 
@@ -105,7 +110,7 @@ class Project(Node):
         info: Info[Context, None],
     ) -> Optional[datetime]:
         start_time = await info.context.data_loaders.min_start_or_max_end_times.load(
-            (self.project_rowid, "start"),
+            (self.id, "start"),
         )
         start_time, _ = right_open_time_range(start_time, None)
         return start_time
@@ -116,7 +121,7 @@ class Project(Node):
         info: Info[Context, None],
     ) -> Optional[datetime]:
         end_time = await info.context.data_loaders.min_start_or_max_end_times.load(
-            (self.project_rowid, "end"),
+            (self.id, "end"),
         )
         _, end_time = right_open_time_range(None, end_time)
         return end_time
@@ -127,9 +132,21 @@ class Project(Node):
         info: Info[Context, None],
         time_range: Optional[TimeRange] = UNSET,
         filter_condition: Optional[str] = UNSET,
+        session_filter_condition: Optional[str] = UNSET,
     ) -> int:
+        if filter_condition and session_filter_condition:
+            raise BadRequest(
+                "Both a filter condition and session filter condition "
+                "cannot be applied at the same time"
+            )
         return await info.context.data_loaders.record_counts.load(
-            ("span", self.project_rowid, time_range, filter_condition),
+            (
+                "span",
+                self.id,
+                time_range or None,
+                filter_condition or None,
+                session_filter_condition or None,
+            ),
         )
 
     @strawberry.field
@@ -137,9 +154,22 @@ class Project(Node):
         self,
         info: Info[Context, None],
         time_range: Optional[TimeRange] = UNSET,
+        filter_condition: Optional[str] = UNSET,
+        session_filter_condition: Optional[str] = UNSET,
     ) -> int:
+        if filter_condition and session_filter_condition:
+            raise BadRequest(
+                "Both a filter condition and session filter condition "
+                "cannot be applied at the same time"
+            )
         return await info.context.data_loaders.record_counts.load(
-            ("trace", self.project_rowid, time_range, None),
+            (
+                "trace",
+                self.id,
+                time_range or None,
+                filter_condition or None,
+                session_filter_condition or None,
+            ),
         )
 
     @strawberry.field
@@ -150,7 +180,7 @@ class Project(Node):
         filter_condition: Optional[str] = UNSET,
     ) -> float:
         return await info.context.data_loaders.token_counts.load(
-            ("total", self.project_rowid, time_range, filter_condition),
+            ("total", self.id, time_range, filter_condition),
         )
 
     @strawberry.field
@@ -161,7 +191,7 @@ class Project(Node):
         filter_condition: Optional[str] = UNSET,
     ) -> float:
         return await info.context.data_loaders.token_counts.load(
-            ("prompt", self.project_rowid, time_range, filter_condition),
+            ("prompt", self.id, time_range, filter_condition),
         )
 
     @strawberry.field
@@ -172,7 +202,43 @@ class Project(Node):
         filter_condition: Optional[str] = UNSET,
     ) -> float:
         return await info.context.data_loaders.token_counts.load(
-            ("completion", self.project_rowid, time_range, filter_condition),
+            ("completion", self.id, time_range, filter_condition),
+        )
+
+    @strawberry.field
+    async def cost_summary(
+        self,
+        info: Info[Context, None],
+        time_range: Optional[TimeRange] = UNSET,
+        filter_condition: Optional[str] = UNSET,
+        session_filter_condition: Optional[str] = UNSET,
+    ) -> SpanCostSummary:
+        if filter_condition and session_filter_condition:
+            raise BadRequest(
+                "Both a filter condition and session filter condition "
+                "cannot be applied at the same time"
+            )
+        summary = await info.context.data_loaders.span_cost_summary_by_project.load(
+            (
+                self.id,
+                time_range or None,
+                filter_condition or None,
+                session_filter_condition or None,
+            )
+        )
+        return SpanCostSummary(
+            prompt=CostBreakdown(
+                tokens=summary.prompt.tokens,
+                cost=summary.prompt.cost,
+            ),
+            completion=CostBreakdown(
+                tokens=summary.completion.tokens,
+                cost=summary.completion.cost,
+            ),
+            total=CostBreakdown(
+                tokens=summary.total.tokens,
+                cost=summary.total.cost,
+            ),
         )
 
     @strawberry.field
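Note on the new field above: Strawberry exposes snake_case resolvers as camelCase by default, so `cost_summary` is queried as `costSummary`. A minimal client-side sketch, assuming a local Phoenix server with its GraphQL endpoint at `/graphql`; the endpoint path and relay node ID are assumptions for illustration, not taken from this diff:

# Hypothetical query exercising the costSummary field added above.
import httpx

QUERY = """
query ($id: ID!) {
  node(id: $id) {
    ... on Project {
      costSummary {
        prompt { tokens cost }
        completion { tokens cost }
        total { tokens cost }
      }
    }
  }
}
"""

resp = httpx.post(
    "http://localhost:6006/graphql",  # assumed endpoint path
    json={"query": QUERY, "variables": {"id": "UHJvamVjdDox"}},  # illustrative relay ID
)
resp.raise_for_status()
print(resp.json()["data"]["node"]["costSummary"])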
@@ -181,13 +247,21 @@ class Project(Node):
         info: Info[Context, None],
         probability: float,
         time_range: Optional[TimeRange] = UNSET,
+        filter_condition: Optional[str] = UNSET,
+        session_filter_condition: Optional[str] = UNSET,
     ) -> Optional[float]:
+        if filter_condition and session_filter_condition:
+            raise BadRequest(
+                "Both a filter condition and session filter condition "
+                "cannot be applied at the same time"
+            )
         return await info.context.data_loaders.latency_ms_quantile.load(
             (
                 "trace",
-                self.project_rowid,
-                time_range,
-                None,
+                self.id,
+                time_range or None,
+                filter_condition or None,
+                session_filter_condition or None,
                 probability,
             ),
         )
@@ -199,13 +273,20 @@ class Project(Node):
         probability: float,
         time_range: Optional[TimeRange] = UNSET,
         filter_condition: Optional[str] = UNSET,
+        session_filter_condition: Optional[str] = UNSET,
     ) -> Optional[float]:
+        if filter_condition and session_filter_condition:
+            raise BadRequest(
+                "Both a filter condition and session filter condition "
+                "cannot be applied at the same time"
+            )
         return await info.context.data_loaders.latency_ms_quantile.load(
             (
                 "span",
-                self.project_rowid,
-                time_range,
-                filter_condition,
+                self.id,
+                time_range or None,
+                filter_condition or None,
+                session_filter_condition or None,
                 probability,
             ),
         )
@@ -215,12 +296,12 @@ class Project(Node):
         stmt = (
             select(models.Trace)
             .where(models.Trace.trace_id == str(trace_id))
-            .where(models.Trace.project_rowid == self.project_rowid)
+            .where(models.Trace.project_rowid == self.id)
         )
         async with info.context.db() as session:
             if (trace := await session.scalar(stmt)) is None:
                 return None
-            return Trace(trace_rowid=trace.id, db_trace=trace)
+            return Trace(id=trace.id, db_record=trace)
 
     @strawberry.field
     async def spans(
@@ -236,10 +317,21 @@ class Project(Node):
         filter_condition: Optional[str] = UNSET,
         orphan_span_as_root_span: Optional[bool] = True,
     ) -> Connection[Span]:
+        if root_spans_only and not filter_condition and sort and sort.col is SpanColumn.startTime:
+            return await _paginate_span_by_trace_start_time(
+                db=info.context.db,
+                project_rowid=self.id,
+                time_range=time_range,
+                first=first,
+                after=after,
+                sort=sort,
+                orphan_span_as_root_span=orphan_span_as_root_span,
+            )
         stmt = (
             select(models.Span.id)
+            .select_from(models.Span)
             .join(models.Trace)
-            .where(models.Trace.project_rowid == self.project_rowid)
+            .where(models.Trace.project_rowid == self.id)
         )
         if time_range:
             if time_range.start:
@@ -261,12 +353,16 @@ class Project(Node):
             if sort_config and cursor.sort_column:
                 sort_column = cursor.sort_column
                 compare = operator.lt if sort_config.dir is SortDir.desc else operator.gt
-                stmt = stmt.where(
-                    compare(
-                        tuple_(sort_config.orm_expression, models.Span.id),
-                        (sort_column.value, cursor.rowid),
+                if sort_column.type is CursorSortColumnDataType.NULL:
+                    stmt = stmt.where(sort_config.orm_expression.is_(None))
+                    stmt = stmt.where(compare(models.Span.id, cursor.rowid))
+                else:
+                    stmt = stmt.where(
+                        compare(
+                            tuple_(sort_config.orm_expression, models.Span.id),
+                            (sort_column.value, cursor.rowid),
+                        )
                     )
-                )
             else:
                 stmt = stmt.where(models.Span.id > cursor.rowid)
             stmt = stmt.order_by(cursor_rowid_column)
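The branch added above is what makes cursor (keyset) pagination safe once the sort column contains NULLs: a SQL tuple comparison against a NULL sort value evaluates to NULL and would silently drop the rest of the page, so the query instead pins the sort expression with IS NULL and compares rowids alone inside that bucket. A self-contained sketch of the same idea on plain Python data (names and rows are illustrative):

# Keyset pagination over (sort_value, rowid) pairs, mirroring the NULL-cursor
# branch above. Rows are assumed pre-sorted descending with NULLs last.
from typing import Optional

Row = tuple[Optional[float], int]  # (sort_value, rowid)

def page_after(rows: list[Row], cursor: Row, limit: int) -> list[Row]:
    value, rowid = cursor
    if value is None:
        # NULL bucket: only the rowid is comparable, so paginate by rowid alone.
        remaining = [r for r in rows if r[0] is None and r[1] < rowid]
    else:
        # Lexicographic comparison on (value, rowid), like SQL's tuple compare.
        remaining = [r for r in rows if r[0] is not None and (r[0], r[1]) < (value, rowid)]
    return remaining[:limit]

rows = [(9.5, 7), (9.5, 3), (4.0, 8), (None, 6), (None, 2)]
assert page_after(rows, cursor=(9.5, 3), limit=2) == [(4.0, 8)]
assert page_after(rows, cursor=(None, 6), limit=2) == [(None, 2)]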
@@ -304,7 +400,7 @@ class Project(Node):
                         type=sort_config.column_data_type,
                         value=span_record[1],
                     )
-                cursors_and_nodes.append((cursor, Span(span_rowid=span_rowid)))
+                cursors_and_nodes.append((cursor, Span(id=span_rowid)))
             has_next_page = True
             try:
                 await span_records.__anext__()
@@ -326,87 +422,66 @@ class Project(Node):
         after: Optional[CursorString] = UNSET,
         sort: Optional[ProjectSessionSort] = UNSET,
         filter_io_substring: Optional[str] = UNSET,
+        session_id: Optional[str] = UNSET,
     ) -> Connection[ProjectSession]:
         table = models.ProjectSession
-        stmt = select(table).filter_by(project_id=self.project_rowid)
+        if session_id:
+            async with info.context.db() as session:
+                ans = await session.scalar(
+                    select(table).filter_by(
+                        session_id=session_id,
+                        project_id=self.id,
+                    )
+                )
+            if ans:
+                return connection_from_list(
+                    data=[ProjectSession(id=ans.id, db_record=ans)],
+                    args=ConnectionArgs(),
+                )
+            elif not filter_io_substring:
+                return connection_from_list(
+                    data=[],
+                    args=ConnectionArgs(),
+                )
+        stmt = select(table).filter_by(project_id=self.id)
         if time_range:
             if time_range.start:
                 stmt = stmt.where(time_range.start <= table.start_time)
             if time_range.end:
                 stmt = stmt.where(table.start_time < time_range.end)
         if filter_io_substring:
-            filter_subq = (
-                stmt.with_only_columns(distinct(table.id).label("id"))
-                .join_from(table, models.Trace)
-                .join_from(models.Trace, models.Span)
-                .where(models.Span.parent_id.is_(None))
-                .where(
-                    or_(
-                        models.TextContains(
-                            models.Span.attributes[INPUT_VALUE].as_string(),
-                            filter_io_substring,
-                        ),
-                        models.TextContains(
-                            models.Span.attributes[OUTPUT_VALUE].as_string(),
-                            filter_io_substring,
-                        ),
-                    )
-                )
-            ).subquery()
-            stmt = stmt.join(filter_subq, table.id == filter_subq.c.id)
+            filtered_session_rowids = get_filtered_session_rowids_subquery(
+                session_filter_condition=filter_io_substring,
+                project_rowids=[self.id],
+                start_time=time_range.start if time_range else None,
+                end_time=time_range.end if time_range else None,
+            )
+            stmt = stmt.where(table.id.in_(filtered_session_rowids))
+        sort_config: Optional[ProjectSessionSortConfig] = None
+        cursor_rowid_column: Any = table.id
         if sort:
-            key: ColumnElement[Any]
-            if sort.col is ProjectSessionColumn.startTime:
-                key = table.start_time.label("key")
-            elif sort.col is ProjectSessionColumn.endTime:
-                key = table.end_time.label("key")
-            elif (
-                sort.col is ProjectSessionColumn.tokenCountTotal
-                or sort.col is ProjectSessionColumn.numTraces
-            ):
-                if sort.col is ProjectSessionColumn.tokenCountTotal:
-                    sort_subq = (
-                        select(
-                            models.Trace.project_session_rowid.label("id"),
-                            func.sum(models.Span.cumulative_llm_token_count_total).label("key"),
-                        )
-                        .join_from(models.Trace, models.Span)
-                        .where(models.Span.parent_id.is_(None))
-                        .group_by(models.Trace.project_session_rowid)
-                    ).subquery()
-                elif sort.col is ProjectSessionColumn.numTraces:
-                    sort_subq = (
-                        select(
-                            models.Trace.project_session_rowid.label("id"),
-                            func.count(models.Trace.id).label("key"),
-                        ).group_by(models.Trace.project_session_rowid)
-                    ).subquery()
+            sort_config = sort.update_orm_expr(stmt)
+            stmt = sort_config.stmt
+            if sort_config.dir is SortDir.desc:
+                cursor_rowid_column = desc(cursor_rowid_column)
+        if after:
+            cursor = Cursor.from_string(after)
+            if sort_config and cursor.sort_column:
+                sort_column = cursor.sort_column
+                compare = operator.lt if sort_config.dir is SortDir.desc else operator.gt
+                if sort_column.type is CursorSortColumnDataType.NULL:
+                    stmt = stmt.where(sort_config.orm_expression.is_(None))
+                    stmt = stmt.where(compare(table.id, cursor.rowid))
                 else:
-                    assert_never(sort.col)
-                key = sort_subq.c.key
-                stmt = stmt.join(sort_subq, table.id == sort_subq.c.id)
-            else:
-                assert_never(sort.col)
-            stmt = stmt.add_columns(key)
-            if sort.dir is SortDir.asc:
-                stmt = stmt.order_by(key.asc(), table.id.asc())
-            else:
-                stmt = stmt.order_by(key.desc(), table.id.desc())
-            if after:
-                cursor = Cursor.from_string(after)
-                assert cursor.sort_column is not None
-                compare = operator.lt if sort.dir is SortDir.desc else operator.gt
-                stmt = stmt.where(
-                    compare(
-                        tuple_(key, table.id),
-                        (cursor.sort_column.value, cursor.rowid),
+                    stmt = stmt.where(
+                        compare(
+                            tuple_(sort_config.orm_expression, table.id),
+                            (sort_column.value, cursor.rowid),
+                        )
                     )
-                )
-        else:
-            stmt = stmt.order_by(table.id.desc())
-        if after:
-            cursor = Cursor.from_string(after)
+            else:
                 stmt = stmt.where(table.id < cursor.rowid)
+        stmt = stmt.order_by(cursor_rowid_column)
         if first:
             stmt = stmt.limit(
                 first + 1  # over-fetch by one to determine whether there's a next page
@@ -417,13 +492,15 @@ class Project(Node):
             async for record in islice(records, first):
                 project_session = record[0]
                 cursor = Cursor(rowid=project_session.id)
-                if sort:
+                if sort_config:
                     assert len(record) > 1
                     cursor.sort_column = CursorSortColumn(
-                        type=sort.col.data_type,
+                        type=sort_config.column_data_type,
                         value=record[1],
                     )
-                cursors_and_nodes.append((cursor, to_gql_project_session(project_session)))
+                cursors_and_nodes.append(
+                    (cursor, ProjectSession(id=project_session.id, db_record=project_session))
+                )
             has_next_page = True
             try:
                 await records.__anext__()
@@ -446,7 +523,7 @@ class Project(Node):
         stmt = (
             select(distinct(models.TraceAnnotation.name))
             .join(models.Trace)
-            .where(models.Trace.project_rowid == self.project_rowid)
+            .where(models.Trace.project_rowid == self.id)
         )
         async with info.context.db() as session:
             return list(await session.scalars(stmt))
@@ -463,7 +540,23 @@ class Project(Node):
             select(distinct(models.SpanAnnotation.name))
             .join(models.Span)
             .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
-            .where(models.Trace.project_rowid == self.project_rowid)
+            .where(models.Trace.project_rowid == self.id)
+        )
+        async with info.context.db() as session:
+            return list(await session.scalars(stmt))
+
+    @strawberry.field(
+        description="Names of all available annotations for sessions. "
+        "(The list contains no duplicates.)"
+    )  # type: ignore
+    async def session_annotation_names(
+        self,
+        info: Info[Context, None],
+    ) -> list[str]:
+        stmt = (
+            select(distinct(models.ProjectSessionAnnotation.name))
+            .join(models.ProjectSession)
+            .where(models.ProjectSession.project_id == self.id)
         )
         async with info.context.db() as session:
             return list(await session.scalars(stmt))
@@ -480,7 +573,7 @@ class Project(Node):
             select(distinct(models.DocumentAnnotation.name))
             .join(models.Span)
             .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
-            .where(models.Trace.project_rowid == self.project_rowid)
+            .where(models.Trace.project_rowid == self.id)
             .where(models.DocumentAnnotation.annotator_kind == "LLM")
         )
         if span_id:
@@ -493,10 +586,24 @@ class Project(Node):
         self,
         info: Info[Context, None],
         annotation_name: str,
+        filter_condition: Optional[str] = UNSET,
+        session_filter_condition: Optional[str] = UNSET,
         time_range: Optional[TimeRange] = UNSET,
     ) -> Optional[AnnotationSummary]:
+        if filter_condition and session_filter_condition:
+            raise BadRequest(
+                "Both a filter condition and session filter condition "
+                "cannot be applied at the same time"
+            )
         return await info.context.data_loaders.annotation_summaries.load(
-            ("trace", self.project_rowid, time_range, None, annotation_name),
+            (
+                "trace",
+                self.id,
+                time_range or None,
+                filter_condition or None,
+                session_filter_condition or None,
+                annotation_name,
+            ),
         )
 
     @strawberry.field
@@ -506,9 +613,22 @@ class Project(Node):
         annotation_name: str,
         time_range: Optional[TimeRange] = UNSET,
         filter_condition: Optional[str] = UNSET,
+        session_filter_condition: Optional[str] = UNSET,
     ) -> Optional[AnnotationSummary]:
+        if filter_condition and session_filter_condition:
+            raise BadRequest(
+                "Both a filter condition and session filter condition "
+                "cannot be applied at the same time"
+            )
         return await info.context.data_loaders.annotation_summaries.load(
-            ("span", self.project_rowid, time_range, filter_condition, annotation_name),
+            (
+                "span",
+                self.id,
+                time_range or None,
+                filter_condition or None,
+                session_filter_condition or None,
+                annotation_name,
+            ),
         )
 
     @strawberry.field
@@ -520,7 +640,7 @@ class Project(Node):
         time_range: Optional[TimeRange] = UNSET,
         filter_condition: Optional[str] = UNSET,
     ) -> Optional[DocumentEvaluationSummary]:
         return await info.context.data_loaders.document_evaluation_summaries.load(
-            (self.project_rowid, time_range, filter_condition, evaluation_name),
+            (self.id, time_range, filter_condition, evaluation_name),
         )
 
@@ -528,7 +648,7 @@ class Project(Node):
         self,
         info: Info[Context, None],
     ) -> Optional[datetime]:
-        return info.context.last_updated_at.get(self._table, self.project_rowid)
+        return info.context.last_updated_at.get(models.Project, self.id)
 
     @strawberry.field
     async def validate_span_filter_condition(
@@ -561,7 +681,7 @@ class Project(Node):
             stmt = span_filter(select(models.Span))
             dialect = info.context.db.dialect
             if dialect is SupportedSQLDialect.POSTGRESQL:
-                str(stmt.compile(dialect=sqlite.dialect()))  # type: ignore[no-untyped-call]
+                str(stmt.compile(dialect=sqlite.dialect()))
             elif dialect is SupportedSQLDialect.SQLITE:
                 str(stmt.compile(dialect=postgresql.dialect()))  # type: ignore[no-untyped-call]
             else:
@@ -588,30 +708,19 @@ class Project(Node):
             last=last,
             before=before if isinstance(before, CursorString) else None,
         )
-        async with info.context.db() as session:
-            annotation_configs = await session.stream_scalars(
-                select(models.AnnotationConfig)
-                .join(
-                    models.ProjectAnnotationConfig,
-                    models.AnnotationConfig.id
-                    == models.ProjectAnnotationConfig.annotation_config_id,
-                )
-                .where(models.ProjectAnnotationConfig.project_id == self.project_rowid)
-                .order_by(models.AnnotationConfig.name)
-            )
-            data = [to_gql_annotation_config(config) async for config in annotation_configs]
+        loader = info.context.data_loaders.annotation_configs_by_project
+        configs = await loader.load(self.id)
+        data = [to_gql_annotation_config(config) for config in configs]
         return connection_from_list(data=data, args=args)
 
     @strawberry.field
     async def trace_retention_policy(
         self,
         info: Info[Context, None],
-    ) -> Annotated[ProjectTraceRetentionPolicy, lazy(".ProjectTraceRetentionPolicy")]:
+    ) -> Annotated["ProjectTraceRetentionPolicy", lazy(".ProjectTraceRetentionPolicy")]:
         from .ProjectTraceRetentionPolicy import ProjectTraceRetentionPolicy
 
-        id_ = await info.context.data_loaders.trace_retention_policy_id_by_project_id.load(
-            self.project_rowid
-        )
+        id_ = await info.context.data_loaders.trace_retention_policy_id_by_project_id.load(self.id)
         return ProjectTraceRetentionPolicy(id=id_)
 
     @strawberry.field
@@ -619,11 +728,11 @@ class Project(Node):
         self,
         info: Info[Context, None],
     ) -> datetime:
-        if self.db_project:
-            created_at = self.db_project.created_at
+        if self.db_record:
+            created_at = self.db_record.created_at
         else:
             created_at = await info.context.data_loaders.project_fields.load(
-                (self.project_rowid, models.Project.created_at),
+                (self.id, models.Project.created_at),
             )
         return created_at
 
@@ -632,96 +741,841 @@ class Project(Node):
         self,
         info: Info[Context, None],
     ) -> datetime:
-        if self.db_project:
-            updated_at = self.db_project.updated_at
+        if self.db_record:
+            updated_at = self.db_record.updated_at
         else:
             updated_at = await info.context.data_loaders.project_fields.load(
-                (self.project_rowid, models.Project.updated_at),
+                (self.id, models.Project.updated_at),
             )
         return updated_at
 
-    @strawberry.field(
-        description="Hourly span count for the project.",
-    )  # type: ignore
+    @strawberry.field
     async def span_count_time_series(
         self,
         info: Info[Context, None],
-        time_range: Optional[TimeRange] = UNSET,
-    ) -> SpanCountTimeSeries:
-        """Returns a time series of span counts grouped by hour for the project.
+        time_range: TimeRange,
+        time_bin_config: Optional[TimeBinConfig] = UNSET,
+        filter_condition: Optional[str] = UNSET,
+    ) -> "SpanCountTimeSeries":
+        if time_range.start is None:
+            raise BadRequest("Start time is required")
 
-        This field provides hourly aggregated span counts, which can be useful for
-        visualizing span activity over time. The data points represent the number
-        of spans that started in each hour.
+        dialect = info.context.db.dialect
+        utc_offset_minutes = 0
+        field: Literal["minute", "hour", "day", "week", "month", "year"] = "hour"
+        if time_bin_config:
+            utc_offset_minutes = time_bin_config.utc_offset_minutes
+            if time_bin_config.scale is TimeBinScale.MINUTE:
+                field = "minute"
+            elif time_bin_config.scale is TimeBinScale.HOUR:
+                field = "hour"
+            elif time_bin_config.scale is TimeBinScale.DAY:
+                field = "day"
+            elif time_bin_config.scale is TimeBinScale.WEEK:
+                field = "week"
+            elif time_bin_config.scale is TimeBinScale.MONTH:
+                field = "month"
+            elif time_bin_config.scale is TimeBinScale.YEAR:
+                field = "year"
+        bucket = date_trunc(dialect, field, models.Span.start_time, utc_offset_minutes)
+        stmt = (
+            select(
+                bucket,
+                func.count(models.Span.id).label("total_count"),
+                func.sum(case((models.Span.status_code == "OK", 1), else_=0)).label("ok_count"),
+                func.sum(case((models.Span.status_code == "ERROR", 1), else_=0)).label(
+                    "error_count"
+                ),
+                func.sum(case((models.Span.status_code == "UNSET", 1), else_=0)).label(
+                    "unset_count"
+                ),
+            )
+            .join_from(models.Span, models.Trace)
+            .where(models.Trace.project_rowid == self.id)
+            .group_by(bucket)
+            .order_by(bucket)
+        )
+        if time_range.start:
+            stmt = stmt.where(time_range.start <= models.Span.start_time)
+        if time_range.end:
+            stmt = stmt.where(models.Span.start_time < time_range.end)
+        if filter_condition:
+            span_filter = SpanFilter(condition=filter_condition)
+            stmt = span_filter(stmt)
 
-        Args:
-            info: The GraphQL info object containing context information.
-            time_range: Optional time range to filter the spans. If provided, only
-                spans that started within this range will be counted.
+        data = {}
+        async with info.context.db() as session:
+            async for t, total_count, ok_count, error_count, unset_count in await session.stream(
+                stmt
+            ):
+                timestamp = _as_datetime(t)
+                data[timestamp] = SpanCountTimeSeriesDataPoint(
+                    timestamp=timestamp,
+                    ok_count=ok_count,
+                    error_count=error_count,
+                    unset_count=unset_count,
+                    total_count=total_count,
+                )
 
-        Returns:
-            A SpanCountTimeSeries object containing data points with timestamps
-            (rounded to the nearest hour) and corresponding span counts.
-
-        Notes:
-            - The timestamps are rounded down to the nearest hour.
-            - If a time range is provided, the start time is rounded down to the
-              nearest hour, and the end time is rounded up to the nearest hour.
-            - The SQL query is optimized for both PostgreSQL and SQLite databases.
-        """
-        # Determine the appropriate SQL function to truncate timestamps to hours
-        # based on the database dialect
-        if info.context.db.dialect is SupportedSQLDialect.POSTGRESQL:
-            # PostgreSQL uses date_trunc for timestamp truncation
-            hour = func.date_trunc("hour", models.Span.start_time)
-        elif info.context.db.dialect is SupportedSQLDialect.SQLITE:
-            # SQLite uses strftime for timestamp formatting
-            hour = func.strftime("%Y-%m-%dT%H:00:00.000+00:00", models.Span.start_time)
+        data_timestamps: list[datetime] = [data_point.timestamp for data_point in data.values()]
+        min_time = min([*data_timestamps, time_range.start])
+        max_time = max(
+            [
+                *data_timestamps,
+                *([time_range.end] if time_range.end else [datetime.now(timezone.utc)]),
+            ],
+        )
+        for timestamp in get_timestamp_range(
+            start_time=min_time,
+            end_time=max_time,
+            stride=field,
+            utc_offset_minutes=utc_offset_minutes,
+        ):
+            if timestamp not in data:
+                data[timestamp] = SpanCountTimeSeriesDataPoint(timestamp=timestamp)
+        return SpanCountTimeSeries(data=sorted(data.values(), key=lambda x: x.timestamp))
+
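Each of the new time-series resolvers in this hunk repeats the same `TimeBinScale`-to-stride mapping inline. For readers of the diff, the if/elif chain is equivalent to a small lookup table; a sketch under the assumption that `TimeBinScale` and `TimeBinConfig` are the types imported at the top of the file (this refactoring is not part of the diff):

# Equivalent lookup table for the repeated if/elif chains (illustrative only).
from typing import Literal, Optional

from phoenix.server.api.input_types.TimeBinConfig import TimeBinConfig, TimeBinScale

Stride = Literal["minute", "hour", "day", "week", "month", "year"]

_STRIDE_BY_SCALE: dict[TimeBinScale, Stride] = {
    TimeBinScale.MINUTE: "minute",
    TimeBinScale.HOUR: "hour",
    TimeBinScale.DAY: "day",
    TimeBinScale.WEEK: "week",
    TimeBinScale.MONTH: "month",
    TimeBinScale.YEAR: "year",
}

def resolve_stride(time_bin_config: Optional[TimeBinConfig]) -> Stride:
    # Falls back to hourly bins, matching the resolvers' default above.
    return _STRIDE_BY_SCALE[time_bin_config.scale] if time_bin_config else "hour"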
+ @strawberry.field
839
+ async def trace_count_time_series(
840
+ self,
841
+ info: Info[Context, None],
842
+ time_range: TimeRange,
843
+ time_bin_config: Optional[TimeBinConfig] = UNSET,
844
+ ) -> "TraceCountTimeSeries":
845
+ if time_range.start is None:
846
+ raise BadRequest("Start time is required")
847
+
848
+ dialect = info.context.db.dialect
849
+ utc_offset_minutes = 0
850
+ field: Literal["minute", "hour", "day", "week", "month", "year"] = "hour"
851
+ if time_bin_config:
852
+ utc_offset_minutes = time_bin_config.utc_offset_minutes
853
+ if time_bin_config.scale is TimeBinScale.MINUTE:
854
+ field = "minute"
855
+ elif time_bin_config.scale is TimeBinScale.HOUR:
856
+ field = "hour"
857
+ elif time_bin_config.scale is TimeBinScale.DAY:
858
+ field = "day"
859
+ elif time_bin_config.scale is TimeBinScale.WEEK:
860
+ field = "week"
861
+ elif time_bin_config.scale is TimeBinScale.MONTH:
862
+ field = "month"
863
+ elif time_bin_config.scale is TimeBinScale.YEAR:
864
+ field = "year"
865
+ bucket = date_trunc(dialect, field, models.Trace.start_time, utc_offset_minutes)
866
+ stmt = (
867
+ select(bucket, func.count(models.Trace.id))
868
+ .where(models.Trace.project_rowid == self.id)
869
+ .group_by(bucket)
870
+ .order_by(bucket)
871
+ )
872
+ if time_range:
873
+ if time_range.start:
874
+ stmt = stmt.where(time_range.start <= models.Trace.start_time)
875
+ if time_range.end:
876
+ stmt = stmt.where(models.Trace.start_time < time_range.end)
877
+ data = {}
878
+ async with info.context.db() as session:
879
+ async for t, v in await session.stream(stmt):
880
+ timestamp = _as_datetime(t)
881
+ data[timestamp] = TimeSeriesDataPoint(timestamp=timestamp, value=v)
882
+
883
+ data_timestamps: list[datetime] = [data_point.timestamp for data_point in data.values()]
884
+ min_time = min([*data_timestamps, time_range.start])
885
+ max_time = max(
886
+ [
887
+ *data_timestamps,
888
+ *([time_range.end] if time_range.end else [datetime.now(timezone.utc)]),
889
+ ],
890
+ )
891
+ for timestamp in get_timestamp_range(
892
+ start_time=min_time,
893
+ end_time=max_time,
894
+ stride=field,
895
+ utc_offset_minutes=utc_offset_minutes,
896
+ ):
897
+ if timestamp not in data:
898
+ data[timestamp] = TimeSeriesDataPoint(timestamp=timestamp)
899
+ return TraceCountTimeSeries(data=sorted(data.values(), key=lambda x: x.timestamp))
900
+
901
+ @strawberry.field
902
+ async def trace_count_by_status_time_series(
903
+ self,
904
+ info: Info[Context, None],
905
+ time_range: TimeRange,
906
+ time_bin_config: Optional[TimeBinConfig] = UNSET,
907
+ ) -> "TraceCountByStatusTimeSeries":
908
+ if time_range.start is None:
909
+ raise BadRequest("Start time is required")
910
+
911
+ dialect = info.context.db.dialect
912
+ utc_offset_minutes = 0
913
+ field: Literal["minute", "hour", "day", "week", "month", "year"] = "hour"
914
+ if time_bin_config:
915
+ utc_offset_minutes = time_bin_config.utc_offset_minutes
916
+ if time_bin_config.scale is TimeBinScale.MINUTE:
917
+ field = "minute"
918
+ elif time_bin_config.scale is TimeBinScale.HOUR:
919
+ field = "hour"
920
+ elif time_bin_config.scale is TimeBinScale.DAY:
921
+ field = "day"
922
+ elif time_bin_config.scale is TimeBinScale.WEEK:
923
+ field = "week"
924
+ elif time_bin_config.scale is TimeBinScale.MONTH:
925
+ field = "month"
926
+ elif time_bin_config.scale is TimeBinScale.YEAR:
927
+ field = "year"
928
+ bucket = date_trunc(dialect, field, models.Trace.start_time, utc_offset_minutes)
929
+ trace_error_status_counts = (
930
+ select(
931
+ models.Span.trace_rowid,
932
+ )
933
+ .where(models.Span.parent_id.is_(None))
934
+ .group_by(models.Span.trace_rowid)
935
+ .having(func.max(models.Span.cumulative_error_count) > 0)
936
+ ).subquery()
937
+ stmt = (
938
+ select(
939
+ bucket,
940
+ func.count(models.Trace.id).label("total_count"),
941
+ func.coalesce(func.count(trace_error_status_counts.c.trace_rowid), 0).label(
942
+ "error_count"
943
+ ),
944
+ )
945
+ .join_from(
946
+ models.Trace,
947
+ trace_error_status_counts,
948
+ onclause=trace_error_status_counts.c.trace_rowid == models.Trace.id,
949
+ isouter=True,
950
+ )
951
+ .where(models.Trace.project_rowid == self.id)
952
+ .group_by(bucket)
953
+ .order_by(bucket)
954
+ )
955
+ if time_range:
956
+ if time_range.start:
957
+ stmt = stmt.where(time_range.start <= models.Trace.start_time)
958
+ if time_range.end:
959
+ stmt = stmt.where(models.Trace.start_time < time_range.end)
960
+ data: dict[datetime, TraceCountByStatusTimeSeriesDataPoint] = {}
961
+ async with info.context.db() as session:
962
+ async for t, total_count, error_count in await session.stream(stmt):
963
+ timestamp = _as_datetime(t)
964
+ data[timestamp] = TraceCountByStatusTimeSeriesDataPoint(
965
+ timestamp=timestamp,
966
+ ok_count=total_count - error_count,
967
+ error_count=error_count,
968
+ total_count=total_count,
969
+ )
970
+
971
+ data_timestamps: list[datetime] = [data_point.timestamp for data_point in data.values()]
972
+ min_time = min([*data_timestamps, time_range.start])
973
+ max_time = max(
974
+ [
975
+ *data_timestamps,
976
+ *([time_range.end] if time_range.end else [datetime.now(timezone.utc)]),
977
+ ],
978
+ )
979
+ for timestamp in get_timestamp_range(
980
+ start_time=min_time,
981
+ end_time=max_time,
982
+ stride=field,
983
+ utc_offset_minutes=utc_offset_minutes,
984
+ ):
985
+ if timestamp not in data:
986
+ data[timestamp] = TraceCountByStatusTimeSeriesDataPoint(
987
+ timestamp=timestamp,
988
+ ok_count=0,
989
+ error_count=0,
990
+ total_count=0,
991
+ )
992
+ return TraceCountByStatusTimeSeries(data=sorted(data.values(), key=lambda x: x.timestamp))
+
+    @strawberry.field
+    async def trace_latency_ms_percentile_time_series(
+        self,
+        info: Info[Context, None],
+        time_range: TimeRange,
+        time_bin_config: Optional[TimeBinConfig] = UNSET,
+    ) -> "TraceLatencyPercentileTimeSeries":
+        if time_range.start is None:
+            raise BadRequest("Start time is required")
+
+        dialect = info.context.db.dialect
+        utc_offset_minutes = 0
+        field: Literal["minute", "hour", "day", "week", "month", "year"] = "hour"
+        if time_bin_config:
+            utc_offset_minutes = time_bin_config.utc_offset_minutes
+            if time_bin_config.scale is TimeBinScale.MINUTE:
+                field = "minute"
+            elif time_bin_config.scale is TimeBinScale.HOUR:
+                field = "hour"
+            elif time_bin_config.scale is TimeBinScale.DAY:
+                field = "day"
+            elif time_bin_config.scale is TimeBinScale.WEEK:
+                field = "week"
+            elif time_bin_config.scale is TimeBinScale.MONTH:
+                field = "month"
+            elif time_bin_config.scale is TimeBinScale.YEAR:
+                field = "year"
+        bucket = date_trunc(dialect, field, models.Trace.start_time, utc_offset_minutes)
+
+        stmt = select(bucket).where(models.Trace.project_rowid == self.id)
+        if time_range.start:
+            stmt = stmt.where(time_range.start <= models.Trace.start_time)
+        if time_range.end:
+            stmt = stmt.where(models.Trace.start_time < time_range.end)
+
+        if dialect is SupportedSQLDialect.POSTGRESQL:
+            stmt = stmt.add_columns(
+                percentile_cont(0.50).within_group(models.Trace.latency_ms.asc()).label("p50"),
+                percentile_cont(0.75).within_group(models.Trace.latency_ms.asc()).label("p75"),
+                percentile_cont(0.90).within_group(models.Trace.latency_ms.asc()).label("p90"),
+                percentile_cont(0.95).within_group(models.Trace.latency_ms.asc()).label("p95"),
+                percentile_cont(0.99).within_group(models.Trace.latency_ms.asc()).label("p99"),
+                percentile_cont(0.999).within_group(models.Trace.latency_ms.asc()).label("p999"),
+                func.max(models.Trace.latency_ms).label("max"),
+            )
+        elif dialect is SupportedSQLDialect.SQLITE:
+            stmt = stmt.add_columns(
+                func.percentile(models.Trace.latency_ms, 50).label("p50"),
+                func.percentile(models.Trace.latency_ms, 75).label("p75"),
+                func.percentile(models.Trace.latency_ms, 90).label("p90"),
+                func.percentile(models.Trace.latency_ms, 95).label("p95"),
+                func.percentile(models.Trace.latency_ms, 99).label("p99"),
+                func.percentile(models.Trace.latency_ms, 99.9).label("p999"),
+                func.max(models.Trace.latency_ms).label("max"),
+            )
         else:
-            assert_never(info.context.db.dialect)
+            assert_never(dialect)
+
+        stmt = stmt.group_by(bucket).order_by(bucket)
+
+        data: dict[datetime, TraceLatencyMsPercentileTimeSeriesDataPoint] = {}
+        async with info.context.db() as session:
+            async for (
+                bucket_time,
+                p50,
+                p75,
+                p90,
+                p95,
+                p99,
+                p999,
+                max_latency,
+            ) in await session.stream(stmt):
+                timestamp = _as_datetime(bucket_time)
+                data[timestamp] = TraceLatencyMsPercentileTimeSeriesDataPoint(
+                    timestamp=timestamp,
+                    p50=p50,
+                    p75=p75,
+                    p90=p90,
+                    p95=p95,
+                    p99=p99,
+                    p999=p999,
+                    max=max_latency,
+                )
+
+        data_timestamps: list[datetime] = [data_point.timestamp for data_point in data.values()]
+        min_time = min([*data_timestamps, time_range.start])
+        max_time = max(
+            [
+                *data_timestamps,
+                *([time_range.end] if time_range.end else [datetime.now(timezone.utc)]),
+            ],
+        )
+        for timestamp in get_timestamp_range(
+            start_time=min_time,
+            end_time=max_time,
+            stride=field,
+            utc_offset_minutes=utc_offset_minutes,
+        ):
+            if timestamp not in data:
+                data[timestamp] = TraceLatencyMsPercentileTimeSeriesDataPoint(timestamp=timestamp)
+        return TraceLatencyPercentileTimeSeries(
+            data=sorted(data.values(), key=lambda x: x.timestamp)
+        )
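Editor's note: the two dialect branches exist because `percentile_cont` is an ordered-set aggregate (`WITHIN GROUP`) that SQLite lacks; the SQLite branch calls a plain `percentile(value, fraction)` function, presumably registered by Phoenix on its SQLite connections rather than a SQLite built-in. A sketch of how the two forms render, using SQLAlchemy's generic `func` accessor and a stand-in table (Phoenix's imported `percentile_cont` may differ in detail):

```python
from sqlalchemy import Column, Float, MetaData, Table, func, select
from sqlalchemy.dialects import postgresql, sqlite

t = Table("traces", MetaData(), Column("latency_ms", Float))

# PostgreSQL: ordered-set aggregate with WITHIN GROUP (ORDER BY ...).
pg_stmt = select(
    func.percentile_cont(0.5).within_group(t.c.latency_ms.asc()).label("p50")
)
print(pg_stmt.compile(dialect=postgresql.dialect()))
# roughly: SELECT percentile_cont(...) WITHIN GROUP
#          (ORDER BY traces.latency_ms ASC) AS p50 FROM traces

# SQLite: rendered as an ordinary two-argument function call.
lite_stmt = select(func.percentile(t.c.latency_ms, 50).label("p50"))
print(lite_stmt.compile(dialect=sqlite.dialect()))
# roughly: SELECT percentile(traces.latency_ms, ?) AS p50 FROM traces
```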
 
-        # Build the base query to count spans grouped by hour
+    @strawberry.field
+    async def trace_token_count_time_series(
+        self,
+        info: Info[Context, None],
+        time_range: TimeRange,
+        time_bin_config: Optional[TimeBinConfig] = UNSET,
+    ) -> "TraceTokenCountTimeSeries":
+        if time_range.start is None:
+            raise BadRequest("Start time is required")
+
+        dialect = info.context.db.dialect
+        utc_offset_minutes = 0
+        field: Literal["minute", "hour", "day", "week", "month", "year"] = "hour"
+        if time_bin_config:
+            utc_offset_minutes = time_bin_config.utc_offset_minutes
+            if time_bin_config.scale is TimeBinScale.MINUTE:
+                field = "minute"
+            elif time_bin_config.scale is TimeBinScale.HOUR:
+                field = "hour"
+            elif time_bin_config.scale is TimeBinScale.DAY:
+                field = "day"
+            elif time_bin_config.scale is TimeBinScale.WEEK:
+                field = "week"
+            elif time_bin_config.scale is TimeBinScale.MONTH:
+                field = "month"
+            elif time_bin_config.scale is TimeBinScale.YEAR:
+                field = "year"
+        bucket = date_trunc(dialect, field, models.Trace.start_time, utc_offset_minutes)
         stmt = (
-            select(hour, func.count())
-            .join(models.Trace)
-            .where(models.Trace.project_rowid == self.project_rowid)
-            .group_by(hour)
-            .order_by(hour)
+            select(
+                bucket,
+                func.sum(models.SpanCost.total_tokens),
+                func.sum(models.SpanCost.prompt_tokens),
+                func.sum(models.SpanCost.completion_tokens),
+            )
+            .join_from(
+                models.Trace,
+                models.SpanCost,
+                onclause=models.SpanCost.trace_rowid == models.Trace.id,
+            )
+            .where(models.Trace.project_rowid == self.id)
+            .group_by(bucket)
+            .order_by(bucket)
         )
+        if time_range:
+            if time_range.start:
+                stmt = stmt.where(time_range.start <= models.Trace.start_time)
+            if time_range.end:
+                stmt = stmt.where(models.Trace.start_time < time_range.end)
+        data: dict[datetime, TraceTokenCountTimeSeriesDataPoint] = {}
+        async with info.context.db() as session:
+            async for (
+                t,
+                total_tokens,
+                prompt_tokens,
+                completion_tokens,
+            ) in await session.stream(stmt):
+                timestamp = _as_datetime(t)
+                data[timestamp] = TraceTokenCountTimeSeriesDataPoint(
+                    timestamp=timestamp,
+                    prompt_token_count=prompt_tokens,
+                    completion_token_count=completion_tokens,
+                    total_token_count=total_tokens,
+                )
 
-        # Apply time range filtering if provided
+        data_timestamps: list[datetime] = [data_point.timestamp for data_point in data.values()]
+        min_time = min([*data_timestamps, time_range.start])
+        max_time = max(
+            [
+                *data_timestamps,
+                *([time_range.end] if time_range.end else [datetime.now(timezone.utc)]),
+            ],
+        )
+        for timestamp in get_timestamp_range(
+            start_time=min_time,
+            end_time=max_time,
+            stride=field,
+            utc_offset_minutes=utc_offset_minutes,
+        ):
+            if timestamp not in data:
+                data[timestamp] = TraceTokenCountTimeSeriesDataPoint(timestamp=timestamp)
+        return TraceTokenCountTimeSeries(data=sorted(data.values(), key=lambda x: x.timestamp))
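Editor's note: after streaming the SQL rows, each resolver back-fills empty buckets so the series has no gaps. A simplified stand-in for `get_timestamp_range`, assuming an hourly stride (the real helper also handles the other strides and UTC offsets):

```python
from datetime import datetime, timedelta, timezone

def hourly_range(start: datetime, end: datetime) -> list[datetime]:
    # Truncate to the hour, then step until the end is covered.
    t = start.replace(minute=0, second=0, microsecond=0)
    out = []
    while t <= end:
        out.append(t)
        t += timedelta(hours=1)
    return out

data = {datetime(2024, 1, 1, 10, tzinfo=timezone.utc): {"total": 42}}
start = datetime(2024, 1, 1, 9, tzinfo=timezone.utc)
end = datetime(2024, 1, 1, 12, tzinfo=timezone.utc)
for ts in hourly_range(start, end):
    data.setdefault(ts, {"total": None})  # empty point, like the resolver
print(sorted(data))  # 09:00, 10:00, 11:00, 12:00 — no missing buckets
```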
+
+    @strawberry.field
+    async def trace_token_cost_time_series(
+        self,
+        info: Info[Context, None],
+        time_range: TimeRange,
+        time_bin_config: Optional[TimeBinConfig] = UNSET,
+    ) -> "TraceTokenCostTimeSeries":
+        if time_range.start is None:
+            raise BadRequest("Start time is required")
+
+        dialect = info.context.db.dialect
+        utc_offset_minutes = 0
+        field: Literal["minute", "hour", "day", "week", "month", "year"] = "hour"
+        if time_bin_config:
+            utc_offset_minutes = time_bin_config.utc_offset_minutes
+            if time_bin_config.scale is TimeBinScale.MINUTE:
+                field = "minute"
+            elif time_bin_config.scale is TimeBinScale.HOUR:
+                field = "hour"
+            elif time_bin_config.scale is TimeBinScale.DAY:
+                field = "day"
+            elif time_bin_config.scale is TimeBinScale.WEEK:
+                field = "week"
+            elif time_bin_config.scale is TimeBinScale.MONTH:
+                field = "month"
+            elif time_bin_config.scale is TimeBinScale.YEAR:
+                field = "year"
+        bucket = date_trunc(dialect, field, models.Trace.start_time, utc_offset_minutes)
+        stmt = (
+            select(
+                bucket,
+                func.sum(models.SpanCost.total_cost),
+                func.sum(models.SpanCost.prompt_cost),
+                func.sum(models.SpanCost.completion_cost),
+            )
+            .join_from(
+                models.Trace,
+                models.SpanCost,
+                onclause=models.SpanCost.trace_rowid == models.Trace.id,
+            )
+            .where(models.Trace.project_rowid == self.id)
+            .group_by(bucket)
+            .order_by(bucket)
+        )
         if time_range:
-            if t := time_range.start:
-                # Round down to nearest hour for the start time
-                start = t.replace(minute=0, second=0, microsecond=0)
-                stmt = stmt.where(start <= models.Span.start_time)
-            if t := time_range.end:
-                # Round up to nearest hour for the end time
-                # If the time is already at the start of an hour, use it as is
-                if t.minute == 0 and t.second == 0 and t.microsecond == 0:
-                    end = t
-                else:
-                    # Otherwise, round up to the next hour
-                    end = t.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
-                stmt = stmt.where(models.Span.start_time < end)
+            if time_range.start:
+                stmt = stmt.where(time_range.start <= models.Trace.start_time)
+            if time_range.end:
+                stmt = stmt.where(models.Trace.start_time < time_range.end)
+        data: dict[datetime, TraceTokenCostTimeSeriesDataPoint] = {}
+        async with info.context.db() as session:
+            async for (
+                t,
+                total_cost,
+                prompt_cost,
+                completion_cost,
+            ) in await session.stream(stmt):
+                timestamp = _as_datetime(t)
+                data[timestamp] = TraceTokenCostTimeSeriesDataPoint(
+                    timestamp=timestamp,
+                    prompt_cost=prompt_cost,
+                    completion_cost=completion_cost,
+                    total_cost=total_cost,
+                )
+
+        data_timestamps: list[datetime] = [data_point.timestamp for data_point in data.values()]
+        min_time = min([*data_timestamps, time_range.start])
+        max_time = max(
+            [
+                *data_timestamps,
+                *([time_range.end] if time_range.end else [datetime.now(timezone.utc)]),
+            ],
+        )
+        for timestamp in get_timestamp_range(
+            start_time=min_time,
+            end_time=max_time,
+            stride=field,
+            utc_offset_minutes=utc_offset_minutes,
+        ):
+            if timestamp not in data:
+                data[timestamp] = TraceTokenCostTimeSeriesDataPoint(timestamp=timestamp)
+        return TraceTokenCostTimeSeries(data=sorted(data.values(), key=lambda x: x.timestamp))
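Editor's note: the `TimeBinScale`-to-stride mapping is repeated verbatim in each resolver above. A hypothetical consolidation (not part of this diff) that collapses the if/elif chain; the enum below is a plain stand-in for Phoenix's strawberry enum of the same name:

```python
from enum import Enum
from typing import Literal, Optional

Stride = Literal["minute", "hour", "day", "week", "month", "year"]

class TimeBinScale(Enum):  # stand-in, not the strawberry enum itself
    MINUTE = "minute"
    HOUR = "hour"
    DAY = "day"
    WEEK = "week"
    MONTH = "month"
    YEAR = "year"

def resolve_stride(scale: Optional[TimeBinScale]) -> Stride:
    # Fall back to "hour" when no bin config is given, as the resolvers do.
    return scale.value if scale is not None else "hour"  # type: ignore[return-value]

assert resolve_stride(TimeBinScale.WEEK) == "week"
assert resolve_stride(None) == "hour"
```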
 
-        # Execute the query and convert the results to a time series
+    @strawberry.field
+    async def span_annotation_score_time_series(
+        self,
+        info: Info[Context, None],
+        time_range: TimeRange,
+        time_bin_config: Optional[TimeBinConfig] = UNSET,
+    ) -> "SpanAnnotationScoreTimeSeries":
+        if time_range.start is None:
+            raise BadRequest("Start time is required")
+
+        dialect = info.context.db.dialect
+        utc_offset_minutes = 0
+        field: Literal["minute", "hour", "day", "week", "month", "year"] = "hour"
+        if time_bin_config:
+            utc_offset_minutes = time_bin_config.utc_offset_minutes
+            if time_bin_config.scale is TimeBinScale.MINUTE:
+                field = "minute"
+            elif time_bin_config.scale is TimeBinScale.HOUR:
+                field = "hour"
+            elif time_bin_config.scale is TimeBinScale.DAY:
+                field = "day"
+            elif time_bin_config.scale is TimeBinScale.WEEK:
+                field = "week"
+            elif time_bin_config.scale is TimeBinScale.MONTH:
+                field = "month"
+            elif time_bin_config.scale is TimeBinScale.YEAR:
+                field = "year"
+        bucket = date_trunc(dialect, field, models.Trace.start_time, utc_offset_minutes)
+        stmt = (
+            select(
+                bucket,
+                models.SpanAnnotation.name,
+                func.avg(models.SpanAnnotation.score).label("average_score"),
+            )
+            .join_from(
+                models.SpanAnnotation,
+                models.Span,
+                onclause=models.SpanAnnotation.span_rowid == models.Span.id,
+            )
+            .join_from(
+                models.Span,
+                models.Trace,
+                onclause=models.Span.trace_rowid == models.Trace.id,
+            )
+            .where(models.Trace.project_rowid == self.id)
+            .group_by(bucket, models.SpanAnnotation.name)
+            .order_by(bucket)
+        )
+        if time_range:
+            if time_range.start:
+                stmt = stmt.where(time_range.start <= models.Trace.start_time)
+            if time_range.end:
+                stmt = stmt.where(models.Trace.start_time < time_range.end)
+        scores: dict[datetime, dict[str, float]] = {}
+        unique_names: set[str] = set()
         async with info.context.db() as session:
-            data = await session.stream(stmt)
-            return SpanCountTimeSeries(
-                data=[
-                    TimeSeriesDataPoint(
-                        timestamp=_as_datetime(t),
-                        value=v,
-                    )
-                    async for t, v in data
-                ]
+            async for (
+                t,
+                name,
+                average_score,
+            ) in await session.stream(stmt):
+                if average_score is None:
+                    continue
+                timestamp = _as_datetime(t)
+                if timestamp not in scores:
+                    scores[timestamp] = {}
+                scores[timestamp][name] = average_score
+                unique_names.add(name)
+
+        score_timestamps: list[datetime] = [timestamp for timestamp in scores]
+        min_time = min([*score_timestamps, time_range.start])
+        max_time = max(
+            [
+                *score_timestamps,
+                *([time_range.end] if time_range.end else [datetime.now(timezone.utc)]),
+            ],
+        )
+        data: dict[datetime, SpanAnnotationScoreTimeSeriesDataPoint] = {
+            timestamp: SpanAnnotationScoreTimeSeriesDataPoint(
+                timestamp=timestamp,
+                scores_with_labels=[
+                    SpanAnnotationScoreWithLabel(label=label, score=scores[timestamp][label])
+                    for label in scores[timestamp]
+                ],
             )
+            for timestamp in score_timestamps
+        }
+        for timestamp in get_timestamp_range(
+            start_time=min_time,
+            end_time=max_time,
+            stride=field,
+            utc_offset_minutes=utc_offset_minutes,
+        ):
+            if timestamp not in data:
+                data[timestamp] = SpanAnnotationScoreTimeSeriesDataPoint(
+                    timestamp=timestamp,
+                    scores_with_labels=[],
+                )
+        return SpanAnnotationScoreTimeSeries(
+            data=sorted(data.values(), key=lambda x: x.timestamp),
+            names=sorted(list(unique_names)),
+        )
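Editor's note: the resolver folds streamed rows of `(bucket, annotation name, average score)` into nested dicts before building data points. A pure-Python sketch with fabricated rows, mirroring the None-skipping behavior above:

```python
from datetime import datetime

rows = [
    (datetime(2024, 1, 1, 10), "correctness", 0.9),
    (datetime(2024, 1, 1, 10), "toxicity", 0.1),
    (datetime(2024, 1, 1, 11), "correctness", None),  # skipped below
]

scores: dict[datetime, dict[str, float]] = {}
names: set[str] = set()
for bucket, name, avg in rows:
    if avg is None:
        continue  # NULL averages (no scored annotations) are dropped
    scores.setdefault(bucket, {})[name] = avg
    names.add(name)

print(scores)         # {10:00: {'correctness': 0.9, 'toxicity': 0.1}}
print(sorted(names))  # ['correctness', 'toxicity']
```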
+
+    @strawberry.field
+    async def top_models_by_cost(
+        self,
+        info: Info[Context, None],
+        time_range: TimeRange,
+    ) -> list[GenerativeModel]:
+        if time_range.start is None:
+            raise BadRequest("Start time is required")
+
+        async with info.context.db() as session:
+            stmt = (
+                select(
+                    models.GenerativeModel,
+                    func.sum(models.SpanCost.total_tokens).label("total_tokens"),
+                    func.sum(models.SpanCost.prompt_tokens).label("prompt_tokens"),
+                    func.sum(models.SpanCost.completion_tokens).label("completion_tokens"),
+                    func.sum(models.SpanCost.total_cost).label("total_cost"),
+                    func.sum(models.SpanCost.prompt_cost).label("prompt_cost"),
+                    func.sum(models.SpanCost.completion_cost).label("completion_cost"),
+                )
+                .join(
+                    models.SpanCost,
+                    models.SpanCost.model_id == models.GenerativeModel.id,
+                )
+                .join(
+                    models.Trace,
+                    models.SpanCost.trace_rowid == models.Trace.id,
+                )
+                .where(models.Trace.project_rowid == self.id)
+                .where(models.SpanCost.model_id.isnot(None))
+                .where(models.SpanCost.span_start_time >= time_range.start)
+                .group_by(models.GenerativeModel.id)
+                .order_by(func.sum(models.SpanCost.total_cost).desc())
+            )
+            if time_range.end:
+                stmt = stmt.where(models.SpanCost.span_start_time < time_range.end)
+            results: list[GenerativeModel] = []
+            async for (
+                model,
+                total_tokens,
+                prompt_tokens,
+                completion_tokens,
+                total_cost,
+                prompt_cost,
+                completion_cost,
+            ) in await session.stream(stmt):
+                cost_summary = SpanCostSummary(
+                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
+                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
+                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
+                )
+                cache_time_range = TimeRange(
+                    start=time_range.start,
+                    end=time_range.end,
+                )
+                gql_model = GenerativeModel(id=model.id, db_record=model)
+                gql_model.add_cached_cost_summary(self.id, cache_time_range, cost_summary)
+                results.append(gql_model)
+            return results
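Editor's note: conceptually, `top_models_by_cost` is a GROUP BY model with an ORDER BY on the summed cost (and `top_models_by_token_count`, next, differs only in ordering by summed tokens). A pure-Python analogue with fabricated (model, cost) records:

```python
from collections import defaultdict

# (model name, span cost) pairs, fabricated for illustration.
span_costs = [
    ("gpt-4o", 1.25),
    ("gpt-4o", 0.75),
    ("claude-3-5-sonnet", 0.50),
]

totals: defaultdict[str, float] = defaultdict(float)
for model, cost in span_costs:
    totals[model] += cost  # like func.sum(models.SpanCost.total_cost)

# ORDER BY summed cost, descending — the "top models" ordering.
top = sorted(totals.items(), key=lambda kv: kv[1], reverse=True)
print(top)  # [('gpt-4o', 2.0), ('claude-3-5-sonnet', 0.5)]
```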
+
+    @strawberry.field
+    async def top_models_by_token_count(
+        self,
+        info: Info[Context, None],
+        time_range: TimeRange,
+    ) -> list[GenerativeModel]:
+        if time_range.start is None:
+            raise BadRequest("Start time is required")
+
+        async with info.context.db() as session:
+            stmt = (
+                select(
+                    models.GenerativeModel,
+                    func.sum(models.SpanCost.total_tokens).label("total_tokens"),
+                    func.sum(models.SpanCost.prompt_tokens).label("prompt_tokens"),
+                    func.sum(models.SpanCost.completion_tokens).label("completion_tokens"),
+                    func.sum(models.SpanCost.total_cost).label("total_cost"),
+                    func.sum(models.SpanCost.prompt_cost).label("prompt_cost"),
+                    func.sum(models.SpanCost.completion_cost).label("completion_cost"),
+                )
+                .join(
+                    models.SpanCost,
+                    models.SpanCost.model_id == models.GenerativeModel.id,
+                )
+                .join(
+                    models.Trace,
+                    models.SpanCost.trace_rowid == models.Trace.id,
+                )
+                .where(models.Trace.project_rowid == self.id)
+                .where(models.SpanCost.model_id.isnot(None))
+                .where(models.SpanCost.span_start_time >= time_range.start)
+                .group_by(models.GenerativeModel.id)
+                .order_by(func.sum(models.SpanCost.total_tokens).desc())
+            )
+            if time_range.end:
+                stmt = stmt.where(models.SpanCost.span_start_time < time_range.end)
+            results: list[GenerativeModel] = []
+            async for (
+                model,
+                total_tokens,
+                prompt_tokens,
+                completion_tokens,
+                total_cost,
+                prompt_cost,
+                completion_cost,
+            ) in await session.stream(stmt):
+                cost_summary = SpanCostSummary(
+                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
+                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
+                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
+                )
+                cache_time_range = TimeRange(
+                    start=time_range.start,
+                    end=time_range.end,
+                )
+                gql_model = GenerativeModel(id=model.id, db_record=model)
+                gql_model.add_cached_cost_summary(self.id, cache_time_range, cost_summary)
+                results.append(gql_model)
+            return results
+
+
+@strawberry.type
+class SpanCountTimeSeriesDataPoint:
+    timestamp: datetime
+    ok_count: Optional[int] = None
+    error_count: Optional[int] = None
+    unset_count: Optional[int] = None
+    total_count: Optional[int] = None
+
+
+@strawberry.type
+class SpanCountTimeSeries:
+    data: list[SpanCountTimeSeriesDataPoint]
+
+
+@strawberry.type
+class TraceCountTimeSeries(TimeSeries):
+    """A time series of trace count"""
+
+
+@strawberry.type
+class TraceCountByStatusTimeSeriesDataPoint:
+    timestamp: datetime
+    ok_count: int
+    error_count: int
+    total_count: int
+
+
+@strawberry.type
+class TraceCountByStatusTimeSeries:
+    data: list[TraceCountByStatusTimeSeriesDataPoint]
+
+
+@strawberry.type
+class TraceLatencyMsPercentileTimeSeriesDataPoint:
+    timestamp: datetime
+    p50: Optional[float] = None
+    p75: Optional[float] = None
+    p90: Optional[float] = None
+    p95: Optional[float] = None
+    p99: Optional[float] = None
+    p999: Optional[float] = None
+    max: Optional[float] = None
+
+
+@strawberry.type
+class TraceLatencyPercentileTimeSeries:
+    data: list[TraceLatencyMsPercentileTimeSeriesDataPoint]
+
+
+@strawberry.type
+class TraceTokenCountTimeSeriesDataPoint:
+    timestamp: datetime
+    prompt_token_count: Optional[float] = None
+    completion_token_count: Optional[float] = None
+    total_token_count: Optional[float] = None
+
+
+@strawberry.type
+class TraceTokenCountTimeSeries:
+    data: list[TraceTokenCountTimeSeriesDataPoint]
+
+
+@strawberry.type
+class TraceTokenCostTimeSeriesDataPoint:
+    timestamp: datetime
+    prompt_cost: Optional[float] = None
+    completion_cost: Optional[float] = None
+    total_cost: Optional[float] = None
+
+
+@strawberry.type
+class TraceTokenCostTimeSeries:
+    data: list[TraceTokenCostTimeSeriesDataPoint]
+
+
+@strawberry.type
+class SpanAnnotationScoreWithLabel:
+    label: str
+    score: float
 
 
 @strawberry.type
-class SpanCountTimeSeries(TimeSeries):
-    """A time series of span count"""
+class SpanAnnotationScoreTimeSeriesDataPoint:
+    timestamp: datetime
+    scores_with_labels: list[SpanAnnotationScoreWithLabel]
+
+
+@strawberry.type
+class SpanAnnotationScoreTimeSeries:
+    data: list[SpanAnnotationScoreTimeSeriesDataPoint]
+    names: list[str]
 
 
 INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
@@ -732,5 +1586,213 @@ def _as_datetime(value: Any) -> datetime:
     if isinstance(value, datetime):
         return value
     if isinstance(value, str):
-        return datetime.fromisoformat(value)
+        return cast(datetime, normalize_datetime(datetime.fromisoformat(value), timezone.utc))
     raise ValueError(f"Cannot convert {value} to datetime")
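Editor's note: the string branch now pipes through a UTC normalizer because `datetime.fromisoformat()` returns a naive value for strings without an offset, and naive datetimes cannot be compared with aware ones. The `normalize_utc` below is a simplified stand-in for the `normalize_datetime` helper used above:

```python
from datetime import datetime, timezone

def normalize_utc(dt: datetime) -> datetime:
    # Attach UTC to naive values; convert aware values to UTC.
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt.astimezone(timezone.utc)

naive = datetime.fromisoformat("2024-01-01T10:00:00")
aware = datetime.fromisoformat("2024-01-01T12:00:00+02:00")
assert normalize_utc(naive) == datetime(2024, 1, 1, 10, tzinfo=timezone.utc)
assert normalize_utc(aware) == datetime(2024, 1, 1, 10, tzinfo=timezone.utc)
```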
+
+
+async def _paginate_span_by_trace_start_time(
+    db: DbSessionFactory,
+    project_rowid: int,
+    time_range: Optional[TimeRange] = None,
+    first: Optional[int] = DEFAULT_PAGE_SIZE,
+    after: Optional[CursorString] = None,
+    sort: SpanSort = SpanSort(col=SpanColumn.startTime, dir=SortDir.desc),
+    orphan_span_as_root_span: Optional[bool] = True,
+    retries: int = 3,
+) -> Connection[Span]:
+    """Return one representative root span per trace, ordered by trace start time.
+
+    **Note**: Despite the function name, cursors are based on trace rowids, not span rowids.
+    This is because we paginate by traces (one span per trace), not individual spans.
+
+    **Important**: The edges list can be empty while has_next_page=True. This happens
+    when traces exist but have no matching root spans. Pagination continues because there
+    may be more traces ahead with spans.
+
+    Args:
+        db: Database session factory.
+        project_rowid: Project ID to query spans from.
+        time_range: Optional time range filter on trace start times.
+        first: Maximum number of edges to return (default: DEFAULT_PAGE_SIZE).
+        after: Cursor for pagination (points to trace position, not span).
+        sort: Sort by trace start time (asc/desc only).
+        orphan_span_as_root_span: Whether to include orphan spans as root spans.
+            True: spans with parent_id=NULL OR pointing to non-existent spans.
+            False: only spans with parent_id=NULL.
+        retries: Maximum number of retry attempts when insufficient edges are found.
+            When traces exist but lack root spans, the function retries pagination
+            to find traces with spans. Set to 0 to disable retries.
+
+    Returns:
+        Connection[Span] with:
+        - edges: At most one Edge per trace (may be empty list).
+        - page_info: Pagination info based on trace positions.
+
+    Key Points:
+        - Traces without root spans produce NO edges
+        - Spans ordered by trace start time, not span start time
+        - Cursors track trace positions for efficient large-scale pagination
+    """
+    # Build base trace query ordered by start time
+    traces = select(
+        models.Trace.id,
+        models.Trace.start_time,
+    ).where(models.Trace.project_rowid == project_rowid)
+    if sort.dir is SortDir.desc:
+        traces = traces.order_by(
+            models.Trace.start_time.desc(),
+            models.Trace.id.desc(),
+        )
+    else:
+        traces = traces.order_by(
+            models.Trace.start_time.asc(),
+            models.Trace.id.asc(),
+        )
+
+    # Apply time range filters
+    if time_range:
+        if time_range.start:
+            traces = traces.where(time_range.start <= models.Trace.start_time)
+        if time_range.end:
+            traces = traces.where(models.Trace.start_time < time_range.end)
+
+    # Apply cursor pagination
+    if after:
+        cursor = Cursor.from_string(after)
+        assert cursor.sort_column
+        compare = operator.lt if sort.dir is SortDir.desc else operator.gt
+        traces = traces.where(
+            compare(
+                tuple_(models.Trace.start_time, models.Trace.id),
+                (cursor.sort_column.value, cursor.rowid),
+            )
+        )
+
+    # Limit for pagination
+    if first:
+        traces = traces.limit(
+            first + 1  # over-fetch by one to determine whether there's a next page
+        )
+    traces_cte = traces.cte()
+
+    # Define join condition for root spans
+    if orphan_span_as_root_span:
+        # Include both NULL parent_id and orphaned spans
+        parent_spans = select(models.Span.span_id).alias("parent_spans")
+        onclause = and_(
+            models.Span.trace_rowid == traces_cte.c.id,
+            or_(
+                models.Span.parent_id.is_(None),
+                ~exists().where(models.Span.parent_id == parent_spans.c.span_id),
+            ),
+        )
+    else:
+        # Only spans with no parent (parent_id is NULL, excludes orphaned spans)
+        onclause = and_(
+            models.Span.trace_rowid == traces_cte.c.id,
+            models.Span.parent_id.is_(None),
+        )
+
+    # Join traces with root spans (left join allows traces without spans)
+    stmt = select(
+        traces_cte.c.id,
+        traces_cte.c.start_time,
+        models.Span.id,
+    ).join_from(
+        traces_cte,
+        models.Span,
+        onclause=onclause,
+        isouter=True,
+    )
+
+    # Order by trace time, then pick earliest span per trace
+    if sort.dir is SortDir.desc:
+        stmt = stmt.order_by(
+            traces_cte.c.start_time.desc(),
+            traces_cte.c.id.desc(),
+            models.Span.start_time.asc(),  # earliest span
+            models.Span.id.desc(),
+        )
+    else:
+        stmt = stmt.order_by(
+            traces_cte.c.start_time.asc(),
+            traces_cte.c.id.asc(),
+            models.Span.start_time.asc(),  # earliest span
+            models.Span.id.desc(),
+        )
+
+    # Use DISTINCT for PostgreSQL, manual grouping for SQLite
+    if db.dialect is SupportedSQLDialect.POSTGRESQL:
+        stmt = stmt.distinct(traces_cte.c.start_time, traces_cte.c.id)
+    elif db.dialect is SupportedSQLDialect.SQLITE:
+        # too complicated for SQLite, so we rely on groupby() below
+        pass
+    else:
+        assert_never(db.dialect)
+
+    # Process results and build edges
+    edges: list[Edge[Span]] = []
+    start_cursor: Optional[str] = None
+    end_cursor: Optional[str] = None
+    async with db() as session:
+        records = groupby(await session.stream(stmt), key=lambda record: record[:2])
+        async for (trace_rowid, trace_start_time), group in islice(records, first):
+            cursor = Cursor(
+                rowid=trace_rowid,
+                sort_column=CursorSortColumn(
+                    type=CursorSortColumnDataType.DATETIME,
+                    value=trace_start_time,
+                ),
+            )
+            if start_cursor is None:
+                start_cursor = str(cursor)
+            end_cursor = str(cursor)
+            first_record = group[0]
+            # Only create edge if trace has a root span
+            if (span_rowid := first_record[2]) is not None:
+                edges.append(Edge(node=Span(id=span_rowid), cursor=str(cursor)))
+        has_next_page = True
+        try:
+            await records.__anext__()
+        except StopAsyncIteration:
+            has_next_page = False
+
+    # Retry if we need more edges and more traces exist
+    if first and len(edges) < first and has_next_page:
+        while retries and (num_needed := first - len(edges)) and has_next_page:
+            retries -= 1
+            batch_size = max(first, 1000)
+            more = await _paginate_span_by_trace_start_time(
+                db=db,
+                project_rowid=project_rowid,
+                time_range=time_range,
+                first=batch_size,
+                after=end_cursor,
+                sort=sort,
+                orphan_span_as_root_span=orphan_span_as_root_span,
+                retries=0,
+            )
+            edges.extend(more.edges[:num_needed])
+            start_cursor = start_cursor or more.page_info.start_cursor
+            end_cursor = more.page_info.end_cursor if len(edges) < first else edges[-1].cursor
+            has_next_page = len(more.edges) > num_needed or more.page_info.has_next_page
+
+    return Connection(
+        edges=edges,
+        page_info=PageInfo(
+            start_cursor=start_cursor,
+            end_cursor=end_cursor,
+            has_previous_page=False,
+            has_next_page=has_next_page,
+        ),
+    )
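Editor's note: this function uses keyset ("seek") pagination — the cursor stores the last trace's `(start_time, id)`, and the next page filters with a row-value comparison instead of OFFSET, over-fetching by one row to detect whether a next page exists. An in-memory analogue with fabricated traces:

```python
import operator
from datetime import datetime
from typing import Optional

# Traces as (start_time, rowid), sorted newest-first with rowid as tiebreaker,
# mirroring SortDir.desc above.
traces = sorted(
    [(datetime(2024, 1, 1, h), rowid) for h, rowid in [(9, 1), (9, 2), (10, 3)]],
    reverse=True,
)

def page(after: Optional[tuple[datetime, int]], first: int):
    # desc order: the next page is everything strictly "before" the cursor,
    # compared as a tuple just like tuple_(start_time, id) in SQL.
    rows = [t for t in traces if after is None or operator.lt(t, after)]
    return rows[:first], len(rows) > first  # (edges, has_next_page)

page1, has_next = page(None, 2)
assert has_next
page2, _ = page(page1[-1], 2)  # resume exactly after the last cursor
print(page1)  # traces 3 and 2 (newest first)
print(page2)  # trace 1 — no rows skipped or repeated
```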
+
+
+def to_gql_project(project: models.Project) -> Project:
+    """
+    Converts an ORM project to a GraphQL project.
+    """
+    return Project(
+        id=project.id,
+        db_record=project,
+    )