arize-phoenix 3.16.1__py3-none-any.whl → 7.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (338)
  1. arize_phoenix-7.7.0.dist-info/METADATA +261 -0
  2. arize_phoenix-7.7.0.dist-info/RECORD +345 -0
  3. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.0.dist-info}/WHEEL +1 -1
  4. arize_phoenix-7.7.0.dist-info/entry_points.txt +3 -0
  5. phoenix/__init__.py +86 -14
  6. phoenix/auth.py +309 -0
  7. phoenix/config.py +675 -45
  8. phoenix/core/model.py +32 -30
  9. phoenix/core/model_schema.py +102 -109
  10. phoenix/core/model_schema_adapter.py +48 -45
  11. phoenix/datetime_utils.py +24 -3
  12. phoenix/db/README.md +54 -0
  13. phoenix/db/__init__.py +4 -0
  14. phoenix/db/alembic.ini +85 -0
  15. phoenix/db/bulk_inserter.py +294 -0
  16. phoenix/db/engines.py +208 -0
  17. phoenix/db/enums.py +20 -0
  18. phoenix/db/facilitator.py +113 -0
  19. phoenix/db/helpers.py +159 -0
  20. phoenix/db/insertion/constants.py +2 -0
  21. phoenix/db/insertion/dataset.py +227 -0
  22. phoenix/db/insertion/document_annotation.py +171 -0
  23. phoenix/db/insertion/evaluation.py +191 -0
  24. phoenix/db/insertion/helpers.py +98 -0
  25. phoenix/db/insertion/span.py +193 -0
  26. phoenix/db/insertion/span_annotation.py +158 -0
  27. phoenix/db/insertion/trace_annotation.py +158 -0
  28. phoenix/db/insertion/types.py +256 -0
  29. phoenix/db/migrate.py +86 -0
  30. phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
  31. phoenix/db/migrations/env.py +114 -0
  32. phoenix/db/migrations/script.py.mako +26 -0
  33. phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
  34. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
  35. phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
  36. phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
  37. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  38. phoenix/db/models.py +807 -0
  39. phoenix/exceptions.py +5 -1
  40. phoenix/experiments/__init__.py +6 -0
  41. phoenix/experiments/evaluators/__init__.py +29 -0
  42. phoenix/experiments/evaluators/base.py +158 -0
  43. phoenix/experiments/evaluators/code_evaluators.py +184 -0
  44. phoenix/experiments/evaluators/llm_evaluators.py +473 -0
  45. phoenix/experiments/evaluators/utils.py +236 -0
  46. phoenix/experiments/functions.py +772 -0
  47. phoenix/experiments/tracing.py +86 -0
  48. phoenix/experiments/types.py +726 -0
  49. phoenix/experiments/utils.py +25 -0
  50. phoenix/inferences/__init__.py +0 -0
  51. phoenix/{datasets → inferences}/errors.py +6 -5
  52. phoenix/{datasets → inferences}/fixtures.py +49 -42
  53. phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
  54. phoenix/{datasets → inferences}/schema.py +11 -11
  55. phoenix/{datasets → inferences}/validation.py +13 -14
  56. phoenix/logging/__init__.py +3 -0
  57. phoenix/logging/_config.py +90 -0
  58. phoenix/logging/_filter.py +6 -0
  59. phoenix/logging/_formatter.py +69 -0
  60. phoenix/metrics/__init__.py +5 -4
  61. phoenix/metrics/binning.py +4 -3
  62. phoenix/metrics/metrics.py +2 -1
  63. phoenix/metrics/mixins.py +7 -6
  64. phoenix/metrics/retrieval_metrics.py +2 -1
  65. phoenix/metrics/timeseries.py +5 -4
  66. phoenix/metrics/wrappers.py +9 -3
  67. phoenix/pointcloud/clustering.py +5 -5
  68. phoenix/pointcloud/pointcloud.py +7 -5
  69. phoenix/pointcloud/projectors.py +5 -6
  70. phoenix/pointcloud/umap_parameters.py +53 -52
  71. phoenix/server/api/README.md +28 -0
  72. phoenix/server/api/auth.py +44 -0
  73. phoenix/server/api/context.py +152 -9
  74. phoenix/server/api/dataloaders/__init__.py +91 -0
  75. phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
  76. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  77. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  78. phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
  79. phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
  80. phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
  81. phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
  82. phoenix/server/api/dataloaders/document_evaluations.py +31 -0
  83. phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
  84. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
  85. phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
  86. phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
  87. phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
  88. phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
  89. phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
  90. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
  91. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  92. phoenix/server/api/dataloaders/record_counts.py +116 -0
  93. phoenix/server/api/dataloaders/session_io.py +79 -0
  94. phoenix/server/api/dataloaders/session_num_traces.py +30 -0
  95. phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
  96. phoenix/server/api/dataloaders/session_token_usages.py +41 -0
  97. phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
  98. phoenix/server/api/dataloaders/span_annotations.py +26 -0
  99. phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
  100. phoenix/server/api/dataloaders/span_descendants.py +57 -0
  101. phoenix/server/api/dataloaders/span_projects.py +33 -0
  102. phoenix/server/api/dataloaders/token_counts.py +124 -0
  103. phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
  104. phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
  105. phoenix/server/api/dataloaders/user_roles.py +30 -0
  106. phoenix/server/api/dataloaders/users.py +33 -0
  107. phoenix/server/api/exceptions.py +48 -0
  108. phoenix/server/api/helpers/__init__.py +12 -0
  109. phoenix/server/api/helpers/dataset_helpers.py +217 -0
  110. phoenix/server/api/helpers/experiment_run_filters.py +763 -0
  111. phoenix/server/api/helpers/playground_clients.py +948 -0
  112. phoenix/server/api/helpers/playground_registry.py +70 -0
  113. phoenix/server/api/helpers/playground_spans.py +455 -0
  114. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  115. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  116. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  117. phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
  118. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  119. phoenix/server/api/input_types/ClusterInput.py +2 -2
  120. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  121. phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
  122. phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
  123. phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
  124. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  125. phoenix/server/api/input_types/DatasetSort.py +17 -0
  126. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  127. phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
  128. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  129. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  130. phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
  131. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  132. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  133. phoenix/server/api/input_types/Granularity.py +1 -1
  134. phoenix/server/api/input_types/InvocationParameters.py +162 -0
  135. phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
  136. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  137. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  138. phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
  139. phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
  140. phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
  141. phoenix/server/api/input_types/SpanSort.py +134 -69
  142. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  143. phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
  144. phoenix/server/api/input_types/UserRoleInput.py +9 -0
  145. phoenix/server/api/mutations/__init__.py +28 -0
  146. phoenix/server/api/mutations/api_key_mutations.py +167 -0
  147. phoenix/server/api/mutations/chat_mutations.py +593 -0
  148. phoenix/server/api/mutations/dataset_mutations.py +591 -0
  149. phoenix/server/api/mutations/experiment_mutations.py +75 -0
  150. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
  151. phoenix/server/api/mutations/project_mutations.py +57 -0
  152. phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
  153. phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
  154. phoenix/server/api/mutations/user_mutations.py +329 -0
  155. phoenix/server/api/openapi/__init__.py +0 -0
  156. phoenix/server/api/openapi/main.py +17 -0
  157. phoenix/server/api/openapi/schema.py +16 -0
  158. phoenix/server/api/queries.py +738 -0
  159. phoenix/server/api/routers/__init__.py +11 -0
  160. phoenix/server/api/routers/auth.py +284 -0
  161. phoenix/server/api/routers/embeddings.py +26 -0
  162. phoenix/server/api/routers/oauth2.py +488 -0
  163. phoenix/server/api/routers/v1/__init__.py +64 -0
  164. phoenix/server/api/routers/v1/datasets.py +1017 -0
  165. phoenix/server/api/routers/v1/evaluations.py +362 -0
  166. phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
  167. phoenix/server/api/routers/v1/experiment_runs.py +167 -0
  168. phoenix/server/api/routers/v1/experiments.py +308 -0
  169. phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
  170. phoenix/server/api/routers/v1/spans.py +267 -0
  171. phoenix/server/api/routers/v1/traces.py +208 -0
  172. phoenix/server/api/routers/v1/utils.py +95 -0
  173. phoenix/server/api/schema.py +44 -241
  174. phoenix/server/api/subscriptions.py +597 -0
  175. phoenix/server/api/types/Annotation.py +21 -0
  176. phoenix/server/api/types/AnnotationSummary.py +55 -0
  177. phoenix/server/api/types/AnnotatorKind.py +16 -0
  178. phoenix/server/api/types/ApiKey.py +27 -0
  179. phoenix/server/api/types/AuthMethod.py +9 -0
  180. phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
  181. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
  182. phoenix/server/api/types/Cluster.py +25 -24
  183. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  184. phoenix/server/api/types/DataQualityMetric.py +31 -13
  185. phoenix/server/api/types/Dataset.py +288 -63
  186. phoenix/server/api/types/DatasetExample.py +85 -0
  187. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  188. phoenix/server/api/types/DatasetVersion.py +14 -0
  189. phoenix/server/api/types/Dimension.py +32 -31
  190. phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
  191. phoenix/server/api/types/EmbeddingDimension.py +56 -49
  192. phoenix/server/api/types/Evaluation.py +25 -31
  193. phoenix/server/api/types/EvaluationSummary.py +30 -50
  194. phoenix/server/api/types/Event.py +20 -20
  195. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  196. phoenix/server/api/types/Experiment.py +152 -0
  197. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  198. phoenix/server/api/types/ExperimentComparison.py +17 -0
  199. phoenix/server/api/types/ExperimentRun.py +119 -0
  200. phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
  201. phoenix/server/api/types/GenerativeModel.py +9 -0
  202. phoenix/server/api/types/GenerativeProvider.py +85 -0
  203. phoenix/server/api/types/Inferences.py +80 -0
  204. phoenix/server/api/types/InferencesRole.py +23 -0
  205. phoenix/server/api/types/LabelFraction.py +7 -0
  206. phoenix/server/api/types/MimeType.py +2 -2
  207. phoenix/server/api/types/Model.py +54 -54
  208. phoenix/server/api/types/PerformanceMetric.py +8 -5
  209. phoenix/server/api/types/Project.py +407 -142
  210. phoenix/server/api/types/ProjectSession.py +139 -0
  211. phoenix/server/api/types/Segments.py +4 -4
  212. phoenix/server/api/types/Span.py +221 -176
  213. phoenix/server/api/types/SpanAnnotation.py +43 -0
  214. phoenix/server/api/types/SpanIOValue.py +15 -0
  215. phoenix/server/api/types/SystemApiKey.py +9 -0
  216. phoenix/server/api/types/TemplateLanguage.py +10 -0
  217. phoenix/server/api/types/TimeSeries.py +19 -15
  218. phoenix/server/api/types/TokenUsage.py +11 -0
  219. phoenix/server/api/types/Trace.py +154 -0
  220. phoenix/server/api/types/TraceAnnotation.py +45 -0
  221. phoenix/server/api/types/UMAPPoints.py +7 -7
  222. phoenix/server/api/types/User.py +60 -0
  223. phoenix/server/api/types/UserApiKey.py +45 -0
  224. phoenix/server/api/types/UserRole.py +15 -0
  225. phoenix/server/api/types/node.py +4 -112
  226. phoenix/server/api/types/pagination.py +156 -57
  227. phoenix/server/api/utils.py +34 -0
  228. phoenix/server/app.py +864 -115
  229. phoenix/server/bearer_auth.py +163 -0
  230. phoenix/server/dml_event.py +136 -0
  231. phoenix/server/dml_event_handler.py +256 -0
  232. phoenix/server/email/__init__.py +0 -0
  233. phoenix/server/email/sender.py +97 -0
  234. phoenix/server/email/templates/__init__.py +0 -0
  235. phoenix/server/email/templates/password_reset.html +19 -0
  236. phoenix/server/email/types.py +11 -0
  237. phoenix/server/grpc_server.py +102 -0
  238. phoenix/server/jwt_store.py +505 -0
  239. phoenix/server/main.py +305 -116
  240. phoenix/server/oauth2.py +52 -0
  241. phoenix/server/openapi/__init__.py +0 -0
  242. phoenix/server/prometheus.py +111 -0
  243. phoenix/server/rate_limiters.py +188 -0
  244. phoenix/server/static/.vite/manifest.json +87 -0
  245. phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
  246. phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
  247. phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
  248. phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
  249. phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
  250. phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
  251. phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
  252. phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
  253. phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
  254. phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
  255. phoenix/server/telemetry.py +68 -0
  256. phoenix/server/templates/index.html +82 -23
  257. phoenix/server/thread_server.py +3 -3
  258. phoenix/server/types.py +275 -0
  259. phoenix/services.py +27 -18
  260. phoenix/session/client.py +743 -68
  261. phoenix/session/data_extractor.py +31 -7
  262. phoenix/session/evaluation.py +3 -9
  263. phoenix/session/session.py +263 -219
  264. phoenix/settings.py +22 -0
  265. phoenix/trace/__init__.py +2 -22
  266. phoenix/trace/attributes.py +338 -0
  267. phoenix/trace/dsl/README.md +116 -0
  268. phoenix/trace/dsl/filter.py +663 -213
  269. phoenix/trace/dsl/helpers.py +73 -21
  270. phoenix/trace/dsl/query.py +574 -201
  271. phoenix/trace/exporter.py +24 -19
  272. phoenix/trace/fixtures.py +368 -32
  273. phoenix/trace/otel.py +71 -219
  274. phoenix/trace/projects.py +3 -2
  275. phoenix/trace/schemas.py +33 -11
  276. phoenix/trace/span_evaluations.py +21 -16
  277. phoenix/trace/span_json_decoder.py +6 -4
  278. phoenix/trace/span_json_encoder.py +2 -2
  279. phoenix/trace/trace_dataset.py +47 -32
  280. phoenix/trace/utils.py +21 -4
  281. phoenix/utilities/__init__.py +0 -26
  282. phoenix/utilities/client.py +132 -0
  283. phoenix/utilities/deprecation.py +31 -0
  284. phoenix/utilities/error_handling.py +3 -2
  285. phoenix/utilities/json.py +109 -0
  286. phoenix/utilities/logging.py +8 -0
  287. phoenix/utilities/project.py +2 -2
  288. phoenix/utilities/re.py +49 -0
  289. phoenix/utilities/span_store.py +0 -23
  290. phoenix/utilities/template_formatters.py +99 -0
  291. phoenix/version.py +1 -1
  292. arize_phoenix-3.16.1.dist-info/METADATA +0 -495
  293. arize_phoenix-3.16.1.dist-info/RECORD +0 -178
  294. phoenix/core/project.py +0 -619
  295. phoenix/core/traces.py +0 -96
  296. phoenix/experimental/evals/__init__.py +0 -73
  297. phoenix/experimental/evals/evaluators.py +0 -413
  298. phoenix/experimental/evals/functions/__init__.py +0 -4
  299. phoenix/experimental/evals/functions/classify.py +0 -453
  300. phoenix/experimental/evals/functions/executor.py +0 -353
  301. phoenix/experimental/evals/functions/generate.py +0 -138
  302. phoenix/experimental/evals/functions/processing.py +0 -76
  303. phoenix/experimental/evals/models/__init__.py +0 -14
  304. phoenix/experimental/evals/models/anthropic.py +0 -175
  305. phoenix/experimental/evals/models/base.py +0 -170
  306. phoenix/experimental/evals/models/bedrock.py +0 -221
  307. phoenix/experimental/evals/models/litellm.py +0 -134
  308. phoenix/experimental/evals/models/openai.py +0 -448
  309. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  310. phoenix/experimental/evals/models/vertex.py +0 -173
  311. phoenix/experimental/evals/models/vertexai.py +0 -186
  312. phoenix/experimental/evals/retrievals.py +0 -96
  313. phoenix/experimental/evals/templates/__init__.py +0 -50
  314. phoenix/experimental/evals/templates/default_templates.py +0 -472
  315. phoenix/experimental/evals/templates/template.py +0 -195
  316. phoenix/experimental/evals/utils/__init__.py +0 -172
  317. phoenix/experimental/evals/utils/threads.py +0 -27
  318. phoenix/server/api/helpers.py +0 -11
  319. phoenix/server/api/routers/evaluation_handler.py +0 -109
  320. phoenix/server/api/routers/span_handler.py +0 -70
  321. phoenix/server/api/routers/trace_handler.py +0 -60
  322. phoenix/server/api/types/DatasetRole.py +0 -23
  323. phoenix/server/static/index.css +0 -6
  324. phoenix/server/static/index.js +0 -7447
  325. phoenix/storage/span_store/__init__.py +0 -23
  326. phoenix/storage/span_store/text_file.py +0 -85
  327. phoenix/trace/dsl/missing.py +0 -60
  328. phoenix/trace/langchain/__init__.py +0 -3
  329. phoenix/trace/langchain/instrumentor.py +0 -35
  330. phoenix/trace/llama_index/__init__.py +0 -3
  331. phoenix/trace/llama_index/callback.py +0 -102
  332. phoenix/trace/openai/__init__.py +0 -3
  333. phoenix/trace/openai/instrumentor.py +0 -30
  334. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/IP_NOTICE +0 -0
  335. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.0.dist-info}/licenses/LICENSE +0 -0
  336. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  337. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  338. /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
phoenix/db/models.py ADDED
@@ -0,0 +1,807 @@
1
+ from datetime import datetime, timezone
2
+ from enum import Enum
3
+ from typing import Any, Optional, TypedDict
4
+
5
+ from sqlalchemy import (
6
+ JSON,
7
+ NUMERIC,
8
+ TIMESTAMP,
9
+ CheckConstraint,
10
+ ColumnElement,
11
+ Dialect,
12
+ Float,
13
+ ForeignKey,
14
+ Index,
15
+ MetaData,
16
+ String,
17
+ TypeDecorator,
18
+ UniqueConstraint,
19
+ case,
20
+ func,
21
+ insert,
22
+ not_,
23
+ select,
24
+ text,
25
+ )
26
+ from sqlalchemy.dialects import postgresql
27
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
28
+ from sqlalchemy.ext.compiler import compiles
29
+ from sqlalchemy.ext.hybrid import hybrid_property
30
+ from sqlalchemy.orm import (
31
+ DeclarativeBase,
32
+ Mapped,
33
+ WriteOnlyMapped,
34
+ mapped_column,
35
+ relationship,
36
+ )
37
+ from sqlalchemy.sql import expression
38
+
39
+ from phoenix.config import get_env_database_schema
40
+ from phoenix.datetime_utils import normalize_datetime
41
+
42
+
43
class AuthMethod(Enum):
    """Enumerates the supported authentication methods.

    Every member's value is identical to its name.
    """

    LOCAL = "LOCAL"
    OAUTH2 = "OAUTH2"
46
+
47
+
48
class JSONB(JSON):
    """``JSON`` subtype whose compiler visit name is ``JSONB``.

    Overriding ``__visit_name__`` makes the DDL compiler look up a ``JSONB``
    rendering for this type; the SQLite rendering is supplied by the
    ``@compiles(JSONB, "sqlite")`` hook in this module.
    See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    """

    __visit_name__ = "JSONB"
51
+
52
+
53
@compiles(JSONB, "sqlite")
def _(*args: Any, **kwargs: Any) -> str:
    """Render the custom ``JSONB`` type as the type name ``JSONB`` on SQLite.

    All compiler arguments are ignored.
    See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    """
    sqlite_type_name = "JSONB"
    return sqlite_type_name
57
+
58
+
59
# Portable JSON column type: native JSONB on PostgreSQL, the custom ``JSONB``
# type name on SQLite, and generic ``JSON`` on any other dialect. The variants
# are registered in the same order as before (postgresql first, then sqlite).
JSON_ = JSON().with_variant(postgresql.JSONB(), "postgresql")  # type: ignore
JSON_ = JSON_.with_variant(JSONB(), "sqlite")
70
+
71
+
72
class JsonDict(TypeDecorator[dict[str, Any]]):
    """JSON column type that always binds a dictionary.

    Any bound value that is not a ``dict`` (including ``None``) is replaced by
    an empty dict before it reaches the underlying ``JSON_`` type.
    See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    """

    impl = JSON_
    cache_ok = True

    def process_bind_param(self, value: Optional[dict[str, Any]], _: Dialect) -> dict[str, Any]:
        if not isinstance(value, dict):
            return {}
        return value
79
+
80
+
81
class JsonList(TypeDecorator[list[Any]]):
    """JSON column type that always binds a list.

    Any bound value that is not a ``list`` (including ``None``) is replaced by
    an empty list before it reaches the underlying ``JSON_`` type.
    See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    """

    impl = JSON_
    cache_ok = True

    def process_bind_param(self, value: Optional[list[Any]], _: Dialect) -> list[Any]:
        if not isinstance(value, list):
            return []
        return value
88
+
89
+
90
class UtcTimeStamp(TypeDecorator[datetime]):
    """Timezone-aware ``TIMESTAMP`` column whose values pass through ``normalize_datetime``.

    Bound values are normalized with the helper's default arguments. Values
    read back are normalized with ``timezone.utc`` as the second argument.
    NOTE(review): the exact meaning of that argument lives in
    ``phoenix.datetime_utils`` — confirm it means "assume/convert to UTC".
    See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    """

    impl = TIMESTAMP(timezone=True)
    cache_ok = True

    def process_bind_param(self, value: Optional[datetime], _: Dialect) -> Optional[datetime]:
        bound = normalize_datetime(value)
        return bound

    def process_result_value(self, value: Optional[Any], _: Dialect) -> Optional[datetime]:
        loaded = normalize_datetime(value, timezone.utc)
        return loaded
100
+
101
+
102
class ExperimentRunOutput(TypedDict, total=False):
    """Typed shape of an experiment run's output payload.

    ``total=False`` makes every key optional. The only declared key is
    ``task_output``, which is typed as ``Any``.
    """

    task_output: Any
104
+
105
+
106
class Base(DeclarativeBase):
    """Declarative base shared by every ORM model in this module.

    It provides:

    * a ``MetaData`` bound to the schema returned by ``get_env_database_schema()``,
      with deterministic naming conventions for indexes and constraints;
    * a ``type_annotation_map`` so that bare ``Mapped[...]`` annotations of the
      listed Python types resolve to the project's JSON column types.
    """

    # Python annotation -> column type used when a ``Mapped[...]`` attribute
    # does not name an explicit SQL type.
    type_annotation_map = {
        dict[str, Any]: JsonDict,
        list[dict[str, Any]]: JsonList,
        ExperimentRunOutput: JsonDict,
    }

    # Enforce best practices for naming constraints, see
    # https://alembic.sqlalchemy.org/en/latest/naming.html#integration-of-naming-conventions-into-operations-autogenerate
    # NOTE(review): the "ck" template wraps the constraint name in backticks, so
    # those characters become part of generated CHECK constraint names.
    # Presumably existing migrations depend on these exact names — confirm
    # before changing.
    metadata = MetaData(
        naming_convention={
            "ix": "ix_%(table_name)s_%(column_0_N_name)s",
            "uq": "uq_%(table_name)s_%(column_0_N_name)s",
            "ck": "ck_%(table_name)s_`%(constraint_name)s`",
            "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
            "pk": "pk_%(table_name)s",
        },
        schema=get_env_database_schema(),
    )
124
+
125
+
126
class Project(Base):
    """A project, unique by ``name``, that owns a write-only collection of traces.

    The ``traces`` relationship cascades ``all, delete-orphan`` with
    ``passive_deletes=True``. That setting defers removal of child rows to the
    database-level ``ON DELETE CASCADE`` declared on ``traces.project_rowid``.
    """

    __tablename__ = "projects"
    __table_args__ = (UniqueConstraint("name"),)

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]
    description: Mapped[Optional[str]]
    # Server-side defaults for the two UI gradient colors.
    gradient_start_color: Mapped[str] = mapped_column(String, server_default=text("'#5bdbff'"))
    gradient_end_color: Mapped[str] = mapped_column(String, server_default=text("'#1c76fc'"))
    created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
    updated_at: Mapped[datetime] = mapped_column(
        UtcTimeStamp, server_default=func.now(), onupdate=func.now()
    )

    traces: WriteOnlyMapped[list["Trace"]] = relationship(
        "Trace",
        back_populates="project",
        cascade="all, delete-orphan",
        passive_deletes=True,
        uselist=True,
    )
157
+
158
+
159
class ProjectSession(Base):
    """A session within a project that groups traces.

    ``session_id`` is unique. ``start_time`` and ``end_time`` are non-null and
    indexed.
    """

    __tablename__ = "project_sessions"

    id: Mapped[int] = mapped_column(primary_key=True)
    session_id: Mapped[str] = mapped_column(String, nullable=False, unique=True)
    # NOTE: this FK is named ``project_id`` (``traces`` uses ``project_rowid``).
    # The name is part of the schema, so it is left unchanged.
    project_id: Mapped[int] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True
    )
    start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True, nullable=False)
    end_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True, nullable=False)

    traces: Mapped[list["Trace"]] = relationship(
        "Trace", back_populates="project_session", uselist=True
    )
175
+
176
+
177
class Trace(Base):
    """A trace, unique by ``trace_id``, belonging to a project.

    A trace may optionally belong to a project session, because
    ``project_session_rowid`` is nullable. It owns its spans through an
    ``all, delete-orphan`` cascade.
    """

    __tablename__ = "traces"
    __table_args__ = (UniqueConstraint("trace_id"),)

    id: Mapped[int] = mapped_column(primary_key=True)
    project_rowid: Mapped[int] = mapped_column(
        ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, index=True
    )
    trace_id: Mapped[str]
    project_session_rowid: Mapped[Optional[int]] = mapped_column(
        ForeignKey("project_sessions.id", ondelete="CASCADE"), index=True
    )
    start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
    end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)

    @hybrid_property
    def latency_ms(self) -> float:
        # Instance level: unrounded elapsed time in milliseconds.
        # See https://docs.sqlalchemy.org/en/20/orm/extensions/hybrid.html
        elapsed = self.end_time - self.start_time
        return elapsed.total_seconds() * 1000

    @latency_ms.inplace.expression
    @classmethod
    def _latency_ms_expression(cls) -> ColumnElement[float]:
        # SQL level: delegates to the dialect-specific ``LatencyMs`` construct.
        return LatencyMs(cls.start_time, cls.end_time)

    project: Mapped["Project"] = relationship("Project", back_populates="traces")
    spans: Mapped[list["Span"]] = relationship(
        "Span", back_populates="trace", cascade="all, delete-orphan", uselist=True
    )
    # NOTE(review): ``project_session_rowid`` is nullable, so this attribute may
    # be None at runtime even though the annotation is not Optional.
    project_session: Mapped[ProjectSession] = relationship(
        "ProjectSession", back_populates="traces"
    )
    # Joined on the string ``trace_id`` rather than a row id; ``foreign()`` marks
    # ``ExperimentRun.trace_id`` as the referencing side.
    experiment_runs: Mapped[list["ExperimentRun"]] = relationship(
        primaryjoin="foreign(ExperimentRun.trace_id) == Trace.trace_id",
        back_populates="trace",
    )
227
+
228
+
229
class Span(Base):
    """A span row, linked to its trace through ``trace_rowid``.

    Error and token counts are stored as plain columns (``cumulative_*`` and
    ``llm_token_count_*``). The ``latency_ms`` and ``*_total`` hybrids derive
    values from those columns, both on loaded instances and inside SQL queries.
    ``span_id`` is unique; on SQLite a conflicting insert is ignored
    (``sqlite_on_conflict="IGNORE"``).
    """

    __tablename__ = "spans"
    id: Mapped[int] = mapped_column(primary_key=True)
    trace_rowid: Mapped[int] = mapped_column(
        ForeignKey("traces.id", ondelete="CASCADE"),
        index=True,
    )
    span_id: Mapped[str] = mapped_column(index=True)
    parent_id: Mapped[Optional[str]] = mapped_column(index=True)
    name: Mapped[str]
    span_kind: Mapped[str]
    start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
    end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
    attributes: Mapped[dict[str, Any]]
    events: Mapped[list[dict[str, Any]]]
    status_code: Mapped[str] = mapped_column(
        CheckConstraint("status_code IN ('OK', 'ERROR', 'UNSET')", name="valid_status")
    )
    status_message: Mapped[str]

    # TODO(mikeldking): is computed columns possible here
    cumulative_error_count: Mapped[int]
    cumulative_llm_token_count_prompt: Mapped[int]
    cumulative_llm_token_count_completion: Mapped[int]
    llm_token_count_prompt: Mapped[Optional[int]]
    llm_token_count_completion: Mapped[Optional[int]]

    @hybrid_property
    def latency_ms(self) -> float:
        # Instance level: unrounded elapsed time in milliseconds.
        # See https://docs.sqlalchemy.org/en/20/orm/extensions/hybrid.html
        return (self.end_time - self.start_time).total_seconds() * 1000

    @latency_ms.inplace.expression
    @classmethod
    def _latency_ms_expression(cls) -> ColumnElement[float]:
        # SQL level: delegates to the dialect-specific ``LatencyMs`` construct.
        # See https://docs.sqlalchemy.org/en/20/orm/extensions/hybrid.html
        return LatencyMs(cls.start_time, cls.end_time)

    @hybrid_property
    def cumulative_llm_token_count_total(self) -> int:
        # ``+`` works on both ints and column objects (both columns are
        # non-nullable), so the getter doubles as the SQL expression.
        return self.cumulative_llm_token_count_prompt + self.cumulative_llm_token_count_completion

    @hybrid_property
    def llm_token_count_total(self) -> Optional[int]:
        # None only when both counts are missing; otherwise a missing count counts as 0.
        if self.llm_token_count_prompt is None and self.llm_token_count_completion is None:
            return None
        return (self.llm_token_count_prompt or 0) + (self.llm_token_count_completion or 0)

    @llm_token_count_total.inplace.expression
    @classmethod
    def _llm_token_count_total_expression(cls) -> ColumnElement[Optional[int]]:
        # SQL mirror of the getter above. Without a dedicated expression, class-level
        # access would run the getter against column objects. There ``is None`` is
        # never true, so the NULL handling above could not be reproduced in SQL.
        # COALESCE(p + c, p, c) evaluates to:
        #   p + c when both are set,
        #   whichever one is set when only one is,
        #   NULL only when both are NULL.
        prompt = cls.llm_token_count_prompt
        completion = cls.llm_token_count_completion
        return func.coalesce(prompt + completion, prompt, completion)

    trace: Mapped["Trace"] = relationship("Trace", back_populates="spans")
    document_annotations: Mapped[list["DocumentAnnotation"]] = relationship(back_populates="span")
    dataset_examples: Mapped[list["DatasetExample"]] = relationship(back_populates="span")

    __table_args__ = (
        UniqueConstraint(
            "span_id",
            sqlite_on_conflict="IGNORE",
        ),
        # Expression indexes backing sorts/filters on latency and total token count.
        Index("ix_latency", text("(end_time - start_time)")),
        Index(
            "ix_cumulative_llm_token_count_total",
            text("(cumulative_llm_token_count_prompt + cumulative_llm_token_count_completion)"),
        ),
    )
292
+
293
+
294
class LatencyMs(expression.FunctionElement[float]):
    """SQL construct ``latency_ms(start_time, end_time)`` giving elapsed milliseconds as a float.

    The ``@compiles(LatencyMs, ...)`` hooks in this module render it: a
    default hook and a SQLite-specific one.
    See https://docs.sqlalchemy.org/en/20/core/compiler.html
    """

    name = "latency_ms"
    type = Float()
    inherit_cache = True
299
+
300
+
301
@compiles(LatencyMs)
def _(element: Any, compiler: Any, **kw: Any) -> Any:
    """Default rendering of ``LatencyMs``.

    Renders ``ROUND(CAST((EXTRACT(EPOCH FROM end) - EXTRACT(EPOCH FROM start)) * 1000
    AS NUMERIC), 1)``, i.e. elapsed milliseconds rounded to one decimal place.
    See https://docs.sqlalchemy.org/en/20/core/compiler.html
    """
    start_time, end_time = list(element.clauses)
    elapsed_seconds = func.extract("EPOCH", end_time) - func.extract("EPOCH", start_time)
    elapsed_ms = func.cast(elapsed_seconds * 1000, NUMERIC)
    rounded = func.round(elapsed_ms, 1)
    return compiler.process(rounded, **kw)
315
+
316
+
317
@compiles(LatencyMs, "sqlite")
def _(element: Any, compiler: Any, **kw: Any) -> Any:
    """SQLite rendering of ``LatencyMs`` using ``unixepoch(..., 'subsec')``.

    The result is elapsed milliseconds rounded to one decimal place.
    NOTE(review): the 'subsec' modifier requires a recent SQLite (3.42+) —
    confirm the minimum supported version.
    See https://docs.sqlalchemy.org/en/20/core/compiler.html
    """
    start_time, end_time = list(element.clauses)
    # We don't know why sqlite returns a slightly different value.
    # postgresql is correct because it matches the value computed by Python.
    end_epoch = func.unixepoch(end_time, "subsec")
    start_epoch = func.unixepoch(start_time, "subsec")
    rounded = func.round((end_epoch - start_epoch) * 1000, 1)
    return compiler.process(rounded, **kw)
329
+
330
+
331
class TextContains(expression.FunctionElement[str]):
    """SQL construct ``text_contains(string, substring)`` for substring tests.

    The ``@compiles(TextContains, ...)`` hooks in this module render it per dialect.
    NOTE(review): ``type`` is ``String`` although every rendering is a boolean
    predicate — confirm whether ``Boolean`` was intended.
    See https://docs.sqlalchemy.org/en/20/core/compiler.html
    """

    name = "text_contains"
    type = String()
    inherit_cache = True
336
+
337
+
338
@compiles(TextContains)
def _(element: Any, compiler: Any, **kw: Any) -> Any:
    """Default rendering of ``TextContains`` via ``ColumnOperators.contains``.

    See https://docs.sqlalchemy.org/en/20/core/compiler.html
    """
    haystack, needle = list(element.clauses)
    predicate = haystack.contains(needle)
    return compiler.process(predicate, **kw)
343
+
344
+
345
@compiles(TextContains, "postgresql")
def _(element: Any, compiler: Any, **kw: Any) -> Any:
    """PostgreSQL rendering of ``TextContains`` as ``strpos(string, substring) > 0``.

    See https://docs.sqlalchemy.org/en/20/core/compiler.html
    """
    haystack, needle = list(element.clauses)
    predicate = func.strpos(haystack, needle) > 0
    return compiler.process(predicate, **kw)
350
+
351
+
352
@compiles(TextContains, "sqlite")
def _(element: Any, compiler: Any, **kw: Any) -> Any:
    """SQLite rendering of ``TextContains`` as ``text_contains(string, substring) > 0``.

    NOTE(review): ``text_contains`` is not a built-in SQLite function, so it
    presumably has to be registered on each connection elsewhere — confirm.
    See https://docs.sqlalchemy.org/en/20/core/compiler.html
    """
    haystack, needle = list(element.clauses)
    predicate = func.text_contains(haystack, needle) > 0
    return compiler.process(predicate, **kw)
357
+
358
+
359
async def init_models(engine: AsyncEngine) -> None:
    """Create every table in ``Base.metadata`` and insert a project named ``"default"``.

    Both steps run in the single transaction opened by ``engine.begin()``.
    NOTE(review): the insert is unconditional, so it would presumably violate
    the unique ``projects.name`` constraint on an already-initialized database —
    confirm this is only called against fresh databases.
    """
    default_project = insert(Project).values(
        name="default",
        description="default project",
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
        await conn.execute(default_project)
368
+
369
+
370
class SpanAnnotation(Base):
    """A named annotation (label, score, explanation, metadata) attached to a span.

    At most one annotation exists per ``(name, span_rowid)`` pair. A CHECK
    constraint restricts ``annotator_kind`` to ``'LLM'`` or ``'HUMAN'``.
    """

    __tablename__ = "span_annotations"
    __table_args__ = (UniqueConstraint("name", "span_rowid"),)

    id: Mapped[int] = mapped_column(primary_key=True)
    span_rowid: Mapped[int] = mapped_column(ForeignKey("spans.id", ondelete="CASCADE"), index=True)
    name: Mapped[str]
    label: Mapped[Optional[str]] = mapped_column(String, index=True)
    score: Mapped[Optional[float]] = mapped_column(Float, index=True)
    explanation: Mapped[Optional[str]]
    # The attribute is ``metadata_`` because ``metadata`` is already the
    # ``MetaData`` attribute on ``Base``; the database column is still "metadata".
    metadata_: Mapped[dict[str, Any]] = mapped_column("metadata")
    annotator_kind: Mapped[str] = mapped_column(
        CheckConstraint("annotator_kind IN ('LLM', 'HUMAN')", name="valid_annotator_kind"),
    )
    created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
    updated_at: Mapped[datetime] = mapped_column(
        UtcTimeStamp, server_default=func.now(), onupdate=func.now()
    )
395
+
396
+
397
+ class TraceAnnotation(Base):
398
+ __tablename__ = "trace_annotations"
399
+ id: Mapped[int] = mapped_column(primary_key=True)
400
+ trace_rowid: Mapped[int] = mapped_column(
401
+ ForeignKey("traces.id", ondelete="CASCADE"),
402
+ index=True,
403
+ )
404
+ name: Mapped[str]
405
+ label: Mapped[Optional[str]] = mapped_column(String, index=True)
406
+ score: Mapped[Optional[float]] = mapped_column(Float, index=True)
407
+ explanation: Mapped[Optional[str]]
408
+ metadata_: Mapped[dict[str, Any]] = mapped_column("metadata")
409
+ annotator_kind: Mapped[str] = mapped_column(
410
+ CheckConstraint("annotator_kind IN ('LLM', 'HUMAN')", name="valid_annotator_kind"),
411
+ )
412
+ created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
413
+ updated_at: Mapped[datetime] = mapped_column(
414
+ UtcTimeStamp, server_default=func.now(), onupdate=func.now()
415
+ )
416
+ __table_args__ = (
417
+ UniqueConstraint(
418
+ "name",
419
+ "trace_rowid",
420
+ ),
421
+ )
422
+
423
+
424
+ class DocumentAnnotation(Base):
425
+ __tablename__ = "document_annotations"
426
+ id: Mapped[int] = mapped_column(primary_key=True)
427
+ span_rowid: Mapped[int] = mapped_column(
428
+ ForeignKey("spans.id", ondelete="CASCADE"),
429
+ index=True,
430
+ )
431
+ document_position: Mapped[int]
432
+ name: Mapped[str]
433
+ label: Mapped[Optional[str]] = mapped_column(String, index=True)
434
+ score: Mapped[Optional[float]] = mapped_column(Float, index=True)
435
+ explanation: Mapped[Optional[str]]
436
+ metadata_: Mapped[dict[str, Any]] = mapped_column("metadata")
437
+ annotator_kind: Mapped[str] = mapped_column(
438
+ CheckConstraint("annotator_kind IN ('LLM', 'HUMAN')", name="valid_annotator_kind"),
439
+ )
440
+ created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
441
+ updated_at: Mapped[datetime] = mapped_column(
442
+ UtcTimeStamp, server_default=func.now(), onupdate=func.now()
443
+ )
444
+ span: Mapped["Span"] = relationship(back_populates="document_annotations")
445
+
446
+ __table_args__ = (
447
+ UniqueConstraint(
448
+ "name",
449
+ "span_rowid",
450
+ "document_position",
451
+ ),
452
+ )
453
+
454
+
455
+ class Dataset(Base):
456
+ __tablename__ = "datasets"
457
+ id: Mapped[int] = mapped_column(primary_key=True)
458
+ name: Mapped[str] = mapped_column(unique=True)
459
+ description: Mapped[Optional[str]]
460
+ metadata_: Mapped[dict[str, Any]] = mapped_column("metadata")
461
+ created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
462
+ updated_at: Mapped[datetime] = mapped_column(
463
+ UtcTimeStamp, server_default=func.now(), onupdate=func.now()
464
+ )
465
+
466
+ @hybrid_property
467
+ def example_count(self) -> Optional[int]:
468
+ if hasattr(self, "_example_count_value"):
469
+ assert isinstance(self._example_count_value, int)
470
+ return self._example_count_value
471
+ return None
472
+
473
+ @example_count.inplace.expression
474
+ def _example_count(cls) -> ColumnElement[int]:
475
+ return (
476
+ select(
477
+ func.sum(
478
+ case(
479
+ (DatasetExampleRevision.revision_kind == "CREATE", 1),
480
+ (DatasetExampleRevision.revision_kind == "DELETE", -1),
481
+ else_=0,
482
+ )
483
+ )
484
+ )
485
+ .select_from(DatasetExampleRevision)
486
+ .join(
487
+ DatasetExample,
488
+ onclause=DatasetExample.id == DatasetExampleRevision.dataset_example_id,
489
+ )
490
+ .filter(DatasetExample.dataset_id == cls.id)
491
+ .label("example_count")
492
+ )
493
+
494
+ async def load_example_count(self, session: AsyncSession) -> None:
495
+ if not hasattr(self, "_example_count_value"):
496
+ self._example_count_value = await session.scalar(
497
+ select(
498
+ func.sum(
499
+ case(
500
+ (DatasetExampleRevision.revision_kind == "CREATE", 1),
501
+ (DatasetExampleRevision.revision_kind == "DELETE", -1),
502
+ else_=0,
503
+ )
504
+ )
505
+ )
506
+ .select_from(DatasetExampleRevision)
507
+ .join(
508
+ DatasetExample,
509
+ onclause=DatasetExample.id == DatasetExampleRevision.dataset_example_id,
510
+ )
511
+ .filter(DatasetExample.dataset_id == self.id)
512
+ )
513
+
514
+
515
+ class DatasetVersion(Base):
516
+ __tablename__ = "dataset_versions"
517
+ id: Mapped[int] = mapped_column(primary_key=True)
518
+ dataset_id: Mapped[int] = mapped_column(
519
+ ForeignKey("datasets.id", ondelete="CASCADE"),
520
+ index=True,
521
+ )
522
+ description: Mapped[Optional[str]]
523
+ metadata_: Mapped[dict[str, Any]] = mapped_column("metadata")
524
+ created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
525
+
526
+
527
+ class DatasetExample(Base):
528
+ __tablename__ = "dataset_examples"
529
+ id: Mapped[int] = mapped_column(primary_key=True)
530
+ dataset_id: Mapped[int] = mapped_column(
531
+ ForeignKey("datasets.id", ondelete="CASCADE"),
532
+ index=True,
533
+ )
534
+ span_rowid: Mapped[Optional[int]] = mapped_column(
535
+ ForeignKey("spans.id", ondelete="SET NULL"),
536
+ index=True,
537
+ nullable=True,
538
+ )
539
+ created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
540
+
541
+ span: Mapped[Optional[Span]] = relationship(back_populates="dataset_examples")
542
+
543
+
544
+ class DatasetExampleRevision(Base):
545
+ __tablename__ = "dataset_example_revisions"
546
+ id: Mapped[int] = mapped_column(primary_key=True)
547
+ dataset_example_id: Mapped[int] = mapped_column(
548
+ ForeignKey("dataset_examples.id", ondelete="CASCADE"),
549
+ index=True,
550
+ )
551
+ dataset_version_id: Mapped[int] = mapped_column(
552
+ ForeignKey("dataset_versions.id", ondelete="CASCADE"),
553
+ index=True,
554
+ )
555
+ input: Mapped[dict[str, Any]]
556
+ output: Mapped[dict[str, Any]]
557
+ metadata_: Mapped[dict[str, Any]] = mapped_column("metadata")
558
+ revision_kind: Mapped[str] = mapped_column(
559
+ CheckConstraint(
560
+ "revision_kind IN ('CREATE', 'PATCH', 'DELETE')", name="valid_revision_kind"
561
+ ),
562
+ )
563
+ created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
564
+
565
+ __table_args__ = (
566
+ UniqueConstraint(
567
+ "dataset_example_id",
568
+ "dataset_version_id",
569
+ ),
570
+ )
571
+
572
+
573
+ class Experiment(Base):
574
+ __tablename__ = "experiments"
575
+ id: Mapped[int] = mapped_column(primary_key=True)
576
+ dataset_id: Mapped[int] = mapped_column(
577
+ ForeignKey("datasets.id", ondelete="CASCADE"),
578
+ index=True,
579
+ )
580
+ dataset_version_id: Mapped[int] = mapped_column(
581
+ ForeignKey("dataset_versions.id", ondelete="CASCADE"),
582
+ index=True,
583
+ )
584
+ name: Mapped[str]
585
+ description: Mapped[Optional[str]]
586
+ repetitions: Mapped[int]
587
+ metadata_: Mapped[dict[str, Any]] = mapped_column("metadata")
588
+ project_name: Mapped[Optional[str]] = mapped_column(index=True)
589
+ created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
590
+ updated_at: Mapped[datetime] = mapped_column(
591
+ UtcTimeStamp, server_default=func.now(), onupdate=func.now()
592
+ )
593
+
594
+
595
+ class ExperimentRun(Base):
596
+ __tablename__ = "experiment_runs"
597
+ id: Mapped[int] = mapped_column(primary_key=True)
598
+ experiment_id: Mapped[int] = mapped_column(
599
+ ForeignKey("experiments.id", ondelete="CASCADE"),
600
+ index=True,
601
+ )
602
+ dataset_example_id: Mapped[int] = mapped_column(
603
+ ForeignKey("dataset_examples.id", ondelete="CASCADE"),
604
+ index=True,
605
+ )
606
+ repetition_number: Mapped[int]
607
+ trace_id: Mapped[Optional[str]]
608
+ output: Mapped[ExperimentRunOutput]
609
+ start_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
610
+ end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
611
+ prompt_token_count: Mapped[Optional[int]]
612
+ completion_token_count: Mapped[Optional[int]]
613
+ error: Mapped[Optional[str]]
614
+
615
+ @hybrid_property
616
+ def latency_ms(self) -> float:
617
+ # See https://docs.sqlalchemy.org/en/20/orm/extensions/hybrid.html
618
+ return (self.end_time - self.start_time).total_seconds() * 1000
619
+
620
+ @latency_ms.inplace.expression
621
+ @classmethod
622
+ def _latency_expression(cls) -> ColumnElement[float]:
623
+ # See https://docs.sqlalchemy.org/en/20/orm/extensions/hybrid.html
624
+ return LatencyMs(cls.start_time, cls.end_time)
625
+
626
+ trace: Mapped["Trace"] = relationship(
627
+ primaryjoin="foreign(ExperimentRun.trace_id) == Trace.trace_id",
628
+ back_populates="experiment_runs",
629
+ )
630
+
631
+ __table_args__ = (
632
+ UniqueConstraint(
633
+ "experiment_id",
634
+ "dataset_example_id",
635
+ "repetition_number",
636
+ ),
637
+ )
638
+
639
+
640
+ class ExperimentRunAnnotation(Base):
641
+ __tablename__ = "experiment_run_annotations"
642
+ id: Mapped[int] = mapped_column(primary_key=True)
643
+ experiment_run_id: Mapped[int] = mapped_column(
644
+ ForeignKey("experiment_runs.id", ondelete="CASCADE"),
645
+ index=True,
646
+ )
647
+ name: Mapped[str]
648
+ annotator_kind: Mapped[str] = mapped_column(
649
+ CheckConstraint("annotator_kind IN ('LLM', 'CODE', 'HUMAN')", name="valid_annotator_kind"),
650
+ )
651
+ label: Mapped[Optional[str]]
652
+ score: Mapped[Optional[float]]
653
+ explanation: Mapped[Optional[str]]
654
+ trace_id: Mapped[Optional[str]]
655
+ error: Mapped[Optional[str]]
656
+ metadata_: Mapped[dict[str, Any]] = mapped_column("metadata")
657
+ start_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
658
+ end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
659
+
660
+ __table_args__ = (
661
+ UniqueConstraint(
662
+ "experiment_run_id",
663
+ "name",
664
+ ),
665
+ )
666
+
667
+
668
+ class UserRole(Base):
669
+ __tablename__ = "user_roles"
670
+ id: Mapped[int] = mapped_column(primary_key=True)
671
+ name: Mapped[str] = mapped_column(unique=True, index=True)
672
+ users: Mapped[list["User"]] = relationship("User", back_populates="role")
673
+
674
+
675
+ class User(Base):
676
+ __tablename__ = "users"
677
+ id: Mapped[int] = mapped_column(primary_key=True)
678
+ user_role_id: Mapped[int] = mapped_column(
679
+ ForeignKey("user_roles.id", ondelete="CASCADE"),
680
+ index=True,
681
+ )
682
+ role: Mapped["UserRole"] = relationship("UserRole", back_populates="users")
683
+ username: Mapped[str] = mapped_column(nullable=False, unique=True, index=True)
684
+ email: Mapped[str] = mapped_column(nullable=False, unique=True, index=True)
685
+ profile_picture_url: Mapped[Optional[str]]
686
+ password_hash: Mapped[Optional[bytes]]
687
+ password_salt: Mapped[Optional[bytes]]
688
+ reset_password: Mapped[bool]
689
+ oauth2_client_id: Mapped[Optional[str]] = mapped_column(index=True, nullable=True)
690
+ oauth2_user_id: Mapped[Optional[str]] = mapped_column(index=True, nullable=True)
691
+ created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
692
+ updated_at: Mapped[datetime] = mapped_column(
693
+ UtcTimeStamp, server_default=func.now(), onupdate=func.now()
694
+ )
695
+ password_reset_token: Mapped[Optional["PasswordResetToken"]] = relationship(
696
+ "PasswordResetToken",
697
+ back_populates="user",
698
+ uselist=False,
699
+ )
700
+ access_tokens: Mapped[list["AccessToken"]] = relationship("AccessToken", back_populates="user")
701
+ refresh_tokens: Mapped[list["RefreshToken"]] = relationship(
702
+ "RefreshToken", back_populates="user"
703
+ )
704
+ api_keys: Mapped[list["ApiKey"]] = relationship("ApiKey", back_populates="user")
705
+
706
+ @hybrid_property
707
+ def auth_method(self) -> Optional[str]:
708
+ if self.password_hash is not None:
709
+ return AuthMethod.LOCAL.value
710
+ elif self.oauth2_client_id is not None:
711
+ return AuthMethod.OAUTH2.value
712
+ return None
713
+
714
+ @auth_method.inplace.expression
715
+ @classmethod
716
+ def _auth_method_expression(cls) -> ColumnElement[Optional[str]]:
717
+ return case(
718
+ (
719
+ not_(cls.password_hash.is_(None)),
720
+ AuthMethod.LOCAL.value,
721
+ ),
722
+ (
723
+ not_(cls.oauth2_client_id.is_(None)),
724
+ AuthMethod.OAUTH2.value,
725
+ ),
726
+ else_=None,
727
+ )
728
+
729
+ __table_args__ = (
730
+ CheckConstraint(
731
+ "(password_hash IS NULL) = (password_salt IS NULL)",
732
+ name="password_hash_and_salt",
733
+ ),
734
+ CheckConstraint(
735
+ "(oauth2_client_id IS NULL) = (oauth2_user_id IS NULL)",
736
+ name="oauth2_client_id_and_user_id",
737
+ ),
738
+ CheckConstraint(
739
+ "(password_hash IS NULL) != (oauth2_client_id IS NULL)",
740
+ name="exactly_one_auth_method",
741
+ ),
742
+ UniqueConstraint(
743
+ "oauth2_client_id",
744
+ "oauth2_user_id",
745
+ ),
746
+ dict(sqlite_autoincrement=True),
747
+ )
748
+
749
+
750
+ class PasswordResetToken(Base):
751
+ __tablename__ = "password_reset_tokens"
752
+ id: Mapped[int] = mapped_column(primary_key=True)
753
+ user_id: Mapped[int] = mapped_column(
754
+ ForeignKey("users.id", ondelete="CASCADE"),
755
+ unique=True,
756
+ index=True,
757
+ )
758
+ user: Mapped["User"] = relationship("User", back_populates="password_reset_token")
759
+ created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
760
+ expires_at: Mapped[Optional[datetime]] = mapped_column(UtcTimeStamp, nullable=False, index=True)
761
+ __table_args__ = (dict(sqlite_autoincrement=True),)
762
+
763
+
764
+ class RefreshToken(Base):
765
+ __tablename__ = "refresh_tokens"
766
+ id: Mapped[int] = mapped_column(primary_key=True)
767
+ user_id: Mapped[int] = mapped_column(
768
+ ForeignKey("users.id", ondelete="CASCADE"),
769
+ index=True,
770
+ )
771
+ user: Mapped["User"] = relationship("User", back_populates="refresh_tokens")
772
+ created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
773
+ expires_at: Mapped[Optional[datetime]] = mapped_column(UtcTimeStamp, nullable=False, index=True)
774
+ __table_args__ = (dict(sqlite_autoincrement=True),)
775
+
776
+
777
+ class AccessToken(Base):
778
+ __tablename__ = "access_tokens"
779
+ id: Mapped[int] = mapped_column(primary_key=True)
780
+ user_id: Mapped[int] = mapped_column(
781
+ ForeignKey("users.id", ondelete="CASCADE"),
782
+ index=True,
783
+ )
784
+ user: Mapped["User"] = relationship("User", back_populates="access_tokens")
785
+ created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
786
+ expires_at: Mapped[Optional[datetime]] = mapped_column(UtcTimeStamp, nullable=False, index=True)
787
+ refresh_token_id: Mapped[int] = mapped_column(
788
+ ForeignKey("refresh_tokens.id", ondelete="CASCADE"),
789
+ index=True,
790
+ unique=True,
791
+ )
792
+ __table_args__ = (dict(sqlite_autoincrement=True),)
793
+
794
+
795
+ class ApiKey(Base):
796
+ __tablename__ = "api_keys"
797
+ id: Mapped[int] = mapped_column(primary_key=True)
798
+ user_id: Mapped[int] = mapped_column(
799
+ ForeignKey("users.id", ondelete="CASCADE"),
800
+ index=True,
801
+ )
802
+ user: Mapped["User"] = relationship("User", back_populates="api_keys")
803
+ name: Mapped[str]
804
+ description: Mapped[Optional[str]]
805
+ created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
806
+ expires_at: Mapped[Optional[datetime]] = mapped_column(UtcTimeStamp, nullable=True, index=True)
807
+ __table_args__ = (dict(sqlite_autoincrement=True),)