arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (221)
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,5 @@
1
1
  from datetime import datetime, timezone
2
- from typing import Optional
2
+ from typing import Literal, Optional
3
3
 
4
4
  import strawberry
5
5
  from sqlalchemy import select
@@ -9,7 +9,7 @@ from strawberry.types import Info
9
9
 
10
10
  from phoenix.db import models
11
11
  from phoenix.db.models import UserRoleName
12
- from phoenix.server.api.auth import IsAdmin, IsLocked, IsNotReadOnly
12
+ from phoenix.server.api.auth import IsAdmin, IsLocked, IsNotReadOnly, IsNotViewer
13
13
  from phoenix.server.api.context import Context
14
14
  from phoenix.server.api.exceptions import Unauthorized
15
15
  from phoenix.server.api.queries import Query
@@ -61,7 +61,7 @@ class DeleteApiKeyMutationPayload:
61
61
 
62
62
  @strawberry.type
63
63
  class ApiKeyMutationMixin:
64
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin, IsLocked]) # type: ignore
64
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsAdmin, IsLocked]) # type: ignore
65
65
  async def create_system_api_key(
66
66
  self, info: Info[Context, None], input: CreateApiKeyInput
67
67
  ) -> CreateSystemApiKeyMutationPayload:
@@ -92,13 +92,7 @@ class ApiKeyMutationMixin:
92
92
  token, token_id = await token_store.create_api_key(claims)
93
93
  return CreateSystemApiKeyMutationPayload(
94
94
  jwt=token,
95
- api_key=SystemApiKey(
96
- id_attr=int(token_id),
97
- name=input.name,
98
- description=input.description or None,
99
- created_at=issued_at,
100
- expires_at=input.expires_at or None,
101
- ),
95
+ api_key=SystemApiKey(id=int(token_id)),
102
96
  query=Query(),
103
97
  )
104
98
 
@@ -113,12 +107,20 @@ class ApiKeyMutationMixin:
113
107
  except AttributeError:
114
108
  raise ValueError("User not found")
115
109
  issued_at = datetime.now(timezone.utc)
110
+ # Determine user role for API key
111
+ user_role: Literal["ADMIN", "MEMBER", "VIEWER"]
112
+ if user.is_admin:
113
+ user_role = "ADMIN"
114
+ elif user.is_viewer:
115
+ user_role = "VIEWER"
116
+ else:
117
+ user_role = "MEMBER"
116
118
  claims = ApiKeyClaims(
117
119
  subject=user.identity,
118
120
  issued_at=issued_at,
119
121
  expiration_time=input.expires_at or None,
120
122
  attributes=ApiKeyAttributes(
121
- user_role="ADMIN" if user.is_admin else "MEMBER",
123
+ user_role=user_role,
122
124
  name=input.name,
123
125
  description=input.description,
124
126
  ),
@@ -126,18 +128,11 @@ class ApiKeyMutationMixin:
126
128
  token, token_id = await token_store.create_api_key(claims)
127
129
  return CreateUserApiKeyMutationPayload(
128
130
  jwt=token,
129
- api_key=UserApiKey(
130
- id_attr=int(token_id),
131
- name=input.name,
132
- description=input.description or None,
133
- created_at=issued_at,
134
- expires_at=input.expires_at or None,
135
- user_id=int(user.identity),
136
- ),
131
+ api_key=UserApiKey(id=int(token_id)),
137
132
  query=Query(),
138
133
  )
139
134
 
140
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore
135
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsAdmin]) # type: ignore
141
136
  async def delete_system_api_key(
142
137
  self, info: Info[Context, None], input: DeleteApiKeyInput
143
138
  ) -> DeleteApiKeyMutationPayload:
@@ -4,7 +4,7 @@ from dataclasses import asdict, field
4
4
  from datetime import datetime, timezone
5
5
  from itertools import chain, islice
6
6
  from traceback import format_exc
7
- from typing import Any, Iterable, Iterator, List, Optional, TypeVar, Union
7
+ from typing import Any, Iterable, Iterator, Optional, TypeVar, Union
8
8
 
9
9
  import strawberry
10
10
  from openinference.instrumentation import safe_json_dumps
@@ -26,8 +26,11 @@ from typing_extensions import assert_never
26
26
  from phoenix.config import PLAYGROUND_PROJECT_NAME
27
27
  from phoenix.datetime_utils import local_now, normalize_datetime
28
28
  from phoenix.db import models
29
- from phoenix.db.helpers import get_dataset_example_revisions
30
- from phoenix.server.api.auth import IsLocked, IsNotReadOnly
29
+ from phoenix.db.helpers import (
30
+ get_dataset_example_revisions,
31
+ insert_experiment_with_examples_snapshot,
32
+ )
33
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
31
34
  from phoenix.server.api.context import Context
32
35
  from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
33
36
  from phoenix.server.api.helpers.dataset_helpers import get_dataset_example_output
@@ -46,6 +49,7 @@ from phoenix.server.api.helpers.playground_spans import (
46
49
  llm_tools,
47
50
  prompt_metadata,
48
51
  )
52
+ from phoenix.server.api.helpers.playground_users import get_user
49
53
  from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
50
54
  from phoenix.server.api.input_types.ChatCompletionInput import (
51
55
  ChatCompletionInput,
@@ -80,7 +84,7 @@ logger = logging.getLogger(__name__)
80
84
 
81
85
  initialize_playground_clients()
82
86
 
83
- ChatCompletionMessage = tuple[ChatCompletionMessageRole, str, Optional[str], Optional[List[Any]]]
87
+ ChatCompletionMessage = tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[Any]]]
84
88
 
85
89
 
86
90
  @strawberry.type
@@ -96,24 +100,25 @@ class ChatCompletionToolCall:
96
100
 
97
101
 
98
102
  @strawberry.type
99
- class ChatCompletionMutationPayload:
100
- db_span: strawberry.Private[models.Span]
103
+ class ChatCompletionRepetition:
104
+ repetition_number: int
101
105
  content: Optional[str]
102
- tool_calls: List[ChatCompletionToolCall]
103
- span: Span
106
+ tool_calls: list[ChatCompletionToolCall]
107
+ span: Optional[Span]
104
108
  error_message: Optional[str]
105
109
 
106
110
 
107
111
  @strawberry.type
108
- class ChatCompletionMutationError:
109
- message: str
112
+ class ChatCompletionMutationPayload:
113
+ repetitions: list[ChatCompletionRepetition]
110
114
 
111
115
 
112
116
  @strawberry.type
113
117
  class ChatCompletionOverDatasetMutationExamplePayload:
114
118
  dataset_example_id: GlobalID
119
+ repetition_number: int
115
120
  experiment_run_id: GlobalID
116
- result: Union[ChatCompletionMutationPayload, ChatCompletionMutationError]
121
+ repetition: ChatCompletionRepetition
117
122
 
118
123
 
119
124
  @strawberry.type
@@ -126,7 +131,7 @@ class ChatCompletionOverDatasetMutationPayload:
126
131
 
127
132
  @strawberry.type
128
133
  class ChatCompletionMutationMixin:
129
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
134
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
130
135
  @classmethod
131
136
  async def chat_completion_over_dataset(
132
137
  cls,
@@ -181,16 +186,26 @@ class ChatCompletionMutationMixin:
181
186
  raise NotFound("No versions found for the given dataset")
182
187
  else:
183
188
  resolved_version_id = dataset_version_id
189
+ # Parse split IDs if provided
190
+ resolved_split_ids: Optional[list[int]] = None
191
+ if input.split_ids is not None and len(input.split_ids) > 0:
192
+ resolved_split_ids = [
193
+ from_global_id_with_expected_type(split_id, models.DatasetSplit.__name__)
194
+ for split_id in input.split_ids
195
+ ]
196
+
184
197
  revisions = [
185
198
  revision
186
199
  async for revision in await session.stream_scalars(
187
- get_dataset_example_revisions(resolved_version_id).order_by(
188
- models.DatasetExampleRevision.id
189
- )
200
+ get_dataset_example_revisions(
201
+ resolved_version_id,
202
+ split_ids=resolved_split_ids,
203
+ ).order_by(models.DatasetExampleRevision.id)
190
204
  )
191
205
  ]
192
206
  if not revisions:
193
207
  raise NotFound("No examples found for the given dataset and version")
208
+ user_id = get_user(info)
194
209
  experiment = models.Experiment(
195
210
  dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
196
211
  dataset_version_id=resolved_version_id,
@@ -200,14 +215,24 @@ class ChatCompletionMutationMixin:
200
215
  repetitions=1,
201
216
  metadata_=input.experiment_metadata or dict(),
202
217
  project_name=project_name,
218
+ user_id=user_id,
203
219
  )
204
- session.add(experiment)
205
- await session.flush()
220
+ if resolved_split_ids:
221
+ experiment.experiment_dataset_splits = [
222
+ models.ExperimentDatasetSplit(dataset_split_id=split_id)
223
+ for split_id in resolved_split_ids
224
+ ]
225
+ await insert_experiment_with_examples_snapshot(session, experiment)
206
226
 
207
- results: list[Union[ChatCompletionMutationPayload, BaseException]] = []
227
+ results: list[Union[tuple[ChatCompletionRepetition, models.Span], BaseException]] = []
208
228
  batch_size = 3
209
229
  start_time = datetime.now(timezone.utc)
210
- for batch in _get_batches(revisions, batch_size):
230
+ unbatched_items = [
231
+ (revision, repetition_number)
232
+ for revision in revisions
233
+ for repetition_number in range(1, input.repetitions + 1)
234
+ ]
235
+ for batch in _get_batches(unbatched_items, batch_size):
211
236
  batch_results = await asyncio.gather(
212
237
  *(
213
238
  cls._chat_completion(
@@ -224,10 +249,12 @@ class ChatCompletionMutationMixin:
224
249
  variables=revision.input,
225
250
  ),
226
251
  prompt_name=input.prompt_name,
252
+ repetitions=repetition_number,
227
253
  ),
254
+ repetition_number=repetition_number,
228
255
  project_name=project_name,
229
256
  )
230
- for revision in batch
257
+ for revision, repetition_number in batch
231
258
  ),
232
259
  return_exceptions=True,
233
260
  )
@@ -239,19 +266,19 @@ class ChatCompletionMutationMixin:
239
266
  experiment_id=GlobalID(models.Experiment.__name__, str(experiment.id)),
240
267
  )
241
268
  experiment_runs = []
242
- for revision, result in zip(revisions, results):
269
+ for (revision, repetition_number), result in zip(unbatched_items, results):
243
270
  if isinstance(result, BaseException):
244
271
  experiment_run = models.ExperimentRun(
245
272
  experiment_id=experiment.id,
246
273
  dataset_example_id=revision.dataset_example_id,
247
274
  output={},
248
- repetition_number=1,
275
+ repetition_number=repetition_number,
249
276
  start_time=start_time,
250
277
  end_time=start_time,
251
278
  error=str(result),
252
279
  )
253
280
  else:
254
- db_span: models.Span = result.db_span
281
+ repetition, db_span = result
255
282
  experiment_run = models.ExperimentRun(
256
283
  experiment_id=experiment.id,
257
284
  dataset_example_id=revision.dataset_example_id,
@@ -261,10 +288,10 @@ class ChatCompletionMutationMixin:
261
288
  ),
262
289
  prompt_token_count=db_span.cumulative_llm_token_count_prompt,
263
290
  completion_token_count=db_span.cumulative_llm_token_count_completion,
264
- repetition_number=1,
291
+ repetition_number=repetition_number,
265
292
  start_time=db_span.start_time,
266
293
  end_time=db_span.end_time,
267
- error=str(result.error_message) if result.error_message else None,
294
+ error=str(repetition.error_message) if repetition.error_message else None,
268
295
  )
269
296
  experiment_runs.append(experiment_run)
270
297
 
@@ -272,22 +299,31 @@ class ChatCompletionMutationMixin:
272
299
  session.add_all(experiment_runs)
273
300
  await session.flush()
274
301
 
275
- for revision, experiment_run, result in zip(revisions, experiment_runs, results):
302
+ for (revision, repetition_number), experiment_run, result in zip(
303
+ unbatched_items, experiment_runs, results
304
+ ):
276
305
  dataset_example_id = GlobalID(
277
306
  models.DatasetExample.__name__, str(revision.dataset_example_id)
278
307
  )
279
308
  experiment_run_id = GlobalID(models.ExperimentRun.__name__, str(experiment_run.id))
280
309
  example_payload = ChatCompletionOverDatasetMutationExamplePayload(
281
310
  dataset_example_id=dataset_example_id,
311
+ repetition_number=repetition_number,
282
312
  experiment_run_id=experiment_run_id,
283
- result=result
284
- if isinstance(result, ChatCompletionMutationPayload)
285
- else ChatCompletionMutationError(message=str(result)),
313
+ repetition=ChatCompletionRepetition(
314
+ repetition_number=repetition_number,
315
+ content=None,
316
+ tool_calls=[],
317
+ span=None,
318
+ error_message=str(result),
319
+ )
320
+ if isinstance(result, BaseException)
321
+ else result[0],
286
322
  )
287
323
  payload.examples.append(example_payload)
288
324
  return payload
289
325
 
290
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
326
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
291
327
  @classmethod
292
328
  async def chat_completion(
293
329
  cls, info: Info[Context, None], input: ChatCompletionInput
@@ -316,7 +352,38 @@ class ChatCompletionMutationMixin:
316
352
  f"Failed to connect to LLM API for {provider_key.value} {input.model.name}: "
317
353
  f"{str(error)}"
318
354
  )
319
- return await cls._chat_completion(info, llm_client, input)
355
+
356
+ results: list[Union[tuple[ChatCompletionRepetition, models.Span], BaseException]] = []
357
+ batch_size = 3
358
+ for batch in _get_batches(range(1, input.repetitions + 1), batch_size):
359
+ batch_results = await asyncio.gather(
360
+ *(
361
+ cls._chat_completion(
362
+ info, llm_client, input, repetition_number=repetition_number
363
+ )
364
+ for repetition_number in batch
365
+ ),
366
+ return_exceptions=True,
367
+ )
368
+ results.extend(batch_results)
369
+
370
+ repetitions: list[ChatCompletionRepetition] = []
371
+ for repetition_number, result in enumerate(results, start=1):
372
+ if isinstance(result, BaseException):
373
+ repetitions.append(
374
+ ChatCompletionRepetition(
375
+ repetition_number=repetition_number,
376
+ content=None,
377
+ tool_calls=[],
378
+ span=None,
379
+ error_message=str(result),
380
+ )
381
+ )
382
+ else:
383
+ repetition, _ = result
384
+ repetitions.append(repetition)
385
+
386
+ return ChatCompletionMutationPayload(repetitions=repetitions)
320
387
 
321
388
  @classmethod
322
389
  async def _chat_completion(
@@ -324,9 +391,10 @@ class ChatCompletionMutationMixin:
324
391
  info: Info[Context, None],
325
392
  llm_client: PlaygroundStreamingClient,
326
393
  input: ChatCompletionInput,
394
+ repetition_number: int,
327
395
  project_name: str = PLAYGROUND_PROJECT_NAME,
328
396
  project_description: str = "Traces from prompt playground",
329
- ) -> ChatCompletionMutationPayload:
397
+ ) -> tuple[ChatCompletionRepetition, models.Span]:
330
398
  attributes: dict[str, Any] = {}
331
399
  attributes.update(dict(prompt_metadata(input.prompt_name)))
332
400
 
@@ -473,26 +541,27 @@ class ChatCompletionMutationMixin:
473
541
  session.add(span_cost)
474
542
  await session.flush()
475
543
 
476
- gql_span = Span(span_rowid=span.id, db_span=span)
544
+ gql_span = Span(id=span.id, db_record=span)
477
545
 
478
546
  info.context.event_queue.put(SpanInsertEvent(ids=(project_id,)))
479
547
 
480
548
  if status_code is StatusCode.ERROR:
481
- return ChatCompletionMutationPayload(
482
- db_span=span,
549
+ repetition = ChatCompletionRepetition(
550
+ repetition_number=repetition_number,
483
551
  content=None,
484
552
  tool_calls=[],
485
553
  span=gql_span,
486
554
  error_message=status_message,
487
555
  )
488
556
  else:
489
- return ChatCompletionMutationPayload(
490
- db_span=span,
557
+ repetition = ChatCompletionRepetition(
558
+ repetition_number=repetition_number,
491
559
  content=text_content if text_content else None,
492
560
  tool_calls=list(tool_calls.values()),
493
561
  span=gql_span,
494
562
  error_message=None,
495
563
  )
564
+ return repetition, span
496
565
 
497
566
 
498
567
  def _formatted_messages(
@@ -0,0 +1,243 @@
1
+ from typing import Optional
2
+
3
+ import sqlalchemy
4
+ import strawberry
5
+ from sqlalchemy import delete, select
6
+ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
7
+ from sqlalchemy.orm import joinedload
8
+ from sqlalchemy.sql import tuple_
9
+ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
10
+ from strawberry import UNSET
11
+ from strawberry.relay.types import GlobalID
12
+ from strawberry.types import Info
13
+
14
+ from phoenix.db import models
15
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
16
+ from phoenix.server.api.context import Context
17
+ from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
18
+ from phoenix.server.api.queries import Query
19
+ from phoenix.server.api.types.Dataset import Dataset
20
+ from phoenix.server.api.types.DatasetLabel import DatasetLabel
21
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
22
+
23
+
24
# GraphQL input for the create_dataset_label mutation.
@strawberry.input
class CreateDatasetLabelInput:
    # Name of the label to create.
    name: str
    # Optional description; holds UNSET when the client omits the field.
    description: Optional[str] = UNSET
    # Color of the label, passed through to the stored record unchanged.
    color: str
    # Optional datasets to apply the new label to; UNSET when omitted.
    dataset_ids: Optional[list[GlobalID]] = UNSET
30
+
31
+
32
# Result of the create_dataset_label mutation.
@strawberry.type
class CreateDatasetLabelMutationPayload:
    # The label that was created.
    dataset_label: DatasetLabel
    # The datasets the new label was applied to (empty when none were given).
    datasets: list[Dataset]
36
+
37
+
38
# GraphQL input for the delete_dataset_labels mutation.
@strawberry.input
class DeleteDatasetLabelsInput:
    # Global IDs of the dataset labels to delete.
    dataset_label_ids: list[GlobalID]
41
+
42
+
43
# Result of the delete_dataset_labels mutation.
@strawberry.type
class DeleteDatasetLabelsMutationPayload:
    # The labels that were deleted, in the order their IDs were first requested.
    dataset_labels: list[DatasetLabel]
46
+
47
+
48
# GraphQL input for the set_dataset_labels mutation.
@strawberry.input
class SetDatasetLabelsInput:
    # Global ID of the dataset whose labels are being replaced.
    dataset_id: GlobalID
    # Global IDs of the complete set of labels the dataset should end up with.
    dataset_label_ids: list[GlobalID]
52
+
53
+
54
# Result of the set_dataset_labels mutation.
@strawberry.type
class SetDatasetLabelsMutationPayload:
    # Root query object returned alongside the mutated dataset.
    query: Query
    # The dataset whose labels were set.
    dataset: Dataset
58
+
59
+
60
@strawberry.type
class DatasetLabelMutationMixin:
    # GraphQL mutations for creating dataset labels, deleting dataset labels,
    # and replacing the set of labels applied to a single dataset.
    # Documentation is kept in comments (not docstrings) on purpose.

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked])  # type: ignore
    async def create_dataset_label(
        self,
        info: Info[Context, None],
        input: CreateDatasetLabelInput,
    ) -> CreateDatasetLabelMutationPayload:
        # Create a dataset label and, optionally, apply it to the given datasets.
        #
        # Raises:
        #   BadRequest: a dataset ID is not a valid Dataset global ID, or the
        #       insert statement is rejected.
        #   Conflict: the insert violates an integrity constraint.
        #   NotFound: at least one referenced dataset does not exist.
        name = input.name
        # An omitted optional input field arrives as the UNSET sentinel, which
        # must never be bound as a column value; store NULL instead.
        description = None if input.description is UNSET else input.description
        color = input.color
        dataset_rowids: dict[
            int, None
        ] = {}  # use dictionary to de-duplicate while preserving order
        if input.dataset_ids:
            for dataset_id in input.dataset_ids:
                try:
                    dataset_rowid = from_global_id_with_expected_type(dataset_id, Dataset.__name__)
                except ValueError as error:
                    raise BadRequest(f"Invalid dataset ID: {dataset_id}") from error
                dataset_rowids[dataset_rowid] = None

        async with info.context.db() as session:
            dataset_label_orm = models.DatasetLabel(name=name, description=description, color=color)
            session.add(dataset_label_orm)
            try:
                # Flush so the label gets its primary key and constraint
                # violations surface before any association rows are added.
                await session.flush()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError) as error:
                raise Conflict(f"A dataset label named '{name}' already exists") from error
            except sqlalchemy.exc.StatementError as error:
                raise BadRequest(str(error.orig)) from error

            datasets_by_id: dict[int, models.Dataset] = {}
            if dataset_rowids:
                datasets_by_id = {
                    dataset.id: dataset
                    for dataset in await session.scalars(
                        select(models.Dataset).where(models.Dataset.id.in_(dataset_rowids.keys()))
                    )
                }
                # Fewer rows than requested IDs means some dataset is missing.
                if len(datasets_by_id) < len(dataset_rowids):
                    raise NotFound("One or more datasets not found")
                session.add_all(
                    [
                        models.DatasetsDatasetLabel(
                            dataset_id=dataset_rowid,
                            dataset_label_id=dataset_label_orm.id,
                        )
                        for dataset_rowid in dataset_rowids
                    ]
                )
            # Persist the label together with any dataset associations.
            await session.commit()

        return CreateDatasetLabelMutationPayload(
            dataset_label=DatasetLabel(id=dataset_label_orm.id, db_record=dataset_label_orm),
            datasets=[
                Dataset(
                    id=datasets_by_id[dataset_rowid].id, db_record=datasets_by_id[dataset_rowid]
                )
                for dataset_rowid in dataset_rowids
            ],
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked])  # type: ignore
    async def delete_dataset_labels(
        self, info: Info[Context, None], input: DeleteDatasetLabelsInput
    ) -> DeleteDatasetLabelsMutationPayload:
        # Delete the dataset labels with the given IDs, all or nothing.
        #
        # Raises:
        #   BadRequest: an ID is not a valid DatasetLabel global ID.
        #   NotFound: at least one label does not exist; the deletion is
        #       rolled back before raising.
        dataset_label_row_ids: dict[int, None] = {}  # de-duplicates, keeps request order
        for dataset_label_node_id in input.dataset_label_ids:
            try:
                dataset_label_row_id = from_global_id_with_expected_type(
                    dataset_label_node_id, DatasetLabel.__name__
                )
            except ValueError as error:
                raise BadRequest(f"Unknown dataset label: {dataset_label_node_id}") from error
            dataset_label_row_ids[dataset_label_row_id] = None
        async with info.context.db() as session:
            stmt = (
                delete(models.DatasetLabel)
                .where(models.DatasetLabel.id.in_(dataset_label_row_ids.keys()))
                .returning(models.DatasetLabel)
            )
            deleted_dataset_labels = (await session.scalars(stmt)).all()
            # Fewer deleted rows than requested IDs means some label is missing.
            if len(deleted_dataset_labels) < len(dataset_label_row_ids):
                await session.rollback()
                raise NotFound("Could not find one or more dataset labels with given IDs")
        deleted_dataset_labels_by_id = {
            dataset_label.id: dataset_label for dataset_label in deleted_dataset_labels
        }
        return DeleteDatasetLabelsMutationPayload(
            dataset_labels=[
                DatasetLabel(
                    id=deleted_dataset_labels_by_id[dataset_label_row_id].id,
                    db_record=deleted_dataset_labels_by_id[dataset_label_row_id],
                )
                for dataset_label_row_id in dataset_label_row_ids
            ]
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked])  # type: ignore
    async def set_dataset_labels(
        self, info: Info[Context, None], input: SetDatasetLabelsInput
    ) -> SetDatasetLabelsMutationPayload:
        # Make the dataset's applied labels exactly match the given label IDs:
        # missing associations are added and extra ones are removed.
        #
        # Raises:
        #   BadRequest: the dataset ID or a label ID is not a valid global ID
        #       of the expected type.
        #   NotFound: the dataset or at least one label does not exist.
        try:
            dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
        except ValueError as error:
            raise BadRequest(f"Invalid dataset ID: {input.dataset_id}") from error

        dataset_label_ids: dict[
            int, None
        ] = {}  # use dictionary to de-duplicate while preserving order
        for dataset_label_gid in input.dataset_label_ids:
            try:
                dataset_label_id = from_global_id_with_expected_type(
                    dataset_label_gid, DatasetLabel.__name__
                )
            except ValueError as error:
                raise BadRequest(f"Invalid dataset label ID: {dataset_label_gid}") from error
            dataset_label_ids[dataset_label_id] = None

        async with info.context.db() as session:
            # Load the dataset together with its current label associations.
            dataset = await session.scalar(
                select(models.Dataset)
                .where(models.Dataset.id == dataset_id)
                .options(joinedload(models.Dataset.datasets_dataset_labels))
            )

            if not dataset:
                raise NotFound(f"Dataset with ID {input.dataset_id} not found")

            existing_label_ids = (
                await session.scalars(
                    select(models.DatasetLabel.id).where(
                        models.DatasetLabel.id.in_(dataset_label_ids.keys())
                    )
                )
            ).all()
            # Every requested label must exist before anything is changed.
            if len(existing_label_ids) != len(dataset_label_ids):
                raise NotFound("One or more dataset labels not found")

            previously_applied_dataset_label_ids = {
                dataset_dataset_label.dataset_label_id
                for dataset_dataset_label in dataset.datasets_dataset_labels
            }

            # Requested labels that are not applied yet.
            datasets_dataset_labels_to_add = [
                models.DatasetsDatasetLabel(
                    dataset_id=dataset_id,
                    dataset_label_id=dataset_label_id,
                )
                for dataset_label_id in dataset_label_ids
                if dataset_label_id not in previously_applied_dataset_label_ids
            ]
            if datasets_dataset_labels_to_add:
                session.add_all(datasets_dataset_labels_to_add)
                await session.flush()

            # Applied labels that are no longer requested.
            datasets_dataset_labels_to_delete = [
                dataset_dataset_label
                for dataset_dataset_label in dataset.datasets_dataset_labels
                if dataset_dataset_label.dataset_label_id not in dataset_label_ids
            ]
            if datasets_dataset_labels_to_delete:
                await session.execute(
                    delete(models.DatasetsDatasetLabel).where(
                        tuple_(
                            models.DatasetsDatasetLabel.dataset_id,
                            models.DatasetsDatasetLabel.dataset_label_id,
                        ).in_(
                            [
                                (
                                    datasets_dataset_labels.dataset_id,
                                    datasets_dataset_labels.dataset_label_id,
                                )
                                for datasets_dataset_labels in datasets_dataset_labels_to_delete
                            ]
                        )
                    )
                )

        return SetDatasetLabelsMutationPayload(
            dataset=Dataset(id=dataset.id, db_record=dataset),
            query=Query(),
        )