arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (276)
  1. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
  2. arize_phoenix-12.28.1.dist-info/RECORD +499 -0
  3. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +5 -4
  12. phoenix/auth.py +39 -2
  13. phoenix/config.py +1763 -91
  14. phoenix/datetime_utils.py +120 -2
  15. phoenix/db/README.md +595 -25
  16. phoenix/db/bulk_inserter.py +145 -103
  17. phoenix/db/engines.py +140 -33
  18. phoenix/db/enums.py +3 -12
  19. phoenix/db/facilitator.py +302 -35
  20. phoenix/db/helpers.py +1000 -65
  21. phoenix/db/iam_auth.py +64 -0
  22. phoenix/db/insertion/dataset.py +135 -2
  23. phoenix/db/insertion/document_annotation.py +9 -6
  24. phoenix/db/insertion/evaluation.py +2 -3
  25. phoenix/db/insertion/helpers.py +17 -2
  26. phoenix/db/insertion/session_annotation.py +176 -0
  27. phoenix/db/insertion/span.py +15 -11
  28. phoenix/db/insertion/span_annotation.py +3 -4
  29. phoenix/db/insertion/trace_annotation.py +3 -4
  30. phoenix/db/insertion/types.py +50 -20
  31. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  32. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  33. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  34. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  35. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  36. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  37. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  38. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  39. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  40. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  41. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  42. phoenix/db/models.py +669 -56
  43. phoenix/db/pg_config.py +10 -0
  44. phoenix/db/types/model_provider.py +4 -0
  45. phoenix/db/types/token_price_customization.py +29 -0
  46. phoenix/db/types/trace_retention.py +23 -15
  47. phoenix/experiments/evaluators/utils.py +3 -3
  48. phoenix/experiments/functions.py +160 -52
  49. phoenix/experiments/tracing.py +2 -2
  50. phoenix/experiments/types.py +1 -1
  51. phoenix/inferences/inferences.py +1 -2
  52. phoenix/server/api/auth.py +38 -7
  53. phoenix/server/api/auth_messages.py +46 -0
  54. phoenix/server/api/context.py +100 -4
  55. phoenix/server/api/dataloaders/__init__.py +79 -5
  56. phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
  57. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  58. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  59. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  60. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  61. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  62. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  63. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  64. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  65. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  66. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  67. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  68. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  69. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  70. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  71. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  72. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  73. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  74. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  75. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  76. phoenix/server/api/dataloaders/record_counts.py +37 -10
  77. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  78. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  79. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  80. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  81. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  82. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  83. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  84. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
  85. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  86. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  87. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  88. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
  89. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  90. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  91. phoenix/server/api/dataloaders/span_costs.py +29 -0
  92. phoenix/server/api/dataloaders/table_fields.py +2 -2
  93. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  94. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  95. phoenix/server/api/dataloaders/types.py +29 -0
  96. phoenix/server/api/exceptions.py +11 -1
  97. phoenix/server/api/helpers/dataset_helpers.py +5 -1
  98. phoenix/server/api/helpers/playground_clients.py +1243 -292
  99. phoenix/server/api/helpers/playground_registry.py +2 -2
  100. phoenix/server/api/helpers/playground_spans.py +8 -4
  101. phoenix/server/api/helpers/playground_users.py +26 -0
  102. phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
  103. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  104. phoenix/server/api/helpers/prompts/models.py +205 -22
  105. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  106. phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
  107. phoenix/server/api/input_types/CreateProjectInput.py +27 -0
  108. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  109. phoenix/server/api/input_types/DatasetFilter.py +17 -0
  110. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  111. phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
  112. phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
  113. phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
  114. phoenix/server/api/input_types/PromptFilter.py +14 -0
  115. phoenix/server/api/input_types/PromptVersionInput.py +52 -1
  116. phoenix/server/api/input_types/SpanSort.py +44 -7
  117. phoenix/server/api/input_types/TimeBinConfig.py +23 -0
  118. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  119. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  120. phoenix/server/api/mutations/__init__.py +10 -0
  121. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  122. phoenix/server/api/mutations/api_key_mutations.py +19 -23
  123. phoenix/server/api/mutations/chat_mutations.py +154 -47
  124. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  125. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  126. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  127. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  128. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  129. phoenix/server/api/mutations/model_mutations.py +210 -0
  130. phoenix/server/api/mutations/project_mutations.py +49 -10
  131. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  132. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  133. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  134. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  135. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  136. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  137. phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
  138. phoenix/server/api/mutations/trace_mutations.py +47 -3
  139. phoenix/server/api/mutations/user_mutations.py +66 -41
  140. phoenix/server/api/queries.py +768 -293
  141. phoenix/server/api/routers/__init__.py +2 -2
  142. phoenix/server/api/routers/auth.py +154 -88
  143. phoenix/server/api/routers/ldap.py +229 -0
  144. phoenix/server/api/routers/oauth2.py +369 -106
  145. phoenix/server/api/routers/v1/__init__.py +24 -4
  146. phoenix/server/api/routers/v1/annotation_configs.py +23 -31
  147. phoenix/server/api/routers/v1/annotations.py +481 -17
  148. phoenix/server/api/routers/v1/datasets.py +395 -81
  149. phoenix/server/api/routers/v1/documents.py +142 -0
  150. phoenix/server/api/routers/v1/evaluations.py +24 -31
  151. phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
  152. phoenix/server/api/routers/v1/experiment_runs.py +337 -59
  153. phoenix/server/api/routers/v1/experiments.py +479 -48
  154. phoenix/server/api/routers/v1/models.py +7 -0
  155. phoenix/server/api/routers/v1/projects.py +18 -49
  156. phoenix/server/api/routers/v1/prompts.py +54 -40
  157. phoenix/server/api/routers/v1/sessions.py +108 -0
  158. phoenix/server/api/routers/v1/spans.py +1091 -81
  159. phoenix/server/api/routers/v1/traces.py +132 -78
  160. phoenix/server/api/routers/v1/users.py +389 -0
  161. phoenix/server/api/routers/v1/utils.py +3 -7
  162. phoenix/server/api/subscriptions.py +305 -88
  163. phoenix/server/api/types/Annotation.py +90 -23
  164. phoenix/server/api/types/ApiKey.py +13 -17
  165. phoenix/server/api/types/AuthMethod.py +1 -0
  166. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  167. phoenix/server/api/types/CostBreakdown.py +12 -0
  168. phoenix/server/api/types/Dataset.py +226 -72
  169. phoenix/server/api/types/DatasetExample.py +88 -18
  170. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  171. phoenix/server/api/types/DatasetLabel.py +57 -0
  172. phoenix/server/api/types/DatasetSplit.py +98 -0
  173. phoenix/server/api/types/DatasetVersion.py +49 -4
  174. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  175. phoenix/server/api/types/Experiment.py +264 -59
  176. phoenix/server/api/types/ExperimentComparison.py +5 -10
  177. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  178. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  179. phoenix/server/api/types/ExperimentRun.py +169 -65
  180. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  181. phoenix/server/api/types/GenerativeModel.py +245 -3
  182. phoenix/server/api/types/GenerativeProvider.py +70 -11
  183. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  184. phoenix/server/api/types/ModelInterface.py +16 -0
  185. phoenix/server/api/types/PlaygroundModel.py +20 -0
  186. phoenix/server/api/types/Project.py +1278 -216
  187. phoenix/server/api/types/ProjectSession.py +188 -28
  188. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  189. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  190. phoenix/server/api/types/Prompt.py +119 -39
  191. phoenix/server/api/types/PromptLabel.py +42 -25
  192. phoenix/server/api/types/PromptVersion.py +11 -8
  193. phoenix/server/api/types/PromptVersionTag.py +65 -25
  194. phoenix/server/api/types/ServerStatus.py +6 -0
  195. phoenix/server/api/types/Span.py +167 -123
  196. phoenix/server/api/types/SpanAnnotation.py +189 -42
  197. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  198. phoenix/server/api/types/SpanCostSummary.py +10 -0
  199. phoenix/server/api/types/SystemApiKey.py +65 -1
  200. phoenix/server/api/types/TokenPrice.py +16 -0
  201. phoenix/server/api/types/TokenUsage.py +3 -3
  202. phoenix/server/api/types/Trace.py +223 -51
  203. phoenix/server/api/types/TraceAnnotation.py +149 -50
  204. phoenix/server/api/types/User.py +137 -32
  205. phoenix/server/api/types/UserApiKey.py +73 -26
  206. phoenix/server/api/types/node.py +10 -0
  207. phoenix/server/api/types/pagination.py +11 -2
  208. phoenix/server/app.py +290 -45
  209. phoenix/server/authorization.py +38 -3
  210. phoenix/server/bearer_auth.py +34 -24
  211. phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
  212. phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
  213. phoenix/server/cost_tracking/helpers.py +68 -0
  214. phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
  215. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  216. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  217. phoenix/server/daemons/__init__.py +0 -0
  218. phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
  219. phoenix/server/daemons/generative_model_store.py +103 -0
  220. phoenix/server/daemons/span_cost_calculator.py +99 -0
  221. phoenix/server/dml_event.py +17 -0
  222. phoenix/server/dml_event_handler.py +5 -0
  223. phoenix/server/email/sender.py +56 -3
  224. phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
  225. phoenix/server/email/types.py +11 -0
  226. phoenix/server/experiments/__init__.py +0 -0
  227. phoenix/server/experiments/utils.py +14 -0
  228. phoenix/server/grpc_server.py +11 -11
  229. phoenix/server/jwt_store.py +17 -15
  230. phoenix/server/ldap.py +1449 -0
  231. phoenix/server/main.py +26 -10
  232. phoenix/server/oauth2.py +330 -12
  233. phoenix/server/prometheus.py +66 -6
  234. phoenix/server/rate_limiters.py +4 -9
  235. phoenix/server/retention.py +33 -20
  236. phoenix/server/session_filters.py +49 -0
  237. phoenix/server/static/.vite/manifest.json +55 -51
  238. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  239. phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
  240. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  241. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  242. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  243. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  244. phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
  245. phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
  246. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  247. phoenix/server/templates/index.html +40 -6
  248. phoenix/server/thread_server.py +1 -2
  249. phoenix/server/types.py +14 -4
  250. phoenix/server/utils.py +74 -0
  251. phoenix/session/client.py +56 -3
  252. phoenix/session/data_extractor.py +5 -0
  253. phoenix/session/evaluation.py +14 -5
  254. phoenix/session/session.py +45 -9
  255. phoenix/settings.py +5 -0
  256. phoenix/trace/attributes.py +80 -13
  257. phoenix/trace/dsl/helpers.py +90 -1
  258. phoenix/trace/dsl/query.py +8 -6
  259. phoenix/trace/projects.py +5 -0
  260. phoenix/utilities/template_formatters.py +1 -1
  261. phoenix/version.py +1 -1
  262. arize_phoenix-10.0.4.dist-info/RECORD +0 -405
  263. phoenix/server/api/types/Evaluation.py +0 -39
  264. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  265. phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
  266. phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
  267. phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
  268. phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
  269. phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
  270. phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
  271. phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
  272. phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
  273. phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
  274. phoenix/utilities/deprecation.py +0 -31
  275. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  276. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,9 +1,10 @@
1
1
  import asyncio
2
+ import logging
2
3
  from dataclasses import asdict, field
3
4
  from datetime import datetime, timezone
4
5
  from itertools import chain, islice
5
6
  from traceback import format_exc
6
- from typing import Any, Iterable, Iterator, List, Optional, TypeVar, Union
7
+ from typing import Any, Iterable, Iterator, Optional, TypeVar, Union
7
8
 
8
9
  import strawberry
9
10
  from openinference.instrumentation import safe_json_dumps
@@ -22,14 +23,19 @@ from strawberry.relay import GlobalID
22
23
  from strawberry.types import Info
23
24
  from typing_extensions import assert_never
24
25
 
26
+ from phoenix.config import PLAYGROUND_PROJECT_NAME
25
27
  from phoenix.datetime_utils import local_now, normalize_datetime
26
28
  from phoenix.db import models
27
- from phoenix.db.helpers import get_dataset_example_revisions
28
- from phoenix.server.api.auth import IsLocked, IsNotReadOnly
29
+ from phoenix.db.helpers import (
30
+ get_dataset_example_revisions,
31
+ insert_experiment_with_examples_snapshot,
32
+ )
33
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
29
34
  from phoenix.server.api.context import Context
30
35
  from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
31
36
  from phoenix.server.api.helpers.dataset_helpers import get_dataset_example_output
32
37
  from phoenix.server.api.helpers.playground_clients import (
38
+ PlaygroundClientCredential,
33
39
  PlaygroundStreamingClient,
34
40
  initialize_playground_clients,
35
41
  )
@@ -43,6 +49,7 @@ from phoenix.server.api.helpers.playground_spans import (
43
49
  llm_tools,
44
50
  prompt_metadata,
45
51
  )
52
+ from phoenix.server.api.helpers.playground_users import get_user
46
53
  from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
47
54
  from phoenix.server.api.input_types.ChatCompletionInput import (
48
55
  ChatCompletionInput,
@@ -62,6 +69,7 @@ from phoenix.server.api.types.DatasetVersion import DatasetVersion
62
69
  from phoenix.server.api.types.node import from_global_id_with_expected_type
63
70
  from phoenix.server.api.types.Span import Span
64
71
  from phoenix.server.dml_event import SpanInsertEvent
72
+ from phoenix.server.experiments.utils import generate_experiment_project_name
65
73
  from phoenix.trace.attributes import unflatten
66
74
  from phoenix.trace.schemas import SpanException
67
75
  from phoenix.utilities.json import jsonify
@@ -72,9 +80,11 @@ from phoenix.utilities.template_formatters import (
72
80
  TemplateFormatter,
73
81
  )
74
82
 
83
+ logger = logging.getLogger(__name__)
84
+
75
85
  initialize_playground_clients()
76
86
 
77
- ChatCompletionMessage = tuple[ChatCompletionMessageRole, str, Optional[str], Optional[List[Any]]]
87
+ ChatCompletionMessage = tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[Any]]]
78
88
 
79
89
 
80
90
  @strawberry.type
@@ -90,24 +100,25 @@ class ChatCompletionToolCall:
90
100
 
91
101
 
92
102
  @strawberry.type
93
- class ChatCompletionMutationPayload:
94
- db_span: strawberry.Private[models.Span]
103
+ class ChatCompletionRepetition:
104
+ repetition_number: int
95
105
  content: Optional[str]
96
- tool_calls: List[ChatCompletionToolCall]
97
- span: Span
106
+ tool_calls: list[ChatCompletionToolCall]
107
+ span: Optional[Span]
98
108
  error_message: Optional[str]
99
109
 
100
110
 
101
111
  @strawberry.type
102
- class ChatCompletionMutationError:
103
- message: str
112
+ class ChatCompletionMutationPayload:
113
+ repetitions: list[ChatCompletionRepetition]
104
114
 
105
115
 
106
116
  @strawberry.type
107
117
  class ChatCompletionOverDatasetMutationExamplePayload:
108
118
  dataset_example_id: GlobalID
119
+ repetition_number: int
109
120
  experiment_run_id: GlobalID
110
- result: Union[ChatCompletionMutationPayload, ChatCompletionMutationError]
121
+ repetition: ChatCompletionRepetition
111
122
 
112
123
 
113
124
  @strawberry.type
@@ -120,7 +131,7 @@ class ChatCompletionOverDatasetMutationPayload:
120
131
 
121
132
  @strawberry.type
122
133
  class ChatCompletionMutationMixin:
123
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
134
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
124
135
  @classmethod
125
136
  async def chat_completion_over_dataset(
126
137
  cls,
@@ -132,9 +143,17 @@ class ChatCompletionMutationMixin:
132
143
  if llm_client_class is None:
133
144
  raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
134
145
  try:
146
+ # Convert GraphQL credentials to PlaygroundCredential objects
147
+ credentials = None
148
+ if input.credentials:
149
+ credentials = [
150
+ PlaygroundClientCredential(env_var_name=cred.env_var_name, value=cred.value)
151
+ for cred in input.credentials
152
+ ]
153
+
135
154
  llm_client = llm_client_class(
136
155
  model=input.model,
137
- api_key=input.api_key,
156
+ credentials=credentials,
138
157
  )
139
158
  except CustomGraphQLError:
140
159
  raise
@@ -151,6 +170,7 @@ class ChatCompletionMutationMixin:
151
170
  if input.dataset_version_id
152
171
  else None
153
172
  )
173
+ project_name = generate_experiment_project_name()
154
174
  async with info.context.db() as session:
155
175
  dataset = await session.scalar(select(models.Dataset).filter_by(id=dataset_id))
156
176
  if dataset is None:
@@ -166,16 +186,26 @@ class ChatCompletionMutationMixin:
166
186
  raise NotFound("No versions found for the given dataset")
167
187
  else:
168
188
  resolved_version_id = dataset_version_id
189
+ # Parse split IDs if provided
190
+ resolved_split_ids: Optional[list[int]] = None
191
+ if input.split_ids is not None and len(input.split_ids) > 0:
192
+ resolved_split_ids = [
193
+ from_global_id_with_expected_type(split_id, models.DatasetSplit.__name__)
194
+ for split_id in input.split_ids
195
+ ]
196
+
169
197
  revisions = [
170
198
  revision
171
199
  async for revision in await session.stream_scalars(
172
- get_dataset_example_revisions(resolved_version_id).order_by(
173
- models.DatasetExampleRevision.id
174
- )
200
+ get_dataset_example_revisions(
201
+ resolved_version_id,
202
+ split_ids=resolved_split_ids,
203
+ ).order_by(models.DatasetExampleRevision.id)
175
204
  )
176
205
  ]
177
206
  if not revisions:
178
207
  raise NotFound("No examples found for the given dataset and version")
208
+ user_id = get_user(info)
179
209
  experiment = models.Experiment(
180
210
  dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
181
211
  dataset_version_id=resolved_version_id,
@@ -184,15 +214,25 @@ class ChatCompletionMutationMixin:
184
214
  description=input.experiment_description,
185
215
  repetitions=1,
186
216
  metadata_=input.experiment_metadata or dict(),
187
- project_name=PLAYGROUND_PROJECT_NAME,
217
+ project_name=project_name,
218
+ user_id=user_id,
188
219
  )
189
- session.add(experiment)
190
- await session.flush()
191
-
192
- results: list[Union[ChatCompletionMutationPayload, BaseException]] = []
220
+ if resolved_split_ids:
221
+ experiment.experiment_dataset_splits = [
222
+ models.ExperimentDatasetSplit(dataset_split_id=split_id)
223
+ for split_id in resolved_split_ids
224
+ ]
225
+ await insert_experiment_with_examples_snapshot(session, experiment)
226
+
227
+ results: list[Union[tuple[ChatCompletionRepetition, models.Span], BaseException]] = []
193
228
  batch_size = 3
194
229
  start_time = datetime.now(timezone.utc)
195
- for batch in _get_batches(revisions, batch_size):
230
+ unbatched_items = [
231
+ (revision, repetition_number)
232
+ for revision in revisions
233
+ for repetition_number in range(1, input.repetitions + 1)
234
+ ]
235
+ for batch in _get_batches(unbatched_items, batch_size):
196
236
  batch_results = await asyncio.gather(
197
237
  *(
198
238
  cls._chat_completion(
@@ -200,7 +240,7 @@ class ChatCompletionMutationMixin:
200
240
  llm_client,
201
241
  ChatCompletionInput(
202
242
  model=input.model,
203
- api_key=input.api_key,
243
+ credentials=input.credentials,
204
244
  messages=input.messages,
205
245
  tools=input.tools,
206
246
  invocation_parameters=input.invocation_parameters,
@@ -209,9 +249,12 @@ class ChatCompletionMutationMixin:
209
249
  variables=revision.input,
210
250
  ),
211
251
  prompt_name=input.prompt_name,
252
+ repetitions=repetition_number,
212
253
  ),
254
+ repetition_number=repetition_number,
255
+ project_name=project_name,
213
256
  )
214
- for revision in batch
257
+ for revision, repetition_number in batch
215
258
  ),
216
259
  return_exceptions=True,
217
260
  )
@@ -223,19 +266,19 @@ class ChatCompletionMutationMixin:
223
266
  experiment_id=GlobalID(models.Experiment.__name__, str(experiment.id)),
224
267
  )
225
268
  experiment_runs = []
226
- for revision, result in zip(revisions, results):
269
+ for (revision, repetition_number), result in zip(unbatched_items, results):
227
270
  if isinstance(result, BaseException):
228
271
  experiment_run = models.ExperimentRun(
229
272
  experiment_id=experiment.id,
230
273
  dataset_example_id=revision.dataset_example_id,
231
274
  output={},
232
- repetition_number=1,
275
+ repetition_number=repetition_number,
233
276
  start_time=start_time,
234
277
  end_time=start_time,
235
278
  error=str(result),
236
279
  )
237
280
  else:
238
- db_span: models.Span = result.db_span
281
+ repetition, db_span = result
239
282
  experiment_run = models.ExperimentRun(
240
283
  experiment_id=experiment.id,
241
284
  dataset_example_id=revision.dataset_example_id,
@@ -245,10 +288,10 @@ class ChatCompletionMutationMixin:
245
288
  ),
246
289
  prompt_token_count=db_span.cumulative_llm_token_count_prompt,
247
290
  completion_token_count=db_span.cumulative_llm_token_count_completion,
248
- repetition_number=1,
291
+ repetition_number=repetition_number,
249
292
  start_time=db_span.start_time,
250
293
  end_time=db_span.end_time,
251
- error=str(result.error_message) if result.error_message else None,
294
+ error=str(repetition.error_message) if repetition.error_message else None,
252
295
  )
253
296
  experiment_runs.append(experiment_run)
254
297
 
@@ -256,22 +299,31 @@ class ChatCompletionMutationMixin:
256
299
  session.add_all(experiment_runs)
257
300
  await session.flush()
258
301
 
259
- for revision, experiment_run, result in zip(revisions, experiment_runs, results):
302
+ for (revision, repetition_number), experiment_run, result in zip(
303
+ unbatched_items, experiment_runs, results
304
+ ):
260
305
  dataset_example_id = GlobalID(
261
306
  models.DatasetExample.__name__, str(revision.dataset_example_id)
262
307
  )
263
308
  experiment_run_id = GlobalID(models.ExperimentRun.__name__, str(experiment_run.id))
264
309
  example_payload = ChatCompletionOverDatasetMutationExamplePayload(
265
310
  dataset_example_id=dataset_example_id,
311
+ repetition_number=repetition_number,
266
312
  experiment_run_id=experiment_run_id,
267
- result=result
268
- if isinstance(result, ChatCompletionMutationPayload)
269
- else ChatCompletionMutationError(message=str(result)),
313
+ repetition=ChatCompletionRepetition(
314
+ repetition_number=repetition_number,
315
+ content=None,
316
+ tool_calls=[],
317
+ span=None,
318
+ error_message=str(result),
319
+ )
320
+ if isinstance(result, BaseException)
321
+ else result[0],
270
322
  )
271
323
  payload.examples.append(example_payload)
272
324
  return payload
273
325
 
274
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
326
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
275
327
  @classmethod
276
328
  async def chat_completion(
277
329
  cls, info: Info[Context, None], input: ChatCompletionInput
@@ -281,9 +333,17 @@ class ChatCompletionMutationMixin:
281
333
  if llm_client_class is None:
282
334
  raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
283
335
  try:
336
+ # Convert GraphQL credentials to PlaygroundCredential objects
337
+ credentials = None
338
+ if input.credentials:
339
+ credentials = [
340
+ PlaygroundClientCredential(env_var_name=cred.env_var_name, value=cred.value)
341
+ for cred in input.credentials
342
+ ]
343
+
284
344
  llm_client = llm_client_class(
285
345
  model=input.model,
286
- api_key=input.api_key,
346
+ credentials=credentials,
287
347
  )
288
348
  except CustomGraphQLError:
289
349
  raise
@@ -292,7 +352,38 @@ class ChatCompletionMutationMixin:
292
352
  f"Failed to connect to LLM API for {provider_key.value} {input.model.name}: "
293
353
  f"{str(error)}"
294
354
  )
295
- return await cls._chat_completion(info, llm_client, input)
355
+
356
+ results: list[Union[tuple[ChatCompletionRepetition, models.Span], BaseException]] = []
357
+ batch_size = 3
358
+ for batch in _get_batches(range(1, input.repetitions + 1), batch_size):
359
+ batch_results = await asyncio.gather(
360
+ *(
361
+ cls._chat_completion(
362
+ info, llm_client, input, repetition_number=repetition_number
363
+ )
364
+ for repetition_number in batch
365
+ ),
366
+ return_exceptions=True,
367
+ )
368
+ results.extend(batch_results)
369
+
370
+ repetitions: list[ChatCompletionRepetition] = []
371
+ for repetition_number, result in enumerate(results, start=1):
372
+ if isinstance(result, BaseException):
373
+ repetitions.append(
374
+ ChatCompletionRepetition(
375
+ repetition_number=repetition_number,
376
+ content=None,
377
+ tool_calls=[],
378
+ span=None,
379
+ error_message=str(result),
380
+ )
381
+ )
382
+ else:
383
+ repetition, _ = result
384
+ repetitions.append(repetition)
385
+
386
+ return ChatCompletionMutationPayload(repetitions=repetitions)
296
387
 
297
388
  @classmethod
298
389
  async def _chat_completion(
@@ -300,7 +391,10 @@ class ChatCompletionMutationMixin:
300
391
  info: Info[Context, None],
301
392
  llm_client: PlaygroundStreamingClient,
302
393
  input: ChatCompletionInput,
303
- ) -> ChatCompletionMutationPayload:
394
+ repetition_number: int,
395
+ project_name: str = PLAYGROUND_PROJECT_NAME,
396
+ project_description: str = "Traces from prompt playground",
397
+ ) -> tuple[ChatCompletionRepetition, models.Span]:
304
398
  attributes: dict[str, Any] = {}
305
399
  attributes.update(dict(prompt_metadata(input.prompt_name)))
306
400
 
@@ -394,15 +488,15 @@ class ChatCompletionMutationMixin:
394
488
  # Get or create the project ID
395
489
  if (
396
490
  project_id := await session.scalar(
397
- select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
491
+ select(models.Project.id).where(models.Project.name == project_name)
398
492
  )
399
493
  ) is None:
400
494
  project_id = await session.scalar(
401
495
  insert(models.Project)
402
496
  .returning(models.Project.id)
403
497
  .values(
404
- name=PLAYGROUND_PROJECT_NAME,
405
- description="Traces from prompt playground",
498
+ name=project_name,
499
+ description=project_description,
406
500
  )
407
501
  )
408
502
  trace = models.Trace(
@@ -433,27 +527,41 @@ class ChatCompletionMutationMixin:
433
527
  session.add(trace)
434
528
  session.add(span)
435
529
  await session.flush()
530
+ try:
531
+ span_cost = info.context.span_cost_calculator.calculate_cost(
532
+ start_time=span.start_time,
533
+ attributes=span.attributes,
534
+ )
535
+ except Exception as e:
536
+ logger.exception(f"Failed to calculate cost for span {span.id}: {e}")
537
+ span_cost = None
538
+ if span_cost:
539
+ span_cost.span_rowid = span.id
540
+ span_cost.trace_rowid = trace.id
541
+ session.add(span_cost)
542
+ await session.flush()
436
543
 
437
- gql_span = Span(span_rowid=span.id, db_span=span)
544
+ gql_span = Span(id=span.id, db_record=span)
438
545
 
439
546
  info.context.event_queue.put(SpanInsertEvent(ids=(project_id,)))
440
547
 
441
548
  if status_code is StatusCode.ERROR:
442
- return ChatCompletionMutationPayload(
443
- db_span=span,
549
+ repetition = ChatCompletionRepetition(
550
+ repetition_number=repetition_number,
444
551
  content=None,
445
552
  tool_calls=[],
446
553
  span=gql_span,
447
554
  error_message=status_message,
448
555
  )
449
556
  else:
450
- return ChatCompletionMutationPayload(
451
- db_span=span,
557
+ repetition = ChatCompletionRepetition(
558
+ repetition_number=repetition_number,
452
559
  content=text_content if text_content else None,
453
560
  tool_calls=list(tool_calls.values()),
454
561
  span=gql_span,
455
562
  error_message=None,
456
563
  )
564
+ return repetition, span
457
565
 
458
566
 
459
567
  def _formatted_messages(
@@ -588,5 +696,4 @@ TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUME
588
696
  TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
589
697
  PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
590
698
 
591
-
592
- PLAYGROUND_PROJECT_NAME = "playground"
699
+ LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
@@ -0,0 +1,243 @@
1
+ from typing import Optional
2
+
3
+ import sqlalchemy
4
+ import strawberry
5
+ from sqlalchemy import delete, select
6
+ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
7
+ from sqlalchemy.orm import joinedload
8
+ from sqlalchemy.sql import tuple_
9
+ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
10
+ from strawberry import UNSET
11
+ from strawberry.relay.types import GlobalID
12
+ from strawberry.types import Info
13
+
14
+ from phoenix.db import models
15
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
16
+ from phoenix.server.api.context import Context
17
+ from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
18
+ from phoenix.server.api.queries import Query
19
+ from phoenix.server.api.types.Dataset import Dataset
20
+ from phoenix.server.api.types.DatasetLabel import DatasetLabel
21
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
22
+
23
+
24
@strawberry.input
class CreateDatasetLabelInput:
    """Input for creating a dataset label, optionally attaching it to datasets."""

    # Name of the new label.
    name: str
    # Optional description; left as UNSET when the client omits the field.
    description: Optional[str] = UNSET
    # Color string stored with the label.
    color: str
    # Global IDs of datasets the new label should be applied to; may be omitted.
    dataset_ids: Optional[list[GlobalID]] = UNSET
30
+
31
+
32
@strawberry.type
class CreateDatasetLabelMutationPayload:
    """Result of creating a dataset label."""

    # The label that was created.
    dataset_label: DatasetLabel
    # The datasets the new label was applied to, in the de-duplicated input order.
    datasets: list[Dataset]
36
+
37
+
38
@strawberry.input
class DeleteDatasetLabelsInput:
    """Input for deleting one or more dataset labels."""

    # Global IDs of the dataset labels to delete.
    dataset_label_ids: list[GlobalID]
41
+
42
+
43
@strawberry.type
class DeleteDatasetLabelsMutationPayload:
    """Result of deleting dataset labels."""

    # The labels that were deleted, in the de-duplicated input order.
    dataset_labels: list[DatasetLabel]
46
+
47
+
48
@strawberry.input
class SetDatasetLabelsInput:
    """Input for replacing the full set of labels applied to a single dataset."""

    # Global ID of the dataset whose labels are being set.
    dataset_id: GlobalID
    # Global IDs of the labels the dataset should end up with.
    dataset_label_ids: list[GlobalID]
52
+
53
+
54
@strawberry.type
class SetDatasetLabelsMutationPayload:
    """Result of setting the labels on a dataset."""

    # Root query object returned alongside the mutation result.
    query: Query
    # The dataset whose labels were set.
    dataset: Dataset
58
+
59
+
60
@strawberry.type
class DatasetLabelMutationMixin:
    """GraphQL mutations for creating, deleting, and assigning dataset labels."""

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked])  # type: ignore
    async def create_dataset_label(
        self,
        info: Info[Context, None],
        input: CreateDatasetLabelInput,
    ) -> CreateDatasetLabelMutationPayload:
        """Create a dataset label and optionally apply it to the given datasets.

        Raises:
            BadRequest: If a dataset ID is not a valid ``Dataset`` global ID, or the
                insert statement is rejected by the database driver.
            Conflict: If the insert violates an integrity constraint (reported as a
                duplicate label name).
            NotFound: If any of the referenced datasets does not exist.
        """
        name = input.name
        # An omitted optional field arrives as strawberry's UNSET sentinel, which is not
        # a bindable column value; treat it the same as an explicit null.
        description = None if input.description is UNSET else input.description
        color = input.color
        dataset_rowids: dict[
            int, None
        ] = {}  # use dictionary to de-duplicate while preserving order
        if input.dataset_ids:
            for dataset_id in input.dataset_ids:
                try:
                    dataset_rowid = from_global_id_with_expected_type(dataset_id, Dataset.__name__)
                except ValueError:
                    raise BadRequest(f"Invalid dataset ID: {dataset_id}")
                dataset_rowids[dataset_rowid] = None

        async with info.context.db() as session:
            dataset_label_orm = models.DatasetLabel(name=name, description=description, color=color)
            session.add(dataset_label_orm)
            try:
                # Flush early so constraint violations surface here and the new
                # label's primary key is available for the association rows below.
                await session.flush()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict(f"A dataset label named '{name}' already exists")
            except sqlalchemy.exc.StatementError as error:
                raise BadRequest(str(error.orig))

            datasets_by_id: dict[int, models.Dataset] = {}
            if dataset_rowids:
                datasets_by_id = {
                    dataset.id: dataset
                    for dataset in await session.scalars(
                        select(models.Dataset).where(models.Dataset.id.in_(dataset_rowids.keys()))
                    )
                }
                # Fewer rows than requested IDs means at least one dataset is missing.
                if len(datasets_by_id) < len(dataset_rowids):
                    raise NotFound("One or more datasets not found")
                session.add_all(
                    [
                        models.DatasetsDatasetLabel(
                            dataset_id=dataset_rowid,
                            dataset_label_id=dataset_label_orm.id,
                        )
                        for dataset_rowid in dataset_rowids
                    ]
                )
            await session.commit()

        return CreateDatasetLabelMutationPayload(
            dataset_label=DatasetLabel(id=dataset_label_orm.id, db_record=dataset_label_orm),
            datasets=[
                Dataset(
                    id=datasets_by_id[dataset_rowid].id, db_record=datasets_by_id[dataset_rowid]
                )
                for dataset_rowid in dataset_rowids
            ],
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked])  # type: ignore
    async def delete_dataset_labels(
        self, info: Info[Context, None], input: DeleteDatasetLabelsInput
    ) -> DeleteDatasetLabelsMutationPayload:
        """Delete the dataset labels with the given IDs, all or nothing.

        Raises:
            BadRequest: If an ID is not a valid ``DatasetLabel`` global ID.
            NotFound: If any of the labels does not exist; the deletion is rolled back.
        """
        # Dictionary keys de-duplicate the IDs while preserving the input order.
        dataset_label_row_ids: dict[int, None] = {}
        for dataset_label_node_id in input.dataset_label_ids:
            try:
                dataset_label_row_id = from_global_id_with_expected_type(
                    dataset_label_node_id, DatasetLabel.__name__
                )
            except ValueError:
                raise BadRequest(f"Unknown dataset label: {dataset_label_node_id}")
            dataset_label_row_ids[dataset_label_row_id] = None
        async with info.context.db() as session:
            stmt = (
                delete(models.DatasetLabel)
                .where(models.DatasetLabel.id.in_(dataset_label_row_ids.keys()))
                .returning(models.DatasetLabel)
            )
            deleted_dataset_labels = (await session.scalars(stmt)).all()
            # Fewer deleted rows than requested IDs means some labels were missing;
            # undo the partial deletion before reporting the error.
            if len(deleted_dataset_labels) < len(dataset_label_row_ids):
                await session.rollback()
                raise NotFound("Could not find one or more dataset labels with given IDs")
            deleted_dataset_labels_by_id = {
                dataset_label.id: dataset_label for dataset_label in deleted_dataset_labels
            }
            return DeleteDatasetLabelsMutationPayload(
                dataset_labels=[
                    DatasetLabel(
                        id=deleted_dataset_labels_by_id[dataset_label_row_id].id,
                        db_record=deleted_dataset_labels_by_id[dataset_label_row_id],
                    )
                    for dataset_label_row_id in dataset_label_row_ids
                ]
            )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked])  # type: ignore
    async def set_dataset_labels(
        self, info: Info[Context, None], input: SetDatasetLabelsInput
    ) -> SetDatasetLabelsMutationPayload:
        """Make the dataset's labels exactly match the given label IDs.

        Labels not yet applied are added and applied labels missing from the input
        are removed.

        Raises:
            BadRequest: If the dataset ID or any label ID is not a valid global ID of
                the expected type.
            NotFound: If the dataset or any of the labels does not exist.
        """
        try:
            dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
        except ValueError:
            raise BadRequest(f"Invalid dataset ID: {input.dataset_id}")

        dataset_label_ids: dict[
            int, None
        ] = {}  # use dictionary to de-duplicate while preserving order
        for dataset_label_gid in input.dataset_label_ids:
            try:
                dataset_label_id = from_global_id_with_expected_type(
                    dataset_label_gid, DatasetLabel.__name__
                )
            except ValueError:
                raise BadRequest(f"Invalid dataset label ID: {dataset_label_gid}")
            dataset_label_ids[dataset_label_id] = None

        async with info.context.db() as session:
            # Load the dataset together with its current label associations.
            dataset = await session.scalar(
                select(models.Dataset)
                .where(models.Dataset.id == dataset_id)
                .options(joinedload(models.Dataset.datasets_dataset_labels))
            )

            if not dataset:
                raise NotFound(f"Dataset with ID {input.dataset_id} not found")

            existing_label_ids = (
                await session.scalars(
                    select(models.DatasetLabel.id).where(
                        models.DatasetLabel.id.in_(dataset_label_ids.keys())
                    )
                )
            ).all()
            if len(existing_label_ids) != len(dataset_label_ids):
                raise NotFound("One or more dataset labels not found")

            previously_applied_dataset_label_ids = {
                dataset_dataset_label.dataset_label_id
                for dataset_dataset_label in dataset.datasets_dataset_labels
            }

            # Add only the associations that are not already present.
            datasets_dataset_labels_to_add = [
                models.DatasetsDatasetLabel(
                    dataset_id=dataset_id,
                    dataset_label_id=dataset_label_id,
                )
                for dataset_label_id in dataset_label_ids
                if dataset_label_id not in previously_applied_dataset_label_ids
            ]
            if datasets_dataset_labels_to_add:
                session.add_all(datasets_dataset_labels_to_add)
                await session.flush()

            # Remove the associations whose label is no longer requested.
            datasets_dataset_labels_to_delete = [
                dataset_dataset_label
                for dataset_dataset_label in dataset.datasets_dataset_labels
                if dataset_dataset_label.dataset_label_id not in dataset_label_ids
            ]
            if datasets_dataset_labels_to_delete:
                await session.execute(
                    delete(models.DatasetsDatasetLabel).where(
                        tuple_(
                            models.DatasetsDatasetLabel.dataset_id,
                            models.DatasetsDatasetLabel.dataset_label_id,
                        ).in_(
                            [
                                (
                                    datasets_dataset_labels.dataset_id,
                                    datasets_dataset_labels.dataset_label_id,
                                )
                                for datasets_dataset_labels in datasets_dataset_labels_to_delete
                            ]
                        )
                    )
                )

        return SetDatasetLabelsMutationPayload(
            dataset=Dataset(id=dataset.id, db_record=dataset),
            query=Query(),
        )