arize-phoenix 3.16.1__py3-none-any.whl → 7.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (338) hide show
  1. arize_phoenix-7.7.1.dist-info/METADATA +261 -0
  2. arize_phoenix-7.7.1.dist-info/RECORD +345 -0
  3. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/WHEEL +1 -1
  4. arize_phoenix-7.7.1.dist-info/entry_points.txt +3 -0
  5. phoenix/__init__.py +86 -14
  6. phoenix/auth.py +309 -0
  7. phoenix/config.py +675 -45
  8. phoenix/core/model.py +32 -30
  9. phoenix/core/model_schema.py +102 -109
  10. phoenix/core/model_schema_adapter.py +48 -45
  11. phoenix/datetime_utils.py +24 -3
  12. phoenix/db/README.md +54 -0
  13. phoenix/db/__init__.py +4 -0
  14. phoenix/db/alembic.ini +85 -0
  15. phoenix/db/bulk_inserter.py +294 -0
  16. phoenix/db/engines.py +208 -0
  17. phoenix/db/enums.py +20 -0
  18. phoenix/db/facilitator.py +113 -0
  19. phoenix/db/helpers.py +159 -0
  20. phoenix/db/insertion/constants.py +2 -0
  21. phoenix/db/insertion/dataset.py +227 -0
  22. phoenix/db/insertion/document_annotation.py +171 -0
  23. phoenix/db/insertion/evaluation.py +191 -0
  24. phoenix/db/insertion/helpers.py +98 -0
  25. phoenix/db/insertion/span.py +193 -0
  26. phoenix/db/insertion/span_annotation.py +158 -0
  27. phoenix/db/insertion/trace_annotation.py +158 -0
  28. phoenix/db/insertion/types.py +256 -0
  29. phoenix/db/migrate.py +86 -0
  30. phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
  31. phoenix/db/migrations/env.py +114 -0
  32. phoenix/db/migrations/script.py.mako +26 -0
  33. phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
  34. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
  35. phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
  36. phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
  37. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
  38. phoenix/db/models.py +807 -0
  39. phoenix/exceptions.py +5 -1
  40. phoenix/experiments/__init__.py +6 -0
  41. phoenix/experiments/evaluators/__init__.py +29 -0
  42. phoenix/experiments/evaluators/base.py +158 -0
  43. phoenix/experiments/evaluators/code_evaluators.py +184 -0
  44. phoenix/experiments/evaluators/llm_evaluators.py +473 -0
  45. phoenix/experiments/evaluators/utils.py +236 -0
  46. phoenix/experiments/functions.py +772 -0
  47. phoenix/experiments/tracing.py +86 -0
  48. phoenix/experiments/types.py +726 -0
  49. phoenix/experiments/utils.py +25 -0
  50. phoenix/inferences/__init__.py +0 -0
  51. phoenix/{datasets → inferences}/errors.py +6 -5
  52. phoenix/{datasets → inferences}/fixtures.py +49 -42
  53. phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
  54. phoenix/{datasets → inferences}/schema.py +11 -11
  55. phoenix/{datasets → inferences}/validation.py +13 -14
  56. phoenix/logging/__init__.py +3 -0
  57. phoenix/logging/_config.py +90 -0
  58. phoenix/logging/_filter.py +6 -0
  59. phoenix/logging/_formatter.py +69 -0
  60. phoenix/metrics/__init__.py +5 -4
  61. phoenix/metrics/binning.py +4 -3
  62. phoenix/metrics/metrics.py +2 -1
  63. phoenix/metrics/mixins.py +7 -6
  64. phoenix/metrics/retrieval_metrics.py +2 -1
  65. phoenix/metrics/timeseries.py +5 -4
  66. phoenix/metrics/wrappers.py +9 -3
  67. phoenix/pointcloud/clustering.py +5 -5
  68. phoenix/pointcloud/pointcloud.py +7 -5
  69. phoenix/pointcloud/projectors.py +5 -6
  70. phoenix/pointcloud/umap_parameters.py +53 -52
  71. phoenix/server/api/README.md +28 -0
  72. phoenix/server/api/auth.py +44 -0
  73. phoenix/server/api/context.py +152 -9
  74. phoenix/server/api/dataloaders/__init__.py +91 -0
  75. phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
  76. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  77. phoenix/server/api/dataloaders/cache/__init__.py +3 -0
  78. phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
  79. phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
  80. phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
  81. phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
  82. phoenix/server/api/dataloaders/document_evaluations.py +31 -0
  83. phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
  84. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
  85. phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
  86. phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
  87. phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
  88. phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
  89. phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
  90. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
  91. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  92. phoenix/server/api/dataloaders/record_counts.py +116 -0
  93. phoenix/server/api/dataloaders/session_io.py +79 -0
  94. phoenix/server/api/dataloaders/session_num_traces.py +30 -0
  95. phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
  96. phoenix/server/api/dataloaders/session_token_usages.py +41 -0
  97. phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
  98. phoenix/server/api/dataloaders/span_annotations.py +26 -0
  99. phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
  100. phoenix/server/api/dataloaders/span_descendants.py +57 -0
  101. phoenix/server/api/dataloaders/span_projects.py +33 -0
  102. phoenix/server/api/dataloaders/token_counts.py +124 -0
  103. phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
  104. phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
  105. phoenix/server/api/dataloaders/user_roles.py +30 -0
  106. phoenix/server/api/dataloaders/users.py +33 -0
  107. phoenix/server/api/exceptions.py +48 -0
  108. phoenix/server/api/helpers/__init__.py +12 -0
  109. phoenix/server/api/helpers/dataset_helpers.py +217 -0
  110. phoenix/server/api/helpers/experiment_run_filters.py +763 -0
  111. phoenix/server/api/helpers/playground_clients.py +948 -0
  112. phoenix/server/api/helpers/playground_registry.py +70 -0
  113. phoenix/server/api/helpers/playground_spans.py +455 -0
  114. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  115. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  116. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  117. phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
  118. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  119. phoenix/server/api/input_types/ClusterInput.py +2 -2
  120. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  121. phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
  122. phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
  123. phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
  124. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  125. phoenix/server/api/input_types/DatasetSort.py +17 -0
  126. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  127. phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
  128. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  129. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  130. phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
  131. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  132. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  133. phoenix/server/api/input_types/Granularity.py +1 -1
  134. phoenix/server/api/input_types/InvocationParameters.py +162 -0
  135. phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
  136. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  137. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  138. phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
  139. phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
  140. phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
  141. phoenix/server/api/input_types/SpanSort.py +134 -69
  142. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  143. phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
  144. phoenix/server/api/input_types/UserRoleInput.py +9 -0
  145. phoenix/server/api/mutations/__init__.py +28 -0
  146. phoenix/server/api/mutations/api_key_mutations.py +167 -0
  147. phoenix/server/api/mutations/chat_mutations.py +593 -0
  148. phoenix/server/api/mutations/dataset_mutations.py +591 -0
  149. phoenix/server/api/mutations/experiment_mutations.py +75 -0
  150. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
  151. phoenix/server/api/mutations/project_mutations.py +57 -0
  152. phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
  153. phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
  154. phoenix/server/api/mutations/user_mutations.py +329 -0
  155. phoenix/server/api/openapi/__init__.py +0 -0
  156. phoenix/server/api/openapi/main.py +17 -0
  157. phoenix/server/api/openapi/schema.py +16 -0
  158. phoenix/server/api/queries.py +738 -0
  159. phoenix/server/api/routers/__init__.py +11 -0
  160. phoenix/server/api/routers/auth.py +284 -0
  161. phoenix/server/api/routers/embeddings.py +26 -0
  162. phoenix/server/api/routers/oauth2.py +488 -0
  163. phoenix/server/api/routers/v1/__init__.py +64 -0
  164. phoenix/server/api/routers/v1/datasets.py +1017 -0
  165. phoenix/server/api/routers/v1/evaluations.py +362 -0
  166. phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
  167. phoenix/server/api/routers/v1/experiment_runs.py +167 -0
  168. phoenix/server/api/routers/v1/experiments.py +308 -0
  169. phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
  170. phoenix/server/api/routers/v1/spans.py +267 -0
  171. phoenix/server/api/routers/v1/traces.py +208 -0
  172. phoenix/server/api/routers/v1/utils.py +95 -0
  173. phoenix/server/api/schema.py +44 -241
  174. phoenix/server/api/subscriptions.py +597 -0
  175. phoenix/server/api/types/Annotation.py +21 -0
  176. phoenix/server/api/types/AnnotationSummary.py +55 -0
  177. phoenix/server/api/types/AnnotatorKind.py +16 -0
  178. phoenix/server/api/types/ApiKey.py +27 -0
  179. phoenix/server/api/types/AuthMethod.py +9 -0
  180. phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
  181. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
  182. phoenix/server/api/types/Cluster.py +25 -24
  183. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  184. phoenix/server/api/types/DataQualityMetric.py +31 -13
  185. phoenix/server/api/types/Dataset.py +288 -63
  186. phoenix/server/api/types/DatasetExample.py +85 -0
  187. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  188. phoenix/server/api/types/DatasetVersion.py +14 -0
  189. phoenix/server/api/types/Dimension.py +32 -31
  190. phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
  191. phoenix/server/api/types/EmbeddingDimension.py +56 -49
  192. phoenix/server/api/types/Evaluation.py +25 -31
  193. phoenix/server/api/types/EvaluationSummary.py +30 -50
  194. phoenix/server/api/types/Event.py +20 -20
  195. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  196. phoenix/server/api/types/Experiment.py +152 -0
  197. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  198. phoenix/server/api/types/ExperimentComparison.py +17 -0
  199. phoenix/server/api/types/ExperimentRun.py +119 -0
  200. phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
  201. phoenix/server/api/types/GenerativeModel.py +9 -0
  202. phoenix/server/api/types/GenerativeProvider.py +85 -0
  203. phoenix/server/api/types/Inferences.py +80 -0
  204. phoenix/server/api/types/InferencesRole.py +23 -0
  205. phoenix/server/api/types/LabelFraction.py +7 -0
  206. phoenix/server/api/types/MimeType.py +2 -2
  207. phoenix/server/api/types/Model.py +54 -54
  208. phoenix/server/api/types/PerformanceMetric.py +8 -5
  209. phoenix/server/api/types/Project.py +407 -142
  210. phoenix/server/api/types/ProjectSession.py +139 -0
  211. phoenix/server/api/types/Segments.py +4 -4
  212. phoenix/server/api/types/Span.py +221 -176
  213. phoenix/server/api/types/SpanAnnotation.py +43 -0
  214. phoenix/server/api/types/SpanIOValue.py +15 -0
  215. phoenix/server/api/types/SystemApiKey.py +9 -0
  216. phoenix/server/api/types/TemplateLanguage.py +10 -0
  217. phoenix/server/api/types/TimeSeries.py +19 -15
  218. phoenix/server/api/types/TokenUsage.py +11 -0
  219. phoenix/server/api/types/Trace.py +154 -0
  220. phoenix/server/api/types/TraceAnnotation.py +45 -0
  221. phoenix/server/api/types/UMAPPoints.py +7 -7
  222. phoenix/server/api/types/User.py +60 -0
  223. phoenix/server/api/types/UserApiKey.py +45 -0
  224. phoenix/server/api/types/UserRole.py +15 -0
  225. phoenix/server/api/types/node.py +4 -112
  226. phoenix/server/api/types/pagination.py +156 -57
  227. phoenix/server/api/utils.py +34 -0
  228. phoenix/server/app.py +864 -115
  229. phoenix/server/bearer_auth.py +163 -0
  230. phoenix/server/dml_event.py +136 -0
  231. phoenix/server/dml_event_handler.py +256 -0
  232. phoenix/server/email/__init__.py +0 -0
  233. phoenix/server/email/sender.py +97 -0
  234. phoenix/server/email/templates/__init__.py +0 -0
  235. phoenix/server/email/templates/password_reset.html +19 -0
  236. phoenix/server/email/types.py +11 -0
  237. phoenix/server/grpc_server.py +102 -0
  238. phoenix/server/jwt_store.py +505 -0
  239. phoenix/server/main.py +305 -116
  240. phoenix/server/oauth2.py +52 -0
  241. phoenix/server/openapi/__init__.py +0 -0
  242. phoenix/server/prometheus.py +111 -0
  243. phoenix/server/rate_limiters.py +188 -0
  244. phoenix/server/static/.vite/manifest.json +87 -0
  245. phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
  246. phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
  247. phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
  248. phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
  249. phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
  250. phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
  251. phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
  252. phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
  253. phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
  254. phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
  255. phoenix/server/telemetry.py +68 -0
  256. phoenix/server/templates/index.html +82 -23
  257. phoenix/server/thread_server.py +3 -3
  258. phoenix/server/types.py +275 -0
  259. phoenix/services.py +27 -18
  260. phoenix/session/client.py +743 -68
  261. phoenix/session/data_extractor.py +31 -7
  262. phoenix/session/evaluation.py +3 -9
  263. phoenix/session/session.py +263 -219
  264. phoenix/settings.py +22 -0
  265. phoenix/trace/__init__.py +2 -22
  266. phoenix/trace/attributes.py +338 -0
  267. phoenix/trace/dsl/README.md +116 -0
  268. phoenix/trace/dsl/filter.py +663 -213
  269. phoenix/trace/dsl/helpers.py +73 -21
  270. phoenix/trace/dsl/query.py +574 -201
  271. phoenix/trace/exporter.py +24 -19
  272. phoenix/trace/fixtures.py +368 -32
  273. phoenix/trace/otel.py +71 -219
  274. phoenix/trace/projects.py +3 -2
  275. phoenix/trace/schemas.py +33 -11
  276. phoenix/trace/span_evaluations.py +21 -16
  277. phoenix/trace/span_json_decoder.py +6 -4
  278. phoenix/trace/span_json_encoder.py +2 -2
  279. phoenix/trace/trace_dataset.py +47 -32
  280. phoenix/trace/utils.py +21 -4
  281. phoenix/utilities/__init__.py +0 -26
  282. phoenix/utilities/client.py +132 -0
  283. phoenix/utilities/deprecation.py +31 -0
  284. phoenix/utilities/error_handling.py +3 -2
  285. phoenix/utilities/json.py +109 -0
  286. phoenix/utilities/logging.py +8 -0
  287. phoenix/utilities/project.py +2 -2
  288. phoenix/utilities/re.py +49 -0
  289. phoenix/utilities/span_store.py +0 -23
  290. phoenix/utilities/template_formatters.py +99 -0
  291. phoenix/version.py +1 -1
  292. arize_phoenix-3.16.1.dist-info/METADATA +0 -495
  293. arize_phoenix-3.16.1.dist-info/RECORD +0 -178
  294. phoenix/core/project.py +0 -619
  295. phoenix/core/traces.py +0 -96
  296. phoenix/experimental/evals/__init__.py +0 -73
  297. phoenix/experimental/evals/evaluators.py +0 -413
  298. phoenix/experimental/evals/functions/__init__.py +0 -4
  299. phoenix/experimental/evals/functions/classify.py +0 -453
  300. phoenix/experimental/evals/functions/executor.py +0 -353
  301. phoenix/experimental/evals/functions/generate.py +0 -138
  302. phoenix/experimental/evals/functions/processing.py +0 -76
  303. phoenix/experimental/evals/models/__init__.py +0 -14
  304. phoenix/experimental/evals/models/anthropic.py +0 -175
  305. phoenix/experimental/evals/models/base.py +0 -170
  306. phoenix/experimental/evals/models/bedrock.py +0 -221
  307. phoenix/experimental/evals/models/litellm.py +0 -134
  308. phoenix/experimental/evals/models/openai.py +0 -448
  309. phoenix/experimental/evals/models/rate_limiters.py +0 -246
  310. phoenix/experimental/evals/models/vertex.py +0 -173
  311. phoenix/experimental/evals/models/vertexai.py +0 -186
  312. phoenix/experimental/evals/retrievals.py +0 -96
  313. phoenix/experimental/evals/templates/__init__.py +0 -50
  314. phoenix/experimental/evals/templates/default_templates.py +0 -472
  315. phoenix/experimental/evals/templates/template.py +0 -195
  316. phoenix/experimental/evals/utils/__init__.py +0 -172
  317. phoenix/experimental/evals/utils/threads.py +0 -27
  318. phoenix/server/api/helpers.py +0 -11
  319. phoenix/server/api/routers/evaluation_handler.py +0 -109
  320. phoenix/server/api/routers/span_handler.py +0 -70
  321. phoenix/server/api/routers/trace_handler.py +0 -60
  322. phoenix/server/api/types/DatasetRole.py +0 -23
  323. phoenix/server/static/index.css +0 -6
  324. phoenix/server/static/index.js +0 -7447
  325. phoenix/storage/span_store/__init__.py +0 -23
  326. phoenix/storage/span_store/text_file.py +0 -85
  327. phoenix/trace/dsl/missing.py +0 -60
  328. phoenix/trace/langchain/__init__.py +0 -3
  329. phoenix/trace/langchain/instrumentor.py +0 -35
  330. phoenix/trace/llama_index/__init__.py +0 -3
  331. phoenix/trace/llama_index/callback.py +0 -102
  332. phoenix/trace/openai/__init__.py +0 -3
  333. phoenix/trace/openai/instrumentor.py +0 -30
  334. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/IP_NOTICE +0 -0
  335. {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/LICENSE +0 -0
  336. /phoenix/{datasets → db/insertion}/__init__.py +0 -0
  337. /phoenix/{experimental → db/migrations}/__init__.py +0 -0
  338. /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
@@ -0,0 +1,49 @@
1
+ from sqlalchemy import func, select
2
+ from strawberry.dataloader import DataLoader
3
+ from typing_extensions import TypeAlias
4
+
5
+ from phoenix.db import models
6
+ from phoenix.server.types import DbSessionFactory
7
+
8
+ ExperimentID: TypeAlias = int
9
+ RunCount: TypeAlias = int
10
+ Key: TypeAlias = ExperimentID
11
+ Result: TypeAlias = RunCount
12
+
13
+
14
class ExperimentRunCountsDataLoader(DataLoader[Key, Result]):
    """Batch-loads the number of runs recorded for each experiment id."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # De-duplicate the requested ids and resolve them against the
        # experiments table so only existing experiments produce rows.
        requested = (
            select(models.Experiment.id)
            .where(models.Experiment.id.in_(set(keys)))
            .subquery()
        )
        # LEFT OUTER JOIN so an experiment with zero runs still yields a
        # row with count 0.
        stmt = (
            select(
                requested.c.id,
                func.count(models.ExperimentRun.experiment_id),
            )
            .outerjoin_from(
                from_=requested,
                target=models.ExperimentRun,
                onclause=requested.c.id == models.ExperimentRun.experiment_id,
            )
            .group_by(requested.c.id)
        )
        counts: dict[Key, Result] = {}
        async with self._db() as session:
            async for experiment_id, run_count in await session.stream(stmt):
                counts[experiment_id] = run_count
        # Unknown ids map to a ValueError instance in the results list —
        # presumably the DataLoader surfaces a returned exception as that
        # key's failure (confirm against strawberry's DataLoader docs).
        return [
            counts.get(experiment_id, ValueError(f"Unknown experiment: {experiment_id}"))
            for experiment_id in keys
        ]
@@ -0,0 +1,44 @@
1
+ from typing import Optional
2
+
3
+ from sqlalchemy import distinct, func, select
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db import models
8
+ from phoenix.server.types import DbSessionFactory
9
+
10
+ ExperimentId: TypeAlias = int
11
+ Key: TypeAlias = ExperimentId
12
+ Result: TypeAlias = Optional[int]
13
+
14
+
15
class ExperimentSequenceNumberDataLoader(DataLoader[Key, Result]):
    """Batch-loads each experiment's row-number position within its dataset."""

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # Every dataset touched by the requested experiments.
        datasets = (
            select(distinct(models.Experiment.dataset_id))
            .where(models.Experiment.id.in_(keys))
            .scalar_subquery()
        )
        # Number every experiment within its dataset, ordered by id.
        seq = (
            func.row_number().over(
                partition_by=models.Experiment.dataset_id,
                order_by=models.Experiment.id,
            )
        ).label("row_number")
        numbered = (
            select(models.Experiment.id, seq)
            .where(models.Experiment.dataset_id.in_(datasets))
            .subquery()
        )
        stmt = select(numbered).where(numbered.c.id.in_(keys))
        sequence_numbers: dict[Key, Result] = {}
        async with self._db() as session:
            async for experiment_id, number in await session.stream(stmt):
                sequence_numbers[experiment_id] = number
        # Ids absent from the table resolve to None.
        return [sequence_numbers.get(experiment_id) for experiment_id in keys]
@@ -0,0 +1,188 @@
1
+ from collections import defaultdict
2
+ from collections.abc import AsyncIterator, Mapping
3
+ from datetime import datetime
4
+ from typing import Any, Literal, Optional, cast
5
+
6
+ from cachetools import LFUCache, TTLCache
7
+ from sqlalchemy import (
8
+ ARRAY,
9
+ Float,
10
+ Integer,
11
+ Select,
12
+ SQLColumnExpression,
13
+ Values,
14
+ column,
15
+ func,
16
+ select,
17
+ values,
18
+ )
19
+ from sqlalchemy.ext.asyncio import AsyncSession
20
+ from sqlalchemy.sql.functions import percentile_cont
21
+ from strawberry.dataloader import AbstractCache, DataLoader
22
+ from typing_extensions import TypeAlias, assert_never
23
+
24
+ from phoenix.db import models
25
+ from phoenix.db.helpers import SupportedSQLDialect
26
+ from phoenix.server.api.dataloaders.cache import TwoTierCache
27
+ from phoenix.server.api.input_types.TimeRange import TimeRange
28
+ from phoenix.server.types import DbSessionFactory
29
+ from phoenix.trace.dsl import SpanFilter
30
+
31
+ Kind: TypeAlias = Literal["span", "trace"]
32
+ ProjectRowId: TypeAlias = int
33
+ TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
34
+ FilterCondition: TypeAlias = Optional[str]
35
+ Probability: TypeAlias = float
36
+ QuantileValue: TypeAlias = float
37
+
38
+ Segment: TypeAlias = tuple[Kind, TimeInterval, FilterCondition]
39
+ Param: TypeAlias = tuple[ProjectRowId, Probability]
40
+
41
+ Key: TypeAlias = tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition, Probability]
42
+ Result: TypeAlias = Optional[QuantileValue]
43
+ ResultPosition: TypeAlias = int
44
+ DEFAULT_VALUE: Result = None
45
+
46
+ FloatCol: TypeAlias = SQLColumnExpression[Float[float]]
47
+
48
+
49
def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
    """Split a raw key into its (segment, param) pair.

    The segment groups keys that can share one SQL round trip; the param
    identifies the (project, probability) within that segment.
    """
    kind, project_rowid, time_range, filter_condition, probability = key
    if isinstance(time_range, TimeRange):
        interval: TimeInterval = (time_range.start, time_range.end)
    else:
        interval = (None, None)
    segment: Segment = (kind, interval, filter_condition)
    param: Param = (project_rowid, probability)
    return segment, param
55
+
56
+
57
+ _Section: TypeAlias = ProjectRowId
58
+ _SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, Kind, Probability]
59
+
60
+
61
class LatencyMsQuantileCache(TwoTierCache[Key, Result, _Section, _SubKey]):
    """Two-tier cache sectioned by project, sub-keyed by query parameters."""

    def __init__(self) -> None:
        # TTL=3600 (1 hour): the UI rounds interval endpoints down to the
        # hour and intervals only move forward, so anything older than an
        # hour most likely won't be a cache hit anyway.
        super().__init__(
            main_cache=TTLCache(maxsize=64, ttl=3600),
            sub_cache_factory=lambda: LFUCache(maxsize=2 * 2 * 2 * 16),
        )

    def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
        segment, param = _cache_key_fn(key)
        kind, interval, filter_condition = segment
        project_rowid, probability = param
        return project_rowid, (interval, filter_condition, kind, probability)
76
+
77
+
78
class LatencyMsQuantileDataLoader(DataLoader[Key, Result]):
    """Batch-loads latency-ms quantiles per project for spans or traces."""

    def __init__(
        self,
        db: DbSessionFactory,
        cache_map: Optional[AbstractCache[Key, Result]] = None,
    ) -> None:
        super().__init__(
            load_fn=self._load_fn,
            cache_key_fn=_cache_key_fn,
            cache_map=cache_map,
        )
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        answers: list[Result] = [DEFAULT_VALUE] * len(keys)
        # Group result positions by (segment, param) so each segment is
        # answered by a single batch of queries.
        grouped: defaultdict[
            Segment,
            defaultdict[Param, list[ResultPosition]],
        ] = defaultdict(lambda: defaultdict(list))
        for position, key in enumerate(keys):
            segment, param = _cache_key_fn(key)
            grouped[segment][param].append(position)
        async with self._db() as session:
            dialect = SupportedSQLDialect(session.bind.dialect.name)
            for segment, params in grouped.items():
                async for position, quantile_value in _get_results(
                    dialect, session, segment, params
                ):
                    answers[position] = quantile_value
        return answers
108
+
109
+
110
async def _get_results(
    dialect: SupportedSQLDialect,
    session: AsyncSession,
    segment: Segment,
    params: Mapping[Param, list[ResultPosition]],
) -> AsyncIterator[tuple[ResultPosition, QuantileValue]]:
    """Yield (result position, quantile value) pairs for one segment.

    A segment fixes the kind ("span" or "trace"), the half-open time
    interval, and the optional filter condition; ``params`` maps each
    (project id, probability) pair to the positions in the caller's result
    list that are waiting for its answer.
    """
    kind, (start_time, end_time), filter_condition = segment
    stmt = select(models.Trace.project_rowid)
    if kind == "trace":
        latency_column = cast(FloatCol, models.Trace.latency_ms)
        time_column = models.Trace.start_time
    elif kind == "span":
        latency_column = cast(FloatCol, models.Span.latency_ms)
        time_column = models.Span.start_time
        stmt = stmt.join(models.Span)
        # Span filters are only applied on the span path.
        if filter_condition:
            sf = SpanFilter(filter_condition)
            stmt = sf(stmt)
    else:
        assert_never(kind)
    # Half-open interval: start_time <= time_column < end_time.
    if start_time:
        stmt = stmt.where(start_time <= time_column)
    if end_time:
        stmt = stmt.where(time_column < end_time)
    # Percentile support differs per backend, so dispatch on dialect.
    if dialect is SupportedSQLDialect.POSTGRESQL:
        results = _get_results_postgresql(session, stmt, latency_column, params)
    elif dialect is SupportedSQLDialect.SQLITE:
        results = _get_results_sqlite(session, stmt, latency_column, params)
    else:
        assert_never(dialect)
    async for position, quantile_value in results:
        yield position, quantile_value
142
+
143
+
144
async def _get_results_sqlite(
    session: AsyncSession,
    base_stmt: Select[Any],
    latency_column: FloatCol,
    params: Mapping[Param, list[ResultPosition]],
) -> AsyncIterator[tuple[ResultPosition, QuantileValue]]:
    """SQLite path: one GROUP BY query per distinct probability.

    ``percentile`` takes a single percentage argument, so the requested
    (project, probability) pairs are bucketed by probability and each
    bucket is fetched separately.
    """
    by_probability: defaultdict[Probability, list[ProjectRowId]] = defaultdict(list)
    for project_rowid, probability in params.keys():
        by_probability[probability].append(project_rowid)
    pid = models.Trace.project_rowid
    for probability, project_rowids in by_probability.items():
        # percentile() expects a percentage in [0, 100].
        pctl: FloatCol = func.percentile(latency_column, probability * 100)
        stmt = (
            base_stmt.add_columns(pctl)
            .where(pid.in_(project_rowids))
            .group_by(pid)
        )
        async for project_rowid, quantile_value in await session.stream(stmt):
            for position in params[(project_rowid, probability)]:
                yield position, quantile_value
163
+
164
+
165
async def _get_results_postgresql(
    session: AsyncSession,
    base_stmt: Select[Any],
    latency_column: FloatCol,
    params: Mapping[Param, list[ResultPosition]],
) -> AsyncIterator[tuple[ResultPosition, QuantileValue]]:
    """PostgreSQL path: answer all projects and probabilities in one query.

    ``percentile_cont`` accepts an array of probabilities, so each
    project's requested probabilities are packed into a VALUES table that
    is joined against the base statement — a single round trip.
    """
    probs_per_project: defaultdict[ProjectRowId, list[Probability]] = defaultdict(list)
    for project_rowid, probability in params.keys():
        probs_per_project[project_rowid].append(probability)
    # Inline VALUES table: (project_rowid INTEGER, probabilities FLOAT[]).
    pp: Values = values(
        column("project_rowid", Integer),
        column("probabilities", ARRAY(Float[float])),
        name="project_probabilities",
    ).data(probs_per_project.items())  # type: ignore
    pid = models.Trace.project_rowid
    pctl: FloatCol = percentile_cont(pp.c.probabilities).within_group(latency_column)
    stmt = base_stmt.add_columns(pp.c.probabilities, pctl)
    stmt = stmt.join(pp, pid == pp.c.project_rowid)
    stmt = stmt.group_by(pid, pp.c.probabilities)
    data = await session.stream(stmt)
    # Each row carries the probability array alongside a same-length,
    # same-order array of quantile values.
    async for project_rowid, probabilities, quantile_values in data:
        for probability, quantile_value in zip(probabilities, quantile_values):
            for position in params[(project_rowid, probability)]:
                yield position, quantile_value
@@ -0,0 +1,85 @@
1
+ from collections import defaultdict
2
+ from datetime import datetime
3
+ from typing import Literal, Optional
4
+
5
+ from cachetools import LFUCache
6
+ from sqlalchemy import func, select
7
+ from strawberry.dataloader import AbstractCache, DataLoader
8
+ from typing_extensions import TypeAlias, assert_never
9
+
10
+ from phoenix.db import models
11
+ from phoenix.server.api.dataloaders.cache import TwoTierCache
12
+ from phoenix.server.types import DbSessionFactory
13
+
14
+ Kind: TypeAlias = Literal["start", "end"]
15
+ ProjectRowId: TypeAlias = int
16
+
17
+ Segment: TypeAlias = ProjectRowId
18
+ Param: TypeAlias = Kind
19
+
20
+ Key: TypeAlias = tuple[ProjectRowId, Kind]
21
+ Result: TypeAlias = Optional[datetime]
22
+ ResultPosition: TypeAlias = int
23
+ DEFAULT_VALUE: Result = None
24
+
25
+ _Section = ProjectRowId
26
+ _SubKey = Kind
27
+
28
+
29
+ class MinStartOrMaxEndTimeCache(
30
+ TwoTierCache[Key, Result, _Section, _SubKey],
31
+ ):
32
+ def __init__(self) -> None:
33
+ super().__init__(
34
+ main_cache=LFUCache(maxsize=64),
35
+ sub_cache_factory=lambda: LFUCache(maxsize=2),
36
+ )
37
+
38
+ def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
39
+ return key
40
+
41
+
42
+ class MinStartOrMaxEndTimeDataLoader(DataLoader[Key, Result]):
43
+ def __init__(
44
+ self,
45
+ db: DbSessionFactory,
46
+ cache_map: Optional[AbstractCache[Key, Result]] = None,
47
+ ) -> None:
48
+ super().__init__(
49
+ load_fn=self._load_fn,
50
+ cache_map=cache_map,
51
+ )
52
+ self._db = db
53
+
54
+ async def _load_fn(self, keys: list[Key]) -> list[Result]:
55
+ results: list[Result] = [DEFAULT_VALUE] * len(keys)
56
+ arguments: defaultdict[
57
+ Segment,
58
+ defaultdict[Param, list[ResultPosition]],
59
+ ] = defaultdict(lambda: defaultdict(list))
60
+ for position, key in enumerate(keys):
61
+ segment, param = key
62
+ arguments[segment][param].append(position)
63
+ pid = models.Trace.project_rowid
64
+ stmt = (
65
+ select(
66
+ pid,
67
+ func.min(models.Trace.start_time).label("min_start"),
68
+ func.max(models.Trace.end_time).label("max_end"),
69
+ )
70
+ .where(pid.in_(arguments.keys()))
71
+ .group_by(pid)
72
+ )
73
+ async with self._db() as session:
74
+ data = await session.stream(stmt)
75
+ async for project_rowid, min_start, max_end in data:
76
+ for kind, positions in arguments[project_rowid].items():
77
+ if kind == "start":
78
+ for position in positions:
79
+ results[position] = min_start
80
+ elif kind == "end":
81
+ for position in positions:
82
+ results[position] = max_end
83
+ else:
84
+ assert_never(kind)
85
+ return results
@@ -0,0 +1,31 @@
1
+ from collections import defaultdict
2
+ from typing import Optional
3
+
4
+ from sqlalchemy import select
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.types import DbSessionFactory
10
+
11
# A load key is a project name; the result is the matching Project row,
# or None when no project with that name exists.
ProjectName: TypeAlias = str
Key: TypeAlias = ProjectName
Result: TypeAlias = Optional[models.Project]
14
+
15
+
16
class ProjectByNameDataLoader(DataLoader[Key, Result]):
    """Batch-loads Project rows by name with a single ``IN (...)`` query.

    Names that match no project resolve to None.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # De-duplicate: the same name may be requested at multiple positions.
        project_names = list(set(keys))
        # NOTE: was `defaultdict(None)`, which is just a plain dict — a None
        # default_factory disables the default behavior entirely, so the
        # defaultdict was misleading; a dict plus .get() does the same thing.
        projects_by_name: dict[Key, models.Project] = {}
        async with self._db() as session:
            data = await session.stream_scalars(
                select(models.Project).where(models.Project.name.in_(project_names))
            )
            async for project in data:
                projects_by_name[project.name] = project

        return [projects_by_name.get(project_name) for project_name in keys]
@@ -0,0 +1,116 @@
1
+ from collections import defaultdict
2
+ from datetime import datetime
3
+ from typing import Any, Literal, Optional
4
+
5
+ from cachetools import LFUCache, TTLCache
6
+ from sqlalchemy import Select, func, select
7
+ from strawberry.dataloader import AbstractCache, DataLoader
8
+ from typing_extensions import TypeAlias, assert_never
9
+
10
+ from phoenix.db import models
11
+ from phoenix.server.api.dataloaders.cache import TwoTierCache
12
+ from phoenix.server.api.input_types.TimeRange import TimeRange
13
+ from phoenix.server.types import DbSessionFactory
14
+ from phoenix.trace.dsl import SpanFilter
15
+
16
# What is being counted ("span" or "trace") and how requests are keyed.
Kind: TypeAlias = Literal["span", "trace"]
ProjectRowId: TypeAlias = int
TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]  # (start, end), either may be None
FilterCondition: TypeAlias = Optional[str]  # span-filter DSL expression, if any
SpanCount: TypeAlias = int

# Requests sharing a Segment (kind, interval, filter) can be answered by one
# SQL statement parameterized over project row ids (the Param).
Segment: TypeAlias = tuple[Kind, TimeInterval, FilterCondition]
Param: TypeAlias = ProjectRowId

Key: TypeAlias = tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition]
Result: TypeAlias = SpanCount
ResultPosition: TypeAlias = int  # index into the batched results list
DEFAULT_VALUE: Result = 0
29
+
30
+
31
def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
    """Split a load key into its query segment and per-query parameter.

    The optional TimeRange is normalized to a plain (start, end) tuple so
    segments are hashable and comparable; no range maps to (None, None).
    """
    kind, project_rowid, time_range, filter_condition = key
    if isinstance(time_range, TimeRange):
        interval: TimeInterval = (time_range.start, time_range.end)
    else:
        interval = (None, None)
    return (kind, interval, filter_condition), project_rowid
37
+
38
+
39
# Two-tier cache layout: sectioned by project row id, sub-keyed by the
# (interval, filter, kind) triple distinguishing queries within a project.
_Section: TypeAlias = ProjectRowId
_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, Kind]
41
+
42
+
43
class RecordCountCache(TwoTierCache[Key, Result, _Section, _SubKey]):
    """Caches span/trace record counts, sectioned per project.

    TTL=3600 (1 hour) on the main cache because time intervals are always
    moving forward, but interval endpoints are rounded down to the hour by
    the UI, so anything older than an hour most likely won't be a cache-hit
    anyway. Each project's sub-cache keeps up to 2*2*2 = 8 entries.
    """

    def __init__(self) -> None:
        super().__init__(
            main_cache=TTLCache(maxsize=64, ttl=3600),
            sub_cache_factory=lambda: LFUCache(maxsize=8),
        )

    def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
        segment, project_rowid = _cache_key_fn(key)
        kind, interval, filter_condition = segment
        return project_rowid, (interval, filter_condition, kind)
58
+
59
+
60
class RecordCountDataLoader(DataLoader[Key, Result]):
    """Batch-loads the number of spans or traces per project.

    Requests are bucketed by (kind, time interval, filter condition); each
    bucket is answered with one grouped count query over its project ids.
    Projects absent from the query result resolve to 0.
    """

    def __init__(
        self,
        db: DbSessionFactory,
        cache_map: Optional[AbstractCache[Key, Result]] = None,
    ) -> None:
        super().__init__(
            load_fn=self._load_fn,
            cache_key_fn=_cache_key_fn,
            cache_map=cache_map,
        )
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        counts: list[Result] = [DEFAULT_VALUE] * len(keys)
        # buckets[segment][project_rowid] -> positions in `keys` to fill
        buckets: defaultdict[
            Segment,
            defaultdict[Param, list[ResultPosition]],
        ] = defaultdict(lambda: defaultdict(list))
        for position, key in enumerate(keys):
            segment, project_rowid = _cache_key_fn(key)
            buckets[segment][project_rowid].append(position)
        async with self._db() as session:
            for segment, by_project in buckets.items():
                rows = await session.stream(_get_stmt(segment, *by_project.keys()))
                async for project_rowid, count in rows:
                    for position in by_project[project_rowid]:
                        counts[position] = count
        return counts
90
+
91
+
92
def _get_stmt(
    segment: Segment,
    *project_rowids: Param,
) -> Select[Any]:
    """Build the per-project count query for one (kind, interval, filter) bucket.

    Returns rows of (project_rowid, count). The time interval is half-open:
    start_time <= t < end_time, with either bound optional.
    """
    kind, (start_time, end_time), filter_condition = segment
    project_col = models.Trace.project_rowid
    stmt = select(project_col)
    if kind == "span":
        time_column = models.Span.start_time
        stmt = stmt.join(models.Span)
        if filter_condition:
            # Apply the user-supplied span-filter DSL to the joined spans.
            stmt = SpanFilter(filter_condition)(stmt)
    elif kind == "trace":
        time_column = models.Trace.start_time
    else:
        assert_never(kind)
    stmt = (
        stmt.add_columns(func.count().label("count"))
        .where(project_col.in_(project_rowids))
        .group_by(project_col)
    )
    if start_time:
        stmt = stmt.where(start_time <= time_column)
    if end_time:
        stmt = stmt.where(time_column < end_time)
    return stmt
@@ -0,0 +1,79 @@
1
+ from functools import cached_property
2
+ from typing import Literal, Optional, cast
3
+
4
+ from openinference.semconv.trace import SpanAttributes
5
+ from sqlalchemy import Select, func, select
6
+ from strawberry.dataloader import DataLoader
7
+ from typing_extensions import TypeAlias, assert_never
8
+
9
+ from phoenix.db import models
10
+ from phoenix.server.types import DbSessionFactory
11
+ from phoenix.trace.schemas import MimeType, SpanIOValue
12
+
13
# A load key is a project session row id; the result is the selected root
# span's I/O value with its mime type, or None when unavailable.
Key: TypeAlias = int
Result: TypeAlias = Optional[SpanIOValue]

# Which end of the session to read: the first trace's input value or the
# last trace's output value.
Kind = Literal["first_input", "last_output"]
17
+
18
+
19
class SessionIODataLoader(DataLoader[Key, Result]):
    """Batch-loads per-session I/O: the first trace's root-span input value
    or the last trace's root-span output value, depending on ``kind``."""

    def __init__(self, db: DbSessionFactory, kind: Kind) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db
        self._kind = kind  # "first_input" or "last_output"; fixed per loader

    @cached_property
    def _subq(self) -> Select[tuple[Optional[int], str, str, int]]:
        """Build the inner select of (session id, value, mime_type, rank).

        Root spans (parent_id IS NULL) are joined to their traces, and each
        trace is ranked within its session by start time (trace id breaks
        ties): ascending for "first_input", descending for "last_output",
        so rank 1 is always the row of interest.
        """
        stmt = (
            select(models.Trace.project_session_rowid.label("id_"))
            .join_from(models.Span, models.Trace)
            # Restrict to root spans (no parent).
            .where(models.Span.parent_id.is_(None))
        )
        if self._kind == "first_input":
            stmt = stmt.add_columns(
                models.Span.attributes[INPUT_VALUE].as_string().label("value"),
                models.Span.attributes[INPUT_MIME_TYPE].as_string().label("mime_type"),
                func.row_number()
                .over(
                    partition_by=models.Trace.project_session_rowid,
                    order_by=[models.Trace.start_time.asc(), models.Trace.id.asc()],
                )
                .label("rank"),
            )
        elif self._kind == "last_output":
            stmt = stmt.add_columns(
                models.Span.attributes[OUTPUT_VALUE].as_string().label("value"),
                models.Span.attributes[OUTPUT_MIME_TYPE].as_string().label("mime_type"),
                func.row_number()
                .over(
                    partition_by=models.Trace.project_session_rowid,
                    order_by=[models.Trace.start_time.desc(), models.Trace.id.desc()],
                )
                .label("rank"),
            )
        else:
            assert_never(self._kind)
        return cast(Select[tuple[Optional[int], str, str, int]], stmt)

    def _stmt(self, *keys: Key) -> Select[tuple[int, str, str]]:
        """Wrap ``_subq`` for the given session ids, keeping only the rank-1
        row per session and dropping rows whose value is NULL."""
        subq = self._subq.where(models.Trace.project_session_rowid.in_(keys)).subquery()
        return (
            select(subq.c.id_, subq.c.value, subq.c.mime_type)
            .filter_by(rank=1)
            .where(subq.c.value.isnot(None))
        )

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Resolve each session id to a SpanIOValue, or None when absent."""
        async with self._db() as session:
            result: dict[Key, SpanIOValue] = {
                id_: SpanIOValue(value=value, mime_type=MimeType(mime_type))
                async for id_, value, mime_type in await session.stream(self._stmt(*keys))
                if id_ is not None  # sessions missing from the result are skipped
            }
        return [result.get(key) for key in keys]
74
+
75
+
76
# Attribute paths used to index into the Span.attributes JSON column above:
# the dotted OpenInference attribute names are split into path segments.
INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE.split(".")
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE.split(".")
@@ -0,0 +1,30 @@
1
+ from sqlalchemy import func, select
2
+ from strawberry.dataloader import DataLoader
3
+ from typing_extensions import TypeAlias
4
+
5
+ from phoenix.db import models
6
+ from phoenix.server.types import DbSessionFactory
7
+
8
# Key: project session row id -> Result: number of traces in that session.
Key: TypeAlias = int
Result: TypeAlias = int
10
+
11
+
12
class SessionNumTracesDataLoader(DataLoader[Key, Result]):
    """Batch-loads the trace count of each project session.

    Session ids that match no traces resolve to 0.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        session_col = models.Trace.project_session_rowid
        stmt = (
            select(
                session_col.label("id_"),
                func.count(models.Trace.id).label("value"),
            )
            .group_by(session_col)
            .where(session_col.in_(keys))
        )
        counts: dict[Key, int] = {}
        async with self._db() as db_session:
            async for id_, value in await db_session.stream(stmt):
                if id_ is not None:
                    counts[id_] = value
        return [counts.get(key, 0) for key in keys]
@@ -0,0 +1,32 @@
1
+ from sqlalchemy import func, select
2
+ from strawberry.dataloader import DataLoader
3
+ from typing_extensions import TypeAlias
4
+
5
+ from phoenix.db import models
6
+ from phoenix.server.types import DbSessionFactory
7
+
8
# Key: project session row id -> Result: number of traces with errors.
Key: TypeAlias = int
Result: TypeAlias = int
10
+
11
+
12
class SessionNumTracesWithErrorDataLoader(DataLoader[Key, Result]):
    """Batch-loads, per project session, the number of traces that contain
    at least one error.

    Session ids with no erroring traces resolve to 0.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = (
            select(
                models.Trace.project_session_rowid.label("id_"),
                # Count DISTINCT trace ids: the join to Span yields one row
                # per span with cumulative_error_count > 0, and a trace with
                # an error will generally match on several spans (the erroring
                # span plus its ancestors, whose cumulative counts include
                # descendants), so a plain count(Trace.id) over-counts traces.
                func.count(models.Trace.id.distinct()).label("value"),
            )
            .join(models.Span)
            .group_by(models.Trace.project_session_rowid)
            .where(models.Span.cumulative_error_count > 0)
            .where(models.Trace.project_session_rowid.in_(keys))
        )
        async with self._db() as session:
            result: dict[Key, int] = {
                id_: value async for id_, value in await session.stream(stmt) if id_ is not None
            }
        return [result.get(key, 0) for key in keys]