arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (276)
  1. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
  2. arize_phoenix-12.28.1.dist-info/RECORD +499 -0
  3. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +5 -4
  12. phoenix/auth.py +39 -2
  13. phoenix/config.py +1763 -91
  14. phoenix/datetime_utils.py +120 -2
  15. phoenix/db/README.md +595 -25
  16. phoenix/db/bulk_inserter.py +145 -103
  17. phoenix/db/engines.py +140 -33
  18. phoenix/db/enums.py +3 -12
  19. phoenix/db/facilitator.py +302 -35
  20. phoenix/db/helpers.py +1000 -65
  21. phoenix/db/iam_auth.py +64 -0
  22. phoenix/db/insertion/dataset.py +135 -2
  23. phoenix/db/insertion/document_annotation.py +9 -6
  24. phoenix/db/insertion/evaluation.py +2 -3
  25. phoenix/db/insertion/helpers.py +17 -2
  26. phoenix/db/insertion/session_annotation.py +176 -0
  27. phoenix/db/insertion/span.py +15 -11
  28. phoenix/db/insertion/span_annotation.py +3 -4
  29. phoenix/db/insertion/trace_annotation.py +3 -4
  30. phoenix/db/insertion/types.py +50 -20
  31. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  32. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  33. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  34. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  35. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  36. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  37. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  38. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  39. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  40. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  41. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  42. phoenix/db/models.py +669 -56
  43. phoenix/db/pg_config.py +10 -0
  44. phoenix/db/types/model_provider.py +4 -0
  45. phoenix/db/types/token_price_customization.py +29 -0
  46. phoenix/db/types/trace_retention.py +23 -15
  47. phoenix/experiments/evaluators/utils.py +3 -3
  48. phoenix/experiments/functions.py +160 -52
  49. phoenix/experiments/tracing.py +2 -2
  50. phoenix/experiments/types.py +1 -1
  51. phoenix/inferences/inferences.py +1 -2
  52. phoenix/server/api/auth.py +38 -7
  53. phoenix/server/api/auth_messages.py +46 -0
  54. phoenix/server/api/context.py +100 -4
  55. phoenix/server/api/dataloaders/__init__.py +79 -5
  56. phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
  57. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  58. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  59. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  60. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  61. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  62. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  63. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  64. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  65. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  66. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  67. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  68. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  69. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  70. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  71. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  72. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  73. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  74. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  75. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  76. phoenix/server/api/dataloaders/record_counts.py +37 -10
  77. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  78. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  79. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  80. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  81. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  82. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  83. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  84. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
  85. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  86. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  87. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  88. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
  89. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  90. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  91. phoenix/server/api/dataloaders/span_costs.py +29 -0
  92. phoenix/server/api/dataloaders/table_fields.py +2 -2
  93. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  94. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  95. phoenix/server/api/dataloaders/types.py +29 -0
  96. phoenix/server/api/exceptions.py +11 -1
  97. phoenix/server/api/helpers/dataset_helpers.py +5 -1
  98. phoenix/server/api/helpers/playground_clients.py +1243 -292
  99. phoenix/server/api/helpers/playground_registry.py +2 -2
  100. phoenix/server/api/helpers/playground_spans.py +8 -4
  101. phoenix/server/api/helpers/playground_users.py +26 -0
  102. phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
  103. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  104. phoenix/server/api/helpers/prompts/models.py +205 -22
  105. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  106. phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
  107. phoenix/server/api/input_types/CreateProjectInput.py +27 -0
  108. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  109. phoenix/server/api/input_types/DatasetFilter.py +17 -0
  110. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  111. phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
  112. phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
  113. phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
  114. phoenix/server/api/input_types/PromptFilter.py +14 -0
  115. phoenix/server/api/input_types/PromptVersionInput.py +52 -1
  116. phoenix/server/api/input_types/SpanSort.py +44 -7
  117. phoenix/server/api/input_types/TimeBinConfig.py +23 -0
  118. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  119. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  120. phoenix/server/api/mutations/__init__.py +10 -0
  121. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  122. phoenix/server/api/mutations/api_key_mutations.py +19 -23
  123. phoenix/server/api/mutations/chat_mutations.py +154 -47
  124. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  125. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  126. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  127. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  128. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  129. phoenix/server/api/mutations/model_mutations.py +210 -0
  130. phoenix/server/api/mutations/project_mutations.py +49 -10
  131. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  132. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  133. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  134. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  135. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  136. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  137. phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
  138. phoenix/server/api/mutations/trace_mutations.py +47 -3
  139. phoenix/server/api/mutations/user_mutations.py +66 -41
  140. phoenix/server/api/queries.py +768 -293
  141. phoenix/server/api/routers/__init__.py +2 -2
  142. phoenix/server/api/routers/auth.py +154 -88
  143. phoenix/server/api/routers/ldap.py +229 -0
  144. phoenix/server/api/routers/oauth2.py +369 -106
  145. phoenix/server/api/routers/v1/__init__.py +24 -4
  146. phoenix/server/api/routers/v1/annotation_configs.py +23 -31
  147. phoenix/server/api/routers/v1/annotations.py +481 -17
  148. phoenix/server/api/routers/v1/datasets.py +395 -81
  149. phoenix/server/api/routers/v1/documents.py +142 -0
  150. phoenix/server/api/routers/v1/evaluations.py +24 -31
  151. phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
  152. phoenix/server/api/routers/v1/experiment_runs.py +337 -59
  153. phoenix/server/api/routers/v1/experiments.py +479 -48
  154. phoenix/server/api/routers/v1/models.py +7 -0
  155. phoenix/server/api/routers/v1/projects.py +18 -49
  156. phoenix/server/api/routers/v1/prompts.py +54 -40
  157. phoenix/server/api/routers/v1/sessions.py +108 -0
  158. phoenix/server/api/routers/v1/spans.py +1091 -81
  159. phoenix/server/api/routers/v1/traces.py +132 -78
  160. phoenix/server/api/routers/v1/users.py +389 -0
  161. phoenix/server/api/routers/v1/utils.py +3 -7
  162. phoenix/server/api/subscriptions.py +305 -88
  163. phoenix/server/api/types/Annotation.py +90 -23
  164. phoenix/server/api/types/ApiKey.py +13 -17
  165. phoenix/server/api/types/AuthMethod.py +1 -0
  166. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  167. phoenix/server/api/types/CostBreakdown.py +12 -0
  168. phoenix/server/api/types/Dataset.py +226 -72
  169. phoenix/server/api/types/DatasetExample.py +88 -18
  170. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  171. phoenix/server/api/types/DatasetLabel.py +57 -0
  172. phoenix/server/api/types/DatasetSplit.py +98 -0
  173. phoenix/server/api/types/DatasetVersion.py +49 -4
  174. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  175. phoenix/server/api/types/Experiment.py +264 -59
  176. phoenix/server/api/types/ExperimentComparison.py +5 -10
  177. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  178. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  179. phoenix/server/api/types/ExperimentRun.py +169 -65
  180. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  181. phoenix/server/api/types/GenerativeModel.py +245 -3
  182. phoenix/server/api/types/GenerativeProvider.py +70 -11
  183. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  184. phoenix/server/api/types/ModelInterface.py +16 -0
  185. phoenix/server/api/types/PlaygroundModel.py +20 -0
  186. phoenix/server/api/types/Project.py +1278 -216
  187. phoenix/server/api/types/ProjectSession.py +188 -28
  188. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  189. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  190. phoenix/server/api/types/Prompt.py +119 -39
  191. phoenix/server/api/types/PromptLabel.py +42 -25
  192. phoenix/server/api/types/PromptVersion.py +11 -8
  193. phoenix/server/api/types/PromptVersionTag.py +65 -25
  194. phoenix/server/api/types/ServerStatus.py +6 -0
  195. phoenix/server/api/types/Span.py +167 -123
  196. phoenix/server/api/types/SpanAnnotation.py +189 -42
  197. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  198. phoenix/server/api/types/SpanCostSummary.py +10 -0
  199. phoenix/server/api/types/SystemApiKey.py +65 -1
  200. phoenix/server/api/types/TokenPrice.py +16 -0
  201. phoenix/server/api/types/TokenUsage.py +3 -3
  202. phoenix/server/api/types/Trace.py +223 -51
  203. phoenix/server/api/types/TraceAnnotation.py +149 -50
  204. phoenix/server/api/types/User.py +137 -32
  205. phoenix/server/api/types/UserApiKey.py +73 -26
  206. phoenix/server/api/types/node.py +10 -0
  207. phoenix/server/api/types/pagination.py +11 -2
  208. phoenix/server/app.py +290 -45
  209. phoenix/server/authorization.py +38 -3
  210. phoenix/server/bearer_auth.py +34 -24
  211. phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
  212. phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
  213. phoenix/server/cost_tracking/helpers.py +68 -0
  214. phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
  215. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  216. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  217. phoenix/server/daemons/__init__.py +0 -0
  218. phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
  219. phoenix/server/daemons/generative_model_store.py +103 -0
  220. phoenix/server/daemons/span_cost_calculator.py +99 -0
  221. phoenix/server/dml_event.py +17 -0
  222. phoenix/server/dml_event_handler.py +5 -0
  223. phoenix/server/email/sender.py +56 -3
  224. phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
  225. phoenix/server/email/types.py +11 -0
  226. phoenix/server/experiments/__init__.py +0 -0
  227. phoenix/server/experiments/utils.py +14 -0
  228. phoenix/server/grpc_server.py +11 -11
  229. phoenix/server/jwt_store.py +17 -15
  230. phoenix/server/ldap.py +1449 -0
  231. phoenix/server/main.py +26 -10
  232. phoenix/server/oauth2.py +330 -12
  233. phoenix/server/prometheus.py +66 -6
  234. phoenix/server/rate_limiters.py +4 -9
  235. phoenix/server/retention.py +33 -20
  236. phoenix/server/session_filters.py +49 -0
  237. phoenix/server/static/.vite/manifest.json +55 -51
  238. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  239. phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
  240. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  241. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  242. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  243. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  244. phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
  245. phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
  246. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  247. phoenix/server/templates/index.html +40 -6
  248. phoenix/server/thread_server.py +1 -2
  249. phoenix/server/types.py +14 -4
  250. phoenix/server/utils.py +74 -0
  251. phoenix/session/client.py +56 -3
  252. phoenix/session/data_extractor.py +5 -0
  253. phoenix/session/evaluation.py +14 -5
  254. phoenix/session/session.py +45 -9
  255. phoenix/settings.py +5 -0
  256. phoenix/trace/attributes.py +80 -13
  257. phoenix/trace/dsl/helpers.py +90 -1
  258. phoenix/trace/dsl/query.py +8 -6
  259. phoenix/trace/projects.py +5 -0
  260. phoenix/utilities/template_formatters.py +1 -1
  261. phoenix/version.py +1 -1
  262. arize_phoenix-10.0.4.dist-info/RECORD +0 -405
  263. phoenix/server/api/types/Evaluation.py +0 -39
  264. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  265. phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
  266. phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
  267. phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
  268. phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
  269. phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
  270. phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
  271. phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
  272. phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
  273. phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
  274. phoenix/utilities/deprecation.py +0 -31
  275. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  276. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,56 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import func, select
4
+ from sqlalchemy.sql.functions import coalesce
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.dataloaders.types import (
10
+ CostBreakdown,
11
+ SpanCostDetailSummaryEntry,
12
+ )
13
+ from phoenix.server.types import DbSessionFactory
14
+
15
TraceRowId: TypeAlias = int
Key: TypeAlias = TraceRowId
Result: TypeAlias = list[SpanCostDetailSummaryEntry]


class SpanCostDetailSummaryEntriesByTraceDataLoader(DataLoader[Key, Result]):
    """Batch-load per-token-type cost and token totals for each trace.

    For every requested trace row id, the loader yields one
    ``SpanCostDetailSummaryEntry`` per ``(token_type, is_prompt)`` combination
    found among the ``SpanCostDetail`` rows of that trace's ``SpanCost`` rows.
    A trace with no detail rows maps to an empty list.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        detail = models.SpanCostDetail
        span_cost = models.SpanCost
        trace_rowid = span_cost.trace_rowid
        # One aggregate row per (trace, token_type, is_prompt); COALESCE turns an
        # all-NULL SUM into 0.
        stmt = (
            select(
                trace_rowid,
                detail.token_type,
                detail.is_prompt,
                coalesce(func.sum(detail.cost), 0).label("cost"),
                coalesce(func.sum(detail.tokens), 0).label("tokens"),
            )
            .select_from(detail)
            .join(span_cost, detail.span_cost_id == span_cost.id)
            .where(trace_rowid.in_(keys))
            .group_by(trace_rowid, detail.token_type, detail.is_prompt)
        )
        entries_by_trace: dict[Key, Result] = {}
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for row_trace_id, token_type, is_prompt, cost, tokens in rows:
                entries_by_trace.setdefault(row_trace_id, []).append(
                    SpanCostDetailSummaryEntry(
                        token_type=token_type,
                        is_prompt=is_prompt,
                        value=CostBreakdown(tokens=tokens, cost=cost),
                    )
                )
        # Every key receives its own list object (a shallow copy); keys with no
        # rows receive a fresh empty list.
        return [list(entries_by_trace.get(key, [])) for key in keys]
@@ -0,0 +1,27 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import select
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db import models
8
+ from phoenix.server.types import DbSessionFactory
9
+
10
SpanCostId: TypeAlias = int
Key: TypeAlias = SpanCostId
Result: TypeAlias = list[models.SpanCostDetail]


class SpanCostDetailsBySpanCostDataLoader(DataLoader[Key, Result]):
    """Batch-load the ``SpanCostDetail`` ORM objects belonging to each span cost.

    Results are grouped by ``SpanCostDetail.span_cost_id``. A span cost id with
    no detail rows maps to an empty list.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        detail = models.SpanCostDetail
        query = select(detail).where(detail.span_cost_id.in_(keys))
        grouped: defaultdict[Key, Result] = defaultdict(list)
        async with self._db() as session:
            scalars = await session.stream_scalars(query)
            async for row in scalars:
                grouped[row.span_cost_id].append(row)
        # Indexing the defaultdict supplies an empty list for ids with no rows.
        return [grouped[key] for key in keys]
+ return list(map(result.__getitem__, keys))
@@ -0,0 +1,57 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import func, select
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db import models
8
+ from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
9
+ from phoenix.server.types import DbSessionFactory
10
+
11
ExperimentId: TypeAlias = int
Key: TypeAlias = ExperimentId
Result: TypeAlias = SpanCostSummary


class SpanCostSummaryByExperimentDataLoader(DataLoader[Key, Result]):
    """Batch-load aggregate cost and token totals for each experiment.

    Sums ``SpanCost`` prompt/completion/total cost and token columns over all
    span costs whose trace is linked to one of the experiment's runs.
    Experiments with no matching rows map to a default ``SpanCostSummary()``.
    The sums are passed through as returned by the database (no COALESCE), so
    an all-NULL column comes back as whatever the driver yields for SQL NULL.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        run = models.ExperimentRun
        trace = models.Trace
        cost = models.SpanCost
        stmt = (
            select(
                run.experiment_id,
                func.sum(cost.prompt_cost).label("prompt_cost"),
                func.sum(cost.completion_cost).label("completion_cost"),
                func.sum(cost.total_cost).label("total_cost"),
                func.sum(cost.prompt_tokens).label("prompt_tokens"),
                func.sum(cost.completion_tokens).label("completion_tokens"),
                func.sum(cost.total_tokens).label("total_tokens"),
            )
            .select_from(run)
            # Runs are matched to traces on Trace.trace_id, and span costs are
            # matched to traces on Trace.id (SpanCost.trace_rowid).
            .join(trace, run.trace_id == trace.trace_id)
            .join(cost, cost.trace_rowid == trace.id)
            .where(run.experiment_id.in_(keys))
            .group_by(run.experiment_id)
        )
        summaries: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for row in rows:
                (
                    experiment_id,
                    prompt_cost,
                    completion_cost,
                    total_cost,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                ) = row
                summaries[experiment_id] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        # Indexing the defaultdict fills in an empty SpanCostSummary for any
        # experiment that had no matching cost rows.
        return [summaries[key] for key in keys]
@@ -0,0 +1,64 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import func, select, tuple_
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db import models
8
+ from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
9
+ from phoenix.server.types import DbSessionFactory
10
+
11
ExperimentId: TypeAlias = int
DatasetExampleId: TypeAlias = int
Key: TypeAlias = tuple[ExperimentId, DatasetExampleId]
Result: TypeAlias = SpanCostSummary


class SpanCostSummaryByExperimentRepeatedRunGroupDataLoader(DataLoader[Key, Result]):
    """Batch-load cost and token totals per (experiment, dataset example) pair.

    Each key is ``(experiment_id, dataset_example_id)``. Costs are summed over
    the span costs of every experiment run matching that pair. Pairs with no
    matching rows map to a fresh ``SpanCostSummary()``. The sums are not
    COALESCEd.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        run = models.ExperimentRun
        trace = models.Trace
        cost = models.SpanCost
        # Composite (experiment_id, dataset_example_id) membership test.
        pair = tuple_(run.experiment_id, run.dataset_example_id)
        stmt = (
            select(
                run.experiment_id,
                run.dataset_example_id,
                func.sum(cost.prompt_cost).label("prompt_cost"),
                func.sum(cost.completion_cost).label("completion_cost"),
                func.sum(cost.total_cost).label("total_cost"),
                func.sum(cost.prompt_tokens).label("prompt_tokens"),
                func.sum(cost.completion_tokens).label("completion_tokens"),
                func.sum(cost.total_tokens).label("total_tokens"),
            )
            .select_from(run)
            .join(trace, run.trace_id == trace.trace_id)
            .join(cost, cost.trace_rowid == trace.id)
            .where(pair.in_(set(keys)))
            .group_by(run.experiment_id, run.dataset_example_id)
        )
        summaries: dict[Key, Result] = {}
        async with self._db() as session:
            async for (
                experiment_id,
                dataset_example_id,
                prompt_cost,
                completion_cost,
                total_cost,
                prompt_tokens,
                completion_tokens,
                total_tokens,
            ) in await session.stream(stmt):
                summaries[(experiment_id, dataset_example_id)] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        return [summaries.get(key, SpanCostSummary()) for key in keys]
@@ -0,0 +1,58 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import func, select
4
+ from sqlalchemy.sql.functions import coalesce
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
10
+ from phoenix.server.types import DbSessionFactory
11
+
12
ExperimentRunId: TypeAlias = int
Key: TypeAlias = ExperimentRunId
Result: TypeAlias = SpanCostSummary


class SpanCostSummaryByExperimentRunDataLoader(DataLoader[Key, Result]):
    """Batch-load cost and token totals for each experiment run.

    Sums the ``SpanCost`` columns of the run's trace. COALESCE turns an all-NULL
    sum into 0. Runs with no matching cost rows map to a default
    ``SpanCostSummary()``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        run = models.ExperimentRun
        trace = models.Trace
        cost = models.SpanCost
        stmt = (
            select(
                run.id,
                coalesce(func.sum(cost.prompt_cost), 0).label("prompt_cost"),
                coalesce(func.sum(cost.completion_cost), 0).label("completion_cost"),
                coalesce(func.sum(cost.total_cost), 0).label("total_cost"),
                coalesce(func.sum(cost.prompt_tokens), 0).label("prompt_tokens"),
                coalesce(func.sum(cost.completion_tokens), 0).label("completion_tokens"),
                coalesce(func.sum(cost.total_tokens), 0).label("total_tokens"),
            )
            .select_from(run)
            .join(trace, run.trace_id == trace.trace_id)
            .join(cost, cost.trace_rowid == trace.id)
            .where(run.id.in_(keys))
            .group_by(run.id)
        )
        summaries: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for row in rows:
                (
                    run_id,
                    prompt_cost,
                    completion_cost,
                    total_cost,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                ) = row
                summaries[run_id] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        # Missing run ids get a default summary via the defaultdict factory.
        return [summaries[key] for key in keys]
@@ -0,0 +1,55 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import func, select
4
+ from sqlalchemy.sql.functions import coalesce
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
10
+ from phoenix.server.types import DbSessionFactory
11
+
12
GenerativeModelId: TypeAlias = int
Key: TypeAlias = GenerativeModelId
Result: TypeAlias = SpanCostSummary


class SpanCostSummaryByGenerativeModelDataLoader(DataLoader[Key, Result]):
    """Batch-load cost and token totals for each generative model.

    Sums the ``SpanCost`` columns grouped by ``SpanCost.model_id``. COALESCE
    turns an all-NULL sum into 0. Model ids with no cost rows map to a default
    ``SpanCostSummary()``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        cost = models.SpanCost
        model_id = cost.model_id
        stmt = (
            select(
                model_id,
                coalesce(func.sum(cost.prompt_cost), 0).label("prompt_cost"),
                coalesce(func.sum(cost.completion_cost), 0).label("completion_cost"),
                coalesce(func.sum(cost.total_cost), 0).label("total_cost"),
                coalesce(func.sum(cost.prompt_tokens), 0).label("prompt_tokens"),
                coalesce(func.sum(cost.completion_tokens), 0).label("completion_tokens"),
                coalesce(func.sum(cost.total_tokens), 0).label("total_tokens"),
            )
            .where(model_id.in_(keys))
            .group_by(model_id)
        )
        summaries: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for row in rows:
                (
                    row_model_id,
                    prompt_cost,
                    completion_cost,
                    total_cost,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                ) = row
                summaries[row_model_id] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        # Missing model ids get a default summary via the defaultdict factory.
        return [summaries[key] for key in keys]
@@ -0,0 +1,152 @@
1
+ from collections import defaultdict
2
+ from datetime import datetime
3
+ from typing import Any, Optional
4
+
5
+ from cachetools import LFUCache, TTLCache
6
+ from sqlalchemy import Select, func, select
7
+ from sqlalchemy.sql.functions import coalesce
8
+ from strawberry.dataloader import AbstractCache, DataLoader
9
+ from typing_extensions import TypeAlias
10
+
11
+ from phoenix.db import models
12
+ from phoenix.server.api.dataloaders.cache import TwoTierCache
13
+ from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
14
+ from phoenix.server.api.input_types.TimeRange import TimeRange
15
+ from phoenix.server.session_filters import get_filtered_session_rowids_subquery
16
+ from phoenix.server.types import DbSessionFactory
17
+ from phoenix.trace.dsl import SpanFilter
18
+
19
ProjectRowId: TypeAlias = int
TimeInterval: TypeAlias = tuple[Optional[datetime], Optional[datetime]]
FilterCondition: TypeAlias = Optional[str]
SessionFilterCondition: TypeAlias = Optional[str]

# Everything in a key except the project rowid. Keys that share a segment are
# answered together by a single query over several projects.
Segment: TypeAlias = tuple[
    TimeInterval,
    FilterCondition,
    SessionFilterCondition,
]
Param: TypeAlias = ProjectRowId

Key: TypeAlias = tuple[ProjectRowId, Optional[TimeRange], FilterCondition, SessionFilterCondition]
Result: TypeAlias = SpanCostSummary
ResultPosition: TypeAlias = int
# Placeholder result for keys whose project yields no aggregated row.
DEFAULT_VALUE: Result = SpanCostSummary()


def _cache_key_fn(key: Key) -> tuple[Segment, Param]:
    """Split a loader key into ``(segment, project_rowid)``.

    The ``TimeRange`` is replaced by its ``(start, end)`` pair. Anything that is
    not a ``TimeRange`` (e.g. ``None``) becomes ``(None, None)``. The returned
    pair serves as the DataLoader cache key and as the grouping key used to
    batch queries.
    """
    project_rowid, time_range, filter_condition, session_filter_condition = key
    if isinstance(time_range, TimeRange):
        interval: TimeInterval = (time_range.start, time_range.end)
    else:
        interval = (None, None)
    segment: Segment = (interval, filter_condition, session_filter_condition)
    return segment, project_rowid
43
+
44
+
45
_Section: TypeAlias = ProjectRowId
_SubKey: TypeAlias = tuple[TimeInterval, FilterCondition, SessionFilterCondition]


class SpanCostSummaryCache(
    TwoTierCache[Key, Result, _Section, _SubKey],
):
    """Two-tier cache for per-project span cost summaries.

    The outer tier is sectioned by project rowid. Each section holds a small
    sub-cache keyed by ``(time interval, filter condition, session filter condition)``.
    """

    def __init__(self) -> None:
        # TTL=3600 (1-hour) because time intervals are always moving forward, but
        # interval endpoints are rounded down to the hour by the UI, so anything
        # older than an hour most likely won't be a cache-hit anyway.
        outer = TTLCache(maxsize=64, ttl=3600)
        super().__init__(
            main_cache=outer,
            sub_cache_factory=lambda: LFUCache(maxsize=2 * 2 * 3),
        )

    def _cache_key(self, key: Key) -> tuple[_Section, _SubKey]:
        # The segment already is the (interval, filter, session filter) triple.
        segment, project_rowid = _cache_key_fn(key)
        return project_rowid, segment
64
+
65
+
66
class SpanCostSummaryByProjectDataLoader(DataLoader[Key, Result]):
    """Batch-load summed token counts and costs per project.

    Each key is ``(project_rowid, time_range, filter_condition,
    session_filter_condition)``. Keys are bucketed by segment (everything but
    the project rowid), so one query per distinct segment covers every project
    in that bucket. Keys with no aggregated row resolve to ``DEFAULT_VALUE``.
    """

    def __init__(
        self,
        db: DbSessionFactory,
        cache_map: Optional[AbstractCache[Key, Result]] = None,
    ) -> None:
        super().__init__(
            load_fn=self._load_fn,
            cache_key_fn=_cache_key_fn,
            cache_map=cache_map,
        )
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # segment -> project rowid -> indices into ``keys`` (duplicates allowed)
        buckets: defaultdict[
            Segment,
            defaultdict[Param, list[ResultPosition]],
        ] = defaultdict(lambda: defaultdict(list))
        for index, key in enumerate(keys):
            segment, project_rowid = _cache_key_fn(key)
            buckets[segment][project_rowid].append(index)

        results: list[Result] = [DEFAULT_VALUE] * len(keys)
        async with self._db() as session:
            for segment, indices_by_project in buckets.items():
                rows = await session.stream(_get_stmt(segment, *indices_by_project))
                async for row in rows:
                    (
                        project_rowid,
                        prompt_cost,
                        completion_cost,
                        total_cost,
                        prompt_tokens,
                        completion_tokens,
                        total_tokens,
                    ) = row
                    summary = SpanCostSummary(
                        prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                        completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                        total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                    )
                    # ``.get`` rather than ``[]`` so the defaultdict is never grown.
                    for index in indices_by_project.get(project_rowid, ()):
                        results[index] = summary
        return results
109
+
110
+
111
def _get_stmt(
    segment: Segment,
    *params: Param,
) -> Select[Any]:
    """Build the aggregate cost query for one segment across several projects.

    Returns one row per project rowid in ``params``:
    ``(project_rowid, prompt_cost, completion_cost, total_cost, prompt_tokens,
    completion_tokens, total_tokens)``, summed over ``SpanCost`` rows joined to
    their traces. Optional restrictions come from the segment:

    * trace ``start_time`` within ``[start, end)``,
    * a span filter expression (applied after joining ``Span``),
    * a session filter restricting traces to the matching session rowids.
    """
    (start_time, end_time), filter_condition, session_filter_condition = segment
    project_rowids = params
    trace = models.Trace
    span_cost = models.SpanCost

    # (label, column) pairs. Their order fixes the result column order.
    summed_columns = (
        ("prompt_cost", span_cost.prompt_cost),
        ("completion_cost", span_cost.completion_cost),
        ("total_cost", span_cost.total_cost),
        ("prompt_tokens", span_cost.prompt_tokens),
        ("completion_tokens", span_cost.completion_tokens),
        ("total_tokens", span_cost.total_tokens),
    )
    stmt: Select[Any] = (
        select(
            trace.project_rowid,
            *(coalesce(func.sum(column), 0).label(label) for label, column in summed_columns),
        )
        .select_from(trace)
        .join(span_cost, trace.id == span_cost.trace_rowid)
        .where(trace.project_rowid.in_(project_rowids))
        .group_by(trace.project_rowid)
    )

    # Half-open interval on the trace start time.
    if start_time:
        stmt = stmt.where(start_time <= trace.start_time)
    if end_time:
        stmt = stmt.where(trace.start_time < end_time)

    if filter_condition:
        span_filter = SpanFilter(filter_condition)
        stmt = span_filter(stmt.join_from(span_cost, models.Span))

    if session_filter_condition:
        matching_session_rowids = get_filtered_session_rowids_subquery(
            session_filter_condition=session_filter_condition,
            project_rowids=project_rowids,
            start_time=start_time,
            end_time=end_time,
        )
        stmt = stmt.where(trace.project_session_rowid.in_(matching_session_rowids))

    return stmt
@@ -0,0 +1,56 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import func, select
4
+ from sqlalchemy.sql.functions import coalesce
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
10
+ from phoenix.server.types import DbSessionFactory
11
+
12
ProjectSessionRowId: TypeAlias = int
Key: TypeAlias = ProjectSessionRowId
Result: TypeAlias = SpanCostSummary


class SpanCostSummaryByProjectSessionDataLoader(DataLoader[Key, Result]):
    """Batch-load summed token counts and costs per project session.

    Joins ``SpanCost`` to ``Trace`` and groups by ``Trace.project_session_rowid``.
    A key with no matching rows resolves to a default-constructed
    ``SpanCostSummary()``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        span_cost = models.SpanCost
        group_column = models.Trace.project_session_rowid
        # (label, column) pairs. Their order fixes the order of the result
        # columns unpacked below. SUM yields NULL when every value in a group
        # is NULL, hence the COALESCE to 0.
        summed_columns = (
            ("prompt_cost", span_cost.prompt_cost),
            ("completion_cost", span_cost.completion_cost),
            ("total_cost", span_cost.total_cost),
            ("prompt_tokens", span_cost.prompt_tokens),
            ("completion_tokens", span_cost.completion_tokens),
            ("total_tokens", span_cost.total_tokens),
        )
        stmt = (
            select(
                group_column,
                *(coalesce(func.sum(column), 0).label(label) for label, column in summed_columns),
            )
            .join_from(models.SpanCost, models.Trace)
            .where(group_column.in_(keys))
            .group_by(group_column)
        )
        # Keys absent from the query result get an empty summary on lookup.
        summaries: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            async for row in await session.stream(stmt):
                (
                    session_rowid,
                    prompt_cost,
                    completion_cost,
                    total_cost,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                ) = row
                summaries[session_rowid] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        return [summaries[key] for key in keys]
@@ -0,0 +1,55 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import func, select
4
+ from sqlalchemy.sql.functions import coalesce
5
+ from strawberry.dataloader import DataLoader
6
+ from typing_extensions import TypeAlias
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.dataloaders.types import CostBreakdown, SpanCostSummary
10
+ from phoenix.server.types import DbSessionFactory
11
+
12
TraceRowId: TypeAlias = int
Key: TypeAlias = TraceRowId
Result: TypeAlias = SpanCostSummary


class SpanCostSummaryByTraceDataLoader(DataLoader[Key, Result]):
    """Batch-load summed token counts and costs per trace.

    Groups ``SpanCost`` rows by ``SpanCost.trace_rowid``. A key with no matching
    rows resolves to a default-constructed ``SpanCostSummary()``.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        span_cost = models.SpanCost
        group_column = span_cost.trace_rowid
        # (label, column) pairs. Their order fixes the order of the result
        # columns unpacked below. SUM yields NULL when every value in a group
        # is NULL, hence the COALESCE to 0.
        summed_columns = (
            ("prompt_cost", span_cost.prompt_cost),
            ("completion_cost", span_cost.completion_cost),
            ("total_cost", span_cost.total_cost),
            ("prompt_tokens", span_cost.prompt_tokens),
            ("completion_tokens", span_cost.completion_tokens),
            ("total_tokens", span_cost.total_tokens),
        )
        stmt = (
            select(
                group_column,
                *(coalesce(func.sum(column), 0).label(label) for label, column in summed_columns),
            )
            .where(group_column.in_(keys))
            .group_by(group_column)
        )
        # Keys absent from the query result get an empty summary on lookup.
        summaries: defaultdict[Key, Result] = defaultdict(SpanCostSummary)
        async with self._db() as session:
            async for row in await session.stream(stmt):
                (
                    trace_rowid,
                    prompt_cost,
                    completion_cost,
                    total_cost,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                ) = row
                summaries[trace_rowid] = SpanCostSummary(
                    prompt=CostBreakdown(tokens=prompt_tokens, cost=prompt_cost),
                    completion=CostBreakdown(tokens=completion_tokens, cost=completion_cost),
                    total=CostBreakdown(tokens=total_tokens, cost=total_cost),
                )
        return [summaries[key] for key in keys]
@@ -0,0 +1,29 @@
1
+ from typing import Optional
2
+
3
+ from sqlalchemy import select
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db import models
8
+ from phoenix.server.types import DbSessionFactory
9
+
10
SpanID: TypeAlias = int
Key: TypeAlias = SpanID
Result: TypeAlias = Optional[models.SpanCost]


class SpanCostsDataLoader(DataLoader[Key, Result]):
    """Batch-load the ``SpanCost`` row for each span rowid.

    Resolves to ``None`` for span rowids that have no ``SpanCost`` row. If
    several rows share a span rowid, the last one streamed wins.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # De-duplicate so the IN clause lists each span rowid once.
        unique_span_rowids = list(set(keys))
        cost_by_span_rowid: dict[Key, models.SpanCost] = {}
        async with self._db() as session:
            stmt = select(models.SpanCost).where(
                models.SpanCost.span_rowid.in_(unique_span_rowids)
            )
            async for span_cost in await session.stream_scalars(stmt):
                cost_by_span_rowid[span_cost.span_rowid] = span_cost
        return [cost_by_span_rowid.get(span_rowid) for span_rowid in keys]
@@ -18,7 +18,7 @@ _AttrStrIdentifier: TypeAlias = str
18
18
 
19
19
 
20
20
  class TableFieldsDataLoader(DataLoader[Key, Result]):
21
- def __init__(self, db: DbSessionFactory, table: type[models.Base]) -> None:
21
+ def __init__(self, db: DbSessionFactory, table: type[models.HasId]) -> None:
22
22
  super().__init__(load_fn=self._load_fn)
23
23
  self._db = db
24
24
  self._table = table
@@ -37,7 +37,7 @@ class TableFieldsDataLoader(DataLoader[Key, Result]):
37
37
 
38
38
  def _get_stmt(
39
39
  keys: Iterable[tuple[RowId, QueryableAttribute[Any]]],
40
- table: type[models.Base],
40
+ table: type[models.HasId],
41
41
  ) -> tuple[
42
42
  Select[Any],
43
43
  dict[_ResultColumnPosition, _AttrStrIdentifier],
@@ -0,0 +1,30 @@
1
+ from collections import defaultdict
2
+
3
+ from sqlalchemy import select
4
+ from strawberry.dataloader import DataLoader
5
+ from typing_extensions import TypeAlias
6
+
7
+ from phoenix.db import models
8
+ from phoenix.server.types import DbSessionFactory
9
+
10
ModelId: TypeAlias = int
Key: TypeAlias = ModelId
Result: TypeAlias = list[models.TokenPrice]


class TokenPricesByModelDataLoader(DataLoader[Key, Result]):
    """Batch-load all ``TokenPrice`` rows belonging to each model id.

    A model id with no prices resolves to an empty list. Within a list, rows
    appear in the order the database streams them (no ORDER BY is applied).
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        stmt = select(models.TokenPrice).where(models.TokenPrice.model_id.in_(keys))
        prices_by_model: defaultdict[Key, Result] = defaultdict(list)

        async with self._db() as session:
            async for token_price in await session.stream_scalars(stmt):
                prices_by_model[token_price.model_id].append(token_price)

        # Indexing the defaultdict yields an empty list for ids with no rows.
        return [prices_by_model[key] for key in keys]