arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (276)
  1. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
  2. arize_phoenix-12.28.1.dist-info/RECORD +499 -0
  3. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +5 -4
  12. phoenix/auth.py +39 -2
  13. phoenix/config.py +1763 -91
  14. phoenix/datetime_utils.py +120 -2
  15. phoenix/db/README.md +595 -25
  16. phoenix/db/bulk_inserter.py +145 -103
  17. phoenix/db/engines.py +140 -33
  18. phoenix/db/enums.py +3 -12
  19. phoenix/db/facilitator.py +302 -35
  20. phoenix/db/helpers.py +1000 -65
  21. phoenix/db/iam_auth.py +64 -0
  22. phoenix/db/insertion/dataset.py +135 -2
  23. phoenix/db/insertion/document_annotation.py +9 -6
  24. phoenix/db/insertion/evaluation.py +2 -3
  25. phoenix/db/insertion/helpers.py +17 -2
  26. phoenix/db/insertion/session_annotation.py +176 -0
  27. phoenix/db/insertion/span.py +15 -11
  28. phoenix/db/insertion/span_annotation.py +3 -4
  29. phoenix/db/insertion/trace_annotation.py +3 -4
  30. phoenix/db/insertion/types.py +50 -20
  31. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  32. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  33. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  34. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  35. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  36. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  37. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  38. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  39. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  40. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  41. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  42. phoenix/db/models.py +669 -56
  43. phoenix/db/pg_config.py +10 -0
  44. phoenix/db/types/model_provider.py +4 -0
  45. phoenix/db/types/token_price_customization.py +29 -0
  46. phoenix/db/types/trace_retention.py +23 -15
  47. phoenix/experiments/evaluators/utils.py +3 -3
  48. phoenix/experiments/functions.py +160 -52
  49. phoenix/experiments/tracing.py +2 -2
  50. phoenix/experiments/types.py +1 -1
  51. phoenix/inferences/inferences.py +1 -2
  52. phoenix/server/api/auth.py +38 -7
  53. phoenix/server/api/auth_messages.py +46 -0
  54. phoenix/server/api/context.py +100 -4
  55. phoenix/server/api/dataloaders/__init__.py +79 -5
  56. phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
  57. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  58. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  59. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  60. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  61. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  62. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  63. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  64. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  65. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  66. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  67. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  68. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  69. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  70. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  71. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  72. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  73. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  74. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  75. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  76. phoenix/server/api/dataloaders/record_counts.py +37 -10
  77. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  78. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  79. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  80. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  81. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  82. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  83. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  84. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
  85. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  86. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  87. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  88. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
  89. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  90. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  91. phoenix/server/api/dataloaders/span_costs.py +29 -0
  92. phoenix/server/api/dataloaders/table_fields.py +2 -2
  93. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  94. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  95. phoenix/server/api/dataloaders/types.py +29 -0
  96. phoenix/server/api/exceptions.py +11 -1
  97. phoenix/server/api/helpers/dataset_helpers.py +5 -1
  98. phoenix/server/api/helpers/playground_clients.py +1243 -292
  99. phoenix/server/api/helpers/playground_registry.py +2 -2
  100. phoenix/server/api/helpers/playground_spans.py +8 -4
  101. phoenix/server/api/helpers/playground_users.py +26 -0
  102. phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
  103. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  104. phoenix/server/api/helpers/prompts/models.py +205 -22
  105. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  106. phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
  107. phoenix/server/api/input_types/CreateProjectInput.py +27 -0
  108. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  109. phoenix/server/api/input_types/DatasetFilter.py +17 -0
  110. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  111. phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
  112. phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
  113. phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
  114. phoenix/server/api/input_types/PromptFilter.py +14 -0
  115. phoenix/server/api/input_types/PromptVersionInput.py +52 -1
  116. phoenix/server/api/input_types/SpanSort.py +44 -7
  117. phoenix/server/api/input_types/TimeBinConfig.py +23 -0
  118. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  119. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  120. phoenix/server/api/mutations/__init__.py +10 -0
  121. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  122. phoenix/server/api/mutations/api_key_mutations.py +19 -23
  123. phoenix/server/api/mutations/chat_mutations.py +154 -47
  124. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  125. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  126. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  127. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  128. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  129. phoenix/server/api/mutations/model_mutations.py +210 -0
  130. phoenix/server/api/mutations/project_mutations.py +49 -10
  131. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  132. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  133. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  134. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  135. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  136. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  137. phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
  138. phoenix/server/api/mutations/trace_mutations.py +47 -3
  139. phoenix/server/api/mutations/user_mutations.py +66 -41
  140. phoenix/server/api/queries.py +768 -293
  141. phoenix/server/api/routers/__init__.py +2 -2
  142. phoenix/server/api/routers/auth.py +154 -88
  143. phoenix/server/api/routers/ldap.py +229 -0
  144. phoenix/server/api/routers/oauth2.py +369 -106
  145. phoenix/server/api/routers/v1/__init__.py +24 -4
  146. phoenix/server/api/routers/v1/annotation_configs.py +23 -31
  147. phoenix/server/api/routers/v1/annotations.py +481 -17
  148. phoenix/server/api/routers/v1/datasets.py +395 -81
  149. phoenix/server/api/routers/v1/documents.py +142 -0
  150. phoenix/server/api/routers/v1/evaluations.py +24 -31
  151. phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
  152. phoenix/server/api/routers/v1/experiment_runs.py +337 -59
  153. phoenix/server/api/routers/v1/experiments.py +479 -48
  154. phoenix/server/api/routers/v1/models.py +7 -0
  155. phoenix/server/api/routers/v1/projects.py +18 -49
  156. phoenix/server/api/routers/v1/prompts.py +54 -40
  157. phoenix/server/api/routers/v1/sessions.py +108 -0
  158. phoenix/server/api/routers/v1/spans.py +1091 -81
  159. phoenix/server/api/routers/v1/traces.py +132 -78
  160. phoenix/server/api/routers/v1/users.py +389 -0
  161. phoenix/server/api/routers/v1/utils.py +3 -7
  162. phoenix/server/api/subscriptions.py +305 -88
  163. phoenix/server/api/types/Annotation.py +90 -23
  164. phoenix/server/api/types/ApiKey.py +13 -17
  165. phoenix/server/api/types/AuthMethod.py +1 -0
  166. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  167. phoenix/server/api/types/CostBreakdown.py +12 -0
  168. phoenix/server/api/types/Dataset.py +226 -72
  169. phoenix/server/api/types/DatasetExample.py +88 -18
  170. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  171. phoenix/server/api/types/DatasetLabel.py +57 -0
  172. phoenix/server/api/types/DatasetSplit.py +98 -0
  173. phoenix/server/api/types/DatasetVersion.py +49 -4
  174. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  175. phoenix/server/api/types/Experiment.py +264 -59
  176. phoenix/server/api/types/ExperimentComparison.py +5 -10
  177. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  178. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  179. phoenix/server/api/types/ExperimentRun.py +169 -65
  180. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  181. phoenix/server/api/types/GenerativeModel.py +245 -3
  182. phoenix/server/api/types/GenerativeProvider.py +70 -11
  183. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  184. phoenix/server/api/types/ModelInterface.py +16 -0
  185. phoenix/server/api/types/PlaygroundModel.py +20 -0
  186. phoenix/server/api/types/Project.py +1278 -216
  187. phoenix/server/api/types/ProjectSession.py +188 -28
  188. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  189. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  190. phoenix/server/api/types/Prompt.py +119 -39
  191. phoenix/server/api/types/PromptLabel.py +42 -25
  192. phoenix/server/api/types/PromptVersion.py +11 -8
  193. phoenix/server/api/types/PromptVersionTag.py +65 -25
  194. phoenix/server/api/types/ServerStatus.py +6 -0
  195. phoenix/server/api/types/Span.py +167 -123
  196. phoenix/server/api/types/SpanAnnotation.py +189 -42
  197. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  198. phoenix/server/api/types/SpanCostSummary.py +10 -0
  199. phoenix/server/api/types/SystemApiKey.py +65 -1
  200. phoenix/server/api/types/TokenPrice.py +16 -0
  201. phoenix/server/api/types/TokenUsage.py +3 -3
  202. phoenix/server/api/types/Trace.py +223 -51
  203. phoenix/server/api/types/TraceAnnotation.py +149 -50
  204. phoenix/server/api/types/User.py +137 -32
  205. phoenix/server/api/types/UserApiKey.py +73 -26
  206. phoenix/server/api/types/node.py +10 -0
  207. phoenix/server/api/types/pagination.py +11 -2
  208. phoenix/server/app.py +290 -45
  209. phoenix/server/authorization.py +38 -3
  210. phoenix/server/bearer_auth.py +34 -24
  211. phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
  212. phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
  213. phoenix/server/cost_tracking/helpers.py +68 -0
  214. phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
  215. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  216. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  217. phoenix/server/daemons/__init__.py +0 -0
  218. phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
  219. phoenix/server/daemons/generative_model_store.py +103 -0
  220. phoenix/server/daemons/span_cost_calculator.py +99 -0
  221. phoenix/server/dml_event.py +17 -0
  222. phoenix/server/dml_event_handler.py +5 -0
  223. phoenix/server/email/sender.py +56 -3
  224. phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
  225. phoenix/server/email/types.py +11 -0
  226. phoenix/server/experiments/__init__.py +0 -0
  227. phoenix/server/experiments/utils.py +14 -0
  228. phoenix/server/grpc_server.py +11 -11
  229. phoenix/server/jwt_store.py +17 -15
  230. phoenix/server/ldap.py +1449 -0
  231. phoenix/server/main.py +26 -10
  232. phoenix/server/oauth2.py +330 -12
  233. phoenix/server/prometheus.py +66 -6
  234. phoenix/server/rate_limiters.py +4 -9
  235. phoenix/server/retention.py +33 -20
  236. phoenix/server/session_filters.py +49 -0
  237. phoenix/server/static/.vite/manifest.json +55 -51
  238. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  239. phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
  240. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  241. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  242. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  243. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  244. phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
  245. phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
  246. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  247. phoenix/server/templates/index.html +40 -6
  248. phoenix/server/thread_server.py +1 -2
  249. phoenix/server/types.py +14 -4
  250. phoenix/server/utils.py +74 -0
  251. phoenix/session/client.py +56 -3
  252. phoenix/session/data_extractor.py +5 -0
  253. phoenix/session/evaluation.py +14 -5
  254. phoenix/session/session.py +45 -9
  255. phoenix/settings.py +5 -0
  256. phoenix/trace/attributes.py +80 -13
  257. phoenix/trace/dsl/helpers.py +90 -1
  258. phoenix/trace/dsl/query.py +8 -6
  259. phoenix/trace/projects.py +5 -0
  260. phoenix/utilities/template_formatters.py +1 -1
  261. phoenix/version.py +1 -1
  262. arize_phoenix-10.0.4.dist-info/RECORD +0 -405
  263. phoenix/server/api/types/Evaluation.py +0 -39
  264. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  265. phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
  266. phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
  267. phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
  268. phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
  269. phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
  270. phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
  271. phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
  272. phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
  273. phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
  274. phoenix/utilities/deprecation.py +0 -31
  275. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  276. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0

phoenix/server/api/types/Annotation.py
@@ -1,31 +1,98 @@
 from datetime import datetime
-from typing import Optional
+from typing import TYPE_CHECKING, Annotated, Optional

 import strawberry
+from strawberry.scalars import JSON
+from strawberry.types import Info

-from phoenix.server.api.interceptor import GqlValueMediator
+from phoenix.server.api.context import Context
+
+from .AnnotationSource import AnnotationSource
+from .AnnotatorKind import AnnotatorKind
+
+if TYPE_CHECKING:
+    from .User import User


 @strawberry.interface
 class Annotation:
-    name: str = strawberry.field(
-        description="Name of the annotation, e.g. 'helpfulness' or 'relevance'."
-    )
-    score: Optional[float] = strawberry.field(
-        description="Value of the annotation in the form of a numeric score.",
-        default=GqlValueMediator(),
-    )
-    label: Optional[str] = strawberry.field(
-        description="Value of the annotation in the form of a string, e.g. "
-        "'helpful' or 'not helpful'. Note that the label is not necessarily binary."
-    )
-    explanation: Optional[str] = strawberry.field(
-        description="The annotator's explanation for the annotation result (i.e. "
-        "score or label, or both) given to the subject."
-    )
-    created_at: datetime = strawberry.field(
-        description="The date and time when the annotation was created."
-    )
-    updated_at: datetime = strawberry.field(
-        description="The date and time when the annotation was last updated."
-    )
+    @strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.")  # type: ignore
+    async def name(
+        self,
+        info: Info[Context, None],
+    ) -> str:
+        raise NotImplementedError
+
+    @strawberry.field(description="The kind of annotator that produced the annotation.")  # type: ignore
+    async def annotator_kind(
+        self,
+        info: Info[Context, None],
+    ) -> AnnotatorKind:
+        raise NotImplementedError
+
+    @strawberry.field(
+        description="Value of the annotation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary."  # noqa: E501
+    )  # type: ignore
+    async def label(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[str]:
+        raise NotImplementedError
+
+    @strawberry.field(description="Value of the annotation in the form of a numeric score.")  # type: ignore
+    async def score(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[float]:
+        raise NotImplementedError
+
+    @strawberry.field(
+        description="The annotator's explanation for the annotation result (i.e. score or label, or both) given to the subject."  # noqa: E501
+    )  # type: ignore
+    async def explanation(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[str]:
+        raise NotImplementedError
+
+    @strawberry.field(description="Metadata about the annotation.")  # type: ignore
+    async def metadata(
+        self,
+        info: Info[Context, None],
+    ) -> JSON:
+        raise NotImplementedError
+
+    @strawberry.field(description="The source of the annotation.")  # type: ignore
+    async def source(
+        self,
+        info: Info[Context, None],
+    ) -> AnnotationSource:
+        raise NotImplementedError
+
+    @strawberry.field(description="The identifier of the annotation.")  # type: ignore
+    async def identifier(
+        self,
+        info: Info[Context, None],
+    ) -> str:
+        raise NotImplementedError
+
+    @strawberry.field(description="The date and time the annotation was created.")  # type: ignore
+    async def created_at(
+        self,
+        info: Info[Context, None],
+    ) -> datetime:
+        raise NotImplementedError
+
+    @strawberry.field(description="The date and time the annotation was last updated.")  # type: ignore
+    async def updated_at(
+        self,
+        info: Info[Context, None],
+    ) -> datetime:
+        raise NotImplementedError
+
+    @strawberry.field(description="The user that produced the annotation.")  # type: ignore
+    async def user(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
+        raise NotImplementedError
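
To make the shape of this change concrete: the interface's fields go from plain typed attributes to abstract async resolvers that raise NotImplementedError, and the concrete annotation types elsewhere in this diff (SpanAnnotation, TraceAnnotation, ProjectSessionAnnotation, DocumentAnnotation) presumably override them. A minimal standalone sketch of that pattern, using made-up names rather than Phoenix's actual types, might look like this:

# Standalone sketch of the pattern above: a Strawberry interface whose fields are
# abstract async resolvers, and a concrete type that overrides them. The names
# Note/StoredNote are illustrative and not part of arize-phoenix.
from typing import Optional

import strawberry


@strawberry.interface
class Note:
    @strawberry.field(description="Human-readable label.")  # type: ignore
    async def label(self) -> Optional[str]:
        raise NotImplementedError


@strawberry.type
class StoredNote(Note):
    db_label: strawberry.Private[Optional[str]] = None

    @strawberry.field(description="Human-readable label.")  # type: ignore
    async def label(self) -> Optional[str]:
        # Concrete resolver: return the value carried on the instance instead of
        # raising NotImplementedError like the interface default.
        return self.db_label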

phoenix/server/api/types/ApiKey.py
@@ -3,25 +3,21 @@ from typing import Optional

 import strawberry

-from phoenix.db.models import ApiKey as ORMApiKey
-

 @strawberry.interface
 class ApiKey:
-    name: str = strawberry.field(description="Name of the API key.")
-    description: Optional[str] = strawberry.field(description="Description of the API key.")
-    created_at: datetime = strawberry.field(
-        description="The date and time the API key was created."
-    )
-    expires_at: Optional[datetime] = strawberry.field(
-        description="The date and time the API key will expire."
-    )
+    @strawberry.field(description="Name of the API key.")  # type: ignore
+    async def name(self) -> str:
+        raise NotImplementedError
+
+    @strawberry.field(description="Description of the API key.")  # type: ignore
+    async def description(self) -> Optional[str]:
+        raise NotImplementedError

+    @strawberry.field(description="The date and time the API key was created.")  # type: ignore
+    async def created_at(self) -> datetime:
+        raise NotImplementedError

-def to_gql_api_key(api_key: ORMApiKey) -> ApiKey:
-    return ApiKey(
-        name=api_key.name,
-        description=api_key.description,
-        created_at=api_key.created_at,
-        expires_at=api_key.expires_at,
-    )
+    @strawberry.field(description="The date and time the API key will expire.")  # type: ignore
+    async def expires_at(self) -> Optional[datetime]:
+        raise NotImplementedError

phoenix/server/api/types/AuthMethod.py
@@ -7,3 +7,4 @@ import strawberry
 class AuthMethod(Enum):
     LOCAL = "LOCAL"
     OAUTH2 = "OAUTH2"
+    LDAP = "LDAP"

phoenix/server/api/types/ChatCompletionSubscriptionPayload.py
@@ -11,6 +11,7 @@ from .Span import Span
 @strawberry.interface
 class ChatCompletionSubscriptionPayload:
     dataset_example_id: Optional[GlobalID] = None
+    repetition_number: Optional[int] = None


 @strawberry.type

phoenix/server/api/types/CostBreakdown.py
@@ -0,0 +1,12 @@
+from typing import Optional
+
+import strawberry
+
+
+@strawberry.type
+class CostBreakdown:
+    tokens: Optional[float] = strawberry.field(
+        default=None,
+        description="Total number of tokens, including tokens for which no cost was computed.",
+    )
+    cost: Optional[float] = None
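
The two fields separate token counts from the monetary total, so `cost` can stay unset when no price is known for a model while `tokens` is still reported. A hedged construction example, with made-up numbers and an assumed USD unit:

# Illustrative only: constructing the new CostBreakdown type with made-up values.
from phoenix.server.api.types.CostBreakdown import CostBreakdown

priced = CostBreakdown(tokens=1200, cost=0.0034)  # cost unit assumed to be USD
unpriced = CostBreakdown(tokens=1200, cost=None)  # tokens counted, but no price was available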

phoenix/server/api/types/Dataset.py
@@ -1,9 +1,9 @@
 from collections.abc import AsyncIterable
 from datetime import datetime
-from typing import ClassVar, Optional, cast
+from typing import Optional, cast

 import strawberry
-from sqlalchemy import and_, func, select
+from sqlalchemy import Text, and_, func, or_, select
 from sqlalchemy.sql.functions import count
 from strawberry import UNSET
 from strawberry.relay import Connection, GlobalID, Node, NodeID
@@ -12,11 +12,16 @@ from strawberry.types import Info

 from phoenix.db import models
 from phoenix.server.api.context import Context
+from phoenix.server.api.exceptions import BadRequest
 from phoenix.server.api.input_types.DatasetVersionSort import DatasetVersionSort
 from phoenix.server.api.types.DatasetExample import DatasetExample
+from phoenix.server.api.types.DatasetExperimentAnnotationSummary import (
+    DatasetExperimentAnnotationSummary,
+)
+from phoenix.server.api.types.DatasetLabel import DatasetLabel
+from phoenix.server.api.types.DatasetSplit import DatasetSplit
 from phoenix.server.api.types.DatasetVersion import DatasetVersion
 from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
-from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.types.pagination import (
     ConnectionArgs,
@@ -28,13 +33,77 @@ from phoenix.server.api.types.SortDir import SortDir

 @strawberry.type
 class Dataset(Node):
-    _table: ClassVar[type[models.Base]] = models.Experiment
-    id_attr: NodeID[int]
-    name: str
-    description: Optional[str]
-    metadata: JSON
-    created_at: datetime
-    updated_at: datetime
+    id: NodeID[int]
+    db_record: strawberry.Private[Optional[models.Dataset]] = None
+
+    def __post_init__(self) -> None:
+        if self.db_record and self.id != self.db_record.id:
+            raise ValueError("Dataset ID mismatch")
+
+    @strawberry.field
+    async def name(
+        self,
+        info: Info[Context, None],
+    ) -> str:
+        if self.db_record:
+            val = self.db_record.name
+        else:
+            val = await info.context.data_loaders.dataset_fields.load(
+                (self.id, models.Dataset.name),
+            )
+        return val
+
+    @strawberry.field
+    async def description(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[str]:
+        if self.db_record:
+            val = self.db_record.description
+        else:
+            val = await info.context.data_loaders.dataset_fields.load(
+                (self.id, models.Dataset.description),
+            )
+        return val
+
+    @strawberry.field
+    async def metadata(
+        self,
+        info: Info[Context, None],
+    ) -> JSON:
+        if self.db_record:
+            val = self.db_record.metadata_
+        else:
+            val = await info.context.data_loaders.dataset_fields.load(
+                (self.id, models.Dataset.metadata_),
+            )
+        return val
+
+    @strawberry.field
+    async def created_at(
+        self,
+        info: Info[Context, None],
+    ) -> datetime:
+        if self.db_record:
+            val = self.db_record.created_at
+        else:
+            val = await info.context.data_loaders.dataset_fields.load(
+                (self.id, models.Dataset.created_at),
+            )
+        return val
+
+    @strawberry.field
+    async def updated_at(
+        self,
+        info: Info[Context, None],
+    ) -> datetime:
+        if self.db_record:
+            val = self.db_record.updated_at
+        else:
+            val = await info.context.data_loaders.dataset_fields.load(
+                (self.id, models.Dataset.updated_at),
+            )
+        return val

     @strawberry.field
     async def versions(
@@ -53,7 +122,7 @@ class Dataset(Node):
             before=before if isinstance(before, CursorString) else None,
         )
         async with info.context.db() as session:
-            stmt = select(models.DatasetVersion).filter_by(dataset_id=self.id_attr)
+            stmt = select(models.DatasetVersion).filter_by(dataset_id=self.id)
             if sort:
                 # For now assume the the column names match 1:1 with the enum values
                 sort_col = getattr(models.DatasetVersion, sort.col.value)
@@ -64,15 +133,7 @@ class Dataset(Node):
             else:
                 stmt = stmt.order_by(models.DatasetVersion.created_at.desc())
             versions = await session.scalars(stmt)
-            data = [
-                DatasetVersion(
-                    id_attr=version.id,
-                    description=version.description,
-                    metadata=version.metadata_,
-                    created_at=version.created_at,
-                )
-                for version in versions
-            ]
+            data = [DatasetVersion(id=version.id, db_record=version) for version in versions]
             return connection_from_list(data=data, args=args)

     @strawberry.field(
@@ -83,8 +144,9 @@ class Dataset(Node):
         self,
         info: Info[Context, None],
         dataset_version_id: Optional[GlobalID] = UNSET,
+        split_ids: Optional[list[GlobalID]] = UNSET,
     ) -> int:
-        dataset_id = self.id_attr
+        dataset_id = self.id
         version_id = (
             from_global_id_with_expected_type(
                 global_id=dataset_version_id,
@@ -93,6 +155,20 @@ class Dataset(Node):
             if dataset_version_id
             else None
         )
+
+        # Parse split IDs if provided
+        split_rowids: Optional[list[int]] = None
+        if split_ids:
+            split_rowids = []
+            for split_id in split_ids:
+                try:
+                    split_rowid = from_global_id_with_expected_type(
+                        global_id=split_id, expected_type_name=models.DatasetSplit.__name__
+                    )
+                    split_rowids.append(split_rowid)
+                except Exception:
+                    raise BadRequest(f"Invalid split ID: {split_id}")
+
         revision_ids = (
             select(func.max(models.DatasetExampleRevision.id))
             .join(models.DatasetExample)
@@ -109,11 +185,36 @@ class Dataset(Node):
             revision_ids = revision_ids.where(
                 models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
             )
-        stmt = (
-            select(count(models.DatasetExampleRevision.id))
-            .where(models.DatasetExampleRevision.id.in_(revision_ids))
-            .where(models.DatasetExampleRevision.revision_kind != "DELETE")
-        )
+
+        # Build the count query
+        if split_rowids:
+            # When filtering by splits, count distinct examples that belong to those splits
+            stmt = (
+                select(count(models.DatasetExample.id.distinct()))
+                .join(
+                    models.DatasetExampleRevision,
+                    onclause=(
+                        models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id
+                    ),
+                )
+                .join(
+                    models.DatasetSplitDatasetExample,
+                    onclause=(
+                        models.DatasetExample.id
+                        == models.DatasetSplitDatasetExample.dataset_example_id
+                    ),
+                )
+                .where(models.DatasetExampleRevision.id.in_(revision_ids))
+                .where(models.DatasetExampleRevision.revision_kind != "DELETE")
+                .where(models.DatasetSplitDatasetExample.dataset_split_id.in_(split_rowids))
+            )
+        else:
+            stmt = (
+                select(count(models.DatasetExampleRevision.id))
+                .where(models.DatasetExampleRevision.id.in_(revision_ids))
+                .where(models.DatasetExampleRevision.revision_kind != "DELETE")
+            )
+
         async with info.context.db() as session:
             return (await session.scalar(stmt)) or 0

@@ -122,10 +223,12 @@ class Dataset(Node):
         self,
         info: Info[Context, None],
         dataset_version_id: Optional[GlobalID] = UNSET,
+        split_ids: Optional[list[GlobalID]] = UNSET,
         first: Optional[int] = 50,
         last: Optional[int] = UNSET,
         after: Optional[CursorString] = UNSET,
         before: Optional[CursorString] = UNSET,
+        filter: Optional[str] = UNSET,
     ) -> Connection[DatasetExample]:
         args = ConnectionArgs(
             first=first,
@@ -133,7 +236,7 @@ class Dataset(Node):
             last=last,
             before=before if isinstance(before, CursorString) else None,
         )
-        dataset_id = self.id_attr
+        dataset_id = self.id
         version_id = (
             from_global_id_with_expected_type(
                 global_id=dataset_version_id, expected_type_name=DatasetVersion.__name__
@@ -141,6 +244,20 @@ class Dataset(Node):
             if dataset_version_id
             else None
         )
+
+        # Parse split IDs if provided
+        split_rowids: Optional[list[int]] = None
+        if split_ids:
+            split_rowids = []
+            for split_id in split_ids:
+                try:
+                    split_rowid = from_global_id_with_expected_type(
+                        global_id=split_id, expected_type_name=models.DatasetSplit.__name__
+                    )
+                    split_rowids.append(split_rowid)
+                except Exception:
+                    raise BadRequest(f"Invalid split ID: {split_id}")
+
         revision_ids = (
             select(func.max(models.DatasetExampleRevision.id))
             .join(models.DatasetExample)
@@ -170,19 +287,51 @@ class Dataset(Node):
                     models.DatasetExampleRevision.revision_kind != "DELETE",
                 )
             )
-            .order_by(models.DatasetExampleRevision.dataset_example_id.desc())
+            .order_by(models.DatasetExample.id.desc())
         )
+
+        # Filter by split IDs if provided
+        if split_rowids:
+            query = (
+                query.join(
+                    models.DatasetSplitDatasetExample,
+                    onclause=(
+                        models.DatasetExample.id
+                        == models.DatasetSplitDatasetExample.dataset_example_id
+                    ),
+                )
+                .where(models.DatasetSplitDatasetExample.dataset_split_id.in_(split_rowids))
+                .distinct()
+            )
+        # Apply filter if provided - search through JSON fields (input, output, metadata)
+        if filter is not UNSET and filter:
+            # Create a filter that searches for the filter string in JSON fields
+            # Using PostgreSQL's JSON operators for case-insensitive text search
+            filter_condition = or_(
+                func.cast(models.DatasetExampleRevision.input, Text).ilike(f"%{filter}%"),
+                func.cast(models.DatasetExampleRevision.output, Text).ilike(f"%{filter}%"),
+                func.cast(models.DatasetExampleRevision.metadata_, Text).ilike(f"%{filter}%"),
+            )
+            query = query.where(filter_condition)
+
         async with info.context.db() as session:
             dataset_examples = [
                 DatasetExample(
-                    id_attr=example.id,
+                    id=example.id,
+                    db_record=example,
                     version_id=version_id,
-                    created_at=example.created_at,
                 )
                 async for example in await session.stream_scalars(query)
             ]
             return connection_from_list(data=dataset_examples, args=args)

+    @strawberry.field
+    async def splits(self, info: Info[Context, None]) -> list[DatasetSplit]:
+        return [
+            DatasetSplit(id=split.id, db_record=split)
+            for split in await info.context.data_loaders.dataset_dataset_splits.load(self.id)
+        ]
+
     @strawberry.field(
         description="Number of experiments for a specific version if version is specified, "
         "or for all versions if version is not specified."
@@ -192,9 +341,7 @@ class Dataset(Node):
         info: Info[Context, None],
         dataset_version_id: Optional[GlobalID] = UNSET,
     ) -> int:
-        stmt = select(count(models.Experiment.id)).where(
-            models.Experiment.dataset_id == self.id_attr
-        )
+        stmt = select(count(models.Experiment.id)).where(models.Experiment.dataset_id == self.id)
         version_id = (
             from_global_id_with_expected_type(
                 global_id=dataset_version_id,
@@ -216,6 +363,10 @@ class Dataset(Node):
         last: Optional[int] = UNSET,
         after: Optional[CursorString] = UNSET,
         before: Optional[CursorString] = UNSET,
+        filter_condition: Optional[str] = UNSET,
+        filter_ids: Optional[
+            list[GlobalID]
+        ] = UNSET,  # this is a stopgap until a query DSL is implemented
     ) -> Connection[Experiment]:
         args = ConnectionArgs(
             first=first,
@@ -223,13 +374,35 @@ class Dataset(Node):
             last=last,
             before=before if isinstance(before, CursorString) else None,
         )
-        dataset_id = self.id_attr
+        dataset_id = self.id
         row_number = func.row_number().over(order_by=models.Experiment.id).label("row_number")
         query = (
             select(models.Experiment, row_number)
             .where(models.Experiment.dataset_id == dataset_id)
             .order_by(models.Experiment.id.desc())
         )
+        if filter_condition is not UNSET and filter_condition:
+            # Search both name and description columns with case-insensitive partial matching
+            search_filter = or_(
+                models.Experiment.name.ilike(f"%{filter_condition}%"),
+                models.Experiment.description.ilike(f"%{filter_condition}%"),
+            )
+            query = query.where(search_filter)
+
+        if filter_ids:
+            filter_rowids = []
+            for filter_id in filter_ids:
+                try:
+                    filter_rowids.append(
+                        from_global_id_with_expected_type(
+                            global_id=filter_id,
+                            expected_type_name=Experiment.__name__,
+                        )
+                    )
+                except ValueError:
+                    raise BadRequest(f"Invalid filter ID: {filter_id}")
+            query = query.where(models.Experiment.id.in_(filter_rowids))
+
         async with info.context.db() as session:
             experiments = [
                 to_gql_experiment(experiment, sequence_number)
@@ -243,17 +416,15 @@ class Dataset(Node):
     @strawberry.field
     async def experiment_annotation_summaries(
         self, info: Info[Context, None]
-    ) -> list[ExperimentAnnotationSummary]:
-        dataset_id = self.id_attr
+    ) -> list[DatasetExperimentAnnotationSummary]:
+        dataset_id = self.id
         query = (
             select(
-                models.ExperimentRunAnnotation.name,
-                func.min(models.ExperimentRunAnnotation.score),
-                func.max(models.ExperimentRunAnnotation.score),
-                func.avg(models.ExperimentRunAnnotation.score),
-                func.count(),
-                func.count(models.ExperimentRunAnnotation.error),
+                models.ExperimentRunAnnotation.name.label("annotation_name"),
+                func.min(models.ExperimentRunAnnotation.score).label("min_score"),
+                func.max(models.ExperimentRunAnnotation.score).label("max_score"),
             )
+            .select_from(models.ExperimentRunAnnotation)
             .join(
                 models.ExperimentRun,
                 models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
@@ -268,38 +439,21 @@ class Dataset(Node):
         )
         async with info.context.db() as session:
             return [
-                ExperimentAnnotationSummary(
-                    annotation_name=annotation_name,
-                    min_score=min_score,
-                    max_score=max_score,
-                    mean_score=mean_score,
-                    count=count_,
-                    error_count=error_count,
+                DatasetExperimentAnnotationSummary(
+                    annotation_name=scores_tuple.annotation_name,
+                    min_score=scores_tuple.min_score,
+                    max_score=scores_tuple.max_score,
                 )
-                async for (
-                    annotation_name,
-                    min_score,
-                    max_score,
-                    mean_score,
-                    count_,
-                    error_count,
-                ) in await session.stream(query)
+                async for scores_tuple in await session.stream(query)
             ]

     @strawberry.field
-    def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
-        return info.context.last_updated_at.get(self._table, self.id_attr)
-
+    async def labels(self, info: Info[Context, None]) -> list[DatasetLabel]:
+        return [
+            DatasetLabel(id=label.id, db_record=label)
+            for label in await info.context.data_loaders.dataset_labels.load(self.id)
+        ]

-def to_gql_dataset(dataset: models.Dataset) -> Dataset:
-    """
-    Converts an ORM dataset to a GraphQL dataset.
-    """
-    return Dataset(
-        id_attr=dataset.id,
-        name=dataset.name,
-        description=dataset.description,
-        metadata=dataset.metadata_,
-        created_at=dataset.created_at,
-        updated_at=dataset.updated_at,
-    )
+    @strawberry.field
+    def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
+        return info.context.last_updated_at.get(models.Dataset, self.id)
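
The Dataset diff above repeats one pattern throughout: the GraphQL node carries only an `id` plus a private, optional `db_record`, and every resolver returns the column from the record when it was fetched eagerly, otherwise falls back to a batched field loader. A standalone sketch of that idea, with stand-in names (`FakeRecord`, `load_field`) rather than Phoenix's actual dataloaders, follows:

# Standalone sketch of the "eager record or batched fallback" pattern used by the
# Dataset resolvers above. FakeRecord and load_field are illustrative stand-ins,
# not arize-phoenix APIs.
import asyncio
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeRecord:
    id: int
    name: str


async def load_field(key: tuple[int, str]) -> Any:
    # In the real code this role is played by a DataLoader that batches per-column
    # lookups into one SQL query; here it just simulates a fetch.
    row_id, column = key
    return f"value-of-{column}-for-{row_id}"


@dataclass
class DatasetLike:
    id: int
    db_record: Optional[FakeRecord] = None

    async def name(self) -> str:
        if self.db_record:  # record was fetched eagerly, e.g. while listing datasets
            return self.db_record.name
        return await load_field((self.id, "name"))  # otherwise load just this column


# Usage: the eager path returns immediately; the lazy path goes through the loader.
print(asyncio.run(DatasetLike(id=1, db_record=FakeRecord(id=1, name="golden")).name()))
print(asyncio.run(DatasetLike(id=2).name()))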