arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (276)
  1. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
  2. arize_phoenix-12.28.1.dist-info/RECORD +499 -0
  3. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +5 -4
  12. phoenix/auth.py +39 -2
  13. phoenix/config.py +1763 -91
  14. phoenix/datetime_utils.py +120 -2
  15. phoenix/db/README.md +595 -25
  16. phoenix/db/bulk_inserter.py +145 -103
  17. phoenix/db/engines.py +140 -33
  18. phoenix/db/enums.py +3 -12
  19. phoenix/db/facilitator.py +302 -35
  20. phoenix/db/helpers.py +1000 -65
  21. phoenix/db/iam_auth.py +64 -0
  22. phoenix/db/insertion/dataset.py +135 -2
  23. phoenix/db/insertion/document_annotation.py +9 -6
  24. phoenix/db/insertion/evaluation.py +2 -3
  25. phoenix/db/insertion/helpers.py +17 -2
  26. phoenix/db/insertion/session_annotation.py +176 -0
  27. phoenix/db/insertion/span.py +15 -11
  28. phoenix/db/insertion/span_annotation.py +3 -4
  29. phoenix/db/insertion/trace_annotation.py +3 -4
  30. phoenix/db/insertion/types.py +50 -20
  31. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  32. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  33. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  34. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  35. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  36. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  37. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  38. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  39. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  40. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  41. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  42. phoenix/db/models.py +669 -56
  43. phoenix/db/pg_config.py +10 -0
  44. phoenix/db/types/model_provider.py +4 -0
  45. phoenix/db/types/token_price_customization.py +29 -0
  46. phoenix/db/types/trace_retention.py +23 -15
  47. phoenix/experiments/evaluators/utils.py +3 -3
  48. phoenix/experiments/functions.py +160 -52
  49. phoenix/experiments/tracing.py +2 -2
  50. phoenix/experiments/types.py +1 -1
  51. phoenix/inferences/inferences.py +1 -2
  52. phoenix/server/api/auth.py +38 -7
  53. phoenix/server/api/auth_messages.py +46 -0
  54. phoenix/server/api/context.py +100 -4
  55. phoenix/server/api/dataloaders/__init__.py +79 -5
  56. phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
  57. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  58. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  59. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  60. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  61. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  62. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  63. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  64. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  65. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  66. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  67. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  68. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  69. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  70. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  71. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  72. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  73. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  74. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  75. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  76. phoenix/server/api/dataloaders/record_counts.py +37 -10
  77. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  78. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  79. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  80. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  81. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  82. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  83. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  84. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
  85. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  86. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  87. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  88. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
  89. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  90. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  91. phoenix/server/api/dataloaders/span_costs.py +29 -0
  92. phoenix/server/api/dataloaders/table_fields.py +2 -2
  93. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  94. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  95. phoenix/server/api/dataloaders/types.py +29 -0
  96. phoenix/server/api/exceptions.py +11 -1
  97. phoenix/server/api/helpers/dataset_helpers.py +5 -1
  98. phoenix/server/api/helpers/playground_clients.py +1243 -292
  99. phoenix/server/api/helpers/playground_registry.py +2 -2
  100. phoenix/server/api/helpers/playground_spans.py +8 -4
  101. phoenix/server/api/helpers/playground_users.py +26 -0
  102. phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
  103. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  104. phoenix/server/api/helpers/prompts/models.py +205 -22
  105. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  106. phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
  107. phoenix/server/api/input_types/CreateProjectInput.py +27 -0
  108. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  109. phoenix/server/api/input_types/DatasetFilter.py +17 -0
  110. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  111. phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
  112. phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
  113. phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
  114. phoenix/server/api/input_types/PromptFilter.py +14 -0
  115. phoenix/server/api/input_types/PromptVersionInput.py +52 -1
  116. phoenix/server/api/input_types/SpanSort.py +44 -7
  117. phoenix/server/api/input_types/TimeBinConfig.py +23 -0
  118. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  119. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  120. phoenix/server/api/mutations/__init__.py +10 -0
  121. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  122. phoenix/server/api/mutations/api_key_mutations.py +19 -23
  123. phoenix/server/api/mutations/chat_mutations.py +154 -47
  124. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  125. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  126. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  127. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  128. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  129. phoenix/server/api/mutations/model_mutations.py +210 -0
  130. phoenix/server/api/mutations/project_mutations.py +49 -10
  131. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  132. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  133. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  134. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  135. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  136. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  137. phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
  138. phoenix/server/api/mutations/trace_mutations.py +47 -3
  139. phoenix/server/api/mutations/user_mutations.py +66 -41
  140. phoenix/server/api/queries.py +768 -293
  141. phoenix/server/api/routers/__init__.py +2 -2
  142. phoenix/server/api/routers/auth.py +154 -88
  143. phoenix/server/api/routers/ldap.py +229 -0
  144. phoenix/server/api/routers/oauth2.py +369 -106
  145. phoenix/server/api/routers/v1/__init__.py +24 -4
  146. phoenix/server/api/routers/v1/annotation_configs.py +23 -31
  147. phoenix/server/api/routers/v1/annotations.py +481 -17
  148. phoenix/server/api/routers/v1/datasets.py +395 -81
  149. phoenix/server/api/routers/v1/documents.py +142 -0
  150. phoenix/server/api/routers/v1/evaluations.py +24 -31
  151. phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
  152. phoenix/server/api/routers/v1/experiment_runs.py +337 -59
  153. phoenix/server/api/routers/v1/experiments.py +479 -48
  154. phoenix/server/api/routers/v1/models.py +7 -0
  155. phoenix/server/api/routers/v1/projects.py +18 -49
  156. phoenix/server/api/routers/v1/prompts.py +54 -40
  157. phoenix/server/api/routers/v1/sessions.py +108 -0
  158. phoenix/server/api/routers/v1/spans.py +1091 -81
  159. phoenix/server/api/routers/v1/traces.py +132 -78
  160. phoenix/server/api/routers/v1/users.py +389 -0
  161. phoenix/server/api/routers/v1/utils.py +3 -7
  162. phoenix/server/api/subscriptions.py +305 -88
  163. phoenix/server/api/types/Annotation.py +90 -23
  164. phoenix/server/api/types/ApiKey.py +13 -17
  165. phoenix/server/api/types/AuthMethod.py +1 -0
  166. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  167. phoenix/server/api/types/CostBreakdown.py +12 -0
  168. phoenix/server/api/types/Dataset.py +226 -72
  169. phoenix/server/api/types/DatasetExample.py +88 -18
  170. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  171. phoenix/server/api/types/DatasetLabel.py +57 -0
  172. phoenix/server/api/types/DatasetSplit.py +98 -0
  173. phoenix/server/api/types/DatasetVersion.py +49 -4
  174. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  175. phoenix/server/api/types/Experiment.py +264 -59
  176. phoenix/server/api/types/ExperimentComparison.py +5 -10
  177. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  178. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  179. phoenix/server/api/types/ExperimentRun.py +169 -65
  180. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  181. phoenix/server/api/types/GenerativeModel.py +245 -3
  182. phoenix/server/api/types/GenerativeProvider.py +70 -11
  183. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  184. phoenix/server/api/types/ModelInterface.py +16 -0
  185. phoenix/server/api/types/PlaygroundModel.py +20 -0
  186. phoenix/server/api/types/Project.py +1278 -216
  187. phoenix/server/api/types/ProjectSession.py +188 -28
  188. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  189. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  190. phoenix/server/api/types/Prompt.py +119 -39
  191. phoenix/server/api/types/PromptLabel.py +42 -25
  192. phoenix/server/api/types/PromptVersion.py +11 -8
  193. phoenix/server/api/types/PromptVersionTag.py +65 -25
  194. phoenix/server/api/types/ServerStatus.py +6 -0
  195. phoenix/server/api/types/Span.py +167 -123
  196. phoenix/server/api/types/SpanAnnotation.py +189 -42
  197. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  198. phoenix/server/api/types/SpanCostSummary.py +10 -0
  199. phoenix/server/api/types/SystemApiKey.py +65 -1
  200. phoenix/server/api/types/TokenPrice.py +16 -0
  201. phoenix/server/api/types/TokenUsage.py +3 -3
  202. phoenix/server/api/types/Trace.py +223 -51
  203. phoenix/server/api/types/TraceAnnotation.py +149 -50
  204. phoenix/server/api/types/User.py +137 -32
  205. phoenix/server/api/types/UserApiKey.py +73 -26
  206. phoenix/server/api/types/node.py +10 -0
  207. phoenix/server/api/types/pagination.py +11 -2
  208. phoenix/server/app.py +290 -45
  209. phoenix/server/authorization.py +38 -3
  210. phoenix/server/bearer_auth.py +34 -24
  211. phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
  212. phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
  213. phoenix/server/cost_tracking/helpers.py +68 -0
  214. phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
  215. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  216. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  217. phoenix/server/daemons/__init__.py +0 -0
  218. phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
  219. phoenix/server/daemons/generative_model_store.py +103 -0
  220. phoenix/server/daemons/span_cost_calculator.py +99 -0
  221. phoenix/server/dml_event.py +17 -0
  222. phoenix/server/dml_event_handler.py +5 -0
  223. phoenix/server/email/sender.py +56 -3
  224. phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
  225. phoenix/server/email/types.py +11 -0
  226. phoenix/server/experiments/__init__.py +0 -0
  227. phoenix/server/experiments/utils.py +14 -0
  228. phoenix/server/grpc_server.py +11 -11
  229. phoenix/server/jwt_store.py +17 -15
  230. phoenix/server/ldap.py +1449 -0
  231. phoenix/server/main.py +26 -10
  232. phoenix/server/oauth2.py +330 -12
  233. phoenix/server/prometheus.py +66 -6
  234. phoenix/server/rate_limiters.py +4 -9
  235. phoenix/server/retention.py +33 -20
  236. phoenix/server/session_filters.py +49 -0
  237. phoenix/server/static/.vite/manifest.json +55 -51
  238. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  239. phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
  240. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  241. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  242. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  243. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  244. phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
  245. phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
  246. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  247. phoenix/server/templates/index.html +40 -6
  248. phoenix/server/thread_server.py +1 -2
  249. phoenix/server/types.py +14 -4
  250. phoenix/server/utils.py +74 -0
  251. phoenix/session/client.py +56 -3
  252. phoenix/session/data_extractor.py +5 -0
  253. phoenix/session/evaluation.py +14 -5
  254. phoenix/session/session.py +45 -9
  255. phoenix/settings.py +5 -0
  256. phoenix/trace/attributes.py +80 -13
  257. phoenix/trace/dsl/helpers.py +90 -1
  258. phoenix/trace/dsl/query.py +8 -6
  259. phoenix/trace/projects.py +5 -0
  260. phoenix/utilities/template_formatters.py +1 -1
  261. phoenix/version.py +1 -1
  262. arize_phoenix-10.0.4.dist-info/RECORD +0 -405
  263. phoenix/server/api/types/Evaluation.py +0 -39
  264. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  265. phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
  266. phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
  267. phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
  268. phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
  269. phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
  270. phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
  271. phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
  272. phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
  273. phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
  274. phoenix/utilities/deprecation.py +0 -31
  275. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  276. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
phoenix/db/helpers.py CHANGED
@@ -1,22 +1,28 @@
  from collections.abc import Callable, Hashable, Iterable
+ from datetime import datetime
  from enum import Enum
- from typing import Any, Optional, TypeVar
+ from typing import Any, Literal, Optional, Sequence, TypeVar, Union

- from openinference.semconv.trace import (
-     OpenInferenceSpanKindValues,
-     RerankerAttributes,
-     SpanAttributes,
- )
+ import sqlalchemy as sa
  from sqlalchemy import (
-     Integer,
+     Insert,
      Select,
      SQLColumnExpression,
      and_,
      case,
      distinct,
+     exists,
      func,
+     insert,
+     literal,
+     literal_column,
+     or_,
      select,
+     util,
  )
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from sqlalchemy.orm import QueryableAttribute
+ from sqlalchemy.sql.roles import InElementRole
  from typing_extensions import assert_never

  from phoenix.config import PLAYGROUND_PROJECT_NAME
@@ -34,30 +40,6 @@ class SupportedSQLDialect(Enum):
          raise ValueError(f"`{v}` is not a supported SQL backend/dialect.")


- def num_docs_col(dialect: SupportedSQLDialect) -> SQLColumnExpression[Integer]:
-     if dialect is SupportedSQLDialect.POSTGRESQL:
-         array_length = func.jsonb_array_length
-     elif dialect is SupportedSQLDialect.SQLITE:
-         array_length = func.json_array_length
-     else:
-         assert_never(dialect)
-     retrieval_docs = models.Span.attributes[_RETRIEVAL_DOCUMENTS]
-     num_retrieval_docs = array_length(retrieval_docs)
-     reranker_docs = models.Span.attributes[_RERANKER_OUTPUT_DOCUMENTS]
-     num_reranker_docs = array_length(reranker_docs)
-     return case(
-         (
-             func.upper(models.Span.span_kind) == OpenInferenceSpanKindValues.RERANKER.value.upper(),
-             num_reranker_docs,
-         ),
-         else_=num_retrieval_docs,
-     ).label("num_docs")
-
-
- _RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS.split(".")
- _RERANKER_OUTPUT_DOCUMENTS = RerankerAttributes.RERANKER_OUTPUT_DOCUMENTS.split(".")
-
-
  def get_eval_trace_ids_for_datasets(*dataset_ids: int) -> Select[tuple[Optional[str]]]:
      return (
          select(distinct(models.ExperimentRunAnnotation.trace_id))
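The num_docs_col helper removed above is a small instance of the dialect-dispatch pattern this module keeps using elsewhere: build one SQLAlchemy expression, but pick PostgreSQL's jsonb_array_length or SQLite's json_array_length depending on the backend. A minimal sketch of that pattern, with an invented spans table rather than Phoenix's actual models:

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql, sqlite

metadata = sa.MetaData()
# Invented stand-in for the spans table; Phoenix's real model lives in phoenix.db.models.
spans = sa.Table("spans", metadata, sa.Column("attributes", sa.JSON))


def num_docs(dialect_name: str) -> "sa.ColumnElement[int]":
    # PostgreSQL stores the attributes as jsonb and spells the length function
    # jsonb_array_length; SQLite spells it json_array_length.
    array_length = (
        sa.func.jsonb_array_length if dialect_name == "postgresql" else sa.func.json_array_length
    )
    return array_length(spans.c.attributes).label("num_docs")


print(sa.select(num_docs("postgresql")).compile(dialect=postgresql.dialect()))
print(sa.select(num_docs("sqlite")).compile(dialect=sqlite.dialect()))

Compiling the same expression under both dialects makes the divergence visible without a database connection.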
@@ -115,51 +97,205 @@ def dedup(
      return ans


- def get_dataset_example_revisions(
+ def _build_ranked_revisions_query(
      dataset_version_id: int,
- ) -> Select[tuple[models.DatasetExampleRevision]]:
-     version = (
+     /,
+     *,
+     dataset_id: Optional[int] = None,
+     example_ids: Optional[Union[Sequence[int], InElementRole]] = None,
+ ) -> Select[tuple[int]]:
+     """
+     Build a query that ranks revisions per example within a dataset version.
+
+     This performs the core ranking logic using ROW_NUMBER() to find the latest
+     revision for each example within the specified dataset version.
+
+     Args:
+         dataset_version_id: Maximum dataset version to consider
+         dataset_id: Optional dataset ID - if provided, avoids subquery lookup
+
+     Returns:
+         SQLAlchemy SELECT query with revision ranking and basic dataset filtering
+     """
+     stmt = (
          select(
-             models.DatasetVersion.id,
-             models.DatasetVersion.dataset_id,
+             func.row_number()
+             .over(
+                 partition_by=models.DatasetExampleRevision.dataset_example_id,
+                 order_by=models.DatasetExampleRevision.dataset_version_id.desc(),
+             )
+             .label("rn"),
          )
-         .filter_by(id=dataset_version_id)
-         .subquery()
+         .join(models.DatasetExample)
+         .where(models.DatasetExampleRevision.dataset_version_id <= dataset_version_id)
      )
-     table = models.DatasetExampleRevision
-     revision = (
-         select(
-             table.dataset_example_id,
-             func.max(table.dataset_version_id).label("dataset_version_id"),
-         )
-         .join_from(
-             table,
-             models.DatasetExample,
-             table.dataset_example_id == models.DatasetExample.id,
-         )
-         .join_from(
-             models.DatasetExample,
-             version,
-             models.DatasetExample.dataset_id == version.c.dataset_id,
-         )
-         .where(models.DatasetExample.dataset_id == version.c.dataset_id)
-         .where(table.dataset_version_id <= version.c.id)
-         .group_by(table.dataset_example_id)
-         .subquery()
+
+     if dataset_id is None:
+         version_subquery = (
+             select(models.DatasetVersion.dataset_id)
+             .filter_by(id=dataset_version_id)
+             .scalar_subquery()
+         )
+         stmt = stmt.where(models.DatasetExample.dataset_id == version_subquery)
+     else:
+         stmt = stmt.where(models.DatasetExample.dataset_id == dataset_id)
+
+     if example_ids is not None:
+         stmt = stmt.where(models.DatasetExampleRevision.dataset_example_id.in_(example_ids))
+
+     return stmt
+
+
+ def get_dataset_example_revisions(
+     dataset_version_id: int,
+     /,
+     *,
+     dataset_id: Optional[int] = None,
+     example_ids: Optional[Union[Sequence[int], InElementRole]] = None,
+     split_ids: Optional[Union[Sequence[int], InElementRole]] = None,
+     split_names: Optional[Union[Sequence[str], InElementRole]] = None,
+ ) -> Select[tuple[models.DatasetExampleRevision]]:
+     """
+     Get the latest revisions for all dataset examples within a specific dataset version.
+
+     Excludes examples where the latest revision is a DELETE.
+
+     Args:
+         dataset_version_id: The dataset version to get revisions for
+         dataset_id: Optional dataset ID - if provided, avoids extra subquery lookup
+         example_ids: Optional filter by specific example IDs (subquery or list of IDs).
+             - None = no filtering
+             - Empty sequences/subqueries = no matches (strict filtering)
+         split_ids: Optional filter by split IDs (subquery or list of split IDs).
+             - None = no filtering
+             - Empty sequences/subqueries = no matches (strict filtering)
+         split_names: Optional filter by split names (subquery or list of split names).
+             - None = no filtering
+             - Empty sequences/subqueries = no matches (strict filtering)
+
+     Note:
+         - split_ids and split_names are mutually exclusive
+         - Use split_ids for better performance when IDs are available (avoids JOIN)
+         - Empty filters use strict behavior: empty inputs return zero results
+     """
+     if split_ids is not None and split_names is not None:
+         raise ValueError(
+             "Cannot specify both split_ids and split_names - they are mutually exclusive"
+         )
+
+     stmt = _build_ranked_revisions_query(
+         dataset_version_id,
+         dataset_id=dataset_id,
+         example_ids=example_ids,
+     ).add_columns(
+         models.DatasetExampleRevision.id,
+         models.DatasetExampleRevision.revision_kind,
      )
+
+     if split_ids is not None or split_names is not None:
+         if split_names is not None:
+             split_example_ids_subquery = (
+                 select(models.DatasetSplitDatasetExample.dataset_example_id)
+                 .join(
+                     models.DatasetSplit,
+                     models.DatasetSplit.id == models.DatasetSplitDatasetExample.dataset_split_id,
+                 )
+                 .where(models.DatasetSplit.name.in_(split_names))
+             )
+             stmt = stmt.where(models.DatasetExample.id.in_(split_example_ids_subquery))
+         else:
+             assert split_ids is not None
+             split_example_ids_subquery = select(
+                 models.DatasetSplitDatasetExample.dataset_example_id
+             ).where(models.DatasetSplitDatasetExample.dataset_split_id.in_(split_ids))
+             stmt = stmt.where(models.DatasetExample.id.in_(split_example_ids_subquery))
+
+     ranked_subquery = stmt.subquery()
      return (
-         select(table)
-         .where(table.revision_kind != "DELETE")
+         select(models.DatasetExampleRevision)
          .join(
-             revision,
-             onclause=and_(
-                 revision.c.dataset_example_id == table.dataset_example_id,
-                 revision.c.dataset_version_id == table.dataset_version_id,
-             ),
+             ranked_subquery,
+             models.DatasetExampleRevision.id == ranked_subquery.c.id,
+         )
+         .where(
+             ranked_subquery.c.rn == 1,
+             ranked_subquery.c.revision_kind != "DELETE",
          )
      )


+ def create_experiment_examples_snapshot_insert(
+     experiment: models.Experiment,
+ ) -> Insert:
+     """
+     Create an INSERT statement to snapshot dataset examples for an experiment.
+
+     This captures which examples belong to the experiment at the time of creation,
+     respecting any dataset splits assigned to the experiment.
+
+     Args:
+         experiment: The experiment to create the snapshot for
+
+     Returns:
+         SQLAlchemy INSERT statement ready for execution
+     """
+     stmt = _build_ranked_revisions_query(
+         experiment.dataset_version_id,
+         dataset_id=experiment.dataset_id,
+     ).add_columns(
+         models.DatasetExampleRevision.id,
+         models.DatasetExampleRevision.dataset_example_id,
+         models.DatasetExampleRevision.revision_kind,
+     )
+
+     experiment_splits_subquery = select(models.ExperimentDatasetSplit.dataset_split_id).where(
+         models.ExperimentDatasetSplit.experiment_id == experiment.id
+     )
+     has_splits_condition = exists(experiment_splits_subquery)
+     split_filtered_example_ids = select(models.DatasetSplitDatasetExample.dataset_example_id).where(
+         models.DatasetSplitDatasetExample.dataset_split_id.in_(experiment_splits_subquery)
+     )
+
+     stmt = stmt.where(
+         or_(
+             ~has_splits_condition,  # No splits = include all examples
+             models.DatasetExampleRevision.dataset_example_id.in_(
+                 split_filtered_example_ids
+             ),  # Has splits = filter by splits
+         )
+     )
+
+     ranked_subquery = stmt.subquery()
+     return insert(models.ExperimentDatasetExample).from_select(
+         [
+             models.ExperimentDatasetExample.experiment_id,
+             models.ExperimentDatasetExample.dataset_example_id,
+             models.ExperimentDatasetExample.dataset_example_revision_id,
+         ],
+         select(
+             literal(experiment.id),
+             ranked_subquery.c.dataset_example_id,
+             ranked_subquery.c.id,
+         ).where(
+             ranked_subquery.c.rn == 1,
+             ranked_subquery.c.revision_kind != "DELETE",
+         ),
+     )
+
+
+ async def insert_experiment_with_examples_snapshot(
+     session: AsyncSession,
+     experiment: models.Experiment,
+ ) -> None:
+     """
+     Insert an experiment with its snapshot of dataset examples.
+     """
+     session.add(experiment)
+     await session.flush()
+     insert_stmt = create_experiment_examples_snapshot_insert(experiment)
+     await session.execute(insert_stmt)
+
+
  _AnyTuple = TypeVar("_AnyTuple", bound=tuple[Any, ...])


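The rewritten get_dataset_example_revisions above swaps the old MAX()/GROUP BY self-join for a ROW_NUMBER() window: within each example, revisions at or below the version cutoff are ranked newest-first, row number 1 is kept, and examples whose newest revision is a DELETE drop out. A toy sqlite3 sketch of that ranking, using an invented revisions table rather than Phoenix's schema (window functions need SQLite >= 3.25):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE revisions (id INTEGER, example_id INTEGER, version_id INTEGER, kind TEXT);
    INSERT INTO revisions VALUES
        (1, 100, 1, 'CREATE'),
        (2, 100, 3, 'PATCH'),
        (3, 200, 2, 'CREATE'),
        (4, 200, 4, 'DELETE');
    """
)
rows = conn.execute(
    """
    SELECT id, example_id, version_id, kind FROM (
        SELECT *, ROW_NUMBER() OVER (
            PARTITION BY example_id ORDER BY version_id DESC
        ) AS rn
        FROM revisions
        WHERE version_id <= 4
    )
    WHERE rn = 1 AND kind != 'DELETE'
    """
).fetchall()
print(rows)  # [(2, 100, 3, 'PATCH')] -- example 200 is excluded: its newest revision is a DELETE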
@@ -173,3 +309,802 @@ def exclude_experiment_projects(
              models.Experiment.project_name != PLAYGROUND_PROJECT_NAME,
          ),
      ).where(models.Experiment.project_name.is_(None))
+
+
+ def date_trunc(
+     dialect: SupportedSQLDialect,
+     field: Literal["minute", "hour", "day", "week", "month", "year"],
+     source: Union[QueryableAttribute[datetime], sa.TextClause],
+     utc_offset_minutes: int = 0,
+ ) -> SQLColumnExpression[datetime]:
+     """
+     Truncate a datetime to the specified field with optional UTC offset adjustment.
+
+     This function provides a cross-dialect way to truncate datetime values to a specific
+     time unit (minute, hour, day, week, month, or year). It handles UTC offset conversion
+     by applying the offset before truncation and then converting back to UTC.
+
+     Args:
+         dialect: The SQL dialect to use (PostgreSQL or SQLite).
+         field: The time unit to truncate to. Valid values are:
+             - "minute": Truncate to the start of the minute (seconds set to 0)
+             - "hour": Truncate to the start of the hour (minutes and seconds set to 0)
+             - "day": Truncate to the start of the day (time set to 00:00:00)
+             - "week": Truncate to the start of the week (Monday at 00:00:00)
+             - "month": Truncate to the first day of the month (day set to 1, time to 00:00:00)
+             - "year": Truncate to the first day of the year (date set to Jan 1, time to 00:00:00)
+         source: The datetime column or expression to truncate.
+         utc_offset_minutes: UTC offset in minutes to apply before truncation.
+             Positive values represent time zones ahead of UTC (e.g., +60 for UTC+1).
+             Negative values represent time zones behind UTC (e.g., -300 for UTC-5).
+             Defaults to 0 (no offset).
+
+     Returns:
+         A SQL column expression representing the truncated datetime in UTC.
+
+     Note:
+         - For PostgreSQL, uses the native `date_trunc` function with timezone support.
+         - For SQLite, implements custom truncation logic using datetime functions.
+         - Week truncation starts on Monday (ISO 8601 standard).
+         - The result is always returned in UTC, regardless of the input offset.
+
+     Examples:
+         >>> # Truncate to hour with no offset
+         >>> date_trunc(SupportedSQLDialect.POSTGRESQL, "hour", Span.start_time)
+
+         >>> # Truncate to day with UTC-5 offset (Eastern Time)
+         >>> date_trunc(SupportedSQLDialect.SQLITE, "day", Span.start_time, -300)
+     """
+     if dialect is SupportedSQLDialect.POSTGRESQL:
+         # Note: the usage of the timezone parameter in the form of e.g. "+05:00"
+         # appears to be an undocumented feature of PostgreSQL's date_trunc function.
+         # Below is an example query and its output executed on PostgreSQL v12 and v17.
+         #   SELECT date_trunc('day', TIMESTAMP WITH TIME ZONE '2001-02-16 15:38:40-05'),
+         #          date_trunc('day', TIMESTAMP WITH TIME ZONE '2001-02-16 20:38:40+00', '+05:00'),
+         #          date_trunc('day', TIMESTAMP WITH TIME ZONE '2001-02-16 20:38:40+00', '-05:00');
+         # ┌────────────────────────┬────────────────────────┬────────────────────────┐
+         # │       date_trunc       │       date_trunc       │       date_trunc       │
+         # ├────────────────────────┼────────────────────────┼────────────────────────┤
+         # │ 2001-02-16 00:00:00+00 │ 2001-02-16 05:00:00+00 │ 2001-02-16 19:00:00+00 │
+         # └────────────────────────┴────────────────────────┴────────────────────────┘
+         # (1 row)
+         sign = "-" if utc_offset_minutes >= 0 else "+"
+         timezone = f"{sign}{abs(utc_offset_minutes) // 60}:{abs(utc_offset_minutes) % 60:02d}"
+         return sa.func.date_trunc(field, source, timezone)
+     elif dialect is SupportedSQLDialect.SQLITE:
+         return _date_trunc_for_sqlite(field, source, utc_offset_minutes)
+     else:
+         assert_never(dialect)
+
+
+ def _date_trunc_for_sqlite(
+     field: Literal["minute", "hour", "day", "week", "month", "year"],
+     source: Union[QueryableAttribute[datetime], sa.TextClause],
+     utc_offset_minutes: int = 0,
+ ) -> SQLColumnExpression[datetime]:
+     """
+     SQLite-specific implementation of datetime truncation with UTC offset handling.
+
+     This private helper function implements date truncation for SQLite databases, which
+     lack a native date_trunc function. It uses SQLite's datetime and strftime functions
+     to achieve the same result as PostgreSQL's date_trunc function.
+
+     Args:
+         field: The time unit to truncate to. Valid values are:
+             - "minute": Truncate to the start of the minute (seconds set to 0)
+             - "hour": Truncate to the start of the hour (minutes and seconds set to 0)
+             - "day": Truncate to the start of the day (time set to 00:00:00)
+             - "week": Truncate to the start of the week (Monday at 00:00:00)
+             - "month": Truncate to the first day of the month (day set to 1, time to 00:00:00)
+             - "year": Truncate to the first day of the year (date set to Jan 1, time to 00:00:00)
+         source: The datetime column or expression to truncate.
+         utc_offset_minutes: UTC offset in minutes to apply before truncation.
+             Positive values represent time zones ahead of UTC (e.g., +60 for UTC+1).
+             Negative values represent time zones behind UTC (e.g., -300 for UTC-5).
+
+     Returns:
+         A SQL column expression representing the truncated datetime in UTC.
+
+     Implementation Details:
+         - Uses SQLite's strftime() function to format and extract date components
+         - Applies UTC offset before truncation using datetime(source, "N minutes")
+         - Converts result back to UTC by subtracting the offset
+         - Week truncation uses day-of-week calculations where:
+             * strftime('%w') returns 0=Sunday, 1=Monday, ..., 6=Saturday
+             * Truncates to Monday (start of week) using case-based day adjustments
+         - Month/year truncation reconstructs dates using extracted components
+
+     Raises:
+         ValueError: If the field parameter is not one of the supported values.
+
+     Note:
+         This is a private helper function intended only for use by the date_trunc function
+         when the dialect is SupportedSQLDialect.SQLITE.
+     """
+     # SQLite does not have a built-in date truncation function, so we use datetime functions
+     # First apply UTC offset, then truncate
+     offset_source = func.datetime(source, f"{utc_offset_minutes} minutes")
+
+     if field == "minute":
+         t = func.datetime(func.strftime("%Y-%m-%d %H:%M:00", offset_source))
+     elif field == "hour":
+         t = func.datetime(func.strftime("%Y-%m-%d %H:00:00", offset_source))
+     elif field == "day":
+         t = func.datetime(func.strftime("%Y-%m-%d 00:00:00", offset_source))
+     elif field == "week":
+         # Truncate to Monday (start of week)
+         # SQLite strftime('%w') returns: 0=Sunday, 1=Monday, ..., 6=Saturday
+         dow = func.strftime("%w", offset_source)
+         t = func.datetime(
+             case(
+                 (dow == "0", func.date(offset_source, "-6 days")),  # Sunday -> go back 6 days
+                 (dow == "1", func.date(offset_source, "+0 days")),  # Monday -> stay
+                 (dow == "2", func.date(offset_source, "-1 days")),  # Tuesday -> go back 1 day
+                 (dow == "3", func.date(offset_source, "-2 days")),  # Wednesday -> go back 2 days
+                 (dow == "4", func.date(offset_source, "-3 days")),  # Thursday -> go back 3 days
+                 (dow == "5", func.date(offset_source, "-4 days")),  # Friday -> go back 4 days
+                 (dow == "6", func.date(offset_source, "-5 days")),  # Saturday -> go back 5 days
+             ),
+             "00:00:00",
+         )
+     elif field == "month":
+         # Extract year and month, then construct first day of month
+         year = func.strftime("%Y", offset_source)
+         month = func.strftime("%m", offset_source)
+         t = func.datetime(year + "-" + month + "-01 00:00:00")
+     elif field == "year":
+         # Extract year, then construct first day of year
+         year = func.strftime("%Y", offset_source)
+         t = func.datetime(year + "-01-01 00:00:00")
+     else:
+         raise ValueError(f"Unsupported field for date truncation: {field}")
+
+     # Convert back to UTC by subtracting the offset
+     return func.datetime(t, f"{-utc_offset_minutes} minutes")
+
+
+ def get_ancestor_span_rowids(parent_id: str) -> Select[tuple[int]]:
+     """
+     Get all ancestor span IDs for a given parent_id using recursive CTE.
+
+     This function returns a query that finds all ancestors of a span with the given parent_id.
+     It uses a recursive Common Table Expression (CTE) to traverse up the span hierarchy.
+
+     Args:
+         parent_id: The span_id of the parent span to start the ancestor search from.
+
+     Returns:
+         A Select query that returns tuples of (span_id,) for all ancestor spans.
+     """
+     ancestors = (
+         select(models.Span.id, models.Span.parent_id)
+         .where(models.Span.span_id == parent_id)
+         .cte(recursive=True)
+     )
+     child = ancestors.alias()
+     ancestors = ancestors.union_all(
+         select(models.Span.id, models.Span.parent_id).join(
+             child, models.Span.span_id == child.c.parent_id
+         )
+     )
+     return select(ancestors.c.id)
+
+
+ def truncate_name(name: str, max_len: int = 63) -> str:
+     # https://github.com/sqlalchemy/sqlalchemy/blob/e263825e3c5060bf4f47eed0e833c6660a31658e/lib/sqlalchemy/sql/compiler.py#L7844-L7845
+     if len(name) > max_len:
+         return name[0 : max_len - 8] + "_" + util.md5_hex(name)[-4:]
+     return name
+
+
+ def get_successful_run_counts_subquery(
+     experiment_id: int,
+     repetitions: int,
+ ) -> Any:
+     """
+     Build a subquery that counts successful runs per dataset example for an experiment.
+
+     This subquery outer joins experiment dataset examples with their runs, counting only
+     successful runs (runs that exist and have no error). The HAVING clause filters to only
+     include examples with fewer successful runs than the total repetitions required.
+
+     Args:
+         experiment_id: The experiment ID to query runs for
+         repetitions: The number of repetitions required per example
+
+     Returns:
+         SQLAlchemy subquery with columns:
+         - dataset_example_revision_id: ID of the example revision
+         - dataset_example_id: ID of the dataset example
+         - successful_count: Count of successful runs for this example
+     """
+     # Use CASE to count only successful runs (run exists AND error IS NULL)
+     # Important: Must check that run exists (id IS NOT NULL) to distinguish
+     # "no run" from "successful run" in the outer join
+     successful_run_case = case(
+         (
+             and_(
+                 models.ExperimentRun.id.is_not(None),  # Run exists
+                 models.ExperimentRun.error.is_(None),  # No error (successful)
+             ),
+             1,
+         ),
+         else_=0,
+     )
+
+     return (
+         select(
+             models.ExperimentDatasetExample.dataset_example_revision_id,
+             models.ExperimentDatasetExample.dataset_example_id,
+             func.sum(successful_run_case).label("successful_count"),
+         )
+         .select_from(models.ExperimentDatasetExample)
+         .outerjoin(
+             models.ExperimentRun,
+             and_(
+                 models.ExperimentRun.experiment_id == experiment_id,
+                 models.ExperimentRun.dataset_example_id
+                 == models.ExperimentDatasetExample.dataset_example_id,
+             ),
+         )
+         .where(models.ExperimentDatasetExample.experiment_id == experiment_id)
+         .group_by(
+             models.ExperimentDatasetExample.dataset_example_revision_id,
+             models.ExperimentDatasetExample.dataset_example_id,
+         )
+         .having(
+             # Only include incomplete examples (successful_count < repetitions)
+             func.coalesce(func.sum(successful_run_case), 0) < repetitions
+         )
+         .subquery()
+     )
+
+
+ def generate_expected_repetitions_cte(
+     dialect: SupportedSQLDialect,
+     run_counts_subquery: Any,
+     repetitions: int,
+ ) -> Any:
+     """
+     Generate a CTE that produces all expected repetition numbers for partially complete examples.
+
+     This generates a sequence of repetition numbers [1..repetitions] for each example that has
+     at least one successful run (0 < successful_count < repetitions). The implementation varies
+     by SQL dialect.
+
+     Args:
+         dialect: The SQL dialect to use (PostgreSQL or SQLite)
+         run_counts_subquery: Subquery from get_successful_run_counts_subquery containing
+             dataset_example_revision_id, dataset_example_id, and successful_count columns
+         repetitions: The total number of repetitions required
+
+     Returns:
+         SQLAlchemy CTE with columns:
+         - dataset_example_revision_id: ID of the example revision
+         - dataset_example_id: ID of the dataset example
+         - successful_count: Count of successful runs for this example
+         - repetition_number: Expected repetition number (1..repetitions)
+
+     Note:
+         - For PostgreSQL: Uses generate_series function
+         - For SQLite: Uses recursive CTE to generate the sequence
+     """
+     if dialect is SupportedSQLDialect.POSTGRESQL:
+         # Generate expected repetition numbers only for partially complete examples
+         # Use func.generate_series with direct parameter - SQLAlchemy handles this safely
+         return (
+             select(
+                 run_counts_subquery.c.dataset_example_revision_id,
+                 run_counts_subquery.c.dataset_example_id,
+                 run_counts_subquery.c.successful_count,
+                 func.generate_series(1, repetitions).label("repetition_number"),
+             )
+             .select_from(run_counts_subquery)
+             .where(run_counts_subquery.c.successful_count > 0)  # Only partially complete!
+             .cte("expected_runs")
+         )
+     elif dialect is SupportedSQLDialect.SQLITE:
+         # Recursive CTE only for partially complete examples
+         expected_runs_cte = (
+             select(
+                 run_counts_subquery.c.dataset_example_revision_id,
+                 run_counts_subquery.c.dataset_example_id,
+                 run_counts_subquery.c.successful_count,
+                 literal_column("1").label("repetition_number"),
+             )
+             .select_from(run_counts_subquery)
+             .where(run_counts_subquery.c.successful_count > 0)  # Only partially complete!
+             .cte("expected_runs", recursive=True)
+         )
+
+         # Recursive part: increment repetition_number up to repetitions
+         expected_runs_recursive = expected_runs_cte.union_all(
+             select(
+                 expected_runs_cte.c.dataset_example_revision_id,
+                 expected_runs_cte.c.dataset_example_id,
+                 expected_runs_cte.c.successful_count,
+                 (expected_runs_cte.c.repetition_number + 1).label("repetition_number"),
+             ).where(expected_runs_cte.c.repetition_number < repetitions)
+         )
+
+         return expected_runs_recursive
+     else:
+         assert_never(dialect)
+
+
+ def get_incomplete_repetitions_query(
+     dialect: SupportedSQLDialect,
+     expected_runs_cte: Any,
+     experiment_id: int,
+ ) -> Select[tuple[Any, Any, Any]]:
+     """
+     Build a query that finds incomplete repetitions for partially complete examples.
+
+     This query outer joins the expected repetition numbers with actual successful runs to find
+     which repetitions are missing or failed. It aggregates the incomplete repetitions into an
+     array or JSON array depending on the dialect.
+
+     Args:
+         dialect: The SQL dialect to use (PostgreSQL or SQLite)
+         expected_runs_cte: CTE from generate_expected_repetitions_cte containing expected
+             repetition numbers for partially complete examples
+         experiment_id: The experiment ID to query runs for
+
+     Returns:
+         SQLAlchemy SELECT query with columns:
+         - dataset_example_revision_id: ID of the example revision
+         - successful_count: Count of successful runs for this example
+         - incomplete_reps: Array/JSON array of incomplete repetition numbers
+
+     Note:
+         - For PostgreSQL: Returns an array using array_agg
+         - For SQLite: Returns a JSON string using json_group_array
+     """
+     if dialect is SupportedSQLDialect.POSTGRESQL:
+         agg_func = func.coalesce(
+             func.array_agg(expected_runs_cte.c.repetition_number),
+             literal_column("ARRAY[]::int[]"),
+         )
+     elif dialect is SupportedSQLDialect.SQLITE:
+         agg_func = func.coalesce(
+             func.json_group_array(expected_runs_cte.c.repetition_number),
+             literal_column("'[]'"),
+         )
+     else:
+         assert_never(dialect)
+
+     # Find incomplete runs for partially complete examples
+     return (
+         select(
+             expected_runs_cte.c.dataset_example_revision_id,
+             expected_runs_cte.c.successful_count,
+             agg_func.label("incomplete_reps"),
+         )
+         .select_from(expected_runs_cte)
+         .outerjoin(
+             models.ExperimentRun,
+             and_(
+                 models.ExperimentRun.experiment_id == experiment_id,
+                 models.ExperimentRun.dataset_example_id == expected_runs_cte.c.dataset_example_id,
+                 models.ExperimentRun.repetition_number == expected_runs_cte.c.repetition_number,
+                 # Only join successful runs
+                 models.ExperimentRun.error.is_(None),
+             ),
+         )
+         .where(
+             # Incomplete = no matching run (NULL)
+             models.ExperimentRun.id.is_(None)
+         )
+         .group_by(
+             expected_runs_cte.c.dataset_example_revision_id,
+             expected_runs_cte.c.successful_count,
+         )
+     )
+
+
+ def get_incomplete_runs_with_revisions_query(
+     incomplete_runs_subquery: Any,
+     *,
+     cursor_example_rowid: Optional[int] = None,
+     limit: Optional[int] = None,
+ ) -> Select[tuple[models.DatasetExampleRevision, Any, Any]]:
+     """
+     Build the main query that joins incomplete runs with dataset example revisions.
+
+     This query takes a subquery containing incomplete run information and joins it with
+     the DatasetExampleRevision table to get the full example data. It also applies
+     cursor-based pagination for efficient retrieval of large result sets.
+
+     Args:
+         incomplete_runs_subquery: Subquery with columns:
+             - dataset_example_revision_id: ID of the example revision
+             - successful_count: Count of successful runs for this example
+             - incomplete_reps: Array/JSON array of incomplete repetition numbers
+         cursor_example_rowid: Optional cursor position (dataset_example_id) for pagination.
+             When provided, only returns examples with ID >= cursor_example_rowid
+         limit: Optional maximum number of results to return. If provided, the query
+             will fetch limit+1 rows to enable next-page detection
+
+     Returns:
+         SQLAlchemy SELECT query with columns:
+         - DatasetExampleRevision: The full revision object
+         - successful_count: Count of successful runs
+         - incomplete_reps: Array/JSON array of incomplete repetition numbers
+
+     Note:
+         Results are ordered by dataset_example_id ascending for consistent pagination.
+         When using limit, fetch one extra row to check if there's a next page.
+     """
+     stmt = (
+         select(
+             models.DatasetExampleRevision,
+             incomplete_runs_subquery.c.successful_count,
+             incomplete_runs_subquery.c.incomplete_reps,
+         )
+         .select_from(incomplete_runs_subquery)
+         .join(
+             models.DatasetExampleRevision,
+             models.DatasetExampleRevision.id
+             == incomplete_runs_subquery.c.dataset_example_revision_id,
+         )
+         .order_by(models.DatasetExampleRevision.dataset_example_id.asc())
+     )
+
+     # Apply cursor filter in SQL for efficiency with large datasets
+     if cursor_example_rowid is not None:
+         stmt = stmt.where(models.DatasetExampleRevision.dataset_example_id >= cursor_example_rowid)
+
+     # Fetch limit+1 to check if there's a next page
+     if limit is not None:
+         stmt = stmt.limit(limit + 1)
+
+     return stmt
+
+
+ def get_successful_experiment_runs_query(
+     experiment_id: int,
+     *,
+     cursor_run_rowid: Optional[int] = None,
+     limit: Optional[int] = None,
+ ) -> Select[tuple[models.ExperimentRun, int]]:
+     """
+     Build a query for successful experiment runs with their dataset example revision IDs.
+
+     This query retrieves all experiment runs that completed successfully (error IS NULL)
+     and joins them with the ExperimentDatasetExample table to get the revision IDs.
+     Results are ordered by run ID ascending for consistent pagination.
+
+     Args:
+         experiment_id: The experiment ID to query runs for
+         cursor_run_rowid: Optional cursor position (experiment_run_id) for pagination.
+             When provided, only returns runs with ID >= cursor_run_rowid
+         limit: Optional maximum number of results to return. If provided, the query
+             will fetch limit+1 rows to enable next-page detection
+
+     Returns:
+         SQLAlchemy SELECT query with columns:
+         - ExperimentRun: The full experiment run object
+         - dataset_example_revision_id: ID of the dataset example revision (int)
+
+     Note:
+         - Only includes successful runs (error IS NULL)
+         - Results ordered by run ID ascending for consistent pagination
+         - When using limit, fetch one extra row to check if there's a next page
+     """
+     stmt = (
+         select(
+             models.ExperimentRun,
+             models.ExperimentDatasetExample.dataset_example_revision_id,
+         )
+         .join(
+             models.ExperimentDatasetExample,
+             and_(
+                 models.ExperimentDatasetExample.experiment_id == experiment_id,
+                 models.ExperimentDatasetExample.dataset_example_id
+                 == models.ExperimentRun.dataset_example_id,
+             ),
+         )
+         .where(
+             and_(
+                 models.ExperimentRun.experiment_id == experiment_id,
+                 models.ExperimentRun.error.is_(None),  # Only successful task runs
+             )
+         )
+         .order_by(models.ExperimentRun.id.asc())
+     )
+
+     if cursor_run_rowid is not None:
+         stmt = stmt.where(models.ExperimentRun.id >= cursor_run_rowid)
+
+     if limit is not None:
+         stmt = stmt.limit(limit + 1)
+
+     return stmt
+
+
+ def get_experiment_run_annotations_query(
+     run_ids: Sequence[int],
+     evaluation_names: Sequence[str],
+ ) -> Select[tuple[int, str, Optional[str]]]:
+     """
+     Build a query to get annotations for specific runs and evaluation names.
+
+     This query retrieves annotations (evaluations) for a set of experiment runs,
+     filtered by specific evaluation names. It returns only the essential fields
+     needed to determine if an evaluation is complete or has errors.
+
+     Args:
+         run_ids: List of experiment run IDs to query annotations for
+         evaluation_names: List of evaluation names to filter by
+
+     Returns:
+         SQLAlchemy SELECT query with columns:
+         - experiment_run_id: ID of the experiment run (int)
+         - name: Name of the evaluation (str)
+         - error: Error message if evaluation failed, None if successful (Optional[str])
+
+     Example:
+         >>> run_ids = [1, 2, 3]
+         >>> eval_names = ["relevance", "coherence"]
+         >>> query = get_experiment_run_annotations_query(run_ids, eval_names)
+         >>> results = await session.execute(query)
+         >>> for run_id, name, error in results:
+         ...     # Process annotations...
+     """
+     return (
+         select(
+             models.ExperimentRunAnnotation.experiment_run_id,
+             models.ExperimentRunAnnotation.name,
+             models.ExperimentRunAnnotation.error,
+         )
+         .where(models.ExperimentRunAnnotation.experiment_run_id.in_(run_ids))
+         .where(models.ExperimentRunAnnotation.name.in_(evaluation_names))
+     )
+
+
+ def get_runs_with_incomplete_evaluations_query(
+     experiment_id: int,
+     evaluation_names: Sequence[str],
+     dialect: SupportedSQLDialect,
+     *,
+     cursor_run_rowid: Optional[int] = None,
+     limit: Optional[int] = None,
+     include_annotations_and_revisions: bool = False,
+ ) -> Select[Any]:
+     """
+     Get experiment runs that have incomplete evaluations.
+
+     A run has incomplete evaluations if it's missing successful annotations for any of
+     the requested evaluation names. Both missing (no annotation) and failed (error != NULL)
+     evaluations are considered incomplete.
+
+     Args:
+         experiment_id: The experiment ID to query
+         evaluation_names: Evaluation names to check for completeness
+         dialect: SQL dialect (PostgreSQL or SQLite)
+         cursor_run_rowid: Optional run ID for cursor-based pagination
+         limit: Optional limit (fetches limit+1 for next-page detection)
+         include_annotations_and_revisions: If True, also fetch revision and successful
+             annotation names as JSON array
+
+     Returns:
+         Query returning (ExperimentRun, revision_id, [revision, annotations_json])
+         Results ordered by run ID ascending
+     """
+     # Subquery: Count successful annotations per run
+     successful_annotations_count = (
+         select(
+             models.ExperimentRunAnnotation.experiment_run_id,
+             func.count().label("successful_count"),
+         )
+         .where(
+             models.ExperimentRunAnnotation.name.in_(evaluation_names),
+             models.ExperimentRunAnnotation.error.is_(None),
+         )
+         .group_by(models.ExperimentRunAnnotation.experiment_run_id)
+         .subquery()
+     )
+
+     # Base query: Find runs where successful_count < required evaluations
+     stmt = (
+         select(
+             models.ExperimentRun,
+             models.ExperimentDatasetExample.dataset_example_revision_id,
+         )
+         .join(
+             models.ExperimentDatasetExample,
+             and_(
+                 models.ExperimentDatasetExample.experiment_id == experiment_id,
+                 models.ExperimentDatasetExample.dataset_example_id
+                 == models.ExperimentRun.dataset_example_id,
+             ),
+         )
+         .outerjoin(
+             successful_annotations_count,
+             successful_annotations_count.c.experiment_run_id == models.ExperimentRun.id,
+         )
+         .where(
+             models.ExperimentRun.experiment_id == experiment_id,
+             models.ExperimentRun.error.is_(None),  # Only successful task runs
+             func.coalesce(successful_annotations_count.c.successful_count, 0)
+             < len(evaluation_names),
+         )
+     )
+
+     # Optionally include revisions and successful annotation names
+     if include_annotations_and_revisions:
+         # Subquery: Aggregate successful annotation names as JSON array
+         if dialect is SupportedSQLDialect.POSTGRESQL:
+             json_agg_expr = func.cast(
+                 func.coalesce(
+                     func.json_agg(models.ExperimentRunAnnotation.name),
+                     literal_column("'[]'::json"),
+                 ),
+                 sa.String,
+             )
+         else:  # SQLite
+             json_agg_expr = func.cast(
+                 func.coalesce(
+                     func.json_group_array(models.ExperimentRunAnnotation.name),
+                     literal_column("'[]'"),
+                 ),
+                 sa.String,
+             )
+
+         successful_annotations_json = (
+             select(
+                 models.ExperimentRunAnnotation.experiment_run_id,
+                 json_agg_expr.label("annotations_json"),
+             )
+             .where(
+                 models.ExperimentRunAnnotation.name.in_(evaluation_names),
+                 models.ExperimentRunAnnotation.error.is_(None),
+             )
+             .group_by(models.ExperimentRunAnnotation.experiment_run_id)
+             .subquery()
+         )
+
+         stmt = (
+             stmt.add_columns(
+                 models.DatasetExampleRevision,
+                 successful_annotations_json.c.annotations_json,
+             )
+             .join(
+                 models.DatasetExampleRevision,
+                 models.DatasetExampleRevision.id
+                 == models.ExperimentDatasetExample.dataset_example_revision_id,
+             )
+             .outerjoin(
+                 successful_annotations_json,
+                 successful_annotations_json.c.experiment_run_id == models.ExperimentRun.id,
+             )
+         )
+
+     # Apply ordering, cursor, and limit
+     stmt = stmt.order_by(models.ExperimentRun.id.asc())
+
+     if cursor_run_rowid is not None:
+         stmt = stmt.where(models.ExperimentRun.id >= cursor_run_rowid)
+
+     if limit is not None:
+         stmt = stmt.limit(limit + 1)
+
+     return stmt
+
+
+ def get_experiment_incomplete_runs_query(
+     experiment: models.Experiment,
+     dialect: SupportedSQLDialect,
+     *,
+     cursor_example_rowid: Optional[int] = None,
+     limit: Optional[int] = None,
+ ) -> Select[tuple[models.DatasetExampleRevision, Any, Any]]:
+     """
+     High-level helper to build a complete query for incomplete runs in an experiment.
+
+     This is the main entry point for querying incomplete runs. It encapsulates all the
+     logic for finding runs that need to be completed, including both missing runs
+     (not yet attempted) and failed runs (attempted but have errors).
+
+     The function automatically chooses the optimal query strategy:
+     - For repetitions=1: Simple fast path (no CTE needed)
+     - For repetitions>1: Two-path optimization separating completely missing examples
+       from partially complete examples
+
+     Args:
+         experiment: The Experiment model instance to query incomplete runs for
+         dialect: The SQL dialect to use (PostgreSQL or SQLite)
+         cursor_example_rowid: Optional cursor position (dataset_example_id) for pagination.
+             When provided, only returns examples with ID >= cursor_example_rowid
+         limit: Optional maximum number of results to return. If provided, the query
+             will fetch limit+1 rows to enable next-page detection
+
+     Returns:
+         SQLAlchemy SELECT query with columns:
+         - DatasetExampleRevision: The full revision object with example data
+         - successful_count: Count of successful runs for this example (int)
+         - incomplete_reps: Incomplete repetition numbers as:
+             * PostgreSQL: Array of ints (or empty array for completely missing)
+             * SQLite: JSON string array (or '[]' for completely missing)
+
+     Note:
+         For completely missing examples (successful_count=0), the incomplete_reps
+         column will be an empty array/JSON. Callers should generate the full
+         [1..repetitions] list when successful_count=0.
+
+     Example:
+         >>> experiment = session.get(models.Experiment, experiment_id)
+         >>> dialect = SupportedSQLDialect(session.bind.dialect.name)
+         >>> query = get_experiment_incomplete_runs_query(
+         ...     experiment, dialect, cursor_example_rowid=100, limit=50
+         ... )
+         >>> results = await session.execute(query)
+         >>> for revision, success_count, incomplete_reps in results:
+         ...     # Process incomplete runs...
+     """
+     # Step 1: Get successful run counts for incomplete examples
+     run_counts_subquery = get_successful_run_counts_subquery(experiment.id, experiment.repetitions)
+
+     # Step 2: Build the combined incomplete runs subquery
+     # The strategy depends on whether repetitions=1 or >1
+     if experiment.repetitions == 1:
+         # Fast path optimization for repetitions=1:
+         # All incomplete examples have successful_count=0, so we can skip the expensive CTE
+         empty_array: Any
+         if dialect is SupportedSQLDialect.POSTGRESQL:
+             empty_array = literal_column("ARRAY[]::int[]")
+         elif dialect is SupportedSQLDialect.SQLITE:
+             empty_array = literal_column("'[]'")
+         else:
+             assert_never(dialect)
+
+         combined_incomplete = (
+             select(
+                 run_counts_subquery.c.dataset_example_revision_id,
+                 run_counts_subquery.c.successful_count,
+                 empty_array.label("incomplete_reps"),
+             ).select_from(run_counts_subquery)
+         ).subquery()
+     else:
+         # Two-path optimization for repetitions > 1:
+         # Path 1: Completely missing examples (successful_count = 0) - no CTE needed
+         # Path 2: Partially complete examples (0 < successful_count < R) - use CTE
+
+         # Path 1: Completely missing examples
+         empty_array_inner: Any
+         if dialect is SupportedSQLDialect.POSTGRESQL:
+             empty_array_inner = literal_column("ARRAY[]::int[]")
+         elif dialect is SupportedSQLDialect.SQLITE:
+             empty_array_inner = literal_column("'[]'")
+         else:
+             assert_never(dialect)
+
+         completely_missing_stmt = (
+             select(
+                 run_counts_subquery.c.dataset_example_revision_id,
+                 run_counts_subquery.c.successful_count,
+                 empty_array_inner.label("incomplete_reps"),
+             )
+             .select_from(run_counts_subquery)
+             .where(run_counts_subquery.c.successful_count == 0)
+         )
+
+         # Path 2: Partially complete examples
+         expected_runs_cte = generate_expected_repetitions_cte(
+             dialect, run_counts_subquery, experiment.repetitions
+         )
+         partially_complete_stmt = get_incomplete_repetitions_query(
+             dialect, expected_runs_cte, experiment.id
+         )
+
+         # Combine both paths
+         from sqlalchemy import union_all
+
+         combined_incomplete = union_all(completely_missing_stmt, partially_complete_stmt).subquery()
+
+     # Step 3: Join with revisions and apply pagination
+     return get_incomplete_runs_with_revisions_query(
+         combined_incomplete,
+         cursor_example_rowid=cursor_example_rowid,
+         limit=limit,
+     )
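
The SQLite branch of date_trunc above composes nested strftime()/datetime() calls; its offset round-trip and Monday-start week rule can be sanity-checked with the standard library's sqlite3 module. A small sketch with illustrative timestamps (not taken from Phoenix's tests):

import sqlite3

conn = sqlite3.connect(":memory:")

# Day truncation under a UTC-5 offset (-300 minutes), then back to UTC:
# 2001-02-16 02:30 UTC is 2001-02-15 21:30 local, so the local day starts
# at 2001-02-15 00:00 local, i.e. 2001-02-15 05:00 UTC.
(day_start_utc,) = conn.execute(
    """
    SELECT datetime(
        strftime('%Y-%m-%d 00:00:00', datetime('2001-02-16 02:30:00', '-300 minutes')),
        '300 minutes'
    )
    """
).fetchone()
print(day_start_utc)  # 2001-02-15 05:00:00

# Week truncation: 2001-02-16 was a Friday (strftime('%w') = '5'), so the
# Monday-start rule steps back 4 days to 2001-02-12.
dow, week_start = conn.execute(
    "SELECT strftime('%w', '2001-02-16'), datetime(date('2001-02-16', '-4 days'))"
).fetchone()
print(dow, week_start)  # 5 2001-02-12 00:00:00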