arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221) hide show
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,10 +1,12 @@
1
1
  import asyncio
2
2
  import logging
3
+ from collections import deque
3
4
  from collections.abc import AsyncIterator, Iterator
4
5
  from datetime import datetime, timedelta, timezone
5
6
  from typing import (
6
7
  Any,
7
8
  AsyncGenerator,
9
+ Callable,
8
10
  Coroutine,
9
11
  Iterable,
10
12
  Mapping,
@@ -17,7 +19,7 @@ from typing import (
17
19
  import strawberry
18
20
  from openinference.instrumentation import safe_json_dumps
19
21
  from openinference.semconv.trace import SpanAttributes
20
- from sqlalchemy import and_, func, insert, select
22
+ from sqlalchemy import and_, insert, select
21
23
  from sqlalchemy.orm import load_only
22
24
  from strawberry.relay.types import GlobalID
23
25
  from strawberry.types import Info
@@ -26,7 +28,11 @@ from typing_extensions import TypeAlias, assert_never
26
28
  from phoenix.config import PLAYGROUND_PROJECT_NAME
27
29
  from phoenix.datetime_utils import local_now, normalize_datetime
28
30
  from phoenix.db import models
29
- from phoenix.server.api.auth import IsLocked, IsNotReadOnly
31
+ from phoenix.db.helpers import (
32
+ get_dataset_example_revisions,
33
+ insert_experiment_with_examples_snapshot,
34
+ )
35
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
30
36
  from phoenix.server.api.context import Context
31
37
  from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
32
38
  from phoenix.server.api.helpers.playground_clients import (
@@ -43,6 +49,7 @@ from phoenix.server.api.helpers.playground_spans import (
43
49
  get_db_trace,
44
50
  streaming_llm_span,
45
51
  )
52
+ from phoenix.server.api.helpers.playground_users import get_user
46
53
  from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
47
54
  from phoenix.server.api.input_types.ChatCompletionInput import (
48
55
  ChatCompletionInput,
@@ -59,7 +66,7 @@ from phoenix.server.api.types.Dataset import Dataset
59
66
  from phoenix.server.api.types.DatasetExample import DatasetExample
60
67
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
61
68
  from phoenix.server.api.types.Experiment import to_gql_experiment
62
- from phoenix.server.api.types.ExperimentRun import to_gql_experiment_run
69
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun
63
70
  from phoenix.server.api.types.node import from_global_id_with_expected_type
64
71
  from phoenix.server.api.types.Span import Span
65
72
  from phoenix.server.daemons.span_cost_calculator import SpanCostCalculator
@@ -90,9 +97,109 @@ ChatCompletionResult: TypeAlias = tuple[
90
97
  ChatStream: TypeAlias = AsyncGenerator[ChatCompletionSubscriptionPayload, None]
91
98
 
92
99
 
100
+ async def _stream_single_chat_completion(
101
+ *,
102
+ input: ChatCompletionInput,
103
+ llm_client: PlaygroundStreamingClient,
104
+ project_id: int,
105
+ repetition_number: int,
106
+ results: asyncio.Queue[tuple[Optional[models.Span], int]],
107
+ ) -> ChatStream:
108
+ messages = [
109
+ (
110
+ message.role,
111
+ message.content,
112
+ message.tool_call_id if isinstance(message.tool_call_id, str) else None,
113
+ message.tool_calls if isinstance(message.tool_calls, list) else None,
114
+ )
115
+ for message in input.messages
116
+ ]
117
+ attributes = None
118
+ if template_options := input.template:
119
+ messages = list(
120
+ _formatted_messages(
121
+ messages=messages,
122
+ template_format=template_options.format,
123
+ template_variables=template_options.variables,
124
+ )
125
+ )
126
+ attributes = {PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(template_options.variables)}
127
+ invocation_parameters = llm_client.construct_invocation_parameters(input.invocation_parameters)
128
+ async with streaming_llm_span(
129
+ input=input,
130
+ messages=messages,
131
+ invocation_parameters=invocation_parameters,
132
+ attributes=attributes,
133
+ ) as span:
134
+ try:
135
+ async for chunk in llm_client.chat_completion_create(
136
+ messages=messages, tools=input.tools or [], **invocation_parameters
137
+ ):
138
+ span.add_response_chunk(chunk)
139
+ chunk.repetition_number = repetition_number
140
+ yield chunk
141
+ finally:
142
+ span.set_attributes(llm_client.attributes)
143
+ if span.status_message is not None:
144
+ yield ChatCompletionSubscriptionError(
145
+ message=span.status_message,
146
+ repetition_number=repetition_number,
147
+ )
148
+
149
+ db_trace = get_db_trace(span, project_id)
150
+ db_span = get_db_span(span, db_trace)
151
+ await results.put((db_span, repetition_number))
152
+
153
+
154
+ async def _chat_completion_span_result_payloads(
155
+ *,
156
+ db: DbSessionFactory,
157
+ results: Sequence[tuple[Optional[models.Span], int]],
158
+ span_cost_calculator: SpanCostCalculator,
159
+ on_span_insertion: Callable[[], None],
160
+ ) -> ChatStream:
161
+ if not results:
162
+ return
163
+ async with db() as session:
164
+ for span, repetition_number in results:
165
+ if span:
166
+ session.add(span)
167
+ await session.flush()
168
+ try:
169
+ span_cost = span_cost_calculator.calculate_cost(
170
+ start_time=span.start_time,
171
+ attributes=span.attributes,
172
+ )
173
+ except Exception as e:
174
+ logger.exception(f"Failed to calculate cost for span {span.id}: {e}")
175
+ span_cost = None
176
+ if span_cost:
177
+ span_cost.span_rowid = span.id
178
+ span_cost.trace_rowid = span.trace_rowid
179
+ session.add(span_cost)
180
+ await session.flush()
181
+ for span, repetition_number in results:
182
+ if span:
183
+ yield ChatCompletionSubscriptionResult(
184
+ span=Span(id=span.id, db_record=span),
185
+ repetition_number=repetition_number,
186
+ )
187
+ on_span_insertion()
188
+
189
+
190
+ def _is_span_result_payloads_stream(
191
+ stream: ChatStream,
192
+ ) -> bool:
193
+ """
194
+ Checks if the given generator was instantiated from
195
+ `_chat_completion_span_result_payloads`
196
+ """
197
+ return stream.ag_code == _chat_completion_span_result_payloads.__code__ # type: ignore
198
+
199
+
93
200
  @strawberry.type
94
201
  class Subscription:
95
- @strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
202
+ @strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
96
203
  async def chat_completion(
97
204
  self, info: Info[Context, None], input: ChatCompletionInput
98
205
  ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
@@ -121,42 +228,6 @@ class Subscription:
121
228
  f"{str(error)}"
122
229
  )
123
230
 
124
- messages = [
125
- (
126
- message.role,
127
- message.content,
128
- message.tool_call_id if isinstance(message.tool_call_id, str) else None,
129
- message.tool_calls if isinstance(message.tool_calls, list) else None,
130
- )
131
- for message in input.messages
132
- ]
133
- attributes = None
134
- if template_options := input.template:
135
- messages = list(
136
- _formatted_messages(
137
- messages=messages,
138
- template_format=template_options.format,
139
- template_variables=template_options.variables,
140
- )
141
- )
142
- attributes = {PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(template_options.variables)}
143
- invocation_parameters = llm_client.construct_invocation_parameters(
144
- input.invocation_parameters
145
- )
146
- async with streaming_llm_span(
147
- input=input,
148
- messages=messages,
149
- invocation_parameters=invocation_parameters,
150
- attributes=attributes,
151
- ) as span:
152
- async for chunk in llm_client.chat_completion_create(
153
- messages=messages, tools=input.tools or [], **invocation_parameters
154
- ):
155
- span.add_response_chunk(chunk)
156
- yield chunk
157
- span.set_attributes(llm_client.attributes)
158
- if span.status_message is not None:
159
- yield ChatCompletionSubscriptionError(message=span.status_message)
160
231
  async with info.context.db() as session:
161
232
  if (
162
233
  playground_project_id := await session.scalar(
@@ -171,27 +242,100 @@ class Subscription:
171
242
  description="Traces from prompt playground",
172
243
  )
173
244
  )
174
- db_trace = get_db_trace(span, playground_project_id)
175
- db_span = get_db_span(span, db_trace)
176
- session.add(db_span)
177
- await session.flush()
178
- try:
179
- span_cost = info.context.span_cost_calculator.calculate_cost(
180
- start_time=db_span.start_time,
181
- attributes=span.attributes,
245
+
246
+ results: asyncio.Queue[tuple[Optional[models.Span], int]] = asyncio.Queue()
247
+ not_started: deque[tuple[int, ChatStream]] = deque(
248
+ (
249
+ repetition_number,
250
+ _stream_single_chat_completion(
251
+ input=input,
252
+ llm_client=llm_client,
253
+ project_id=playground_project_id,
254
+ repetition_number=repetition_number,
255
+ results=results,
256
+ ),
257
+ )
258
+ for repetition_number in range(1, input.repetitions + 1)
259
+ )
260
+ in_progress: list[
261
+ tuple[
262
+ Optional[int],
263
+ ChatStream,
264
+ asyncio.Task[ChatCompletionSubscriptionPayload],
265
+ ]
266
+ ] = []
267
+ max_in_progress = 3
268
+ write_batch_size = 10
269
+ write_interval = timedelta(seconds=10)
270
+ last_write_time = datetime.now()
271
+ while not_started or in_progress:
272
+ while not_started and len(in_progress) < max_in_progress:
273
+ rep_num, stream = not_started.popleft()
274
+ task = _create_task_with_timeout(stream)
275
+ in_progress.append((rep_num, stream, task))
276
+ async_tasks_to_run = [task for _, _, task in in_progress]
277
+ completed_tasks, _ = await asyncio.wait(
278
+ async_tasks_to_run, return_when=asyncio.FIRST_COMPLETED
279
+ )
280
+ for completed_task in completed_tasks:
281
+ idx = [task for _, _, task in in_progress].index(completed_task)
282
+ repetition_number, stream, _ = in_progress[idx]
283
+ try:
284
+ yield completed_task.result()
285
+ except StopAsyncIteration:
286
+ del in_progress[idx] # removes exhausted stream
287
+ except asyncio.TimeoutError:
288
+ del in_progress[idx] # removes timed-out stream
289
+ if repetition_number is not None:
290
+ yield ChatCompletionSubscriptionError(
291
+ message="Playground task timed out",
292
+ repetition_number=repetition_number,
293
+ )
294
+ except Exception as error:
295
+ del in_progress[idx] # removes failed stream
296
+ if repetition_number is not None:
297
+ yield ChatCompletionSubscriptionError(
298
+ message="An unexpected error occurred",
299
+ repetition_number=repetition_number,
300
+ )
301
+ logger.exception(error)
302
+ else:
303
+ task = _create_task_with_timeout(stream)
304
+ in_progress[idx] = (repetition_number, stream, task)
305
+
306
+ exceeded_write_batch_size = results.qsize() >= write_batch_size
307
+ exceeded_write_interval = datetime.now() - last_write_time > write_interval
308
+ write_already_in_progress = any(
309
+ _is_span_result_payloads_stream(stream) for _, stream, _ in in_progress
182
310
  )
183
- except Exception as e:
184
- logger.exception(f"Failed to calculate cost for span {db_span.id}: {e}")
185
- span_cost = None
186
- if span_cost:
187
- span_cost.span_rowid = db_span.id
188
- span_cost.trace_rowid = db_span.trace_rowid
189
- session.add(span_cost)
190
-
191
- info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
192
- yield ChatCompletionSubscriptionResult(span=Span(span_rowid=db_span.id, db_span=db_span))
193
-
194
- @strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
311
+ if (
312
+ not results.empty()
313
+ and (exceeded_write_batch_size or exceeded_write_interval)
314
+ and not write_already_in_progress
315
+ ):
316
+ result_payloads_stream = _chat_completion_span_result_payloads(
317
+ db=info.context.db,
318
+ results=_drain_no_wait(results),
319
+ span_cost_calculator=info.context.span_cost_calculator,
320
+ on_span_insertion=lambda: info.context.event_queue.put(
321
+ SpanInsertEvent(ids=(playground_project_id,))
322
+ ),
323
+ )
324
+ task = _create_task_with_timeout(result_payloads_stream)
325
+ in_progress.append((None, result_payloads_stream, task))
326
+ last_write_time = datetime.now()
327
+ if remaining_results := await _drain(results):
328
+ async for result_payload in _chat_completion_span_result_payloads(
329
+ db=info.context.db,
330
+ results=remaining_results,
331
+ span_cost_calculator=info.context.span_cost_calculator,
332
+ on_span_insertion=lambda: info.context.event_queue.put(
333
+ SpanInsertEvent(ids=(playground_project_id,))
334
+ ),
335
+ ):
336
+ yield result_payload
337
+
338
+ @strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
195
339
  async def chat_completion_over_dataset(
196
340
  self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
197
341
  ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
@@ -255,27 +399,22 @@ class Subscription:
255
399
  )
256
400
  ) is None:
257
401
  raise NotFound(f"Could not find dataset version with ID {version_id}")
258
- revision_ids = (
259
- select(func.max(models.DatasetExampleRevision.id))
260
- .join(models.DatasetExample)
261
- .where(
262
- and_(
263
- models.DatasetExample.dataset_id == dataset_id,
264
- models.DatasetExampleRevision.dataset_version_id <= resolved_version_id,
265
- )
266
- )
267
- .group_by(models.DatasetExampleRevision.dataset_example_id)
268
- )
402
+
403
+ # Parse split IDs if provided
404
+ resolved_split_ids: Optional[list[int]] = None
405
+ if input.split_ids is not None and len(input.split_ids) > 0:
406
+ resolved_split_ids = [
407
+ from_global_id_with_expected_type(split_id, models.DatasetSplit.__name__)
408
+ for split_id in input.split_ids
409
+ ]
410
+
269
411
  if not (
270
412
  revisions := [
271
413
  rev
272
414
  async for rev in await session.stream_scalars(
273
- select(models.DatasetExampleRevision)
274
- .where(
275
- and_(
276
- models.DatasetExampleRevision.id.in_(revision_ids),
277
- models.DatasetExampleRevision.revision_kind != "DELETE",
278
- )
415
+ get_dataset_example_revisions(
416
+ resolved_version_id,
417
+ split_ids=resolved_split_ids,
279
418
  )
280
419
  .order_by(models.DatasetExampleRevision.dataset_example_id.asc())
281
420
  .options(
@@ -302,18 +441,24 @@ class Subscription:
302
441
  description="Traces from prompt playground",
303
442
  )
304
443
  )
444
+ user_id = get_user(info)
305
445
  experiment = models.Experiment(
306
446
  dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
307
447
  dataset_version_id=resolved_version_id,
308
448
  name=input.experiment_name
309
449
  or _default_playground_experiment_name(input.prompt_name),
310
450
  description=input.experiment_description,
311
- repetitions=1,
451
+ repetitions=input.repetitions,
312
452
  metadata_=input.experiment_metadata or dict(),
313
453
  project_name=project_name,
454
+ user_id=user_id,
314
455
  )
315
- session.add(experiment)
316
- await session.flush()
456
+ if resolved_split_ids:
457
+ experiment.experiment_dataset_splits = [
458
+ models.ExperimentDatasetSplit(dataset_split_id=split_id)
459
+ for split_id in resolved_split_ids
460
+ ]
461
+ await insert_experiment_with_examples_snapshot(session, experiment)
317
462
  yield ChatCompletionSubscriptionExperiment(
318
463
  experiment=to_gql_experiment(experiment)
319
464
  ) # eagerly yields experiment so it can be linked by consumers of the subscription
@@ -327,11 +472,15 @@ class Subscription:
327
472
  llm_client=llm_client,
328
473
  revision=revision,
329
474
  results=results,
475
+ repetition_number=repetition_number,
330
476
  experiment_id=experiment.id,
331
477
  project_id=playground_project_id,
332
478
  ),
333
479
  )
334
480
  for revision in revisions
481
+ for repetition_number in reversed(
482
+ range(1, input.repetitions + 1)
483
+ ) # since we pop right, this runs the repetitions in increasing order
335
484
  ]
336
485
  in_progress: list[
337
486
  tuple[
@@ -409,6 +558,7 @@ async def _stream_chat_completion_over_dataset_example(
409
558
  input: ChatCompletionOverDatasetInput,
410
559
  llm_client: PlaygroundStreamingClient,
411
560
  revision: models.DatasetExampleRevision,
561
+ repetition_number: int,
412
562
  results: asyncio.Queue[ChatCompletionResult],
413
563
  experiment_id: int,
414
564
  project_id: int,
@@ -435,7 +585,11 @@ async def _stream_chat_completion_over_dataset_example(
435
585
  )
436
586
  except TemplateFormatterError as error:
437
587
  format_end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
438
- yield ChatCompletionSubscriptionError(message=str(error), dataset_example_id=example_id)
588
+ yield ChatCompletionSubscriptionError(
589
+ message=str(error),
590
+ dataset_example_id=example_id,
591
+ repetition_number=repetition_number,
592
+ )
439
593
  await results.put(
440
594
  (
441
595
  example_id,
@@ -445,7 +599,7 @@ async def _stream_chat_completion_over_dataset_example(
445
599
  dataset_example_id=revision.dataset_example_id,
446
600
  trace_id=None,
447
601
  output={},
448
- repetition_number=1,
602
+ repetition_number=repetition_number,
449
603
  start_time=format_start_time,
450
604
  end_time=format_end_time,
451
605
  error=str(error),
@@ -460,22 +614,31 @@ async def _stream_chat_completion_over_dataset_example(
460
614
  invocation_parameters=invocation_parameters,
461
615
  attributes={PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(revision.input)},
462
616
  ) as span:
463
- async for chunk in llm_client.chat_completion_create(
464
- messages=messages, tools=input.tools or [], **invocation_parameters
465
- ):
466
- span.add_response_chunk(chunk)
467
- chunk.dataset_example_id = example_id
468
- yield chunk
469
- span.set_attributes(llm_client.attributes)
617
+ try:
618
+ async for chunk in llm_client.chat_completion_create(
619
+ messages=messages, tools=input.tools or [], **invocation_parameters
620
+ ):
621
+ span.add_response_chunk(chunk)
622
+ chunk.dataset_example_id = example_id
623
+ chunk.repetition_number = repetition_number
624
+ yield chunk
625
+ finally:
626
+ span.set_attributes(llm_client.attributes)
470
627
  db_trace = get_db_trace(span, project_id)
471
628
  db_span = get_db_span(span, db_trace)
472
629
  db_run = get_db_experiment_run(
473
- db_span, db_trace, experiment_id=experiment_id, example_id=revision.dataset_example_id
630
+ db_span,
631
+ db_trace,
632
+ experiment_id=experiment_id,
633
+ example_id=revision.dataset_example_id,
634
+ repetition_number=repetition_number,
474
635
  )
475
636
  await results.put((example_id, db_span, db_run))
476
637
  if span.status_message is not None:
477
638
  yield ChatCompletionSubscriptionError(
478
- message=span.status_message, dataset_example_id=example_id
639
+ message=span.status_message,
640
+ dataset_example_id=example_id,
641
+ repetition_number=repetition_number,
479
642
  )
480
643
 
481
644
 
@@ -508,9 +671,10 @@ async def _chat_completion_result_payloads(
508
671
  await session.flush()
509
672
  for example_id, span, run in results:
510
673
  yield ChatCompletionSubscriptionResult(
511
- span=Span(span_rowid=span.id, db_span=span) if span else None,
512
- experiment_run=to_gql_experiment_run(run),
674
+ span=Span(id=span.id, db_record=span) if span else None,
675
+ experiment_run=ExperimentRun(id=run.id, db_record=run),
513
676
  dataset_example_id=example_id,
677
+ repetition_number=run.repetition_number,
514
678
  )
515
679
 
516
680
 
@@ -1,31 +1,98 @@
1
1
  from datetime import datetime
2
- from typing import Optional
2
+ from typing import TYPE_CHECKING, Annotated, Optional
3
3
 
4
4
  import strawberry
5
+ from strawberry.scalars import JSON
6
+ from strawberry.types import Info
5
7
 
6
- from phoenix.server.api.interceptor import GqlValueMediator
8
+ from phoenix.server.api.context import Context
9
+
10
+ from .AnnotationSource import AnnotationSource
11
+ from .AnnotatorKind import AnnotatorKind
12
+
13
+ if TYPE_CHECKING:
14
+ from .User import User
7
15
 
8
16
 
9
17
  @strawberry.interface
10
18
  class Annotation:
11
- name: str = strawberry.field(
12
- description="Name of the annotation, e.g. 'helpfulness' or 'relevance'."
13
- )
14
- score: Optional[float] = strawberry.field(
15
- description="Value of the annotation in the form of a numeric score.",
16
- default=GqlValueMediator(),
17
- )
18
- label: Optional[str] = strawberry.field(
19
- description="Value of the annotation in the form of a string, e.g. "
20
- "'helpful' or 'not helpful'. Note that the label is not necessarily binary."
21
- )
22
- explanation: Optional[str] = strawberry.field(
23
- description="The annotator's explanation for the annotation result (i.e. "
24
- "score or label, or both) given to the subject."
25
- )
26
- created_at: datetime = strawberry.field(
27
- description="The date and time when the annotation was created."
28
- )
29
- updated_at: datetime = strawberry.field(
30
- description="The date and time when the annotation was last updated."
31
- )
19
+ @strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
20
+ async def name(
21
+ self,
22
+ info: Info[Context, None],
23
+ ) -> str:
24
+ raise NotImplementedError
25
+
26
+ @strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
27
+ async def annotator_kind(
28
+ self,
29
+ info: Info[Context, None],
30
+ ) -> AnnotatorKind:
31
+ raise NotImplementedError
32
+
33
+ @strawberry.field(
34
+ description="Value of the annotation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary." # noqa: E501
35
+ ) # type: ignore
36
+ async def label(
37
+ self,
38
+ info: Info[Context, None],
39
+ ) -> Optional[str]:
40
+ raise NotImplementedError
41
+
42
+ @strawberry.field(description="Value of the annotation in the form of a numeric score.") # type: ignore
43
+ async def score(
44
+ self,
45
+ info: Info[Context, None],
46
+ ) -> Optional[float]:
47
+ raise NotImplementedError
48
+
49
+ @strawberry.field(
50
+ description="The annotator's explanation for the annotation result (i.e. score or label, or both) given to the subject." # noqa: E501
51
+ ) # type: ignore
52
+ async def explanation(
53
+ self,
54
+ info: Info[Context, None],
55
+ ) -> Optional[str]:
56
+ raise NotImplementedError
57
+
58
+ @strawberry.field(description="Metadata about the annotation.") # type: ignore
59
+ async def metadata(
60
+ self,
61
+ info: Info[Context, None],
62
+ ) -> JSON:
63
+ raise NotImplementedError
64
+
65
+ @strawberry.field(description="The source of the annotation.") # type: ignore
66
+ async def source(
67
+ self,
68
+ info: Info[Context, None],
69
+ ) -> AnnotationSource:
70
+ raise NotImplementedError
71
+
72
+ @strawberry.field(description="The identifier of the annotation.") # type: ignore
73
+ async def identifier(
74
+ self,
75
+ info: Info[Context, None],
76
+ ) -> str:
77
+ raise NotImplementedError
78
+
79
+ @strawberry.field(description="The date and time the annotation was created.") # type: ignore
80
+ async def created_at(
81
+ self,
82
+ info: Info[Context, None],
83
+ ) -> datetime:
84
+ raise NotImplementedError
85
+
86
+ @strawberry.field(description="The date and time the annotation was last updated.") # type: ignore
87
+ async def updated_at(
88
+ self,
89
+ info: Info[Context, None],
90
+ ) -> datetime:
91
+ raise NotImplementedError
92
+
93
+ @strawberry.field(description="The user that produced the annotation.") # type: ignore
94
+ async def user(
95
+ self,
96
+ info: Info[Context, None],
97
+ ) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
98
+ raise NotImplementedError
@@ -3,25 +3,21 @@ from typing import Optional
3
3
 
4
4
  import strawberry
5
5
 
6
- from phoenix.db.models import ApiKey as ORMApiKey
7
-
8
6
 
9
7
  @strawberry.interface
10
8
  class ApiKey:
11
- name: str = strawberry.field(description="Name of the API key.")
12
- description: Optional[str] = strawberry.field(description="Description of the API key.")
13
- created_at: datetime = strawberry.field(
14
- description="The date and time the API key was created."
15
- )
16
- expires_at: Optional[datetime] = strawberry.field(
17
- description="The date and time the API key will expire."
18
- )
9
+ @strawberry.field(description="Name of the API key.") # type: ignore
10
+ async def name(self) -> str:
11
+ raise NotImplementedError
12
+
13
+ @strawberry.field(description="Description of the API key.") # type: ignore
14
+ async def description(self) -> Optional[str]:
15
+ raise NotImplementedError
19
16
 
17
+ @strawberry.field(description="The date and time the API key was created.") # type: ignore
18
+ async def created_at(self) -> datetime:
19
+ raise NotImplementedError
20
20
 
21
- def to_gql_api_key(api_key: ORMApiKey) -> ApiKey:
22
- return ApiKey(
23
- name=api_key.name,
24
- description=api_key.description,
25
- created_at=api_key.created_at,
26
- expires_at=api_key.expires_at,
27
- )
21
+ @strawberry.field(description="The date and time the API key will expire.") # type: ignore
22
+ async def expires_at(self) -> Optional[datetime]:
23
+ raise NotImplementedError
@@ -7,3 +7,4 @@ import strawberry
7
7
  class AuthMethod(Enum):
8
8
  LOCAL = "LOCAL"
9
9
  OAUTH2 = "OAUTH2"
10
+ LDAP = "LDAP"
@@ -11,6 +11,7 @@ from .Span import Span
11
11
  @strawberry.interface
12
12
  class ChatCompletionSubscriptionPayload:
13
13
  dataset_example_id: Optional[GlobalID] = None
14
+ repetition_number: Optional[int] = None
14
15
 
15
16
 
16
17
  @strawberry.type