arize-phoenix 3.16.1__py3-none-any.whl → 7.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; see the registry's advisory for more details.

Files changed (338)
  1. arize_phoenix-7.7.1.dist-info/METADATA +261 -0
  2. arize_phoenix-7.7.1.dist-info/RECORD +345 -0
  3. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/WHEEL +1 -1
  4. arize_phoenix-7.7.1.dist-info/entry_points.txt +3 -0
  5. phoenix/__init__.py +86 -14
  6. phoenix/auth.py +309 -0
  7. phoenix/config.py +675 -45
  8. phoenix/core/model.py +32 -30
  9. phoenix/core/model_schema.py +102 -109
  10. phoenix/core/model_schema_adapter.py +48 -45
  11. phoenix/datetime_utils.py +24 -3
  12. phoenix/db/README.md +54 -0
  13. phoenix/db/__init__.py +4 -0
  14. phoenix/db/alembic.ini +85 -0
  15. phoenix/db/bulk_inserter.py +294 -0
  16. phoenix/db/engines.py +208 -0
  17. phoenix/db/enums.py +20 -0
  18. phoenix/db/facilitator.py +113 -0
  19. phoenix/db/helpers.py +159 -0
  20. phoenix/db/insertion/constants.py +2 -0
  21. phoenix/db/insertion/dataset.py +227 -0
  22. phoenix/db/insertion/document_annotation.py +171 -0
  23. phoenix/db/insertion/evaluation.py +191 -0
  24. phoenix/db/insertion/helpers.py +98 -0
  25. phoenix/db/insertion/span.py +193 -0
  26. phoenix/db/insertion/span_annotation.py +158 -0
  27. phoenix/db/insertion/trace_annotation.py +158 -0
  28. phoenix/db/insertion/types.py +256 -0
  29. phoenix/db/migrate.py +86 -0
  30. phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
  31. phoenix/db/migrations/env.py +114 -0
  32. phoenix/db/migrations/script.py.mako +26 -0
  33. phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
  34. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
  35. phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
  36. phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
  37. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  38. phoenix/db/models.py +807 -0
  39. phoenix/exceptions.py +5 -1
  40. phoenix/experiments/__init__.py +6 -0
  41. phoenix/experiments/evaluators/__init__.py +29 -0
  42. phoenix/experiments/evaluators/base.py +158 -0
  43. phoenix/experiments/evaluators/code_evaluators.py +184 -0
  44. phoenix/experiments/evaluators/llm_evaluators.py +473 -0
  45. phoenix/experiments/evaluators/utils.py +236 -0
  46. phoenix/experiments/functions.py +772 -0
  47. phoenix/experiments/tracing.py +86 -0
  48. phoenix/experiments/types.py +726 -0
  49. phoenix/experiments/utils.py +25 -0
  50. phoenix/inferences/__init__.py +0 -0
  51. phoenix/{datasets → inferences}/errors.py +6 -5
  52. phoenix/{datasets → inferences}/fixtures.py +49 -42
  53. phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
  54. phoenix/{datasets → inferences}/schema.py +11 -11
  55. phoenix/{datasets → inferences}/validation.py +13 -14
  56. phoenix/logging/__init__.py +3 -0
  57. phoenix/logging/_config.py +90 -0
  58. phoenix/logging/_filter.py +6 -0
  59. phoenix/logging/_formatter.py +69 -0
  60. phoenix/metrics/__init__.py +5 -4
  61. phoenix/metrics/binning.py +4 -3
  62. phoenix/metrics/metrics.py +2 -1
  63. phoenix/metrics/mixins.py +7 -6
  64. phoenix/metrics/retrieval_metrics.py +2 -1
  65. phoenix/metrics/timeseries.py +5 -4
  66. phoenix/metrics/wrappers.py +9 -3
  67. phoenix/pointcloud/clustering.py +5 -5
  68. phoenix/pointcloud/pointcloud.py +7 -5
  69. phoenix/pointcloud/projectors.py +5 -6
  70. phoenix/pointcloud/umap_parameters.py +53 -52
  71. phoenix/server/api/README.md +28 -0
  72. phoenix/server/api/auth.py +44 -0
  73. phoenix/server/api/context.py +152 -9
  74. phoenix/server/api/dataloaders/__init__.py +91 -0
  75. phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
  76. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  77. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  78. phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
  79. phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
  80. phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
  81. phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
  82. phoenix/server/api/dataloaders/document_evaluations.py +31 -0
  83. phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
  84. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
  85. phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
  86. phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
  87. phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
  88. phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
  89. phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
  90. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
  91. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  92. phoenix/server/api/dataloaders/record_counts.py +116 -0
  93. phoenix/server/api/dataloaders/session_io.py +79 -0
  94. phoenix/server/api/dataloaders/session_num_traces.py +30 -0
  95. phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
  96. phoenix/server/api/dataloaders/session_token_usages.py +41 -0
  97. phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
  98. phoenix/server/api/dataloaders/span_annotations.py +26 -0
  99. phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
  100. phoenix/server/api/dataloaders/span_descendants.py +57 -0
  101. phoenix/server/api/dataloaders/span_projects.py +33 -0
  102. phoenix/server/api/dataloaders/token_counts.py +124 -0
  103. phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
  104. phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
  105. phoenix/server/api/dataloaders/user_roles.py +30 -0
  106. phoenix/server/api/dataloaders/users.py +33 -0
  107. phoenix/server/api/exceptions.py +48 -0
  108. phoenix/server/api/helpers/__init__.py +12 -0
  109. phoenix/server/api/helpers/dataset_helpers.py +217 -0
  110. phoenix/server/api/helpers/experiment_run_filters.py +763 -0
  111. phoenix/server/api/helpers/playground_clients.py +948 -0
  112. phoenix/server/api/helpers/playground_registry.py +70 -0
  113. phoenix/server/api/helpers/playground_spans.py +455 -0
  114. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  115. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  116. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  117. phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
  118. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  119. phoenix/server/api/input_types/ClusterInput.py +2 -2
  120. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  121. phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
  122. phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
  123. phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
  124. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  125. phoenix/server/api/input_types/DatasetSort.py +17 -0
  126. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  127. phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
  128. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  129. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  130. phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
  131. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  132. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  133. phoenix/server/api/input_types/Granularity.py +1 -1
  134. phoenix/server/api/input_types/InvocationParameters.py +162 -0
  135. phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
  136. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  137. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  138. phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
  139. phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
  140. phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
  141. phoenix/server/api/input_types/SpanSort.py +134 -69
  142. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  143. phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
  144. phoenix/server/api/input_types/UserRoleInput.py +9 -0
  145. phoenix/server/api/mutations/__init__.py +28 -0
  146. phoenix/server/api/mutations/api_key_mutations.py +167 -0
  147. phoenix/server/api/mutations/chat_mutations.py +593 -0
  148. phoenix/server/api/mutations/dataset_mutations.py +591 -0
  149. phoenix/server/api/mutations/experiment_mutations.py +75 -0
  150. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
  151. phoenix/server/api/mutations/project_mutations.py +57 -0
  152. phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
  153. phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
  154. phoenix/server/api/mutations/user_mutations.py +329 -0
  155. phoenix/server/api/openapi/__init__.py +0 -0
  156. phoenix/server/api/openapi/main.py +17 -0
  157. phoenix/server/api/openapi/schema.py +16 -0
  158. phoenix/server/api/queries.py +738 -0
  159. phoenix/server/api/routers/__init__.py +11 -0
  160. phoenix/server/api/routers/auth.py +284 -0
  161. phoenix/server/api/routers/embeddings.py +26 -0
  162. phoenix/server/api/routers/oauth2.py +488 -0
  163. phoenix/server/api/routers/v1/__init__.py +64 -0
  164. phoenix/server/api/routers/v1/datasets.py +1017 -0
  165. phoenix/server/api/routers/v1/evaluations.py +362 -0
  166. phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
  167. phoenix/server/api/routers/v1/experiment_runs.py +167 -0
  168. phoenix/server/api/routers/v1/experiments.py +308 -0
  169. phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
  170. phoenix/server/api/routers/v1/spans.py +267 -0
  171. phoenix/server/api/routers/v1/traces.py +208 -0
  172. phoenix/server/api/routers/v1/utils.py +95 -0
  173. phoenix/server/api/schema.py +44 -241
  174. phoenix/server/api/subscriptions.py +597 -0
  175. phoenix/server/api/types/Annotation.py +21 -0
  176. phoenix/server/api/types/AnnotationSummary.py +55 -0
  177. phoenix/server/api/types/AnnotatorKind.py +16 -0
  178. phoenix/server/api/types/ApiKey.py +27 -0
  179. phoenix/server/api/types/AuthMethod.py +9 -0
  180. phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
  181. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
  182. phoenix/server/api/types/Cluster.py +25 -24
  183. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  184. phoenix/server/api/types/DataQualityMetric.py +31 -13
  185. phoenix/server/api/types/Dataset.py +288 -63
  186. phoenix/server/api/types/DatasetExample.py +85 -0
  187. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  188. phoenix/server/api/types/DatasetVersion.py +14 -0
  189. phoenix/server/api/types/Dimension.py +32 -31
  190. phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
  191. phoenix/server/api/types/EmbeddingDimension.py +56 -49
  192. phoenix/server/api/types/Evaluation.py +25 -31
  193. phoenix/server/api/types/EvaluationSummary.py +30 -50
  194. phoenix/server/api/types/Event.py +20 -20
  195. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  196. phoenix/server/api/types/Experiment.py +152 -0
  197. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  198. phoenix/server/api/types/ExperimentComparison.py +17 -0
  199. phoenix/server/api/types/ExperimentRun.py +119 -0
  200. phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
  201. phoenix/server/api/types/GenerativeModel.py +9 -0
  202. phoenix/server/api/types/GenerativeProvider.py +85 -0
  203. phoenix/server/api/types/Inferences.py +80 -0
  204. phoenix/server/api/types/InferencesRole.py +23 -0
  205. phoenix/server/api/types/LabelFraction.py +7 -0
  206. phoenix/server/api/types/MimeType.py +2 -2
  207. phoenix/server/api/types/Model.py +54 -54
  208. phoenix/server/api/types/PerformanceMetric.py +8 -5
  209. phoenix/server/api/types/Project.py +407 -142
  210. phoenix/server/api/types/ProjectSession.py +139 -0
  211. phoenix/server/api/types/Segments.py +4 -4
  212. phoenix/server/api/types/Span.py +221 -176
  213. phoenix/server/api/types/SpanAnnotation.py +43 -0
  214. phoenix/server/api/types/SpanIOValue.py +15 -0
  215. phoenix/server/api/types/SystemApiKey.py +9 -0
  216. phoenix/server/api/types/TemplateLanguage.py +10 -0
  217. phoenix/server/api/types/TimeSeries.py +19 -15
  218. phoenix/server/api/types/TokenUsage.py +11 -0
  219. phoenix/server/api/types/Trace.py +154 -0
  220. phoenix/server/api/types/TraceAnnotation.py +45 -0
  221. phoenix/server/api/types/UMAPPoints.py +7 -7
  222. phoenix/server/api/types/User.py +60 -0
  223. phoenix/server/api/types/UserApiKey.py +45 -0
  224. phoenix/server/api/types/UserRole.py +15 -0
  225. phoenix/server/api/types/node.py +4 -112
  226. phoenix/server/api/types/pagination.py +156 -57
  227. phoenix/server/api/utils.py +34 -0
  228. phoenix/server/app.py +864 -115
  229. phoenix/server/bearer_auth.py +163 -0
  230. phoenix/server/dml_event.py +136 -0
  231. phoenix/server/dml_event_handler.py +256 -0
  232. phoenix/server/email/__init__.py +0 -0
  233. phoenix/server/email/sender.py +97 -0
  234. phoenix/server/email/templates/__init__.py +0 -0
  235. phoenix/server/email/templates/password_reset.html +19 -0
  236. phoenix/server/email/types.py +11 -0
  237. phoenix/server/grpc_server.py +102 -0
  238. phoenix/server/jwt_store.py +505 -0
  239. phoenix/server/main.py +305 -116
  240. phoenix/server/oauth2.py +52 -0
  241. phoenix/server/openapi/__init__.py +0 -0
  242. phoenix/server/prometheus.py +111 -0
  243. phoenix/server/rate_limiters.py +188 -0
  244. phoenix/server/static/.vite/manifest.json +87 -0
  245. phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
  246. phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
  247. phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
  248. phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
  249. phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
  250. phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
  251. phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
  252. phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
  253. phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
  254. phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
  255. phoenix/server/telemetry.py +68 -0
  256. phoenix/server/templates/index.html +82 -23
  257. phoenix/server/thread_server.py +3 -3
  258. phoenix/server/types.py +275 -0
  259. phoenix/services.py +27 -18
  260. phoenix/session/client.py +743 -68
  261. phoenix/session/data_extractor.py +31 -7
  262. phoenix/session/evaluation.py +3 -9
  263. phoenix/session/session.py +263 -219
  264. phoenix/settings.py +22 -0
  265. phoenix/trace/__init__.py +2 -22
  266. phoenix/trace/attributes.py +338 -0
  267. phoenix/trace/dsl/README.md +116 -0
  268. phoenix/trace/dsl/filter.py +663 -213
  269. phoenix/trace/dsl/helpers.py +73 -21
  270. phoenix/trace/dsl/query.py +574 -201
  271. phoenix/trace/exporter.py +24 -19
  272. phoenix/trace/fixtures.py +368 -32
  273. phoenix/trace/otel.py +71 -219
  274. phoenix/trace/projects.py +3 -2
  275. phoenix/trace/schemas.py +33 -11
  276. phoenix/trace/span_evaluations.py +21 -16
  277. phoenix/trace/span_json_decoder.py +6 -4
  278. phoenix/trace/span_json_encoder.py +2 -2
  279. phoenix/trace/trace_dataset.py +47 -32
  280. phoenix/trace/utils.py +21 -4
  281. phoenix/utilities/__init__.py +0 -26
  282. phoenix/utilities/client.py +132 -0
  283. phoenix/utilities/deprecation.py +31 -0
  284. phoenix/utilities/error_handling.py +3 -2
  285. phoenix/utilities/json.py +109 -0
  286. phoenix/utilities/logging.py +8 -0
  287. phoenix/utilities/project.py +2 -2
  288. phoenix/utilities/re.py +49 -0
  289. phoenix/utilities/span_store.py +0 -23
  290. phoenix/utilities/template_formatters.py +99 -0
  291. phoenix/version.py +1 -1
  292. arize_phoenix-3.16.1.dist-info/METADATA +0 -495
  293. arize_phoenix-3.16.1.dist-info/RECORD +0 -178
  294. phoenix/core/project.py +0 -619
  295. phoenix/core/traces.py +0 -96
  296. phoenix/experimental/evals/__init__.py +0 -73
  297. phoenix/experimental/evals/evaluators.py +0 -413
  298. phoenix/experimental/evals/functions/__init__.py +0 -4
  299. phoenix/experimental/evals/functions/classify.py +0 -453
  300. phoenix/experimental/evals/functions/executor.py +0 -353
  301. phoenix/experimental/evals/functions/generate.py +0 -138
  302. phoenix/experimental/evals/functions/processing.py +0 -76
  303. phoenix/experimental/evals/models/__init__.py +0 -14
  304. phoenix/experimental/evals/models/anthropic.py +0 -175
  305. phoenix/experimental/evals/models/base.py +0 -170
  306. phoenix/experimental/evals/models/bedrock.py +0 -221
  307. phoenix/experimental/evals/models/litellm.py +0 -134
  308. phoenix/experimental/evals/models/openai.py +0 -448
  309. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  310. phoenix/experimental/evals/models/vertex.py +0 -173
  311. phoenix/experimental/evals/models/vertexai.py +0 -186
  312. phoenix/experimental/evals/retrievals.py +0 -96
  313. phoenix/experimental/evals/templates/__init__.py +0 -50
  314. phoenix/experimental/evals/templates/default_templates.py +0 -472
  315. phoenix/experimental/evals/templates/template.py +0 -195
  316. phoenix/experimental/evals/utils/__init__.py +0 -172
  317. phoenix/experimental/evals/utils/threads.py +0 -27
  318. phoenix/server/api/helpers.py +0 -11
  319. phoenix/server/api/routers/evaluation_handler.py +0 -109
  320. phoenix/server/api/routers/span_handler.py +0 -70
  321. phoenix/server/api/routers/trace_handler.py +0 -60
  322. phoenix/server/api/types/DatasetRole.py +0 -23
  323. phoenix/server/static/index.css +0 -6
  324. phoenix/server/static/index.js +0 -7447
  325. phoenix/storage/span_store/__init__.py +0 -23
  326. phoenix/storage/span_store/text_file.py +0 -85
  327. phoenix/trace/dsl/missing.py +0 -60
  328. phoenix/trace/langchain/__init__.py +0 -3
  329. phoenix/trace/langchain/instrumentor.py +0 -35
  330. phoenix/trace/llama_index/__init__.py +0 -3
  331. phoenix/trace/llama_index/callback.py +0 -102
  332. phoenix/trace/openai/__init__.py +0 -3
  333. phoenix/trace/openai/instrumentor.py +0 -30
  334. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/IP_NOTICE +0 -0
  335. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/LICENSE +0 -0
  336. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  337. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  338. /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
@@ -0,0 +1,68 @@
1
+ """
2
+ The primary intent of a two-tier system is to make cache invalidation more efficient,
3
+ because the cache keys are typically tuples such as (project_id, time_interval, ...),
4
+ but we need to invalidate subsets of keys, e.g. all those associated with a
5
+ specific project, very frequently (i.e. essentially at each span insertion). In a
6
+ single-tier system we would need to check all the keys to see if they are in the
7
+ subset that we want to invalidate.
8
+ """
9
+
10
+ from abc import ABC, abstractmethod
11
+ from asyncio import Future
12
+ from collections.abc import Callable
13
+ from typing import Any, Generic, Optional, TypeVar
14
+
15
+ from cachetools import Cache
16
+ from strawberry.dataloader import AbstractCache
17
+
18
+ _Key = TypeVar("_Key")
19
+ _Result = TypeVar("_Result")
20
+
21
+ _Section = TypeVar("_Section")
22
+ _SubKey = TypeVar("_SubKey")
23
+
24
+
25
class TwoTierCache(
    AbstractCache[_Key, _Result],
    Generic[_Key, _Result, _Section, _SubKey],
    ABC,
):
    """A dataloader cache organized in two tiers.

    Each full key is split (by the subclass's `_cache_key`) into a *section*
    and a *sub-key*. The outer cache maps sections to inner caches; the inner
    caches map sub-keys to futures. This layout makes it cheap to invalidate
    every entry belonging to one section (e.g. one project) without scanning
    all keys.
    """

    def __init__(
        self,
        main_cache: "Cache[_Section, Cache[_SubKey, Future[_Result]]]",
        sub_cache_factory: Callable[[], "Cache[_SubKey, Future[_Result]]"],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        # Outer tier: section -> inner cache. Inner tier created on demand.
        self._cache = main_cache
        self._sub_cache_factory = sub_cache_factory

    @abstractmethod
    def _cache_key(self, key: _Key) -> tuple[_Section, _SubKey]: ...

    def invalidate(self, section: _Section) -> None:
        """Empty the inner cache for one section, if it exists."""
        sub_cache = self._cache.get(section)
        if sub_cache:
            sub_cache.clear()

    def get(self, key: _Key) -> Optional["Future[_Result]"]:
        """Return the cached future for `key`, or None on a miss in either tier."""
        section, sub_key = self._cache_key(key)
        sub_cache = self._cache.get(section)
        if not sub_cache:
            return None
        return sub_cache.get(sub_key)

    def set(self, key: _Key, value: "Future[_Result]") -> None:
        """Store `value` under `key`, creating the section's inner cache if needed."""
        section, sub_key = self._cache_key(key)
        sub_cache = self._cache.get(section)
        if sub_cache is None:
            sub_cache = self._sub_cache_factory()
            self._cache[section] = sub_cache
        sub_cache[sub_key] = value

    def delete(self, key: _Key) -> None:
        """Remove `key`'s entry; drop the section entirely once it becomes empty."""
        section, sub_key = self._cache_key(key)
        sub_cache = self._cache.get(section)
        if sub_cache:
            del sub_cache[sub_key]
            if not sub_cache:
                del self._cache[section]

    def clear(self) -> None:
        """Discard every section (and therefore every entry)."""
        self._cache.clear()
@@ -0,0 +1,131 @@
1
+ from typing import Optional, Union
2
+
3
+ from sqlalchemy import and_, case, func, null, or_, select
4
+ from sqlalchemy.sql.expression import literal
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.exceptions import NotFound
10
+ from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
11
+ from phoenix.server.types import DbSessionFactory
12
+
13
+ ExampleID: TypeAlias = int
14
+ VersionID: TypeAlias = Optional[int]
15
+ Key: TypeAlias = tuple[ExampleID, Optional[VersionID]]
16
+ Result: TypeAlias = DatasetExampleRevision
17
+
18
+
19
class DatasetExampleRevisionsDataLoader(DataLoader[Key, Result]):
    """Batch-loads the latest revision of each dataset example at a given
    dataset version (or the overall latest revision when the version is None).

    Keys are ``(example_id, version_id)`` pairs; each resolves to a
    `DatasetExampleRevision` or a `NotFound` error value.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(
            load_fn=self._load_fn,
            max_batch_size=200,  # needed to prevent the size of the query from getting too large
        )
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Union[Result, NotFound]]:
        # Split the batch: keys that pin a specific version vs. keys that
        # want the latest revision regardless of version.
        example_and_version_ids = tuple(
            set(
                (example_id, version_id)
                for example_id, version_id in keys
                if version_id is not None
            )
        )
        versionless_example_ids = tuple(
            set(example_id for example_id, version_id in keys if version_id is None)
        )
        # Layer 1: enumerate the requested (example_id, version_id) pairs,
        # with NULL version_id rows standing in for the versionless keys.
        resolved_example_and_version_ids = (
            (
                select(
                    models.DatasetExample.id.label("example_id"),
                    models.DatasetVersion.id.label("version_id"),
                )
                .select_from(models.DatasetExample)
                .join(
                    models.DatasetVersion,
                    onclause=literal(True),  # cross join
                )
                .where(
                    or_(
                        *(
                            and_(
                                models.DatasetExample.id == example_id,
                                models.DatasetVersion.id == version_id,
                            )
                            for example_id, version_id in example_and_version_ids
                        )
                    )
                )
            )
            .union(
                select(
                    models.DatasetExample.id.label("example_id"), null().label("version_id")
                ).where(models.DatasetExample.id.in_(versionless_example_ids))
            )
            .subquery()
        )
        # Layer 2: for each pair pick the highest revision id that does not
        # exceed the requested version (no cap when version_id is NULL).
        revision_ids = (
            select(
                resolved_example_and_version_ids.c.example_id,
                resolved_example_and_version_ids.c.version_id,
                func.max(models.DatasetExampleRevision.id).label("revision_id"),
            )
            .select_from(resolved_example_and_version_ids)
            .join(
                models.DatasetExampleRevision,
                onclause=resolved_example_and_version_ids.c.example_id
                == models.DatasetExampleRevision.dataset_example_id,
            )
            .where(
                or_(
                    resolved_example_and_version_ids.c.version_id.is_(None),
                    models.DatasetExampleRevision.dataset_version_id
                    <= resolved_example_and_version_ids.c.version_id,
                )
            )
            .group_by(
                resolved_example_and_version_ids.c.example_id,
                resolved_example_and_version_ids.c.version_id,
            )
        ).subquery()
        # Layer 3: fetch the revision rows themselves, flagging rows whose
        # requested (non-null) version actually exists, and dropping
        # tombstoned (DELETE) revisions.
        query = (
            select(
                revision_ids.c.example_id,
                revision_ids.c.version_id,
                case(
                    (
                        or_(
                            revision_ids.c.version_id.is_(None),
                            models.DatasetVersion.id.is_not(None),
                        ),
                        True,
                    ),
                    else_=False,
                ).label("is_valid_version"),  # check that non-null versions exist
                models.DatasetExampleRevision,
            )
            .select_from(revision_ids)
            .join(
                models.DatasetExampleRevision,
                onclause=revision_ids.c.revision_id == models.DatasetExampleRevision.id,
            )
            .join(
                models.DatasetVersion,
                onclause=revision_ids.c.version_id == models.DatasetVersion.id,
                isouter=True,  # keep rows where the version id is null
            )
            .where(models.DatasetExampleRevision.revision_kind != "DELETE")
        )
        async with self._db() as session:
            # Index the streamed rows by their key; rows referencing a
            # nonexistent version are filtered out here.
            results = {
                (example_id, version_id): DatasetExampleRevision.from_orm_revision(revision)
                async for (
                    example_id,
                    version_id,
                    is_valid_version,
                    revision,
                ) in await session.stream(query)
                if is_valid_version
            }
        # Preserve input order; missing keys resolve to a NotFound error value.
        return [results.get(key, NotFound("Could not find revision.")) for key in keys]
@@ -0,0 +1,38 @@
1
+ from typing import Optional
2
+
3
+ from sqlalchemy import select
4
+ from sqlalchemy.orm import joinedload
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.types import DbSessionFactory
10
+
11
+ ExampleID: TypeAlias = int
12
+ Key: TypeAlias = ExampleID
13
+ Result: TypeAlias = Optional[models.Span]
14
+
15
+
16
class DatasetExampleSpansDataLoader(DataLoader[Key, Result]):
    """Batch-loads the source span (if any) for each dataset-example id."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # One query for the whole batch: join each example to its span and
        # eagerly pull the parent trace's trace_id alongside it.
        stmt = (
            select(models.DatasetExample.id, models.Span)
            .select_from(models.DatasetExample)
            .join(models.Span, models.DatasetExample.span_rowid == models.Span.id)
            .where(models.DatasetExample.id.in_(keys))
            .options(
                joinedload(models.Span.trace, innerjoin=True).load_only(
                    models.Trace.trace_id
                )
            )
        )
        span_by_example: dict[Key, models.Span] = {}
        async with self._db() as session:
            async for example_rowid, span in await session.stream(stmt):
                span_by_example[example_rowid] = span
        # Keep input order; examples without a span resolve to None.
        return [span_by_example.get(example_id) for example_id in keys]
@@ -0,0 +1,144 @@
1
+ from collections import defaultdict
2
+ from datetime import datetime
3
+ from typing import Any, Optional
4
+
5
+ import numpy as np
6
+ from aioitertools.itertools import groupby
7
+ from cachetools import LFUCache, TTLCache
8
+ from sqlalchemy import Select, select
9
+ from strawberry.dataloader import AbstractCache, DataLoader
10
+ from typing_extensions import TypeAlias
11
+
12
+ from phoenix.db import models
13
+ from phoenix.db.helpers import SupportedSQLDialect, num_docs_col
14
+ from phoenix.metrics.retrieval_metrics import RetrievalMetrics
15
+ from phoenix.server.api.dataloaders.cache import TwoTierCache
16
+ from phoenix.server.api.input_types.TimeRange import TimeRange
17
+ from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary
18
+ from phoenix.server.types import DbSessionFactory
19
+ from phoenix.trace.dsl import SpanFilter
20
+
21
+ ProjectRowId: TypeAlias = int
22
+ TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
23
+ FilterCondition: TypeAlias = Optional[str]
24
+ EvalName: TypeAlias = str
25
+
26
+ Segment: TypeAlias = tuple[ProjectRowId, TimeInterval, FilterCondition]
27
+ Param: TypeAlias = EvalName
28
+
29
+ Key: TypeAlias = tuple[ProjectRowId, Optional[TimeRange], FilterCondition, EvalName]
30
+ Result: TypeAlias = Optional[DocumentEvaluationSummary]
31
+ ResultPosition: TypeAlias = int
32
+ DEFAULT_VALUE: Result = None
33
+
34
+
35
def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
    """Split a loader key into its cacheable segment and eval-name parameter.

    The segment is ``(project_rowid, (start, end), filter_condition)``; a
    missing time range maps to the open interval ``(None, None)``.
    """
    project_rowid, time_range, filter_condition, eval_name = key
    if isinstance(time_range, TimeRange):
        interval: TimeInterval = (time_range.start, time_range.end)
    else:
        interval = (None, None)
    return (project_rowid, interval, filter_condition), eval_name
41
+
42
+
43
+ _Section: TypeAlias = tuple[ProjectRowId, EvalName]
44
+ _SubKey: TypeAlias = tuple[TimeInterval, FilterCondition]
45
+
46
+
47
class DocumentEvaluationSummaryCache(
    TwoTierCache[Key, Result, _Section, _SubKey],
):
    """Two-tier cache for document-evaluation summaries.

    Entries are sectioned by ``(project_rowid, eval_name)`` so that all
    summaries for a project can be invalidated cheaply when new spans arrive.
    """

    def __init__(self) -> None:
        super().__init__(
            # TTL=3600 (1-hour) because time intervals are always moving forward, but
            # interval endpoints are rounded down to the hour by the UI, so anything
            # older than an hour most likely won't be a cache-hit anyway.
            main_cache=TTLCache(maxsize=64 * 32, ttl=3600),
            sub_cache_factory=lambda: LFUCache(maxsize=2 * 2),
        )

    def invalidate_project(self, project_rowid: ProjectRowId) -> None:
        """Drop every cached section belonging to the given project.

        Iterate over a snapshot of the keys: the original code deleted
        entries while iterating the live ``keys()`` view, which raises
        ``RuntimeError`` (dict mutated during iteration) as soon as a
        matching section is removed and iteration continues.
        """
        for section in list(self._cache.keys()):
            if section[0] == project_rowid:
                del self._cache[section]

    def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
        # Section on (project, eval-name); sub-key on (interval, filter).
        (project_rowid, interval, filter_condition), eval_name = _cache_key_fn(key)
        return (project_rowid, eval_name), (interval, filter_condition)
67
+
68
+
69
class DocumentEvaluationSummaryDataLoader(DataLoader[Key, Result]):
    """Batch-loads `DocumentEvaluationSummary` objects, one per
    (project, time range, filter, eval name) key, grouping keys that share a
    segment so each segment needs only one database round-trip.
    """

    def __init__(
        self,
        db: DbSessionFactory,
        cache_map: Optional[AbstractCache[Key, Result]] = None,
    ) -> None:
        super().__init__(
            load_fn=self._load_fn,
            cache_key_fn=_cache_key_fn,
            cache_map=cache_map,
        )
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # Pre-fill with the default (None); unmatched keys keep it.
        results: list[Result] = [DEFAULT_VALUE] * len(keys)
        # Group the requested eval names (and their positions in `keys`)
        # under their shared segment so each segment is queried once.
        arguments: defaultdict[
            Segment,
            defaultdict[Param, list[ResultPosition]],
        ] = defaultdict(lambda: defaultdict(list))
        for position, key in enumerate(keys):
            segment, param = _cache_key_fn(key)
            arguments[segment][param].append(position)
        for segment, params in arguments.items():
            async with self._db() as session:
                dialect = SupportedSQLDialect(session.bind.dialect.name)
                stmt = _get_stmt(dialect, segment, *params.keys())
                data = await session.stream(stmt)
                # Rows arrive ordered by (name, span id) — see _get_stmt's
                # order_by — so consecutive grouping is valid here.
                async for eval_name, group in groupby(data, lambda d: d.name):
                    metrics_collection = []
                    # Inner grouping: one subgroup per span; each row in the
                    # subgroup carries one document's score.
                    async for (_, num_docs), subgroup in groupby(
                        group, lambda g: (g.id, g.num_docs)
                    ):
                        # Positions without a score stay NaN.
                        scores = [np.nan] * num_docs
                        for row in subgroup:
                            scores[row.document_position] = row.score
                        metrics_collection.append(RetrievalMetrics(scores))
                    summary = DocumentEvaluationSummary(
                        evaluation_name=eval_name,
                        metrics_collection=metrics_collection,
                    )
                    # Fan the shared summary out to every key that asked for it.
                    for position in params[eval_name]:
                        results[position] = summary
        return results
112
+
113
+
114
def _get_stmt(
    dialect: SupportedSQLDialect,
    segment: Segment,
    *eval_names: Param,
) -> Select[Any]:
    """Build the SELECT for one segment's document-annotation scores.

    Returns rows of (eval name, span id, document count, score, document
    position), restricted to LLM annotations with non-null scores within the
    segment's project/time-range/filter, ordered so the caller can group by
    eval name and then by span.
    """
    project_rowid, (start_time, end_time), filter_condition = segment
    mda = models.DocumentAnnotation
    stmt = (
        select(
            mda.name,
            models.Span.id,
            num_docs_col(dialect),  # dialect-specific document-count column
            mda.score,
            mda.document_position,
        )
        .join(models.Trace)
        .where(models.Trace.project_rowid == project_rowid)
        .join(mda)
        .where(mda.name.in_(eval_names))
        .where(mda.annotator_kind == "LLM")
        .where(mda.score.is_not(None))
        # Ordering required by the caller's consecutive groupby.
        .order_by(mda.name, models.Span.id)
    )
    # Half-open time window: start inclusive, end exclusive.
    if start_time:
        stmt = stmt.where(start_time <= models.Span.start_time)
    if end_time:
        stmt = stmt.where(models.Span.start_time < end_time)
    # Optional user-supplied span filter (DSL condition) applied last.
    if filter_condition:
        span_filter = SpanFilter(condition=filter_condition)
        stmt = span_filter(stmt)
    return stmt
@@ -0,0 +1,31 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import select
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db import models
8
+ from phoenix.server.api.types.Evaluation import DocumentEvaluation
9
+ from phoenix.server.types import DbSessionFactory
10
+
11
# A key is a span rowid; the result is every LLM document evaluation
# attached to that span.
Key: TypeAlias = int
Result: TypeAlias = list[DocumentEvaluation]


class DocumentEvaluationsDataLoader(DataLoader[Key, Result]):
    """Batch-loads LLM document evaluations, grouped by span rowid."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        grouped: defaultdict[Key, Result] = defaultdict(list)
        mda = models.DocumentAnnotation
        stmt = (
            select(mda)
            .where(mda.span_rowid.in_(keys))
            .where(mda.annotator_kind == "LLM")
        )
        async with self._db() as session:
            records = await session.stream_scalars(stmt)
            async for record in records:
                grouped[record.span_rowid].append(
                    DocumentEvaluation.from_sql_document_annotation(record)
                )
        # Preserve input order; spans without evaluations yield empty lists.
        return [grouped[key] for key in keys]
@@ -0,0 +1,89 @@
1
+ from collections import defaultdict
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+ from aioitertools.itertools import groupby
6
+ from sqlalchemy import select
7
+ from strawberry.dataloader import DataLoader
8
+ from typing_extensions import TypeAlias
9
+
10
+ from phoenix.db import models
11
+ from phoenix.metrics.retrieval_metrics import RetrievalMetrics
12
+ from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics
13
+ from phoenix.server.types import DbSessionFactory
14
+
15
# A key identifies one request: the span rowid, the evaluation name
# (None means "any evaluation name"), and how many document positions
# the requester wants scored.
RowId: TypeAlias = int
NumDocs: TypeAlias = int
EvalName: TypeAlias = Optional[str]

Key: TypeAlias = tuple[RowId, EvalName, NumDocs]
Result: TypeAlias = list[DocumentRetrievalMetrics]


class DocumentRetrievalMetricsDataLoader(DataLoader[Key, Result]):
    """Batch-loads retrieval metrics computed from LLM document annotations."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return, for each key, the retrieval metrics for its span/eval-name
        pair, computed over the first ``num_docs`` document positions.
        Missing positions are scored as NaN.
        """
        mda = models.DocumentAnnotation
        # Fetch all scored LLM annotations with a valid document position,
        # sorted so rows can be stream-grouped by (span_rowid, name).
        stmt = (
            select(
                mda.span_rowid,
                mda.name,
                mda.score,
                mda.document_position,
            )
            .where(mda.score != None)  # noqa: E711
            .where(mda.annotator_kind == "LLM")
            .where(mda.document_position >= 0)
            .order_by(mda.span_rowid, mda.name)
        )
        # Using CTE with VALUES clause is possible in SQLite, but not in
        # SQLAlchemy v2.0.29, hence the workaround below with over-fetching.
        # We could use CTE with VALUES for postgresql, but for now we'll keep
        # it simple and just use one approach for all backends.
        all_row_ids = {row_id for row_id, _, _ in keys}
        stmt = stmt.where(mda.span_rowid.in_(all_row_ids))
        all_eval_names = {eval_name for _, eval_name, _ in keys}
        # A None eval name means "any name", so name filtering is only safe
        # when every key names a specific evaluation.
        if None not in all_eval_names:
            stmt = stmt.where(mda.name.in_(all_eval_names))
        max_position = max(num_docs for _, _, num_docs in keys)
        stmt = stmt.where(mda.document_position < max_position)
        results: dict[Key, Result] = {key: [] for key in keys}
        # For each (span, eval name) pair, remember every distinct num_docs
        # that was requested, since one fetched group may serve several keys.
        requested_num_docs: defaultdict[tuple[RowId, EvalName], set[NumDocs]] = defaultdict(set)
        for row_id, eval_name, num_docs in results.keys():
            requested_num_docs[(row_id, eval_name)].add(num_docs)
        async with self._db() as session:
            data = await session.stream(stmt)
            async for (span_rowid, name), group in groupby(data, lambda r: (r.span_rowid, r.name)):
                # We need to fulfill two types of potential requests: 1. when it
                # specifies an evaluation name, and 2. when it doesn't care about
                # the evaluation name by specifying None.
                max_requested_num_docs = max(
                    (
                        num_docs
                        for eval_name in (name, None)
                        for num_docs in (requested_num_docs.get((span_rowid, eval_name)) or ())
                    ),
                    default=0,
                )
                if max_requested_num_docs <= 0:
                    # We have over-fetched. Skip this group.
                    continue
                # Score buffer sized to the largest request; unfilled
                # positions stay NaN.
                scores = [np.nan] * max_requested_num_docs
                for row in group:
                    # Length check is necessary due to over-fetching.
                    if row.document_position < len(scores):
                        scores[row.document_position] = row.score
                for eval_name in (name, None):
                    for num_docs in requested_num_docs.get((span_rowid, eval_name)) or ():
                        metrics = RetrievalMetrics(scores[:num_docs])
                        doc_metrics = DocumentRetrievalMetrics(
                            evaluation_name=name, metrics=metrics
                        )
                        key = (span_rowid, eval_name, num_docs)
                        results[key].append(doc_metrics)
        # Make sure to copy the result, so we don't return the same list
        # object to two different requesters.
        return [results[key].copy() for key in keys]
@@ -0,0 +1,79 @@
1
+ from collections import defaultdict
2
+ from dataclasses import dataclass
3
+ from typing import Optional
4
+
5
+ from sqlalchemy import func, select
6
+ from strawberry.dataloader import AbstractCache, DataLoader
7
+ from typing_extensions import TypeAlias
8
+
9
+ from phoenix.db import models
10
+ from phoenix.server.types import DbSessionFactory
11
+
12
+
13
@dataclass
class ExperimentAnnotationSummary:
    """Aggregate score statistics for one annotation name within an experiment."""

    annotation_name: str  # name of the annotation being summarized
    min_score: float  # minimum score over the experiment's run annotations
    max_score: float  # maximum score over the experiment's run annotations
    mean_score: float  # average score over the experiment's run annotations
    count: int  # total number of annotation rows aggregated
    error_count: int  # number of rows with a non-null error
21
+
22
+
23
ExperimentID: TypeAlias = int
Key: TypeAlias = ExperimentID
Result: TypeAlias = list[ExperimentAnnotationSummary]


class ExperimentAnnotationSummaryDataLoader(DataLoader[Key, Result]):
    """Batch-loads per-annotation score summaries for experiments."""

    def __init__(
        self,
        db: DbSessionFactory,
        cache_map: Optional[AbstractCache[Key, Result]] = None,
    ) -> None:
        # Fix: `cache_map` was previously accepted but never forwarded to the
        # DataLoader, so the externally supplied cache was silently ignored.
        super().__init__(load_fn=self._load_fn, cache_map=cache_map)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Return, for each experiment id in *keys*, its annotation
        summaries sorted by annotation name (empty list if none)."""
        experiment_ids = keys
        summaries: defaultdict[ExperimentID, Result] = defaultdict(list)
        async with self._db() as session:
            # One grouped aggregate query for the whole batch: stats per
            # (experiment, annotation name) pair.
            async for (
                experiment_id,
                annotation_name,
                min_score,
                max_score,
                mean_score,
                count,
                error_count,
            ) in await session.stream(
                select(
                    models.ExperimentRun.experiment_id,
                    models.ExperimentRunAnnotation.name,
                    func.min(models.ExperimentRunAnnotation.score),
                    func.max(models.ExperimentRunAnnotation.score),
                    func.avg(models.ExperimentRunAnnotation.score),
                    func.count(),
                    # count(column) counts only non-null errors
                    func.count(models.ExperimentRunAnnotation.error),
                )
                .join(
                    models.ExperimentRun,
                    models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
                )
                .where(models.ExperimentRun.experiment_id.in_(experiment_ids))
                .group_by(models.ExperimentRun.experiment_id, models.ExperimentRunAnnotation.name)
            ):
                summaries[experiment_id].append(
                    ExperimentAnnotationSummary(
                        annotation_name=annotation_name,
                        min_score=min_score,
                        max_score=max_score,
                        mean_score=mean_score,
                        count=count,
                        error_count=error_count,
                    )
                )
        return [
            sorted(summaries[experiment_id], key=lambda summary: summary.annotation_name)
            for experiment_id in keys
        ]
@@ -0,0 +1,58 @@
1
+ from typing import Optional
2
+
3
+ from sqlalchemy import case, func, select
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db import models
8
+ from phoenix.server.types import DbSessionFactory
9
+
10
ExperimentID: TypeAlias = int
ErrorRate: TypeAlias = float
Key: TypeAlias = ExperimentID
Result: TypeAlias = Optional[ErrorRate]


class ExperimentErrorRatesDataLoader(DataLoader[Key, Result]):
    """Batch-loads the fraction of errored runs for each experiment."""

    def __init__(
        self,
        db: DbSessionFactory,
    ) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # Resolve the requested ids against the experiments table so that
        # unknown ids simply produce no row in the query result.
        experiment_id_subq = (
            select(models.Experiment.id)
            .where(models.Experiment.id.in_(set(keys)))
            .subquery()
        )
        # errored runs / total runs; NULL (None) when an experiment has no
        # runs at all, since a rate would be undefined.
        error_rate_expr = case(
            (
                func.count(models.ExperimentRun.id) != 0,
                func.count(models.ExperimentRun.error) / func.count(models.ExperimentRun.id),
            ),
            else_=None,
        )
        query = (
            select(experiment_id_subq.c.id, error_rate_expr)
            .outerjoin_from(
                from_=experiment_id_subq,
                target=models.ExperimentRun,
                onclause=experiment_id_subq.c.id == models.ExperimentRun.experiment_id,
            )
            .group_by(experiment_id_subq.c.id)
        )
        async with self._db() as session:
            rows = await session.stream(query)
            error_rates = {experiment_id: rate async for experiment_id, rate in rows}
        # Returning an exception instance makes the dataloader raise it for
        # that particular key instead of failing the whole batch.
        return [
            error_rates.get(experiment_id, ValueError(f"Unknown experiment ID: {experiment_id}"))
            for experiment_id in keys
        ]
@@ -0,0 +1,36 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import select
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db.models import ExperimentRunAnnotation as OrmExperimentRunAnnotation
8
+ from phoenix.server.types import DbSessionFactory
9
+
10
ExperimentRunID: TypeAlias = int
Key: TypeAlias = ExperimentRunID
Result: TypeAlias = list[OrmExperimentRunAnnotation]


class ExperimentRunAnnotations(DataLoader[Key, Result]):
    """Batch-loads the annotations attached to each experiment run."""

    def __init__(
        self,
        db: DbSessionFactory,
    ) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        grouped: defaultdict[Key, Result] = defaultdict(list)
        stmt = select(
            OrmExperimentRunAnnotation.experiment_run_id, OrmExperimentRunAnnotation
        ).where(OrmExperimentRunAnnotation.experiment_run_id.in_(keys))
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for run_id, annotation in rows:
                grouped[run_id].append(annotation)
        # Each run's annotations are returned sorted by name, descending;
        # runs without annotations yield empty lists.
        return [
            sorted(grouped[run_id], key=lambda a: a.name, reverse=True)
            for run_id in keys
        ]