arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (276) hide show
  1. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
  2. arize_phoenix-12.28.1.dist-info/RECORD +499 -0
  3. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +5 -4
  12. phoenix/auth.py +39 -2
  13. phoenix/config.py +1763 -91
  14. phoenix/datetime_utils.py +120 -2
  15. phoenix/db/README.md +595 -25
  16. phoenix/db/bulk_inserter.py +145 -103
  17. phoenix/db/engines.py +140 -33
  18. phoenix/db/enums.py +3 -12
  19. phoenix/db/facilitator.py +302 -35
  20. phoenix/db/helpers.py +1000 -65
  21. phoenix/db/iam_auth.py +64 -0
  22. phoenix/db/insertion/dataset.py +135 -2
  23. phoenix/db/insertion/document_annotation.py +9 -6
  24. phoenix/db/insertion/evaluation.py +2 -3
  25. phoenix/db/insertion/helpers.py +17 -2
  26. phoenix/db/insertion/session_annotation.py +176 -0
  27. phoenix/db/insertion/span.py +15 -11
  28. phoenix/db/insertion/span_annotation.py +3 -4
  29. phoenix/db/insertion/trace_annotation.py +3 -4
  30. phoenix/db/insertion/types.py +50 -20
  31. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  32. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  33. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  34. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  35. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  36. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  37. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  38. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  39. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  40. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  41. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  42. phoenix/db/models.py +669 -56
  43. phoenix/db/pg_config.py +10 -0
  44. phoenix/db/types/model_provider.py +4 -0
  45. phoenix/db/types/token_price_customization.py +29 -0
  46. phoenix/db/types/trace_retention.py +23 -15
  47. phoenix/experiments/evaluators/utils.py +3 -3
  48. phoenix/experiments/functions.py +160 -52
  49. phoenix/experiments/tracing.py +2 -2
  50. phoenix/experiments/types.py +1 -1
  51. phoenix/inferences/inferences.py +1 -2
  52. phoenix/server/api/auth.py +38 -7
  53. phoenix/server/api/auth_messages.py +46 -0
  54. phoenix/server/api/context.py +100 -4
  55. phoenix/server/api/dataloaders/__init__.py +79 -5
  56. phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
  57. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  58. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  59. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  60. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  61. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  62. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  63. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  64. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  65. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  66. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  67. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  68. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  69. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  70. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  71. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  72. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  73. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  74. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  75. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  76. phoenix/server/api/dataloaders/record_counts.py +37 -10
  77. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  78. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  79. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  80. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  81. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  82. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  83. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  84. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
  85. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  86. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  87. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  88. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
  89. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  90. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  91. phoenix/server/api/dataloaders/span_costs.py +29 -0
  92. phoenix/server/api/dataloaders/table_fields.py +2 -2
  93. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  94. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  95. phoenix/server/api/dataloaders/types.py +29 -0
  96. phoenix/server/api/exceptions.py +11 -1
  97. phoenix/server/api/helpers/dataset_helpers.py +5 -1
  98. phoenix/server/api/helpers/playground_clients.py +1243 -292
  99. phoenix/server/api/helpers/playground_registry.py +2 -2
  100. phoenix/server/api/helpers/playground_spans.py +8 -4
  101. phoenix/server/api/helpers/playground_users.py +26 -0
  102. phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
  103. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  104. phoenix/server/api/helpers/prompts/models.py +205 -22
  105. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  106. phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
  107. phoenix/server/api/input_types/CreateProjectInput.py +27 -0
  108. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  109. phoenix/server/api/input_types/DatasetFilter.py +17 -0
  110. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  111. phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
  112. phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
  113. phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
  114. phoenix/server/api/input_types/PromptFilter.py +14 -0
  115. phoenix/server/api/input_types/PromptVersionInput.py +52 -1
  116. phoenix/server/api/input_types/SpanSort.py +44 -7
  117. phoenix/server/api/input_types/TimeBinConfig.py +23 -0
  118. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  119. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  120. phoenix/server/api/mutations/__init__.py +10 -0
  121. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  122. phoenix/server/api/mutations/api_key_mutations.py +19 -23
  123. phoenix/server/api/mutations/chat_mutations.py +154 -47
  124. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  125. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  126. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  127. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  128. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  129. phoenix/server/api/mutations/model_mutations.py +210 -0
  130. phoenix/server/api/mutations/project_mutations.py +49 -10
  131. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  132. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  133. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  134. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  135. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  136. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  137. phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
  138. phoenix/server/api/mutations/trace_mutations.py +47 -3
  139. phoenix/server/api/mutations/user_mutations.py +66 -41
  140. phoenix/server/api/queries.py +768 -293
  141. phoenix/server/api/routers/__init__.py +2 -2
  142. phoenix/server/api/routers/auth.py +154 -88
  143. phoenix/server/api/routers/ldap.py +229 -0
  144. phoenix/server/api/routers/oauth2.py +369 -106
  145. phoenix/server/api/routers/v1/__init__.py +24 -4
  146. phoenix/server/api/routers/v1/annotation_configs.py +23 -31
  147. phoenix/server/api/routers/v1/annotations.py +481 -17
  148. phoenix/server/api/routers/v1/datasets.py +395 -81
  149. phoenix/server/api/routers/v1/documents.py +142 -0
  150. phoenix/server/api/routers/v1/evaluations.py +24 -31
  151. phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
  152. phoenix/server/api/routers/v1/experiment_runs.py +337 -59
  153. phoenix/server/api/routers/v1/experiments.py +479 -48
  154. phoenix/server/api/routers/v1/models.py +7 -0
  155. phoenix/server/api/routers/v1/projects.py +18 -49
  156. phoenix/server/api/routers/v1/prompts.py +54 -40
  157. phoenix/server/api/routers/v1/sessions.py +108 -0
  158. phoenix/server/api/routers/v1/spans.py +1091 -81
  159. phoenix/server/api/routers/v1/traces.py +132 -78
  160. phoenix/server/api/routers/v1/users.py +389 -0
  161. phoenix/server/api/routers/v1/utils.py +3 -7
  162. phoenix/server/api/subscriptions.py +305 -88
  163. phoenix/server/api/types/Annotation.py +90 -23
  164. phoenix/server/api/types/ApiKey.py +13 -17
  165. phoenix/server/api/types/AuthMethod.py +1 -0
  166. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  167. phoenix/server/api/types/CostBreakdown.py +12 -0
  168. phoenix/server/api/types/Dataset.py +226 -72
  169. phoenix/server/api/types/DatasetExample.py +88 -18
  170. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  171. phoenix/server/api/types/DatasetLabel.py +57 -0
  172. phoenix/server/api/types/DatasetSplit.py +98 -0
  173. phoenix/server/api/types/DatasetVersion.py +49 -4
  174. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  175. phoenix/server/api/types/Experiment.py +264 -59
  176. phoenix/server/api/types/ExperimentComparison.py +5 -10
  177. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  178. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  179. phoenix/server/api/types/ExperimentRun.py +169 -65
  180. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  181. phoenix/server/api/types/GenerativeModel.py +245 -3
  182. phoenix/server/api/types/GenerativeProvider.py +70 -11
  183. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  184. phoenix/server/api/types/ModelInterface.py +16 -0
  185. phoenix/server/api/types/PlaygroundModel.py +20 -0
  186. phoenix/server/api/types/Project.py +1278 -216
  187. phoenix/server/api/types/ProjectSession.py +188 -28
  188. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  189. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  190. phoenix/server/api/types/Prompt.py +119 -39
  191. phoenix/server/api/types/PromptLabel.py +42 -25
  192. phoenix/server/api/types/PromptVersion.py +11 -8
  193. phoenix/server/api/types/PromptVersionTag.py +65 -25
  194. phoenix/server/api/types/ServerStatus.py +6 -0
  195. phoenix/server/api/types/Span.py +167 -123
  196. phoenix/server/api/types/SpanAnnotation.py +189 -42
  197. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  198. phoenix/server/api/types/SpanCostSummary.py +10 -0
  199. phoenix/server/api/types/SystemApiKey.py +65 -1
  200. phoenix/server/api/types/TokenPrice.py +16 -0
  201. phoenix/server/api/types/TokenUsage.py +3 -3
  202. phoenix/server/api/types/Trace.py +223 -51
  203. phoenix/server/api/types/TraceAnnotation.py +149 -50
  204. phoenix/server/api/types/User.py +137 -32
  205. phoenix/server/api/types/UserApiKey.py +73 -26
  206. phoenix/server/api/types/node.py +10 -0
  207. phoenix/server/api/types/pagination.py +11 -2
  208. phoenix/server/app.py +290 -45
  209. phoenix/server/authorization.py +38 -3
  210. phoenix/server/bearer_auth.py +34 -24
  211. phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
  212. phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
  213. phoenix/server/cost_tracking/helpers.py +68 -0
  214. phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
  215. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  216. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  217. phoenix/server/daemons/__init__.py +0 -0
  218. phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
  219. phoenix/server/daemons/generative_model_store.py +103 -0
  220. phoenix/server/daemons/span_cost_calculator.py +99 -0
  221. phoenix/server/dml_event.py +17 -0
  222. phoenix/server/dml_event_handler.py +5 -0
  223. phoenix/server/email/sender.py +56 -3
  224. phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
  225. phoenix/server/email/types.py +11 -0
  226. phoenix/server/experiments/__init__.py +0 -0
  227. phoenix/server/experiments/utils.py +14 -0
  228. phoenix/server/grpc_server.py +11 -11
  229. phoenix/server/jwt_store.py +17 -15
  230. phoenix/server/ldap.py +1449 -0
  231. phoenix/server/main.py +26 -10
  232. phoenix/server/oauth2.py +330 -12
  233. phoenix/server/prometheus.py +66 -6
  234. phoenix/server/rate_limiters.py +4 -9
  235. phoenix/server/retention.py +33 -20
  236. phoenix/server/session_filters.py +49 -0
  237. phoenix/server/static/.vite/manifest.json +55 -51
  238. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  239. phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
  240. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  241. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  242. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  243. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  244. phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
  245. phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
  246. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  247. phoenix/server/templates/index.html +40 -6
  248. phoenix/server/thread_server.py +1 -2
  249. phoenix/server/types.py +14 -4
  250. phoenix/server/utils.py +74 -0
  251. phoenix/session/client.py +56 -3
  252. phoenix/session/data_extractor.py +5 -0
  253. phoenix/session/evaluation.py +14 -5
  254. phoenix/session/session.py +45 -9
  255. phoenix/settings.py +5 -0
  256. phoenix/trace/attributes.py +80 -13
  257. phoenix/trace/dsl/helpers.py +90 -1
  258. phoenix/trace/dsl/query.py +8 -6
  259. phoenix/trace/projects.py +5 -0
  260. phoenix/utilities/template_formatters.py +1 -1
  261. phoenix/version.py +1 -1
  262. arize_phoenix-10.0.4.dist-info/RECORD +0 -405
  263. phoenix/server/api/types/Evaluation.py +0 -39
  264. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  265. phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
  266. phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
  267. phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
  268. phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
  269. phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
  270. phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
  271. phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
  272. phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
  273. phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
  274. phoenix/utilities/deprecation.py +0 -31
  275. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  276. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,10 +1,12 @@
1
1
  import asyncio
2
2
  import logging
3
+ from collections import deque
3
4
  from collections.abc import AsyncIterator, Iterator
4
5
  from datetime import datetime, timedelta, timezone
5
6
  from typing import (
6
7
  Any,
7
8
  AsyncGenerator,
9
+ Callable,
8
10
  Coroutine,
9
11
  Iterable,
10
12
  Mapping,
@@ -17,7 +19,7 @@ from typing import (
17
19
  import strawberry
18
20
  from openinference.instrumentation import safe_json_dumps
19
21
  from openinference.semconv.trace import SpanAttributes
20
- from sqlalchemy import and_, func, insert, select
22
+ from sqlalchemy import and_, insert, select
21
23
  from sqlalchemy.orm import load_only
22
24
  from strawberry.relay.types import GlobalID
23
25
  from strawberry.types import Info
@@ -26,10 +28,15 @@ from typing_extensions import TypeAlias, assert_never
26
28
  from phoenix.config import PLAYGROUND_PROJECT_NAME
27
29
  from phoenix.datetime_utils import local_now, normalize_datetime
28
30
  from phoenix.db import models
29
- from phoenix.server.api.auth import IsLocked, IsNotReadOnly
31
+ from phoenix.db.helpers import (
32
+ get_dataset_example_revisions,
33
+ insert_experiment_with_examples_snapshot,
34
+ )
35
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
30
36
  from phoenix.server.api.context import Context
31
37
  from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
32
38
  from phoenix.server.api.helpers.playground_clients import (
39
+ PlaygroundClientCredential,
33
40
  PlaygroundStreamingClient,
34
41
  initialize_playground_clients,
35
42
  )
@@ -42,6 +49,7 @@ from phoenix.server.api.helpers.playground_spans import (
42
49
  get_db_trace,
43
50
  streaming_llm_span,
44
51
  )
52
+ from phoenix.server.api.helpers.playground_users import get_user
45
53
  from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
46
54
  from phoenix.server.api.input_types.ChatCompletionInput import (
47
55
  ChatCompletionInput,
@@ -58,10 +66,12 @@ from phoenix.server.api.types.Dataset import Dataset
58
66
  from phoenix.server.api.types.DatasetExample import DatasetExample
59
67
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
60
68
  from phoenix.server.api.types.Experiment import to_gql_experiment
61
- from phoenix.server.api.types.ExperimentRun import to_gql_experiment_run
69
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun
62
70
  from phoenix.server.api.types.node import from_global_id_with_expected_type
63
71
  from phoenix.server.api.types.Span import Span
72
+ from phoenix.server.daemons.span_cost_calculator import SpanCostCalculator
64
73
  from phoenix.server.dml_event import SpanInsertEvent
74
+ from phoenix.server.experiments.utils import generate_experiment_project_name
65
75
  from phoenix.server.types import DbSessionFactory
66
76
  from phoenix.utilities.template_formatters import (
67
77
  FStringTemplateFormatter,
@@ -87,9 +97,109 @@ ChatCompletionResult: TypeAlias = tuple[
87
97
  ChatStream: TypeAlias = AsyncGenerator[ChatCompletionSubscriptionPayload, None]
88
98
 
89
99
 
100
+ async def _stream_single_chat_completion(
101
+ *,
102
+ input: ChatCompletionInput,
103
+ llm_client: PlaygroundStreamingClient,
104
+ project_id: int,
105
+ repetition_number: int,
106
+ results: asyncio.Queue[tuple[Optional[models.Span], int]],
107
+ ) -> ChatStream:
108
+ messages = [
109
+ (
110
+ message.role,
111
+ message.content,
112
+ message.tool_call_id if isinstance(message.tool_call_id, str) else None,
113
+ message.tool_calls if isinstance(message.tool_calls, list) else None,
114
+ )
115
+ for message in input.messages
116
+ ]
117
+ attributes = None
118
+ if template_options := input.template:
119
+ messages = list(
120
+ _formatted_messages(
121
+ messages=messages,
122
+ template_format=template_options.format,
123
+ template_variables=template_options.variables,
124
+ )
125
+ )
126
+ attributes = {PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(template_options.variables)}
127
+ invocation_parameters = llm_client.construct_invocation_parameters(input.invocation_parameters)
128
+ async with streaming_llm_span(
129
+ input=input,
130
+ messages=messages,
131
+ invocation_parameters=invocation_parameters,
132
+ attributes=attributes,
133
+ ) as span:
134
+ try:
135
+ async for chunk in llm_client.chat_completion_create(
136
+ messages=messages, tools=input.tools or [], **invocation_parameters
137
+ ):
138
+ span.add_response_chunk(chunk)
139
+ chunk.repetition_number = repetition_number
140
+ yield chunk
141
+ finally:
142
+ span.set_attributes(llm_client.attributes)
143
+ if span.status_message is not None:
144
+ yield ChatCompletionSubscriptionError(
145
+ message=span.status_message,
146
+ repetition_number=repetition_number,
147
+ )
148
+
149
+ db_trace = get_db_trace(span, project_id)
150
+ db_span = get_db_span(span, db_trace)
151
+ await results.put((db_span, repetition_number))
152
+
153
+
154
+ async def _chat_completion_span_result_payloads(
155
+ *,
156
+ db: DbSessionFactory,
157
+ results: Sequence[tuple[Optional[models.Span], int]],
158
+ span_cost_calculator: SpanCostCalculator,
159
+ on_span_insertion: Callable[[], None],
160
+ ) -> ChatStream:
161
+ if not results:
162
+ return
163
+ async with db() as session:
164
+ for span, repetition_number in results:
165
+ if span:
166
+ session.add(span)
167
+ await session.flush()
168
+ try:
169
+ span_cost = span_cost_calculator.calculate_cost(
170
+ start_time=span.start_time,
171
+ attributes=span.attributes,
172
+ )
173
+ except Exception as e:
174
+ logger.exception(f"Failed to calculate cost for span {span.id}: {e}")
175
+ span_cost = None
176
+ if span_cost:
177
+ span_cost.span_rowid = span.id
178
+ span_cost.trace_rowid = span.trace_rowid
179
+ session.add(span_cost)
180
+ await session.flush()
181
+ for span, repetition_number in results:
182
+ if span:
183
+ yield ChatCompletionSubscriptionResult(
184
+ span=Span(id=span.id, db_record=span),
185
+ repetition_number=repetition_number,
186
+ )
187
+ on_span_insertion()
188
+
189
+
190
+ def _is_span_result_payloads_stream(
191
+ stream: ChatStream,
192
+ ) -> bool:
193
+ """
194
+ Checks if the given generator was instantiated from
195
+ `_chat_completion_span_result_payloads`
196
+ """
197
+ return stream.ag_code == _chat_completion_span_result_payloads.__code__ # type: ignore
198
+
199
+
90
200
  @strawberry.type
91
201
  class Subscription:
92
- @strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
202
+ @strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
93
203
  async def chat_completion(
94
204
  self, info: Info[Context, None], input: ChatCompletionInput
95
205
  ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
@@ -98,9 +208,17 @@ class Subscription:
98
208
  if llm_client_class is None:
99
209
  raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
100
210
  try:
211
+ # Convert GraphQL credentials to PlaygroundCredential objects
212
+ playground_credentials = None
213
+ if input.credentials:
214
+ playground_credentials = [
215
+ PlaygroundClientCredential(env_var_name=cred.env_var_name, value=cred.value)
216
+ for cred in input.credentials
217
+ ]
218
+
101
219
  llm_client = llm_client_class(
102
220
  model=input.model,
103
- api_key=input.api_key,
221
+ credentials=playground_credentials,
104
222
  )
105
223
  except CustomGraphQLError:
106
224
  raise
@@ -110,42 +228,6 @@ class Subscription:
110
228
  f"{str(error)}"
111
229
  )
112
230
 
113
- messages = [
114
- (
115
- message.role,
116
- message.content,
117
- message.tool_call_id if isinstance(message.tool_call_id, str) else None,
118
- message.tool_calls if isinstance(message.tool_calls, list) else None,
119
- )
120
- for message in input.messages
121
- ]
122
- attributes = None
123
- if template_options := input.template:
124
- messages = list(
125
- _formatted_messages(
126
- messages=messages,
127
- template_format=template_options.format,
128
- template_variables=template_options.variables,
129
- )
130
- )
131
- attributes = {PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(template_options.variables)}
132
- invocation_parameters = llm_client.construct_invocation_parameters(
133
- input.invocation_parameters
134
- )
135
- async with streaming_llm_span(
136
- input=input,
137
- messages=messages,
138
- invocation_parameters=invocation_parameters,
139
- attributes=attributes,
140
- ) as span:
141
- async for chunk in llm_client.chat_completion_create(
142
- messages=messages, tools=input.tools or [], **invocation_parameters
143
- ):
144
- span.add_response_chunk(chunk)
145
- yield chunk
146
- span.set_attributes(llm_client.attributes)
147
- if span.status_message is not None:
148
- yield ChatCompletionSubscriptionError(message=span.status_message)
149
231
  async with info.context.db() as session:
150
232
  if (
151
233
  playground_project_id := await session.scalar(
@@ -160,14 +242,100 @@ class Subscription:
160
242
  description="Traces from prompt playground",
161
243
  )
162
244
  )
163
- db_trace = get_db_trace(span, playground_project_id)
164
- db_span = get_db_span(span, db_trace)
165
- session.add(db_span)
166
- await session.flush()
167
- info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
168
- yield ChatCompletionSubscriptionResult(span=Span(span_rowid=db_span.id, db_span=db_span))
169
-
170
- @strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
245
+
246
+ results: asyncio.Queue[tuple[Optional[models.Span], int]] = asyncio.Queue()
247
+ not_started: deque[tuple[int, ChatStream]] = deque(
248
+ (
249
+ repetition_number,
250
+ _stream_single_chat_completion(
251
+ input=input,
252
+ llm_client=llm_client,
253
+ project_id=playground_project_id,
254
+ repetition_number=repetition_number,
255
+ results=results,
256
+ ),
257
+ )
258
+ for repetition_number in range(1, input.repetitions + 1)
259
+ )
260
+ in_progress: list[
261
+ tuple[
262
+ Optional[int],
263
+ ChatStream,
264
+ asyncio.Task[ChatCompletionSubscriptionPayload],
265
+ ]
266
+ ] = []
267
+ max_in_progress = 3
268
+ write_batch_size = 10
269
+ write_interval = timedelta(seconds=10)
270
+ last_write_time = datetime.now()
271
+ while not_started or in_progress:
272
+ while not_started and len(in_progress) < max_in_progress:
273
+ rep_num, stream = not_started.popleft()
274
+ task = _create_task_with_timeout(stream)
275
+ in_progress.append((rep_num, stream, task))
276
+ async_tasks_to_run = [task for _, _, task in in_progress]
277
+ completed_tasks, _ = await asyncio.wait(
278
+ async_tasks_to_run, return_when=asyncio.FIRST_COMPLETED
279
+ )
280
+ for completed_task in completed_tasks:
281
+ idx = [task for _, _, task in in_progress].index(completed_task)
282
+ repetition_number, stream, _ = in_progress[idx]
283
+ try:
284
+ yield completed_task.result()
285
+ except StopAsyncIteration:
286
+ del in_progress[idx] # removes exhausted stream
287
+ except asyncio.TimeoutError:
288
+ del in_progress[idx] # removes timed-out stream
289
+ if repetition_number is not None:
290
+ yield ChatCompletionSubscriptionError(
291
+ message="Playground task timed out",
292
+ repetition_number=repetition_number,
293
+ )
294
+ except Exception as error:
295
+ del in_progress[idx] # removes failed stream
296
+ if repetition_number is not None:
297
+ yield ChatCompletionSubscriptionError(
298
+ message="An unexpected error occurred",
299
+ repetition_number=repetition_number,
300
+ )
301
+ logger.exception(error)
302
+ else:
303
+ task = _create_task_with_timeout(stream)
304
+ in_progress[idx] = (repetition_number, stream, task)
305
+
306
+ exceeded_write_batch_size = results.qsize() >= write_batch_size
307
+ exceeded_write_interval = datetime.now() - last_write_time > write_interval
308
+ write_already_in_progress = any(
309
+ _is_span_result_payloads_stream(stream) for _, stream, _ in in_progress
310
+ )
311
+ if (
312
+ not results.empty()
313
+ and (exceeded_write_batch_size or exceeded_write_interval)
314
+ and not write_already_in_progress
315
+ ):
316
+ result_payloads_stream = _chat_completion_span_result_payloads(
317
+ db=info.context.db,
318
+ results=_drain_no_wait(results),
319
+ span_cost_calculator=info.context.span_cost_calculator,
320
+ on_span_insertion=lambda: info.context.event_queue.put(
321
+ SpanInsertEvent(ids=(playground_project_id,))
322
+ ),
323
+ )
324
+ task = _create_task_with_timeout(result_payloads_stream)
325
+ in_progress.append((None, result_payloads_stream, task))
326
+ last_write_time = datetime.now()
327
+ if remaining_results := await _drain(results):
328
+ async for result_payload in _chat_completion_span_result_payloads(
329
+ db=info.context.db,
330
+ results=remaining_results,
331
+ span_cost_calculator=info.context.span_cost_calculator,
332
+ on_span_insertion=lambda: info.context.event_queue.put(
333
+ SpanInsertEvent(ids=(playground_project_id,))
334
+ ),
335
+ ):
336
+ yield result_payload
337
+
338
+ @strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
171
339
  async def chat_completion_over_dataset(
172
340
  self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
173
341
  ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
@@ -176,9 +344,17 @@ class Subscription:
176
344
  if llm_client_class is None:
177
345
  raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
178
346
  try:
347
+ # Convert GraphQL credentials to PlaygroundCredential objects
348
+ playground_credentials = None
349
+ if input.credentials:
350
+ playground_credentials = [
351
+ PlaygroundClientCredential(env_var_name=cred.env_var_name, value=cred.value)
352
+ for cred in input.credentials
353
+ ]
354
+
179
355
  llm_client = llm_client_class(
180
356
  model=input.model,
181
- api_key=input.api_key,
357
+ credentials=playground_credentials,
182
358
  )
183
359
  except CustomGraphQLError:
184
360
  raise
@@ -223,27 +399,22 @@ class Subscription:
223
399
  )
224
400
  ) is None:
225
401
  raise NotFound(f"Could not find dataset version with ID {version_id}")
226
- revision_ids = (
227
- select(func.max(models.DatasetExampleRevision.id))
228
- .join(models.DatasetExample)
229
- .where(
230
- and_(
231
- models.DatasetExample.dataset_id == dataset_id,
232
- models.DatasetExampleRevision.dataset_version_id <= resolved_version_id,
233
- )
234
- )
235
- .group_by(models.DatasetExampleRevision.dataset_example_id)
236
- )
402
+
403
+ # Parse split IDs if provided
404
+ resolved_split_ids: Optional[list[int]] = None
405
+ if input.split_ids is not None and len(input.split_ids) > 0:
406
+ resolved_split_ids = [
407
+ from_global_id_with_expected_type(split_id, models.DatasetSplit.__name__)
408
+ for split_id in input.split_ids
409
+ ]
410
+
237
411
  if not (
238
412
  revisions := [
239
413
  rev
240
414
  async for rev in await session.stream_scalars(
241
- select(models.DatasetExampleRevision)
242
- .where(
243
- and_(
244
- models.DatasetExampleRevision.id.in_(revision_ids),
245
- models.DatasetExampleRevision.revision_kind != "DELETE",
246
- )
415
+ get_dataset_example_revisions(
416
+ resolved_version_id,
417
+ split_ids=resolved_split_ids,
247
418
  )
248
419
  .order_by(models.DatasetExampleRevision.dataset_example_id.asc())
249
420
  .options(
@@ -256,31 +427,38 @@ class Subscription:
256
427
  ]
257
428
  ):
258
429
  raise NotFound("No examples found for the given dataset and version")
430
+ project_name = generate_experiment_project_name()
259
431
  if (
260
432
  playground_project_id := await session.scalar(
261
- select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
433
+ select(models.Project.id).where(models.Project.name == project_name)
262
434
  )
263
435
  ) is None:
264
436
  playground_project_id = await session.scalar(
265
437
  insert(models.Project)
266
438
  .returning(models.Project.id)
267
439
  .values(
268
- name=PLAYGROUND_PROJECT_NAME,
440
+ name=project_name,
269
441
  description="Traces from prompt playground",
270
442
  )
271
443
  )
444
+ user_id = get_user(info)
272
445
  experiment = models.Experiment(
273
446
  dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
274
447
  dataset_version_id=resolved_version_id,
275
448
  name=input.experiment_name
276
449
  or _default_playground_experiment_name(input.prompt_name),
277
450
  description=input.experiment_description,
278
- repetitions=1,
451
+ repetitions=input.repetitions,
279
452
  metadata_=input.experiment_metadata or dict(),
280
- project_name=PLAYGROUND_PROJECT_NAME,
453
+ project_name=project_name,
454
+ user_id=user_id,
281
455
  )
282
- session.add(experiment)
283
- await session.flush()
456
+ if resolved_split_ids:
457
+ experiment.experiment_dataset_splits = [
458
+ models.ExperimentDatasetSplit(dataset_split_id=split_id)
459
+ for split_id in resolved_split_ids
460
+ ]
461
+ await insert_experiment_with_examples_snapshot(session, experiment)
284
462
  yield ChatCompletionSubscriptionExperiment(
285
463
  experiment=to_gql_experiment(experiment)
286
464
  ) # eagerly yields experiment so it can be linked by consumers of the subscription
@@ -294,11 +472,15 @@ class Subscription:
294
472
  llm_client=llm_client,
295
473
  revision=revision,
296
474
  results=results,
475
+ repetition_number=repetition_number,
297
476
  experiment_id=experiment.id,
298
477
  project_id=playground_project_id,
299
478
  ),
300
479
  )
301
480
  for revision in revisions
481
+ for repetition_number in reversed(
482
+ range(1, input.repetitions + 1)
483
+ ) # since we pop right, this runs the repetitions in increasing order
302
484
  ]
303
485
  in_progress: list[
304
486
  tuple[
@@ -355,14 +537,18 @@ class Subscription:
355
537
  and not write_already_in_progress
356
538
  ):
357
539
  result_payloads_stream = _chat_completion_result_payloads(
358
- db=info.context.db, results=_drain_no_wait(results)
540
+ db=info.context.db,
541
+ results=_drain_no_wait(results),
542
+ span_cost_calculator=info.context.span_cost_calculator,
359
543
  )
360
544
  task = _create_task_with_timeout(result_payloads_stream)
361
545
  in_progress.append((None, result_payloads_stream, task))
362
546
  last_write_time = datetime.now()
363
547
  if remaining_results := await _drain(results):
364
548
  async for result_payload in _chat_completion_result_payloads(
365
- db=info.context.db, results=remaining_results
549
+ db=info.context.db,
550
+ results=remaining_results,
551
+ span_cost_calculator=info.context.span_cost_calculator,
366
552
  ):
367
553
  yield result_payload
368
554
 
@@ -372,6 +558,7 @@ async def _stream_chat_completion_over_dataset_example(
372
558
  input: ChatCompletionOverDatasetInput,
373
559
  llm_client: PlaygroundStreamingClient,
374
560
  revision: models.DatasetExampleRevision,
561
+ repetition_number: int,
375
562
  results: asyncio.Queue[ChatCompletionResult],
376
563
  experiment_id: int,
377
564
  project_id: int,
@@ -398,7 +585,11 @@ async def _stream_chat_completion_over_dataset_example(
398
585
  )
399
586
  except TemplateFormatterError as error:
400
587
  format_end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
401
- yield ChatCompletionSubscriptionError(message=str(error), dataset_example_id=example_id)
588
+ yield ChatCompletionSubscriptionError(
589
+ message=str(error),
590
+ dataset_example_id=example_id,
591
+ repetition_number=repetition_number,
592
+ )
402
593
  await results.put(
403
594
  (
404
595
  example_id,
@@ -408,7 +599,7 @@ async def _stream_chat_completion_over_dataset_example(
408
599
  dataset_example_id=revision.dataset_example_id,
409
600
  trace_id=None,
410
601
  output={},
411
- repetition_number=1,
602
+ repetition_number=repetition_number,
412
603
  start_time=format_start_time,
413
604
  end_time=format_end_time,
414
605
  error=str(error),
@@ -423,22 +614,31 @@ async def _stream_chat_completion_over_dataset_example(
423
614
  invocation_parameters=invocation_parameters,
424
615
  attributes={PROMPT_TEMPLATE_VARIABLES: safe_json_dumps(revision.input)},
425
616
  ) as span:
426
- async for chunk in llm_client.chat_completion_create(
427
- messages=messages, tools=input.tools or [], **invocation_parameters
428
- ):
429
- span.add_response_chunk(chunk)
430
- chunk.dataset_example_id = example_id
431
- yield chunk
432
- span.set_attributes(llm_client.attributes)
617
+ try:
618
+ async for chunk in llm_client.chat_completion_create(
619
+ messages=messages, tools=input.tools or [], **invocation_parameters
620
+ ):
621
+ span.add_response_chunk(chunk)
622
+ chunk.dataset_example_id = example_id
623
+ chunk.repetition_number = repetition_number
624
+ yield chunk
625
+ finally:
626
+ span.set_attributes(llm_client.attributes)
433
627
  db_trace = get_db_trace(span, project_id)
434
628
  db_span = get_db_span(span, db_trace)
435
629
  db_run = get_db_experiment_run(
436
- db_span, db_trace, experiment_id=experiment_id, example_id=revision.dataset_example_id
630
+ db_span,
631
+ db_trace,
632
+ experiment_id=experiment_id,
633
+ example_id=revision.dataset_example_id,
634
+ repetition_number=repetition_number,
437
635
  )
438
636
  await results.put((example_id, db_span, db_run))
439
637
  if span.status_message is not None:
440
638
  yield ChatCompletionSubscriptionError(
441
- message=span.status_message, dataset_example_id=example_id
639
+ message=span.status_message,
640
+ dataset_example_id=example_id,
641
+ repetition_number=repetition_number,
442
642
  )
443
643
 
444
644
 
@@ -446,6 +646,7 @@ async def _chat_completion_result_payloads(
446
646
  *,
447
647
  db: DbSessionFactory,
448
648
  results: Sequence[ChatCompletionResult],
649
+ span_cost_calculator: SpanCostCalculator,
449
650
  ) -> ChatStream:
450
651
  if not results:
451
652
  return
@@ -453,13 +654,27 @@ async def _chat_completion_result_payloads(
453
654
  for _, span, run in results:
454
655
  if span:
455
656
  session.add(span)
657
+ await session.flush()
658
+ try:
659
+ span_cost = span_cost_calculator.calculate_cost(
660
+ start_time=span.start_time,
661
+ attributes=span.attributes,
662
+ )
663
+ except Exception as e:
664
+ logger.exception(f"Failed to calculate cost for span {span.id}: {e}")
665
+ span_cost = None
666
+ if span_cost:
667
+ span_cost.span_rowid = span.id
668
+ span_cost.trace_rowid = span.trace_rowid
669
+ session.add(span_cost)
456
670
  session.add(run)
457
671
  await session.flush()
458
672
  for example_id, span, run in results:
459
673
  yield ChatCompletionSubscriptionResult(
460
- span=Span(span_rowid=span.id, db_span=span) if span else None,
461
- experiment_run=to_gql_experiment_run(run),
674
+ span=Span(id=span.id, db_record=span) if span else None,
675
+ experiment_run=ExperimentRun(id=run.id, db_record=run),
462
676
  dataset_example_id=example_id,
677
+ repetition_number=run.repetition_number,
463
678
  )
464
679
 
465
680
 
@@ -577,3 +792,5 @@ LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
577
792
  LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
578
793
  LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
579
794
  PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
795
+ LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
796
+ LLM_PROVIDER = SpanAttributes.LLM_PROVIDER