arize-phoenix 3.16.0__py3-none-any.whl → 7.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (338) hide show
  1. arize_phoenix-7.7.0.dist-info/METADATA +261 -0
  2. arize_phoenix-7.7.0.dist-info/RECORD +345 -0
  3. {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/WHEEL +1 -1
  4. arize_phoenix-7.7.0.dist-info/entry_points.txt +3 -0
  5. phoenix/__init__.py +86 -14
  6. phoenix/auth.py +309 -0
  7. phoenix/config.py +675 -45
  8. phoenix/core/model.py +32 -30
  9. phoenix/core/model_schema.py +102 -109
  10. phoenix/core/model_schema_adapter.py +48 -45
  11. phoenix/datetime_utils.py +24 -3
  12. phoenix/db/README.md +54 -0
  13. phoenix/db/__init__.py +4 -0
  14. phoenix/db/alembic.ini +85 -0
  15. phoenix/db/bulk_inserter.py +294 -0
  16. phoenix/db/engines.py +208 -0
  17. phoenix/db/enums.py +20 -0
  18. phoenix/db/facilitator.py +113 -0
  19. phoenix/db/helpers.py +159 -0
  20. phoenix/db/insertion/constants.py +2 -0
  21. phoenix/db/insertion/dataset.py +227 -0
  22. phoenix/db/insertion/document_annotation.py +171 -0
  23. phoenix/db/insertion/evaluation.py +191 -0
  24. phoenix/db/insertion/helpers.py +98 -0
  25. phoenix/db/insertion/span.py +193 -0
  26. phoenix/db/insertion/span_annotation.py +158 -0
  27. phoenix/db/insertion/trace_annotation.py +158 -0
  28. phoenix/db/insertion/types.py +256 -0
  29. phoenix/db/migrate.py +86 -0
  30. phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
  31. phoenix/db/migrations/env.py +114 -0
  32. phoenix/db/migrations/script.py.mako +26 -0
  33. phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
  34. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
  35. phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
  36. phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
  37. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  38. phoenix/db/models.py +807 -0
  39. phoenix/exceptions.py +5 -1
  40. phoenix/experiments/__init__.py +6 -0
  41. phoenix/experiments/evaluators/__init__.py +29 -0
  42. phoenix/experiments/evaluators/base.py +158 -0
  43. phoenix/experiments/evaluators/code_evaluators.py +184 -0
  44. phoenix/experiments/evaluators/llm_evaluators.py +473 -0
  45. phoenix/experiments/evaluators/utils.py +236 -0
  46. phoenix/experiments/functions.py +772 -0
  47. phoenix/experiments/tracing.py +86 -0
  48. phoenix/experiments/types.py +726 -0
  49. phoenix/experiments/utils.py +25 -0
  50. phoenix/inferences/__init__.py +0 -0
  51. phoenix/{datasets → inferences}/errors.py +6 -5
  52. phoenix/{datasets → inferences}/fixtures.py +49 -42
  53. phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
  54. phoenix/{datasets → inferences}/schema.py +11 -11
  55. phoenix/{datasets → inferences}/validation.py +13 -14
  56. phoenix/logging/__init__.py +3 -0
  57. phoenix/logging/_config.py +90 -0
  58. phoenix/logging/_filter.py +6 -0
  59. phoenix/logging/_formatter.py +69 -0
  60. phoenix/metrics/__init__.py +5 -4
  61. phoenix/metrics/binning.py +4 -3
  62. phoenix/metrics/metrics.py +2 -1
  63. phoenix/metrics/mixins.py +7 -6
  64. phoenix/metrics/retrieval_metrics.py +2 -1
  65. phoenix/metrics/timeseries.py +5 -4
  66. phoenix/metrics/wrappers.py +9 -3
  67. phoenix/pointcloud/clustering.py +5 -5
  68. phoenix/pointcloud/pointcloud.py +7 -5
  69. phoenix/pointcloud/projectors.py +5 -6
  70. phoenix/pointcloud/umap_parameters.py +53 -52
  71. phoenix/server/api/README.md +28 -0
  72. phoenix/server/api/auth.py +44 -0
  73. phoenix/server/api/context.py +152 -9
  74. phoenix/server/api/dataloaders/__init__.py +91 -0
  75. phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
  76. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  77. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  78. phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
  79. phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
  80. phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
  81. phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
  82. phoenix/server/api/dataloaders/document_evaluations.py +31 -0
  83. phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
  84. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
  85. phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
  86. phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
  87. phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
  88. phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
  89. phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
  90. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
  91. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  92. phoenix/server/api/dataloaders/record_counts.py +116 -0
  93. phoenix/server/api/dataloaders/session_io.py +79 -0
  94. phoenix/server/api/dataloaders/session_num_traces.py +30 -0
  95. phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
  96. phoenix/server/api/dataloaders/session_token_usages.py +41 -0
  97. phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
  98. phoenix/server/api/dataloaders/span_annotations.py +26 -0
  99. phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
  100. phoenix/server/api/dataloaders/span_descendants.py +57 -0
  101. phoenix/server/api/dataloaders/span_projects.py +33 -0
  102. phoenix/server/api/dataloaders/token_counts.py +124 -0
  103. phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
  104. phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
  105. phoenix/server/api/dataloaders/user_roles.py +30 -0
  106. phoenix/server/api/dataloaders/users.py +33 -0
  107. phoenix/server/api/exceptions.py +48 -0
  108. phoenix/server/api/helpers/__init__.py +12 -0
  109. phoenix/server/api/helpers/dataset_helpers.py +217 -0
  110. phoenix/server/api/helpers/experiment_run_filters.py +763 -0
  111. phoenix/server/api/helpers/playground_clients.py +948 -0
  112. phoenix/server/api/helpers/playground_registry.py +70 -0
  113. phoenix/server/api/helpers/playground_spans.py +455 -0
  114. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  115. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  116. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  117. phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
  118. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  119. phoenix/server/api/input_types/ClusterInput.py +2 -2
  120. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  121. phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
  122. phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
  123. phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
  124. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  125. phoenix/server/api/input_types/DatasetSort.py +17 -0
  126. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  127. phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
  128. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  129. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  130. phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
  131. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  132. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  133. phoenix/server/api/input_types/Granularity.py +1 -1
  134. phoenix/server/api/input_types/InvocationParameters.py +162 -0
  135. phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
  136. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  137. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  138. phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
  139. phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
  140. phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
  141. phoenix/server/api/input_types/SpanSort.py +134 -69
  142. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  143. phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
  144. phoenix/server/api/input_types/UserRoleInput.py +9 -0
  145. phoenix/server/api/mutations/__init__.py +28 -0
  146. phoenix/server/api/mutations/api_key_mutations.py +167 -0
  147. phoenix/server/api/mutations/chat_mutations.py +593 -0
  148. phoenix/server/api/mutations/dataset_mutations.py +591 -0
  149. phoenix/server/api/mutations/experiment_mutations.py +75 -0
  150. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
  151. phoenix/server/api/mutations/project_mutations.py +57 -0
  152. phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
  153. phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
  154. phoenix/server/api/mutations/user_mutations.py +329 -0
  155. phoenix/server/api/openapi/__init__.py +0 -0
  156. phoenix/server/api/openapi/main.py +17 -0
  157. phoenix/server/api/openapi/schema.py +16 -0
  158. phoenix/server/api/queries.py +738 -0
  159. phoenix/server/api/routers/__init__.py +11 -0
  160. phoenix/server/api/routers/auth.py +284 -0
  161. phoenix/server/api/routers/embeddings.py +26 -0
  162. phoenix/server/api/routers/oauth2.py +488 -0
  163. phoenix/server/api/routers/v1/__init__.py +64 -0
  164. phoenix/server/api/routers/v1/datasets.py +1017 -0
  165. phoenix/server/api/routers/v1/evaluations.py +362 -0
  166. phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
  167. phoenix/server/api/routers/v1/experiment_runs.py +167 -0
  168. phoenix/server/api/routers/v1/experiments.py +308 -0
  169. phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
  170. phoenix/server/api/routers/v1/spans.py +267 -0
  171. phoenix/server/api/routers/v1/traces.py +208 -0
  172. phoenix/server/api/routers/v1/utils.py +95 -0
  173. phoenix/server/api/schema.py +44 -247
  174. phoenix/server/api/subscriptions.py +597 -0
  175. phoenix/server/api/types/Annotation.py +21 -0
  176. phoenix/server/api/types/AnnotationSummary.py +55 -0
  177. phoenix/server/api/types/AnnotatorKind.py +16 -0
  178. phoenix/server/api/types/ApiKey.py +27 -0
  179. phoenix/server/api/types/AuthMethod.py +9 -0
  180. phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
  181. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
  182. phoenix/server/api/types/Cluster.py +25 -24
  183. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  184. phoenix/server/api/types/DataQualityMetric.py +31 -13
  185. phoenix/server/api/types/Dataset.py +288 -63
  186. phoenix/server/api/types/DatasetExample.py +85 -0
  187. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  188. phoenix/server/api/types/DatasetVersion.py +14 -0
  189. phoenix/server/api/types/Dimension.py +32 -31
  190. phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
  191. phoenix/server/api/types/EmbeddingDimension.py +56 -49
  192. phoenix/server/api/types/Evaluation.py +25 -31
  193. phoenix/server/api/types/EvaluationSummary.py +30 -50
  194. phoenix/server/api/types/Event.py +20 -20
  195. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  196. phoenix/server/api/types/Experiment.py +152 -0
  197. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  198. phoenix/server/api/types/ExperimentComparison.py +17 -0
  199. phoenix/server/api/types/ExperimentRun.py +119 -0
  200. phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
  201. phoenix/server/api/types/GenerativeModel.py +9 -0
  202. phoenix/server/api/types/GenerativeProvider.py +85 -0
  203. phoenix/server/api/types/Inferences.py +80 -0
  204. phoenix/server/api/types/InferencesRole.py +23 -0
  205. phoenix/server/api/types/LabelFraction.py +7 -0
  206. phoenix/server/api/types/MimeType.py +2 -2
  207. phoenix/server/api/types/Model.py +54 -54
  208. phoenix/server/api/types/PerformanceMetric.py +8 -5
  209. phoenix/server/api/types/Project.py +407 -142
  210. phoenix/server/api/types/ProjectSession.py +139 -0
  211. phoenix/server/api/types/Segments.py +4 -4
  212. phoenix/server/api/types/Span.py +221 -176
  213. phoenix/server/api/types/SpanAnnotation.py +43 -0
  214. phoenix/server/api/types/SpanIOValue.py +15 -0
  215. phoenix/server/api/types/SystemApiKey.py +9 -0
  216. phoenix/server/api/types/TemplateLanguage.py +10 -0
  217. phoenix/server/api/types/TimeSeries.py +19 -15
  218. phoenix/server/api/types/TokenUsage.py +11 -0
  219. phoenix/server/api/types/Trace.py +154 -0
  220. phoenix/server/api/types/TraceAnnotation.py +45 -0
  221. phoenix/server/api/types/UMAPPoints.py +7 -7
  222. phoenix/server/api/types/User.py +60 -0
  223. phoenix/server/api/types/UserApiKey.py +45 -0
  224. phoenix/server/api/types/UserRole.py +15 -0
  225. phoenix/server/api/types/node.py +13 -107
  226. phoenix/server/api/types/pagination.py +156 -57
  227. phoenix/server/api/utils.py +34 -0
  228. phoenix/server/app.py +864 -115
  229. phoenix/server/bearer_auth.py +163 -0
  230. phoenix/server/dml_event.py +136 -0
  231. phoenix/server/dml_event_handler.py +256 -0
  232. phoenix/server/email/__init__.py +0 -0
  233. phoenix/server/email/sender.py +97 -0
  234. phoenix/server/email/templates/__init__.py +0 -0
  235. phoenix/server/email/templates/password_reset.html +19 -0
  236. phoenix/server/email/types.py +11 -0
  237. phoenix/server/grpc_server.py +102 -0
  238. phoenix/server/jwt_store.py +505 -0
  239. phoenix/server/main.py +305 -116
  240. phoenix/server/oauth2.py +52 -0
  241. phoenix/server/openapi/__init__.py +0 -0
  242. phoenix/server/prometheus.py +111 -0
  243. phoenix/server/rate_limiters.py +188 -0
  244. phoenix/server/static/.vite/manifest.json +87 -0
  245. phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
  246. phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
  247. phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
  248. phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
  249. phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
  250. phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
  251. phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
  252. phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
  253. phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
  254. phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
  255. phoenix/server/telemetry.py +68 -0
  256. phoenix/server/templates/index.html +82 -23
  257. phoenix/server/thread_server.py +3 -3
  258. phoenix/server/types.py +275 -0
  259. phoenix/services.py +27 -18
  260. phoenix/session/client.py +743 -68
  261. phoenix/session/data_extractor.py +31 -7
  262. phoenix/session/evaluation.py +3 -9
  263. phoenix/session/session.py +263 -219
  264. phoenix/settings.py +22 -0
  265. phoenix/trace/__init__.py +2 -22
  266. phoenix/trace/attributes.py +338 -0
  267. phoenix/trace/dsl/README.md +116 -0
  268. phoenix/trace/dsl/filter.py +663 -213
  269. phoenix/trace/dsl/helpers.py +73 -21
  270. phoenix/trace/dsl/query.py +574 -201
  271. phoenix/trace/exporter.py +24 -19
  272. phoenix/trace/fixtures.py +368 -32
  273. phoenix/trace/otel.py +71 -219
  274. phoenix/trace/projects.py +3 -2
  275. phoenix/trace/schemas.py +33 -11
  276. phoenix/trace/span_evaluations.py +21 -16
  277. phoenix/trace/span_json_decoder.py +6 -4
  278. phoenix/trace/span_json_encoder.py +2 -2
  279. phoenix/trace/trace_dataset.py +47 -32
  280. phoenix/trace/utils.py +21 -4
  281. phoenix/utilities/__init__.py +0 -26
  282. phoenix/utilities/client.py +132 -0
  283. phoenix/utilities/deprecation.py +31 -0
  284. phoenix/utilities/error_handling.py +3 -2
  285. phoenix/utilities/json.py +109 -0
  286. phoenix/utilities/logging.py +8 -0
  287. phoenix/utilities/project.py +2 -2
  288. phoenix/utilities/re.py +49 -0
  289. phoenix/utilities/span_store.py +0 -23
  290. phoenix/utilities/template_formatters.py +99 -0
  291. phoenix/version.py +1 -1
  292. arize_phoenix-3.16.0.dist-info/METADATA +0 -495
  293. arize_phoenix-3.16.0.dist-info/RECORD +0 -178
  294. phoenix/core/project.py +0 -617
  295. phoenix/core/traces.py +0 -100
  296. phoenix/experimental/evals/__init__.py +0 -73
  297. phoenix/experimental/evals/evaluators.py +0 -413
  298. phoenix/experimental/evals/functions/__init__.py +0 -4
  299. phoenix/experimental/evals/functions/classify.py +0 -453
  300. phoenix/experimental/evals/functions/executor.py +0 -353
  301. phoenix/experimental/evals/functions/generate.py +0 -138
  302. phoenix/experimental/evals/functions/processing.py +0 -76
  303. phoenix/experimental/evals/models/__init__.py +0 -14
  304. phoenix/experimental/evals/models/anthropic.py +0 -175
  305. phoenix/experimental/evals/models/base.py +0 -170
  306. phoenix/experimental/evals/models/bedrock.py +0 -221
  307. phoenix/experimental/evals/models/litellm.py +0 -134
  308. phoenix/experimental/evals/models/openai.py +0 -448
  309. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  310. phoenix/experimental/evals/models/vertex.py +0 -173
  311. phoenix/experimental/evals/models/vertexai.py +0 -186
  312. phoenix/experimental/evals/retrievals.py +0 -96
  313. phoenix/experimental/evals/templates/__init__.py +0 -50
  314. phoenix/experimental/evals/templates/default_templates.py +0 -472
  315. phoenix/experimental/evals/templates/template.py +0 -195
  316. phoenix/experimental/evals/utils/__init__.py +0 -172
  317. phoenix/experimental/evals/utils/threads.py +0 -27
  318. phoenix/server/api/helpers.py +0 -11
  319. phoenix/server/api/routers/evaluation_handler.py +0 -109
  320. phoenix/server/api/routers/span_handler.py +0 -70
  321. phoenix/server/api/routers/trace_handler.py +0 -60
  322. phoenix/server/api/types/DatasetRole.py +0 -23
  323. phoenix/server/static/index.css +0 -6
  324. phoenix/server/static/index.js +0 -7447
  325. phoenix/storage/span_store/__init__.py +0 -23
  326. phoenix/storage/span_store/text_file.py +0 -85
  327. phoenix/trace/dsl/missing.py +0 -60
  328. phoenix/trace/langchain/__init__.py +0 -3
  329. phoenix/trace/langchain/instrumentor.py +0 -35
  330. phoenix/trace/llama_index/__init__.py +0 -3
  331. phoenix/trace/llama_index/callback.py +0 -102
  332. phoenix/trace/openai/__init__.py +0 -3
  333. phoenix/trace/openai/instrumentor.py +0 -30
  334. {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/IP_NOTICE +0 -0
  335. {arize_phoenix-3.16.0.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/LICENSE +0 -0
  336. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  337. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  338. /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
@@ -0,0 +1,763 @@
1
+ import ast
2
+ import operator
3
+ from abc import ABC, abstractmethod
4
+ from copy import deepcopy
5
+ from dataclasses import dataclass, field
6
+ from hashlib import sha256
7
+ from typing import Any, Callable, Literal, Optional, Union, get_args
8
+
9
+ from sqlalchemy import (
10
+ BinaryExpression,
11
+ Boolean,
12
+ Float,
13
+ Integer,
14
+ Null,
15
+ Select,
16
+ String,
17
+ and_,
18
+ cast,
19
+ literal,
20
+ or_,
21
+ )
22
+ from sqlalchemy.orm import aliased
23
+ from sqlalchemy.sql import operators as sqlalchemy_operators
24
+ from typing_extensions import TypeAlias, TypeGuard, assert_never
25
+
26
+ from phoenix.db import models
27
+
28
# AST comparison-operator node types that a filter condition may use.
SupportedComparisonOperator: TypeAlias = Union[
    ast.Is, ast.IsNot, ast.In, ast.NotIn, ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE
]
# Python literal types that may appear as constants in a filter condition.
SupportedConstantType: TypeAlias = Union[bool, int, float, str, None]
# SQLAlchemy column types reported by a term's `data_type`.
SQLAlchemyDataType: TypeAlias = Union[Boolean, Integer, Float[float], String]
# Integer identifier of an experiment.
ExperimentID: TypeAlias = int
# Unary operators: one applying to boolean expressions, one applying to terms.
SupportedUnaryBooleanOperator: TypeAlias = ast.Not
SupportedUnaryTermOperator: TypeAlias = ast.USub
# Attribute names recognized on dataset examples, experiment runs, and run evals.
SupportedDatasetExampleAttributeName: TypeAlias = Literal["input", "reference_output", "metadata"]
SupportedExperimentRunAttributeName: TypeAlias = Literal["output", "error", "latency_ms", "evals"]
SupportedExperimentRunEvalAttributeName: TypeAlias = Literal["score", "explanation", "label"]
# Name of an eval (the key used to select an experiment run annotation).
EvalName: TypeAlias = str
49
+
50
+
51
def update_examples_query_with_filter_condition(
    query: Select[Any], filter_condition: str, experiment_ids: list[int]
) -> Select[Any]:
    """
    Applies a filter condition to a query over dataset examples.

    The condition is compiled to a SQLAlchemy expression. For every experiment
    whose runs table was aliased during compilation, the aliased runs table and
    each of its aliased annotation tables are outer-joined onto the query before
    the compiled expression is added as a WHERE clause.

    Raises:
        ExperimentRunFilterConditionSyntaxError: propagated from compilation when
            the filter condition is invalid.
    """
    where_clause, transformer = compile_sqlalchemy_filter_condition(
        filter_condition=filter_condition, experiment_ids=experiment_ids
    )
    for experiment_id in experiment_ids:
        runs_alias = transformer.get_experiment_runs_alias(experiment_id)
        if runs_alias is not None:
            # Only experiments referenced by the condition have an alias to join.
            runs_on_clause = and_(
                runs_alias.dataset_example_id == models.DatasetExample.id,
                runs_alias.experiment_id == experiment_id,
            )
            query = query.join(runs_alias, onclause=runs_on_clause, isouter=True)
            annotation_aliases = transformer.get_experiment_run_annotations_aliases(experiment_id)
            for eval_name, annotations_alias in annotation_aliases.items():
                annotations_on_clause = and_(
                    annotations_alias.experiment_run_id == runs_alias.id,
                    annotations_alias.name == eval_name,
                )
                query = query.join(
                    annotations_alias, onclause=annotations_on_clause, isouter=True
                )
    return query.where(where_clause)
85
+
86
+
87
def compile_sqlalchemy_filter_condition(
    filter_condition: str, experiment_ids: list[int]
) -> tuple[Any, "SQLAlchemyTransformer"]:
    """
    Compiles a filter condition string into a SQLAlchemy expression.

    The string is parsed as a single Python expression. If it contains free
    experiment run attribute names, it is compiled once per experiment (with the
    names bound to that experiment) and the disjunction of the results is
    returned. Otherwise it is compiled once as written.

    Returns:
        A tuple of the compiled SQLAlchemy expression and the transformer used to
        build it.

    Raises:
        ExperimentRunFilterConditionSyntaxError: if the string is not a valid
            Python expression or does not evaluate to a boolean expression.
    """
    try:
        original_tree = ast.parse(filter_condition, mode="eval")
    except SyntaxError as error:
        # Chain explicitly so the original parser error is kept as the cause.
        raise ExperimentRunFilterConditionSyntaxError(str(error)) from error

    trees_with_bound_attribute_names = _bind_free_attribute_names(original_tree, experiment_ids)
    # One transformer is shared by all compilations and returned to the caller.
    sqlalchemy_transformer = SQLAlchemyTransformer(experiment_ids=experiment_ids)

    def _compile_tree(tree: ast.AST) -> Any:
        """Transforms one parsed tree and compiles its boolean root."""
        node = sqlalchemy_transformer.visit(tree).body
        if not isinstance(node, BooleanExpression):
            raise ExperimentRunFilterConditionSyntaxError(
                "Filter condition must be a boolean expression"
            )
        return node.compile()

    if trees_with_bound_attribute_names:
        # compile the filter condition once for each experiment and return the disjunction
        compiled_filter_conditions = [
            _compile_tree(tree) for tree in trees_with_bound_attribute_names.values()
        ]
        return or_(*compiled_filter_conditions), sqlalchemy_transformer

    # compile the filter condition once for all experiments
    return _compile_tree(original_tree), sqlalchemy_transformer
121
+
122
+
123
def _bind_free_attribute_names(
    tree: ast.AST, experiment_ids: list[ExperimentID]
) -> dict[ExperimentID, ast.AST]:
    """
    Produces one copy of `tree` per experiment with free attribute names bound.

    Each copy is rewritten by a `FreeAttributeNameBinder` constructed with the
    position of the experiment in `experiment_ids`. An empty dictionary is
    returned as soon as a binder reports that it bound nothing.
    """
    bound_trees: dict[ExperimentID, ast.AST] = {}
    for position, experiment_id in enumerate(experiment_ids):
        binder = FreeAttributeNameBinder(experiment_index=position)
        rewritten_tree = binder.visit(deepcopy(tree))
        if not binder.binds_free_attribute_name:
            return {}  # return early since there are no free attribute names
        bound_trees[experiment_id] = rewritten_tree
    return bound_trees
134
+
135
+
136
class FreeAttributeNameBinder(ast.NodeTransformer):
    """
    Rewrites bare experiment run attribute names so they refer to one experiment.

    A supported name `<name>` becomes `experiments[<experiment_index>].<name>`.
    All other names are left untouched.
    """

    def __init__(self, *, experiment_index: int) -> None:
        super().__init__()
        self._binds_free_attribute_name = False
        self._experiment_index = experiment_index

    @property
    def binds_free_attribute_name(self) -> bool:
        """Whether at least one name has been rewritten by this binder."""
        return self._binds_free_attribute_name

    def visit_Name(self, node: ast.Name) -> Any:
        if not _is_supported_experiment_run_attribute_name(node.id):
            return node
        self._binds_free_attribute_name = True
        # Build `experiments[<index>]`, then attach the original name as an attribute.
        indexed_experiment = ast.Subscript(
            value=ast.Name(id="experiments", ctx=ast.Load()),
            slice=ast.Constant(value=self._experiment_index),
            ctx=ast.Load(),
        )
        return ast.Attribute(value=indexed_experiment, attr=node.id, ctx=node.ctx)
160
+
161
+
162
class ExperimentRunFilterConditionSyntaxError(Exception):
    """Raised when an experiment run filter condition is invalid."""
164
+
165
+
166
@dataclass(frozen=True)
class ExperimentRunFilterConditionNode(ABC):
    """
    Base class for the nodes of a tree that represents a SQLAlchemy expression.
    """

    # The Python AST node associated with this node.
    ast_node: ast.AST

    @abstractmethod
    def compile(self) -> Any:
        """
        Returns the SQLAlchemy expression for this node.

        Subclasses must override this method.
        """
        raise NotImplementedError
180
+
181
+
182
@dataclass(frozen=True)
class Term(ExperimentRunFilterConditionNode):
    """
    A filter condition node that can report a SQLAlchemy data type.
    """

    @property
    def data_type(self) -> Optional[SQLAlchemyDataType]:
        """The SQLAlchemy data type of the term; this base implementation returns None."""
        return None
187
+
188
+
189
@dataclass(frozen=True)
class Constant(Term):
    """
    A literal value (bool, int, float, str or None) in a filter condition.
    """

    value: SupportedConstantType

    def compile(self) -> Any:
        """Returns a SQL NULL for None and a bound literal for any other value."""
        constant = self.value
        return Null() if constant is None else literal(constant)

    @property
    def data_type(self) -> Optional[SQLAlchemyDataType]:
        """The SQLAlchemy type matching the Python type of the value, or None for None."""
        constant = self.value
        # bool is tested before int because every bool is also an int.
        if isinstance(constant, bool):
            return Boolean()
        if isinstance(constant, int):
            return Integer()
        if isinstance(constant, float):
            return Float()
        if isinstance(constant, str):
            return String()
        if constant is None:
            return None
        assert_never(constant)
213
+
214
+
215
class ExperimentsName(ExperimentRunFilterConditionNode):
    """
    A node that cannot be compiled by itself; compiling it raises an error asking
    for an experiment to be selected by index.
    """

    def compile(self) -> Any:
        message = "Select an experiment with [<index>]"
        raise ExperimentRunFilterConditionSyntaxError(message)
218
+
219
+
220
@dataclass(frozen=True)
class ExperimentRun(ExperimentRunFilterConditionNode):
    """
    A node that selects one experiment by its position in `experiment_ids`.

    The position is taken from `slice`, validated on construction, and the
    selected experiment ID is stored in `experiment_id`. The node cannot be
    compiled by itself; an attribute must be accessed on it.

    Raises:
        ExperimentRunFilterConditionSyntaxError: on construction, if the index is
            not an integer (booleans are rejected) or is out of range.
    """

    slice: Constant
    experiment_ids: list[int]
    experiment_id: int = field(init=False)

    def __post_init__(self) -> None:
        experiment_index = self.slice.value
        # bool is a subclass of int, so it is excluded explicitly; otherwise
        # an index of True or False would silently select position 1 or 0.
        if isinstance(experiment_index, bool) or not isinstance(experiment_index, int):
            raise ExperimentRunFilterConditionSyntaxError("Index to experiments must be an integer")
        if not (0 <= experiment_index < len(self.experiment_ids)):
            raise ExperimentRunFilterConditionSyntaxError("Select an experiment with [<index>]")
        # The dataclass is frozen, so the derived field is set via object.__setattr__.
        object.__setattr__(self, "experiment_id", self.experiment_ids[experiment_index])

    def compile(self) -> Any:
        raise ExperimentRunFilterConditionSyntaxError("Add an attribute")
236
+
237
+
238
@dataclass(frozen=True)
class Attribute(Term):
    """
    Common base class for terms that represent an attribute.
    """
241
+
242
+
243
@dataclass(frozen=True)
class HasAliasedTables:
    """
    Mixin that looks up table aliases on a transformer, creating them on demand.
    """

    transformer: "SQLAlchemyTransformer"

    def experiment_run_alias(self, experiment_id: ExperimentID) -> Any:
        """Returns the transformer's runs alias for the experiment, creating it if absent."""
        existing_alias = self.transformer.get_experiment_runs_alias(experiment_id)
        if existing_alias:
            return existing_alias
        return self.transformer.create_experiment_runs_alias(experiment_id)

    def experiment_run_annotation_alias(
        self, experiment_id: ExperimentID, eval_name: EvalName
    ) -> Any:
        """Returns the transformer's annotations alias for the eval, creating it if absent."""
        existing_alias = self.transformer.get_experiment_run_annotations_alias(
            experiment_id, eval_name
        )
        if existing_alias:
            return existing_alias
        return self.transformer.create_experiment_run_annotations_alias(experiment_id, eval_name)
258
+
259
+
260
@dataclass(frozen=True)
class DatasetExampleAttribute(HasAliasedTables, Attribute):
    """
    An attribute of a dataset example: `input`, `reference_output` or `metadata`.

    Raises:
        ExperimentRunFilterConditionSyntaxError: on construction, if the attribute
            name is not supported.
    """

    attribute_name: str
    _attribute_name: SupportedDatasetExampleAttributeName = field(init=False)

    def __post_init__(self) -> None:
        if not _is_supported_dataset_example_attribute(self.attribute_name):
            raise ExperimentRunFilterConditionSyntaxError("Unknown name")
        # The dataclass is frozen, so the validated name is set via object.__setattr__.
        object.__setattr__(self, "_attribute_name", self.attribute_name)

    def compile(self) -> Any:
        """Returns the matching column of the dataset example revision model."""
        name = self._attribute_name
        if name == "input":
            return models.DatasetExampleRevision.input
        if name == "reference_output":
            return models.DatasetExampleRevision.output
        if name == "metadata":
            return models.DatasetExampleRevision.metadata_
        assert_never(name)
279
+
280
+
281
@dataclass(frozen=True)
class ExperimentRunAttribute(HasAliasedTables, Attribute):
    """
    An attribute of an experiment run: `output`, `error`, `latency_ms` or `evals`.

    Raises:
        ExperimentRunFilterConditionSyntaxError: on construction, if the attribute
            name is not supported.
    """

    attribute_name: str
    experiment_id: int
    _attribute_name: SupportedExperimentRunAttributeName = field(init=False)

    def __post_init__(self) -> None:
        if not _is_supported_experiment_run_attribute_name(self.attribute_name):
            raise ExperimentRunFilterConditionSyntaxError("Unknown name")
        # The dataclass is frozen, so the validated name is set via object.__setattr__.
        object.__setattr__(self, "_attribute_name", self.attribute_name)

    def compile(self) -> Any:
        """
        Returns the matching column of the aliased experiment runs table.

        `evals` cannot be compiled by itself and raises a syntax error instead.
        """
        name = self._attribute_name
        if name == "evals":
            raise ExperimentRunFilterConditionSyntaxError("Select an eval with [<eval-name>]")
        # Every remaining supported name reads from the aliased runs table.
        runs_alias = self.experiment_run_alias(self.experiment_id)
        if name == "output":
            return runs_alias.output["task_output"]
        if name == "error":
            return runs_alias.error
        if name == "latency_ms":
            return runs_alias.latency_ms
        assert_never(name)

    @property
    def is_eval_attribute(self) -> bool:
        """Whether the attribute name is `evals`."""
        return self.attribute_name == "evals"

    @property
    def is_json_attribute(self) -> bool:
        """Whether the attribute name is one of `input`, `reference_output`, `output`."""
        return self.attribute_name in ("input", "reference_output", "output")

    @property
    def data_type(self) -> Optional[SQLAlchemyDataType]:
        """String for `error`, Float for `latency_ms`, None for `evals` and `output`."""
        name = self._attribute_name
        if name == "evals" or name == "output":
            return None
        if name == "error":
            return String()
        if name == "latency_ms":
            return Float()
        assert_never(name)
328
+
329
+
330
@dataclass(frozen=True)
class JSONAttribute(Attribute):
    """
    An attribute obtained by indexing into another attribute with a constant.

    Raises:
        ExperimentRunFilterConditionSyntaxError: on construction, if the index is
            not an integer or string (booleans are rejected).
    """

    attribute: Attribute
    index_constant: Constant
    _index_value: Union[int, str] = field(init=False)

    def __post_init__(self) -> None:
        index_value = self.index_constant.value
        # bool is a subclass of int, so it is excluded explicitly; otherwise a
        # boolean index would be accepted as if it were an integer position.
        if isinstance(index_value, bool) or not isinstance(index_value, (int, str)):
            raise ExperimentRunFilterConditionSyntaxError("Index must be an integer or string")
        # The dataclass is frozen, so the validated index is set via object.__setattr__.
        object.__setattr__(self, "_index_value", index_value)

    def compile(self) -> Any:
        """Compiles the wrapped attribute and indexes the result with the stored index."""
        compiled_attribute = self.attribute.compile()
        return compiled_attribute[self._index_value]
345
+
346
+
347
@dataclass(frozen=True)
class ExperimentRunEval(ExperimentRunFilterConditionNode):
    """
    An eval selected by name from an experiment run's `evals` attribute.

    The experiment ID is copied from the wrapped experiment run attribute. The
    node cannot be compiled by itself; an eval attribute must be chosen on it.

    Raises:
        ExperimentRunFilterConditionSyntaxError: on construction, if the eval name
            is not a string.
    """

    experiment_run_attribute: ExperimentRunAttribute
    eval_name: str
    experiment_id: int = field(init=False)

    def __post_init__(self) -> None:
        if not isinstance(self.eval_name, str):
            raise ExperimentRunFilterConditionSyntaxError("Eval must be indexed by string")
        # The dataclass is frozen, so the derived field is set via object.__setattr__.
        source_experiment_id = self.experiment_run_attribute.experiment_id
        object.__setattr__(self, "experiment_id", source_experiment_id)

    def compile(self) -> Any:
        message = "Choose an attribute for your eval (label, score, etc.)"
        raise ExperimentRunFilterConditionSyntaxError(message)
362
+
363
+
364
@dataclass(frozen=True)
class ExperimentRunEvalAttribute(HasAliasedTables, Attribute):
    """
    An attribute (``label``, ``score`` or ``explanation``) of a named eval of
    an experiment run.
    """

    experiment_run_eval: ExperimentRunEval
    attribute_name: str
    experiment_id: int = field(init=False)
    _attribute_name: SupportedExperimentRunEvalAttributeName = field(init=False)
    _eval_name: str = field(init=False)

    def __post_init__(self) -> None:
        """Validate the attribute name and cache values derived from the parent eval."""
        if not _is_supported_experiment_run_eval_attribute_name(self.attribute_name):
            raise ExperimentRunFilterConditionSyntaxError("Unknown eval attribute")
        run_eval = self.experiment_run_eval
        derived_fields = {
            "experiment_id": run_eval.experiment_id,
            "_attribute_name": self.attribute_name,
            "_eval_name": run_eval.eval_name,
        }
        # Frozen dataclass: derived fields are stored through object.__setattr__.
        for field_name, field_value in derived_fields.items():
            object.__setattr__(self, field_name, field_value)

    def compile(self) -> Any:
        """Return the matching column on the aliased annotations table for this eval."""
        annotations_alias = self.experiment_run_annotation_alias(
            self.experiment_id, self._eval_name
        )
        return getattr(annotations_alias, self._attribute_name)

    @property
    def data_type(self) -> Optional[SQLAlchemyDataType]:
        """``String()`` for ``label`` and ``explanation``; ``Float()`` for ``score``."""
        name = self._attribute_name
        if name == "label" or name == "explanation":
            return String()
        if name == "score":
            return Float()
        assert_never(name)
396
+
397
+
398
@dataclass(frozen=True)
class UnaryTermOperation(Term):
    """A unary operator applied to a term; only arithmetic negation is handled."""

    operand: Term
    operator: SupportedUnaryTermOperator

    def compile(self) -> Any:
        """Compile the operand and apply the SQLAlchemy negation operator to it."""
        unary_operator = self.operator
        if not isinstance(unary_operator, ast.USub):
            assert_never(unary_operator)
        negate: Callable[[Any], Any] = sqlalchemy_operators.neg
        return negate(self.operand.compile())
413
+
414
+
415
@dataclass(frozen=True)
class BooleanExpression(ExperimentRunFilterConditionNode):
    """Marker base class for filter-condition nodes that are boolean expressions."""
418
+
419
+
420
@dataclass(frozen=True)
class ComparisonOperation(BooleanExpression):
    """A binary comparison between two terms."""

    left_operand: Term
    right_operand: Term
    operator: ast.cmpop
    _operator: SupportedComparisonOperator = field(init=False)

    def __post_init__(self) -> None:
        """Reject unsupported comparison operators and cache the validated one."""
        candidate = self.operator
        if _is_supported_comparison_operator(candidate):
            # Frozen dataclass: derived fields are stored through object.__setattr__.
            object.__setattr__(self, "_operator", candidate)
            return
        raise ExperimentRunFilterConditionSyntaxError("Unsupported comparison operator")

    def compile(self) -> Any:
        """
        Compile both operands, cast those without a known data type when a cast
        type applies, and combine them with the SQLAlchemy comparison operator.
        """
        left, right = self.left_operand, self.right_operand
        comparison_operator = self._operator
        compiled_left = left.compile()
        compiled_right = right.compile()
        cast_type = _get_cast_type_for_comparison(
            operator=comparison_operator,
            left_operand=left,
            right_operand=right,
        )
        if cast_type is not None:
            # Only operands whose type is unknown are cast.
            if left.data_type is None:
                compiled_left = cast(compiled_left, cast_type)
            if right.data_type is None:
                compiled_right = cast(compiled_right, cast_type)
        compare = _get_sqlalchemy_comparison_operator(comparison_operator)
        return compare(compiled_left, compiled_right)
451
+
452
+
453
@dataclass(frozen=True)
class UnaryBooleanOperation(BooleanExpression):
    """A unary boolean operator applied to a boolean expression; only ``not`` is handled."""

    operand: ExperimentRunFilterConditionNode
    operator: SupportedUnaryBooleanOperator

    def __post_init__(self) -> None:
        """Require the operand to be a boolean expression."""
        if isinstance(self.operand, BooleanExpression):
            return
        raise ExperimentRunFilterConditionSyntaxError("Operand must be a boolean expression")

    def compile(self) -> Any:
        """Compile the operand and apply the SQLAlchemy inversion operator to it."""
        boolean_operator = self.operator
        if not isinstance(boolean_operator, ast.Not):
            assert_never(boolean_operator)
        invert: Callable[[Any], Any] = sqlalchemy_operators.inv
        return invert(self.operand.compile())
471
+
472
+
473
@dataclass(frozen=True)
class BooleanOperation(BooleanExpression):
    """An ``and`` / ``or`` combination of two or more boolean expressions."""

    operator: ast.boolop
    operands: list[BooleanExpression]

    def __post_init__(self) -> None:
        """Require at least two operands."""
        if len(self.operands) >= 2:
            return
        raise ExperimentRunFilterConditionSyntaxError(
            "Boolean operators require at least two operands"
        )

    def compile(self) -> Any:
        """Compile every operand, then join them with ``and_`` or ``or_``."""
        compiled_operands = []
        for operand in self.operands:
            compiled_operands.append(operand.compile())
        boolean_operator = self.operator
        if isinstance(boolean_operator, ast.And):
            return and_(*compiled_operands)
        if isinstance(boolean_operator, ast.Or):
            return or_(*compiled_operands)
        raise ExperimentRunFilterConditionSyntaxError("Unsupported boolean operator")
492
+
493
+
494
class SQLAlchemyTransformer(ast.NodeTransformer):
    """
    Transforms the Python AST of an experiment-run filter condition into a
    tree of filter-condition nodes, and keeps a registry of the aliased tables
    created per experiment and per (experiment, eval name) pair.
    """

    def __init__(self, experiment_ids: list[int]) -> None:
        """
        Args:
            experiment_ids: IDs of the experiments the filter may refer to. The
                position of an ID in this list is used in alias names.

        Raises:
            ValueError: If ``experiment_ids`` is empty.
        """
        if not experiment_ids:
            raise ValueError("Must provide one or more experiments")
        self._experiment_ids = experiment_ids
        self._aliased_experiment_runs: dict[ExperimentID, Any] = {}
        self._aliased_experiment_run_annotations: dict[ExperimentID, dict[EvalName, Any]] = {}

    def visit_Constant(self, node: ast.Constant) -> Constant:
        """Wrap a literal in a ``Constant`` node."""
        return Constant(value=node.value, ast_node=node)

    def visit_Name(self, node: ast.Name) -> ExperimentRunFilterConditionNode:
        """
        Resolve a bare name: ``experiments`` or a supported dataset example
        attribute. Any other name is a syntax error.
        """
        name = node.id
        if name == "experiments":
            return ExperimentsName(ast_node=node)
        elif _is_supported_dataset_example_attribute(name):
            return DatasetExampleAttribute(
                attribute_name=name,
                transformer=self,
                ast_node=node,
            )
        raise ExperimentRunFilterConditionSyntaxError("Unknown name")

    def visit_UnaryOp(self, node: ast.UnaryOp) -> Union[UnaryBooleanOperation, UnaryTermOperation]:
        """Build a unary boolean or unary term operation, depending on the operator."""
        operator = node.op
        operand = self.visit(node.operand)
        if _is_supported_unary_boolean_operator(operator):
            return UnaryBooleanOperation(operand=operand, operator=operator, ast_node=node)
        if _is_supported_unary_term_operator(operator):
            return UnaryTermOperation(operand=operand, operator=operator, ast_node=node)
        raise ExperimentRunFilterConditionSyntaxError("Unsupported unary operator")

    def visit_BoolOp(self, node: ast.BoolOp) -> BooleanOperation:
        """Build a boolean operation from the visited operands."""
        operator = node.op
        operands = [self.visit(value) for value in node.values]
        return BooleanOperation(operator=operator, operands=operands, ast_node=node)

    def visit_Compare(self, node: ast.Compare) -> ExperimentRunFilterConditionNode:
        """
        Build a comparison operation. Chained comparisons (more than one
        operator) are rejected.
        """
        if not (len(node.ops) == 1 and len(node.comparators) == 1):
            raise ExperimentRunFilterConditionSyntaxError("Only binary comparisons are supported")
        left_operand = self.visit(node.left)
        right_operand = self.visit(node.comparators[0])
        operator = node.ops[0]
        return ComparisonOperation(
            left_operand=left_operand,
            right_operand=right_operand,
            operator=operator,
            ast_node=node,
        )

    def visit_Subscript(self, node: ast.Subscript) -> ExperimentRunFilterConditionNode:
        """
        Resolve a subscript according to its container:

        - ``experiments[...]`` becomes an ``ExperimentRun``;
        - the ``evals`` attribute of an experiment run becomes an
          ``ExperimentRunEval``;
        - JSON-valued attributes become a ``JSONAttribute``.

        Raises:
            ExperimentRunFilterConditionSyntaxError: If the index is not a
                constant, or the container cannot be subscripted.
        """
        container = self.visit(node.value)
        key = self.visit(node.slice)
        if isinstance(container, ExperimentsName):
            if not isinstance(key, Constant):
                raise ExperimentRunFilterConditionSyntaxError("Index must be a constant")
            return ExperimentRun(
                slice=key,
                experiment_ids=self._experiment_ids,
                ast_node=node,
            )
        if isinstance(container, ExperimentRunAttribute) and container.is_eval_attribute:
            # The key comes from user input and may be any node (slice, negated
            # number, attribute, ...); only a Constant exposes the `.value`
            # read below.
            if not isinstance(key, Constant):
                raise ExperimentRunFilterConditionSyntaxError("Index must be a constant")
            return ExperimentRunEval(
                experiment_run_attribute=container,
                eval_name=key.value,
                ast_node=node,
            )
        if isinstance(container, (JSONAttribute, DatasetExampleAttribute)) or (
            isinstance(container, ExperimentRunAttribute) and container.is_json_attribute
        ):
            # JSONAttribute reads the index's `.value`, so reject non-constants here.
            if not isinstance(key, Constant):
                raise ExperimentRunFilterConditionSyntaxError("Index must be a constant")
            return JSONAttribute(
                attribute=container,
                index_constant=key,
                ast_node=node,
            )
        raise ExperimentRunFilterConditionSyntaxError("Invalid subscript")

    def visit_Attribute(self, node: ast.Attribute) -> ExperimentRunFilterConditionNode:
        """
        Resolve attribute access on an experiment run (run attributes or
        dataset example attributes) or on an eval (eval attributes). Anything
        else is a syntax error.
        """
        parent = self.visit(node.value)
        attribute_name = node.attr
        if isinstance(parent, ExperimentRun):
            if _is_supported_experiment_run_attribute_name(attribute_name):
                return ExperimentRunAttribute(
                    attribute_name=attribute_name,
                    experiment_id=parent.experiment_id,
                    transformer=self,
                    ast_node=node,
                )
            elif _is_supported_dataset_example_attribute(attribute_name):
                return DatasetExampleAttribute(
                    attribute_name=attribute_name,
                    transformer=self,
                    ast_node=node,
                )
            raise ExperimentRunFilterConditionSyntaxError("Unknown attribute")
        if isinstance(parent, ExperimentRunEval):
            return ExperimentRunEvalAttribute(
                attribute_name=attribute_name,
                experiment_run_eval=parent,
                transformer=self,
                ast_node=node,
            )
        raise ExperimentRunFilterConditionSyntaxError("Unknown attribute")

    def create_experiment_runs_alias(self, experiment_id: ExperimentID) -> Any:
        """
        Create and register an aliased ``ExperimentRun`` table named
        ``experiment_runs_<index>`` for the experiment.

        Raises:
            ValueError: If an alias is already registered for the experiment.
        """
        if self.get_experiment_runs_alias(experiment_id) is not None:
            raise ValueError(f"Alias already exists for experiment ID: {experiment_id}")
        experiment_index = self.get_experiment_index(experiment_id)
        alias_name = f"experiment_runs_{experiment_index}"
        aliased_table = aliased(models.ExperimentRun, name=alias_name)
        self._aliased_experiment_runs[experiment_id] = aliased_table
        return aliased_table

    def get_experiment_runs_alias(self, experiment_id: ExperimentID) -> Any:
        """Return the registered experiment-runs alias, or None if there is none."""
        return self._aliased_experiment_runs.get(experiment_id)

    def create_experiment_run_annotations_alias(
        self, experiment_id: ExperimentID, eval_name: EvalName
    ) -> Any:
        """
        Create and register an aliased ``ExperimentRunAnnotation`` table for
        the (experiment, eval name) pair, creating the experiment-runs alias
        first if it does not exist yet.

        Raises:
            ValueError: If an alias is already registered for the pair.
        """
        if self.get_experiment_run_annotations_alias(experiment_id, eval_name) is not None:
            raise ValueError(
                f"Alias exists for experiment ID and eval name: {(experiment_id, eval_name)}"
            )
        self._ensure_experiment_runs_alias_exists(
            experiment_id
        )  # experiment_runs are needed so we have something to join experiment_run_annotations to
        experiment_index = self.get_experiment_index(experiment_id)
        # A short hash keeps the alias name bounded whatever the eval name is.
        eval_name_hash = sha256(eval_name.encode()).hexdigest()[:9]
        alias_name = (  # postgres truncates identifiers at 63 chars, so cap the length
            f"experiment_run_annotations_{experiment_index}_{eval_name_hash}"
        )
        aliased_table = aliased(models.ExperimentRunAnnotation, name=alias_name)
        self._aliased_experiment_run_annotations.setdefault(experiment_id, {})[eval_name] = (
            aliased_table
        )
        return aliased_table

    def get_experiment_run_annotations_alias(
        self, experiment_id: ExperimentID, eval_name: EvalName
    ) -> Any:
        """Return the registered annotations alias for the pair, or None if there is none."""
        return self._aliased_experiment_run_annotations.get(experiment_id, {}).get(eval_name)

    def get_experiment_run_annotations_aliases(
        self, experiment_id: ExperimentID
    ) -> dict[EvalName, Any]:
        """Return all annotations aliases registered for the experiment, keyed by eval name."""
        return self._aliased_experiment_run_annotations.get(experiment_id, {})

    def get_experiment_index(self, experiment_id: ExperimentID) -> int:
        """Return the position of the experiment ID in the list given at construction."""
        return self._experiment_ids.index(experiment_id)

    def _ensure_experiment_runs_alias_exists(self, experiment_id: ExperimentID) -> None:
        """Create the experiment-runs alias for the experiment if it is not registered yet."""
        if self.get_experiment_runs_alias(experiment_id) is None:
            self.create_experiment_runs_alias(experiment_id)
648
+
649
+
650
def _get_sqlalchemy_comparison_operator(
    ast_operator: SupportedComparisonOperator,
) -> Callable[[Any, Any], Any]:
    """
    Map a supported AST comparison operator to a two-argument callable that
    builds the corresponding comparison.

    ``in`` / ``not in`` are expressed with ``models.TextContains``, with the
    right operand as the text searched and the left operand as the text sought.
    """
    if isinstance(ast_operator, ast.Eq):
        return operator.eq
    if isinstance(ast_operator, ast.NotEq):
        return operator.ne
    if isinstance(ast_operator, ast.Lt):
        return sqlalchemy_operators.lt
    if isinstance(ast_operator, ast.LtE):
        return sqlalchemy_operators.le
    if isinstance(ast_operator, ast.Gt):
        return sqlalchemy_operators.gt
    if isinstance(ast_operator, ast.GtE):
        return sqlalchemy_operators.ge
    if isinstance(ast_operator, ast.Is):
        return sqlalchemy_operators.is_
    if isinstance(ast_operator, ast.IsNot):
        return sqlalchemy_operators.is_not
    if isinstance(ast_operator, ast.In):

        def _contains(left: Any, right: Any) -> Any:
            return models.TextContains(right, left)

        return _contains
    if isinstance(ast_operator, ast.NotIn):

        def _does_not_contain(left: Any, right: Any) -> Any:
            return ~models.TextContains(right, left)

        return _does_not_contain
    assert_never(ast_operator)
674
+
675
+
676
def _get_cast_type_for_comparison(
    *,
    operator: SupportedComparisonOperator,
    left_operand: Term,
    right_operand: Term,
) -> Optional[SQLAlchemyDataType]:
    """
    Choose the type that operands of unknown type should be cast to before a
    comparison, or None when no cast should be applied.

    Columns such as JSON columns need an explicit cast before being compared
    with non-null values, and the real type of the stored value is not known,
    so the choice is heuristic:

    - both operand types known: no cast;
    - ordering operators (``<``, ``<=``, ``>``, ``>=``): cast to float;
    - ``in`` / ``not in``: cast to string;
    - either operand is the constant ``None``: no cast;
    - exactly one operand type known: cast to that type;
    - neither known, with ``==``, ``!=``, ``is``, ``is not``: cast to string.
    """
    left_type = left_operand.data_type
    right_type = right_operand.data_type
    left_is_typed = left_type is not None
    right_is_typed = right_type is not None

    if left_is_typed and right_is_typed:
        return None  # Nothing to cast when both sides already have a type.

    if isinstance(operator, (ast.Gt, ast.GtE, ast.Lt, ast.LtE)):
        # Ordering comparisons always use float, even against an integer.
        return Float()

    if isinstance(operator, (ast.In, ast.NotIn)):
        # Containment is evaluated on strings.
        return String()

    # Comparisons against a literal None are left uncast.
    for operand in (left_operand, right_operand):
        if isinstance(operand, Constant) and operand.value is None:
            return None

    # At most one side is typed at this point; adopt that side's type.
    if left_is_typed:
        return left_type
    if right_is_typed:
        return right_type

    # Neither side is typed: fall back on the operator to pick a type.
    if isinstance(operator, (ast.Eq, ast.NotEq, ast.Is, ast.IsNot)):
        return String()
    assert_never(operator)
728
+
729
+
730
def _is_supported_comparison_operator(
    operator: ast.cmpop,
) -> TypeGuard[SupportedComparisonOperator]:
    """Return True when the operator is an instance of a supported comparison operator type."""
    supported_operator_types = get_args(SupportedComparisonOperator)
    return isinstance(operator, supported_operator_types)
734
+
735
+
736
def _is_supported_dataset_example_attribute(
    name: str,
) -> TypeGuard[SupportedDatasetExampleAttributeName]:
    """Return True when the name is one of the supported dataset example attribute names."""
    supported_names = get_args(SupportedDatasetExampleAttributeName)
    return name in supported_names
740
+
741
+
742
def _is_supported_experiment_run_attribute_name(
    name: str,
) -> TypeGuard[SupportedExperimentRunAttributeName]:
    """Return True when the name is one of the supported experiment run attribute names."""
    supported_names = get_args(SupportedExperimentRunAttributeName)
    return name in supported_names
746
+
747
+
748
def _is_supported_experiment_run_eval_attribute_name(
    name: str,
) -> TypeGuard[SupportedExperimentRunEvalAttributeName]:
    """Return True when the name is one of the supported eval attribute names."""
    supported_names = get_args(SupportedExperimentRunEvalAttributeName)
    return name in supported_names
752
+
753
+
754
def _is_supported_unary_boolean_operator(
    operator: ast.unaryop,
) -> TypeGuard[SupportedUnaryBooleanOperator]:
    """Return True when the operator is an instance of the supported unary boolean operator."""
    is_supported = isinstance(operator, SupportedUnaryBooleanOperator)
    return is_supported
758
+
759
+
760
def _is_supported_unary_term_operator(
    operator: ast.unaryop,
) -> TypeGuard[SupportedUnaryTermOperator]:
    """Return True when the operator is an instance of the supported unary term operator."""
    is_supported = isinstance(operator, SupportedUnaryTermOperator)
    return is_supported