arize-phoenix 5.5.2__py3-none-any.whl → 5.7.0__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (186)
  1. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/METADATA +4 -7
  2. arize_phoenix-5.7.0.dist-info/RECORD +330 -0
  3. phoenix/config.py +50 -8
  4. phoenix/core/model.py +3 -3
  5. phoenix/core/model_schema.py +41 -50
  6. phoenix/core/model_schema_adapter.py +17 -16
  7. phoenix/datetime_utils.py +2 -2
  8. phoenix/db/bulk_inserter.py +10 -20
  9. phoenix/db/engines.py +2 -1
  10. phoenix/db/enums.py +2 -2
  11. phoenix/db/helpers.py +8 -7
  12. phoenix/db/insertion/dataset.py +9 -19
  13. phoenix/db/insertion/document_annotation.py +14 -13
  14. phoenix/db/insertion/helpers.py +6 -16
  15. phoenix/db/insertion/span_annotation.py +14 -13
  16. phoenix/db/insertion/trace_annotation.py +14 -13
  17. phoenix/db/insertion/types.py +19 -30
  18. phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +8 -8
  19. phoenix/db/models.py +28 -28
  20. phoenix/experiments/evaluators/base.py +2 -1
  21. phoenix/experiments/evaluators/code_evaluators.py +4 -5
  22. phoenix/experiments/evaluators/llm_evaluators.py +157 -4
  23. phoenix/experiments/evaluators/utils.py +3 -2
  24. phoenix/experiments/functions.py +10 -21
  25. phoenix/experiments/tracing.py +2 -1
  26. phoenix/experiments/types.py +20 -29
  27. phoenix/experiments/utils.py +2 -1
  28. phoenix/inferences/errors.py +6 -5
  29. phoenix/inferences/fixtures.py +6 -5
  30. phoenix/inferences/inferences.py +37 -37
  31. phoenix/inferences/schema.py +11 -10
  32. phoenix/inferences/validation.py +13 -14
  33. phoenix/logging/_formatter.py +3 -3
  34. phoenix/metrics/__init__.py +5 -4
  35. phoenix/metrics/binning.py +2 -1
  36. phoenix/metrics/metrics.py +2 -1
  37. phoenix/metrics/mixins.py +7 -6
  38. phoenix/metrics/retrieval_metrics.py +2 -1
  39. phoenix/metrics/timeseries.py +5 -4
  40. phoenix/metrics/wrappers.py +2 -2
  41. phoenix/pointcloud/clustering.py +3 -4
  42. phoenix/pointcloud/pointcloud.py +7 -5
  43. phoenix/pointcloud/umap_parameters.py +2 -1
  44. phoenix/server/api/dataloaders/annotation_summaries.py +12 -19
  45. phoenix/server/api/dataloaders/average_experiment_run_latency.py +2 -2
  46. phoenix/server/api/dataloaders/cache/two_tier_cache.py +3 -2
  47. phoenix/server/api/dataloaders/dataset_example_revisions.py +3 -8
  48. phoenix/server/api/dataloaders/dataset_example_spans.py +2 -5
  49. phoenix/server/api/dataloaders/document_evaluation_summaries.py +12 -18
  50. phoenix/server/api/dataloaders/document_evaluations.py +3 -7
  51. phoenix/server/api/dataloaders/document_retrieval_metrics.py +6 -13
  52. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +4 -8
  53. phoenix/server/api/dataloaders/experiment_error_rates.py +2 -5
  54. phoenix/server/api/dataloaders/experiment_run_annotations.py +3 -7
  55. phoenix/server/api/dataloaders/experiment_run_counts.py +1 -5
  56. phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -5
  57. phoenix/server/api/dataloaders/latency_ms_quantile.py +21 -30
  58. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +7 -13
  59. phoenix/server/api/dataloaders/project_by_name.py +3 -3
  60. phoenix/server/api/dataloaders/record_counts.py +11 -18
  61. phoenix/server/api/dataloaders/span_annotations.py +3 -7
  62. phoenix/server/api/dataloaders/span_dataset_examples.py +3 -8
  63. phoenix/server/api/dataloaders/span_descendants.py +3 -7
  64. phoenix/server/api/dataloaders/span_projects.py +2 -2
  65. phoenix/server/api/dataloaders/token_counts.py +12 -19
  66. phoenix/server/api/dataloaders/trace_row_ids.py +3 -7
  67. phoenix/server/api/dataloaders/user_roles.py +3 -3
  68. phoenix/server/api/dataloaders/users.py +3 -3
  69. phoenix/server/api/helpers/__init__.py +4 -3
  70. phoenix/server/api/helpers/dataset_helpers.py +10 -9
  71. phoenix/server/api/helpers/playground_clients.py +671 -0
  72. phoenix/server/api/helpers/playground_registry.py +70 -0
  73. phoenix/server/api/helpers/playground_spans.py +325 -0
  74. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +2 -2
  75. phoenix/server/api/input_types/AddSpansToDatasetInput.py +2 -2
  76. phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
  77. phoenix/server/api/input_types/ChatCompletionMessageInput.py +13 -1
  78. phoenix/server/api/input_types/ClusterInput.py +2 -2
  79. phoenix/server/api/input_types/DeleteAnnotationsInput.py +1 -3
  80. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +2 -2
  81. phoenix/server/api/input_types/DeleteExperimentsInput.py +1 -3
  82. phoenix/server/api/input_types/DimensionFilter.py +4 -4
  83. phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
  84. phoenix/server/api/input_types/Granularity.py +1 -1
  85. phoenix/server/api/input_types/InvocationParameters.py +156 -13
  86. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +2 -2
  87. phoenix/server/api/input_types/TemplateOptions.py +10 -0
  88. phoenix/server/api/mutations/__init__.py +4 -0
  89. phoenix/server/api/mutations/chat_mutations.py +374 -0
  90. phoenix/server/api/mutations/dataset_mutations.py +4 -4
  91. phoenix/server/api/mutations/experiment_mutations.py +1 -2
  92. phoenix/server/api/mutations/export_events_mutations.py +7 -7
  93. phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
  94. phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
  95. phoenix/server/api/mutations/user_mutations.py +4 -4
  96. phoenix/server/api/openapi/schema.py +2 -2
  97. phoenix/server/api/queries.py +61 -72
  98. phoenix/server/api/routers/oauth2.py +4 -4
  99. phoenix/server/api/routers/v1/datasets.py +22 -36
  100. phoenix/server/api/routers/v1/evaluations.py +6 -5
  101. phoenix/server/api/routers/v1/experiment_evaluations.py +2 -2
  102. phoenix/server/api/routers/v1/experiment_runs.py +2 -2
  103. phoenix/server/api/routers/v1/experiments.py +4 -4
  104. phoenix/server/api/routers/v1/spans.py +13 -12
  105. phoenix/server/api/routers/v1/traces.py +5 -5
  106. phoenix/server/api/routers/v1/utils.py +5 -5
  107. phoenix/server/api/schema.py +42 -10
  108. phoenix/server/api/subscriptions.py +347 -494
  109. phoenix/server/api/types/AnnotationSummary.py +3 -3
  110. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +44 -0
  111. phoenix/server/api/types/Cluster.py +8 -7
  112. phoenix/server/api/types/Dataset.py +5 -4
  113. phoenix/server/api/types/Dimension.py +3 -3
  114. phoenix/server/api/types/DocumentEvaluationSummary.py +8 -7
  115. phoenix/server/api/types/EmbeddingDimension.py +6 -5
  116. phoenix/server/api/types/EvaluationSummary.py +3 -3
  117. phoenix/server/api/types/Event.py +7 -7
  118. phoenix/server/api/types/Experiment.py +3 -3
  119. phoenix/server/api/types/ExperimentComparison.py +2 -4
  120. phoenix/server/api/types/GenerativeProvider.py +27 -3
  121. phoenix/server/api/types/Inferences.py +9 -8
  122. phoenix/server/api/types/InferencesRole.py +2 -2
  123. phoenix/server/api/types/Model.py +2 -2
  124. phoenix/server/api/types/Project.py +11 -18
  125. phoenix/server/api/types/Segments.py +3 -3
  126. phoenix/server/api/types/Span.py +45 -7
  127. phoenix/server/api/types/TemplateLanguage.py +9 -0
  128. phoenix/server/api/types/TimeSeries.py +8 -7
  129. phoenix/server/api/types/Trace.py +2 -2
  130. phoenix/server/api/types/UMAPPoints.py +6 -6
  131. phoenix/server/api/types/User.py +3 -3
  132. phoenix/server/api/types/node.py +1 -3
  133. phoenix/server/api/types/pagination.py +4 -4
  134. phoenix/server/api/utils.py +2 -4
  135. phoenix/server/app.py +76 -37
  136. phoenix/server/bearer_auth.py +4 -10
  137. phoenix/server/dml_event.py +3 -3
  138. phoenix/server/dml_event_handler.py +10 -24
  139. phoenix/server/grpc_server.py +3 -2
  140. phoenix/server/jwt_store.py +22 -21
  141. phoenix/server/main.py +17 -4
  142. phoenix/server/oauth2.py +3 -2
  143. phoenix/server/rate_limiters.py +5 -8
  144. phoenix/server/static/.vite/manifest.json +31 -31
  145. phoenix/server/static/assets/components-Csu8UKOs.js +1612 -0
  146. phoenix/server/static/assets/{index-DCzakdJq.js → index-Bk5C9EA7.js} +2 -2
  147. phoenix/server/static/assets/{pages-CAL1FDMt.js → pages-UeWaKXNs.js} +337 -442
  148. phoenix/server/static/assets/{vendor-6IcPAw_j.js → vendor-CtqfhlbC.js} +6 -6
  149. phoenix/server/static/assets/{vendor-arizeai-DRZuoyuF.js → vendor-arizeai-C_3SBz56.js} +2 -2
  150. phoenix/server/static/assets/{vendor-codemirror-DVE2_WBr.js → vendor-codemirror-wfdk9cjp.js} +1 -1
  151. phoenix/server/static/assets/{vendor-recharts-DwrexFA4.js → vendor-recharts-BiVnSv90.js} +1 -1
  152. phoenix/server/templates/index.html +1 -0
  153. phoenix/server/thread_server.py +1 -1
  154. phoenix/server/types.py +17 -29
  155. phoenix/services.py +8 -3
  156. phoenix/session/client.py +12 -24
  157. phoenix/session/data_extractor.py +3 -3
  158. phoenix/session/evaluation.py +1 -2
  159. phoenix/session/session.py +26 -21
  160. phoenix/trace/attributes.py +16 -28
  161. phoenix/trace/dsl/filter.py +17 -21
  162. phoenix/trace/dsl/helpers.py +3 -3
  163. phoenix/trace/dsl/query.py +13 -22
  164. phoenix/trace/fixtures.py +11 -17
  165. phoenix/trace/otel.py +5 -15
  166. phoenix/trace/projects.py +3 -2
  167. phoenix/trace/schemas.py +2 -2
  168. phoenix/trace/span_evaluations.py +9 -8
  169. phoenix/trace/span_json_decoder.py +3 -3
  170. phoenix/trace/span_json_encoder.py +2 -2
  171. phoenix/trace/trace_dataset.py +6 -5
  172. phoenix/trace/utils.py +6 -6
  173. phoenix/utilities/deprecation.py +3 -2
  174. phoenix/utilities/error_handling.py +3 -2
  175. phoenix/utilities/json.py +2 -1
  176. phoenix/utilities/logging.py +2 -2
  177. phoenix/utilities/project.py +1 -1
  178. phoenix/utilities/re.py +3 -4
  179. phoenix/utilities/template_formatters.py +16 -5
  180. phoenix/version.py +1 -1
  181. arize_phoenix-5.5.2.dist-info/RECORD +0 -321
  182. phoenix/server/static/assets/components-hX0LgYz3.js +0 -1428
  183. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/WHEEL +0 -0
  184. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/entry_points.txt +0 -0
  185. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/IP_NOTICE +0 -0
  186. {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/subscriptions.py
@@ -1,329 +1,73 @@
1
- import json
2
- from abc import ABC, abstractmethod
3
- from collections import defaultdict
4
- from datetime import datetime, timezone
5
- from enum import Enum
6
- from itertools import chain
1
+ import logging
2
+ from asyncio import FIRST_COMPLETED, Task, create_task, wait
3
+ from collections.abc import Iterator
7
4
  from typing import (
8
- TYPE_CHECKING,
9
- Annotated,
10
5
  Any,
11
6
  AsyncIterator,
12
- Callable,
13
- DefaultDict,
14
- Dict,
7
+ Collection,
15
8
  Iterable,
16
- Iterator,
17
- List,
9
+ Mapping,
18
10
  Optional,
19
- Tuple,
20
- Type,
21
- Union,
11
+ TypeVar,
22
12
  )
23
13
 
24
14
  import strawberry
25
- from openinference.instrumentation import safe_json_dumps
26
- from openinference.semconv.trace import (
27
- MessageAttributes,
28
- OpenInferenceMimeTypeValues,
29
- OpenInferenceSpanKindValues,
30
- SpanAttributes,
31
- ToolAttributes,
32
- ToolCallAttributes,
33
- )
34
- from opentelemetry.sdk.trace import TracerProvider
35
- from opentelemetry.sdk.trace.export import SimpleSpanProcessor
36
- from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
37
- from opentelemetry.trace import StatusCode
38
- from sqlalchemy import insert, select
39
- from strawberry import UNSET
40
- from strawberry.scalars import JSON as JSONScalarType
15
+ from openinference.semconv.trace import SpanAttributes
16
+ from sqlalchemy import and_, func, insert, select
17
+ from sqlalchemy.orm import load_only
18
+ from strawberry.relay.types import GlobalID
41
19
  from strawberry.types import Info
42
20
  from typing_extensions import TypeAlias, assert_never
43
21
 
44
22
  from phoenix.db import models
45
23
  from phoenix.server.api.context import Context
46
- from phoenix.server.api.input_types.ChatCompletionMessageInput import ChatCompletionMessageInput
47
- from phoenix.server.api.input_types.InvocationParameters import InvocationParameters
24
+ from phoenix.server.api.exceptions import BadRequest
25
+ from phoenix.server.api.helpers.playground_clients import (
26
+ PlaygroundStreamingClient,
27
+ initialize_playground_clients,
28
+ )
29
+ from phoenix.server.api.helpers.playground_registry import (
30
+ PLAYGROUND_CLIENT_REGISTRY,
31
+ )
32
+ from phoenix.server.api.helpers.playground_spans import streaming_llm_span
33
+ from phoenix.server.api.input_types.ChatCompletionInput import (
34
+ ChatCompletionInput,
35
+ ChatCompletionOverDatasetInput,
36
+ )
48
37
  from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
49
- from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
50
- from phoenix.server.api.types.Span import Span, to_gql_span
38
+ from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
39
+ ChatCompletionOverDatasetSubscriptionResult,
40
+ ChatCompletionSubscriptionError,
41
+ ChatCompletionSubscriptionPayload,
42
+ FinishedChatCompletion,
43
+ )
44
+ from phoenix.server.api.types.Dataset import Dataset
45
+ from phoenix.server.api.types.DatasetExample import DatasetExample
46
+ from phoenix.server.api.types.DatasetVersion import DatasetVersion
47
+ from phoenix.server.api.types.Experiment import to_gql_experiment
48
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
49
+ from phoenix.server.api.types.Span import to_gql_span
50
+ from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
51
51
  from phoenix.server.dml_event import SpanInsertEvent
52
- from phoenix.trace.attributes import unflatten
53
- from phoenix.utilities.json import jsonify
52
+ from phoenix.trace.attributes import get_attribute_value
54
53
  from phoenix.utilities.template_formatters import (
55
54
  FStringTemplateFormatter,
56
55
  MustacheTemplateFormatter,
57
56
  TemplateFormatter,
57
+ TemplateFormatterError,
58
58
  )
59
59
 
60
- if TYPE_CHECKING:
61
- from anthropic.types import MessageParam
62
- from openai.types import CompletionUsage
63
- from openai.types.chat import ChatCompletionMessageParam
64
-
65
- PLAYGROUND_PROJECT_NAME = "playground"
66
-
67
- ToolCallID: TypeAlias = str
68
-
69
-
70
- @strawberry.enum
71
- class TemplateLanguage(Enum):
72
- MUSTACHE = "MUSTACHE"
73
- F_STRING = "F_STRING"
60
+ GenericType = TypeVar("GenericType")
74
61
 
62
+ logger = logging.getLogger(__name__)
75
63
 
76
- @strawberry.input
77
- class TemplateOptions:
78
- variables: JSONScalarType
79
- language: TemplateLanguage
80
-
81
-
82
- @strawberry.type
83
- class TextChunk:
84
- content: str
85
-
86
-
87
- @strawberry.type
88
- class FunctionCallChunk:
89
- name: str
90
- arguments: str
91
-
92
-
93
- @strawberry.type
94
- class ToolCallChunk:
95
- id: str
96
- function: FunctionCallChunk
97
-
64
+ initialize_playground_clients()
98
65
 
99
- @strawberry.type
100
- class FinishedChatCompletion:
101
- span: Span
102
-
103
-
104
- ChatCompletionSubscriptionPayload: TypeAlias = Annotated[
105
- Union[TextChunk, ToolCallChunk, FinishedChatCompletion],
106
- strawberry.union("ChatCompletionSubscriptionPayload"),
66
+ ChatCompletionMessage: TypeAlias = tuple[
67
+ ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]
107
68
  ]
108
-
109
-
110
- @strawberry.input
111
- class GenerativeModelInput:
112
- provider_key: GenerativeProviderKey
113
- name: str
114
- """ The name of the model. Or the Deployment Name for Azure OpenAI models. """
115
- endpoint: Optional[str] = UNSET
116
- """ The endpoint to use for the model. Only required for Azure OpenAI models. """
117
- api_version: Optional[str] = UNSET
118
- """ The API version to use for the model. """
119
-
120
-
121
- @strawberry.input
122
- class ChatCompletionInput:
123
- messages: List[ChatCompletionMessageInput]
124
- model: GenerativeModelInput
125
- invocation_parameters: InvocationParameters
126
- tools: Optional[List[JSONScalarType]] = UNSET
127
- template: Optional[TemplateOptions] = UNSET
128
- api_key: Optional[str] = strawberry.field(default=None)
129
-
130
-
131
- PLAYGROUND_STREAMING_CLIENT_REGISTRY: Dict[
132
- GenerativeProviderKey, Type["PlaygroundStreamingClient"]
133
- ] = {}
134
-
135
-
136
- def register_llm_client(
137
- provider_key: GenerativeProviderKey,
138
- ) -> Callable[[Type["PlaygroundStreamingClient"]], Type["PlaygroundStreamingClient"]]:
139
- def decorator(cls: Type["PlaygroundStreamingClient"]) -> Type["PlaygroundStreamingClient"]:
140
- PLAYGROUND_STREAMING_CLIENT_REGISTRY[provider_key] = cls
141
- return cls
142
-
143
- return decorator
144
-
145
-
146
- class PlaygroundStreamingClient(ABC):
147
- def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None: ...
148
-
149
- @abstractmethod
150
- async def chat_completion_create(
151
- self,
152
- messages: List[Tuple[ChatCompletionMessageRole, str]],
153
- tools: List[JSONScalarType],
154
- **invocation_parameters: Any,
155
- ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
156
- # a yield statement is needed to satisfy the type-checker
157
- # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
158
- yield TextChunk(content="")
159
-
160
- @property
161
- @abstractmethod
162
- def attributes(self) -> Dict[str, Any]: ...
163
-
164
-
165
- @register_llm_client(GenerativeProviderKey.OPENAI)
166
- class OpenAIStreamingClient(PlaygroundStreamingClient):
167
- def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
168
- from openai import AsyncOpenAI
169
-
170
- self.client = AsyncOpenAI(api_key=api_key)
171
- self.model_name = model.name
172
- self._attributes: Dict[str, Any] = {}
173
-
174
- async def chat_completion_create(
175
- self,
176
- messages: List[Tuple[ChatCompletionMessageRole, str]],
177
- tools: List[JSONScalarType],
178
- **invocation_parameters: Any,
179
- ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
180
- from openai import NOT_GIVEN
181
- from openai.types.chat import ChatCompletionStreamOptionsParam
182
-
183
- # Convert standard messages to OpenAI messages
184
- openai_messages = [self.to_openai_chat_completion_param(*message) for message in messages]
185
- tool_call_ids: Dict[int, str] = {}
186
- token_usage: Optional["CompletionUsage"] = None
187
- async for chunk in await self.client.chat.completions.create(
188
- messages=openai_messages,
189
- model=self.model_name,
190
- stream=True,
191
- stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
192
- tools=tools or NOT_GIVEN,
193
- **invocation_parameters,
194
- ):
195
- if (usage := chunk.usage) is not None:
196
- token_usage = usage
197
- continue
198
- choice = chunk.choices[0]
199
- delta = choice.delta
200
- if choice.finish_reason is None:
201
- if isinstance(chunk_content := delta.content, str):
202
- text_chunk = TextChunk(content=chunk_content)
203
- yield text_chunk
204
- if (tool_calls := delta.tool_calls) is not None:
205
- for tool_call_index, tool_call in enumerate(tool_calls):
206
- tool_call_id = (
207
- tool_call.id
208
- if tool_call.id is not None
209
- else tool_call_ids[tool_call_index]
210
- )
211
- tool_call_ids[tool_call_index] = tool_call_id
212
- if (function := tool_call.function) is not None:
213
- tool_call_chunk = ToolCallChunk(
214
- id=tool_call_id,
215
- function=FunctionCallChunk(
216
- name=function.name or "",
217
- arguments=function.arguments or "",
218
- ),
219
- )
220
- yield tool_call_chunk
221
- if token_usage is not None:
222
- self._attributes.update(_llm_token_counts(token_usage))
223
-
224
- def to_openai_chat_completion_param(
225
- self, role: ChatCompletionMessageRole, content: JSONScalarType
226
- ) -> "ChatCompletionMessageParam":
227
- from openai.types.chat import (
228
- ChatCompletionAssistantMessageParam,
229
- ChatCompletionSystemMessageParam,
230
- ChatCompletionUserMessageParam,
231
- )
232
-
233
- if role is ChatCompletionMessageRole.USER:
234
- return ChatCompletionUserMessageParam(
235
- {
236
- "content": content,
237
- "role": "user",
238
- }
239
- )
240
- if role is ChatCompletionMessageRole.SYSTEM:
241
- return ChatCompletionSystemMessageParam(
242
- {
243
- "content": content,
244
- "role": "system",
245
- }
246
- )
247
- if role is ChatCompletionMessageRole.AI:
248
- return ChatCompletionAssistantMessageParam(
249
- {
250
- "content": content,
251
- "role": "assistant",
252
- }
253
- )
254
- if role is ChatCompletionMessageRole.TOOL:
255
- raise NotImplementedError
256
- assert_never(role)
257
-
258
- @property
259
- def attributes(self) -> Dict[str, Any]:
260
- return self._attributes
261
-
262
-
263
- @register_llm_client(GenerativeProviderKey.AZURE_OPENAI)
264
- class AzureOpenAIStreamingClient(OpenAIStreamingClient):
265
- def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None):
266
- from openai import AsyncAzureOpenAI
267
-
268
- if model.endpoint is None or model.api_version is None:
269
- raise ValueError("endpoint and api_version are required for Azure OpenAI models")
270
- self.client = AsyncAzureOpenAI(
271
- api_key=api_key,
272
- azure_endpoint=model.endpoint,
273
- api_version=model.api_version,
274
- )
275
-
276
-
277
- @register_llm_client(GenerativeProviderKey.ANTHROPIC)
278
- class AnthropicStreamingClient(PlaygroundStreamingClient):
279
- def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
280
- import anthropic
281
-
282
- self.client = anthropic.AsyncAnthropic(api_key=api_key)
283
- self.model_name = model.name
284
-
285
- async def chat_completion_create(
286
- self,
287
- messages: List[Tuple[ChatCompletionMessageRole, str]],
288
- tools: List[JSONScalarType],
289
- **invocation_parameters: Any,
290
- ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
291
- anthropic_messages, system_prompt = self._build_anthropic_messages(messages)
292
-
293
- anthropic_params = {
294
- "messages": anthropic_messages,
295
- "model": self.model_name,
296
- "system": system_prompt,
297
- "max_tokens": 1024,
298
- **invocation_parameters,
299
- }
300
-
301
- async with self.client.messages.stream(**anthropic_params) as stream:
302
- async for text in stream.text_stream:
303
- yield TextChunk(content=text)
304
-
305
- def _build_anthropic_messages(
306
- self, messages: List[Tuple[ChatCompletionMessageRole, str]]
307
- ) -> Tuple[List["MessageParam"], str]:
308
- anthropic_messages: List["MessageParam"] = []
309
- system_prompt = ""
310
- for role, content in messages:
311
- if role == ChatCompletionMessageRole.USER:
312
- anthropic_messages.append({"role": "user", "content": content})
313
- elif role == ChatCompletionMessageRole.AI:
314
- anthropic_messages.append({"role": "assistant", "content": content})
315
- elif role == ChatCompletionMessageRole.SYSTEM:
316
- system_prompt += content + "\n"
317
- elif role == ChatCompletionMessageRole.TOOL:
318
- raise NotImplementedError
319
- else:
320
- assert_never(role)
321
-
322
- return anthropic_messages, system_prompt
323
-
324
- @property
325
- def attributes(self) -> Dict[str, Any]:
326
- return dict()
69
+ DatasetExampleID: TypeAlias = GlobalID
70
+ PLAYGROUND_PROJECT_NAME = "playground"
327
71
 
328
72
 
329
73
  @strawberry.type
@@ -332,82 +76,48 @@ class Subscription:
332
76
  async def chat_completion(
333
77
  self, info: Info[Context, None], input: ChatCompletionInput
334
78
  ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
335
- # Determine which LLM client to use based on provider_key
336
79
  provider_key = input.model.provider_key
337
- llm_client_class = PLAYGROUND_STREAMING_CLIENT_REGISTRY.get(provider_key)
80
+ llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
338
81
  if llm_client_class is None:
339
- raise ValueError(f"No LLM client registered for provider '{provider_key}'")
340
-
341
- llm_client = llm_client_class(model=input.model, api_key=input.api_key)
342
-
343
- messages = [(message.role, message.content) for message in input.messages]
82
+ raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
83
+ llm_client = llm_client_class(
84
+ model=input.model,
85
+ api_key=input.api_key,
86
+ )
344
87
 
88
+ messages = [
89
+ (
90
+ message.role,
91
+ message.content,
92
+ message.tool_call_id if isinstance(message.tool_call_id, str) else None,
93
+ message.tool_calls if isinstance(message.tool_calls, list) else None,
94
+ )
95
+ for message in input.messages
96
+ ]
345
97
  if template_options := input.template:
346
- messages = list(_formatted_messages(messages, template_options))
347
-
348
- invocation_parameters = jsonify(input.invocation_parameters)
349
-
350
- in_memory_span_exporter = InMemorySpanExporter()
351
- tracer_provider = TracerProvider()
352
- tracer_provider.add_span_processor(
353
- span_processor=SimpleSpanProcessor(span_exporter=in_memory_span_exporter)
354
- )
355
- tracer = tracer_provider.get_tracer(__name__)
356
- span_name = "ChatCompletion"
357
- with tracer.start_span(
358
- span_name,
359
- attributes=dict(
360
- chain(
361
- _llm_span_kind(),
362
- _llm_model_name(input.model.name),
363
- _llm_tools(input.tools or []),
364
- _llm_input_messages(messages),
365
- _llm_invocation_parameters(invocation_parameters),
366
- _input_value_and_mime_type(input),
98
+ messages = list(
99
+ _formatted_messages(
100
+ messages=messages,
101
+ template_language=template_options.language,
102
+ template_variables=template_options.variables,
367
103
  )
368
- ),
104
+ )
105
+ invocation_parameters = llm_client.construct_invocation_parameters(
106
+ input.invocation_parameters
107
+ )
108
+ async with streaming_llm_span(
109
+ input=input,
110
+ messages=messages,
111
+ invocation_parameters=invocation_parameters,
369
112
  ) as span:
370
- response_chunks = []
371
- text_chunks: List[TextChunk] = []
372
- tool_call_chunks: DefaultDict[ToolCallID, List[ToolCallChunk]] = defaultdict(list)
373
-
374
113
  async for chunk in llm_client.chat_completion_create(
375
- messages=messages,
376
- tools=input.tools or [],
377
- **invocation_parameters,
114
+ messages=messages, tools=input.tools or [], **invocation_parameters
378
115
  ):
379
- response_chunks.append(chunk)
380
- if isinstance(chunk, TextChunk):
381
- yield chunk
382
- text_chunks.append(chunk)
383
- elif isinstance(chunk, ToolCallChunk):
384
- yield chunk
385
- tool_call_chunks[chunk.id].append(chunk)
386
-
387
- span.set_status(StatusCode.OK)
388
- llm_client_attributes = llm_client.attributes
389
-
390
- span.set_attributes(
391
- dict(
392
- chain(
393
- _output_value_and_mime_type(response_chunks),
394
- llm_client_attributes.items(),
395
- _llm_output_messages(text_chunks, tool_call_chunks),
396
- )
397
- )
398
- )
399
- assert len(spans := in_memory_span_exporter.get_finished_spans()) == 1
400
- finished_span = spans[0]
401
- assert finished_span.start_time is not None
402
- assert finished_span.end_time is not None
403
- assert (attributes := finished_span.attributes) is not None
404
- start_time = _datetime(epoch_nanoseconds=finished_span.start_time)
405
- end_time = _datetime(epoch_nanoseconds=finished_span.end_time)
406
- prompt_tokens = llm_client_attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
407
- completion_tokens = llm_client_attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
408
- trace_id = _hex(finished_span.context.trace_id)
409
- span_id = _hex(finished_span.context.span_id)
410
- status = finished_span.status
116
+ span.add_response_chunk(chunk)
117
+ yield chunk
118
+ span.set_attributes(llm_client.attributes)
119
+ if span.error_message is not None:
120
+ yield ChatCompletionSubscriptionError(message=span.error_message)
411
121
  async with info.context.db() as session:
412
122
  if (
413
123
  playground_project_id := await session.scalar(
@@ -422,130 +132,273 @@ class Subscription:
422
132
  description="Traces from prompt playground",
423
133
  )
424
134
  )
425
- playground_trace = models.Trace(
426
- project_rowid=playground_project_id,
427
- trace_id=trace_id,
428
- start_time=start_time,
429
- end_time=end_time,
430
- )
431
- playground_span = models.Span(
432
- trace_rowid=playground_trace.id,
433
- span_id=span_id,
434
- parent_id=None,
435
- name=span_name,
436
- span_kind=LLM,
437
- start_time=start_time,
438
- end_time=end_time,
439
- attributes=unflatten(attributes.items()),
440
- events=finished_span.events,
441
- status_code=status.status_code.name,
442
- status_message=status.description or "",
443
- cumulative_error_count=int(not status.is_ok),
444
- cumulative_llm_token_count_prompt=prompt_tokens,
445
- cumulative_llm_token_count_completion=completion_tokens,
446
- llm_token_count_prompt=prompt_tokens,
447
- llm_token_count_completion=completion_tokens,
448
- trace=playground_trace,
449
- )
450
- session.add(playground_trace)
451
- session.add(playground_span)
135
+ db_span = span.add_to_session(session, playground_project_id)
452
136
  await session.flush()
453
- yield FinishedChatCompletion(span=to_gql_span(playground_span))
137
+ yield FinishedChatCompletion(span=to_gql_span(db_span))
454
138
  info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
455
139
 
140
+ @strawberry.subscription
141
+ async def chat_completion_over_dataset(
142
+ self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
143
+ ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
144
+ provider_key = input.model.provider_key
145
+ llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
146
+ if llm_client_class is None:
147
+ raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
456
148
 
457
- def _llm_span_kind() -> Iterator[Tuple[str, Any]]:
458
- yield OPENINFERENCE_SPAN_KIND, LLM
459
-
460
-
461
- def _llm_model_name(model_name: str) -> Iterator[Tuple[str, Any]]:
462
- yield LLM_MODEL_NAME, model_name
463
-
464
-
465
- def _llm_invocation_parameters(invocation_parameters: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
466
- yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)
467
-
468
-
469
- def _llm_tools(tools: List[JSONScalarType]) -> Iterator[Tuple[str, Any]]:
470
- for tool_index, tool in enumerate(tools):
471
- yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
472
-
473
-
474
- def _llm_token_counts(usage: "CompletionUsage") -> Iterator[Tuple[str, Any]]:
475
- yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
476
- yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
477
- yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
478
-
479
-
480
- def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]:
481
- assert (api_key := "api_key") in (input_data := jsonify(input))
482
- input_data = {k: v for k, v in input_data.items() if k != api_key}
483
- assert api_key not in input_data
484
- yield INPUT_MIME_TYPE, JSON
485
- yield INPUT_VALUE, safe_json_dumps(input_data)
486
-
487
-
488
- def _output_value_and_mime_type(output: Any) -> Iterator[Tuple[str, Any]]:
489
- yield OUTPUT_MIME_TYPE, JSON
490
- yield OUTPUT_VALUE, safe_json_dumps(jsonify(output))
491
-
492
-
493
- def _llm_input_messages(
494
- messages: Iterable[Tuple[ChatCompletionMessageRole, str]],
495
- ) -> Iterator[Tuple[str, Any]]:
496
- for i, (role, content) in enumerate(messages):
497
- yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_ROLE}", role.value.lower()
498
- yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", content
499
-
149
+ dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
150
+ version_id = (
151
+ from_global_id_with_expected_type(
152
+ global_id=input.dataset_version_id, expected_type_name=DatasetVersion.__name__
153
+ )
154
+ if input.dataset_version_id
155
+ else None
156
+ )
157
+ revision_ids = (
158
+ select(func.max(models.DatasetExampleRevision.id))
159
+ .join(models.DatasetExample)
160
+ .where(models.DatasetExample.dataset_id == dataset_id)
161
+ .group_by(models.DatasetExampleRevision.dataset_example_id)
162
+ )
163
+ if version_id:
164
+ version_id_subquery = (
165
+ select(models.DatasetVersion.id)
166
+ .where(models.DatasetVersion.dataset_id == dataset_id)
167
+ .where(models.DatasetVersion.id == version_id)
168
+ .scalar_subquery()
169
+ )
170
+ revision_ids = revision_ids.where(
171
+ models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
172
+ )
173
+ query = (
174
+ select(models.DatasetExampleRevision)
175
+ .where(
176
+ and_(
177
+ models.DatasetExampleRevision.id.in_(revision_ids),
178
+ models.DatasetExampleRevision.revision_kind != "DELETE",
179
+ )
180
+ )
181
+ .order_by(models.DatasetExampleRevision.dataset_example_id.asc())
182
+ .options(
183
+ load_only(
184
+ models.DatasetExampleRevision.dataset_example_id,
185
+ models.DatasetExampleRevision.input,
186
+ )
187
+ )
188
+ )
189
+ async with info.context.db() as session:
190
+ revisions = [revision async for revision in await session.stream_scalars(query)]
191
+ if not revisions:
192
+ raise BadRequest("No examples found for the given dataset and version")
193
+
194
+ spans: dict[DatasetExampleID, streaming_llm_span] = {}
195
+ async for payload in _merge_iterators(
196
+ [
197
+ _stream_chat_completion_over_dataset_example(
198
+ input=input,
199
+ llm_client_class=llm_client_class,
200
+ revision=revision,
201
+ spans=spans,
202
+ )
203
+ for revision in revisions
204
+ ]
205
+ ):
206
+ yield payload
500
207
 
501
- def _llm_output_messages(
502
- text_chunks: List[TextChunk],
503
- tool_call_chunks: DefaultDict[ToolCallID, List[ToolCallChunk]],
504
- ) -> Iterator[Tuple[str, Any]]:
505
- yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
506
- if content := "".join(chunk.content for chunk in text_chunks):
507
- yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", content
508
- for tool_call_index, tool_call_chunks_ in tool_call_chunks.items():
509
- if tool_call_chunks_ and (name := tool_call_chunks_[0].function.name):
510
- yield (
511
- f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
512
- name,
208
+ async with info.context.db() as session:
209
+ if (
210
+ playground_project_id := await session.scalar(
211
+ select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
212
+ )
213
+ ) is None:
214
+ playground_project_id = await session.scalar(
215
+ insert(models.Project)
216
+ .returning(models.Project.id)
217
+ .values(
218
+ name=PLAYGROUND_PROJECT_NAME,
219
+ description="Traces from prompt playground",
220
+ )
221
+ )
222
+ db_spans = {
223
+ example_id: span.add_to_session(session, playground_project_id)
224
+ for example_id, span in spans.items()
225
+ }
226
+ assert (
227
+ dataset_name := await session.scalar(
228
+ select(models.Dataset.name).where(models.Dataset.id == dataset_id)
229
+ )
230
+ ) is not None
231
+ if version_id is None:
232
+ resolved_version_id = await session.scalar(
233
+ select(models.DatasetVersion.id)
234
+ .where(models.DatasetVersion.dataset_id == dataset_id)
235
+ .order_by(models.DatasetVersion.id.desc())
236
+ .limit(1)
237
+ )
238
+ else:
239
+ resolved_version_id = await session.scalar(
240
+ select(models.DatasetVersion.id).where(
241
+ and_(
242
+ models.DatasetVersion.dataset_id == dataset_id,
243
+ models.DatasetVersion.id == version_id,
244
+ )
245
+ )
246
+ )
247
+ assert resolved_version_id is not None
248
+ resolved_version_node_id = GlobalID(DatasetVersion.__name__, str(resolved_version_id))
249
+ experiment = models.Experiment(
250
+ dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
251
+ dataset_version_id=resolved_version_id,
252
+ name=input.experiment_name or _DEFAULT_PLAYGROUND_EXPERIMENT_NAME,
253
+ description=input.experiment_description
254
+ or _default_playground_experiment_description(dataset_name=dataset_name),
255
+ repetitions=1,
256
+ metadata_=input.experiment_metadata
257
+ or _default_playground_experiment_metadata(
258
+ dataset_name=dataset_name,
259
+ dataset_id=input.dataset_id,
260
+ version_id=resolved_version_node_id,
261
+ ),
262
+ project_name=PLAYGROUND_PROJECT_NAME,
513
263
  )
514
- if arguments := "".join(chunk.function.arguments for chunk in tool_call_chunks_):
515
- yield (
516
- f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
517
- arguments,
264
+ session.add(experiment)
265
+ await session.flush()
266
+ runs = [
267
+ models.ExperimentRun(
268
+ experiment_id=experiment.id,
269
+ dataset_example_id=from_global_id_with_expected_type(
270
+ example_id, DatasetExample.__name__
271
+ ),
272
+ trace_id=span.trace_id,
273
+ output=models.ExperimentRunOutput(
274
+ task_output=_get_playground_experiment_task_output(span)
275
+ ),
276
+ repetition_number=1,
277
+ start_time=span.start_time,
278
+ end_time=span.end_time,
279
+ error=error_message
280
+ if (error_message := span.error_message) is not None
281
+ else None,
282
+ prompt_token_count=get_attribute_value(span.attributes, LLM_TOKEN_COUNT_PROMPT),
283
+ completion_token_count=get_attribute_value(
284
+ span.attributes, LLM_TOKEN_COUNT_COMPLETION
285
+ ),
286
+ )
287
+ for example_id, span in spans.items()
288
+ ]
289
+ session.add_all(runs)
290
+ await session.flush()
291
+ for example_id in spans:
292
+ yield FinishedChatCompletion(
293
+ span=to_gql_span(db_spans[example_id]),
294
+ dataset_example_id=example_id,
518
295
  )
296
+ yield ChatCompletionOverDatasetSubscriptionResult(experiment=to_gql_experiment(experiment))
297
+
298
+
299
async def _stream_chat_completion_over_dataset_example(
    *,
    input: ChatCompletionOverDatasetInput,
    llm_client_class: type["PlaygroundStreamingClient"],
    revision: models.DatasetExampleRevision,
    spans: dict[DatasetExampleID, streaming_llm_span],
) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
    """
    Streams a chat completion for a single dataset example.

    The input messages are treated as templates and formatted with the input of
    the given example revision. Every chunk produced by the LLM client is tagged
    with the example's global ID before being yielded. The span created for the
    completion is stored in `spans` under that same ID so the caller can read it
    once streaming has finished.

    Yields a single `ChatCompletionSubscriptionError` and stops (without
    registering a span) when the templates cannot be formatted. Yields a
    trailing `ChatCompletionSubscriptionError` when the span reports an error
    message after the completion has run.
    """
    node_id = GlobalID(DatasetExample.__name__, str(revision.dataset_example_id))
    client = llm_client_class(model=input.model, api_key=input.api_key)
    parameters = client.construct_invocation_parameters(input.invocation_parameters)

    # Reduce each input message to a (role, content, tool_call_id, tool_calls)
    # tuple; tool fields that are not a str / list respectively become None.
    templates = []
    for message in input.messages:
        role, content = message.role, message.content
        tool_call_id = message.tool_call_id if isinstance(message.tool_call_id, str) else None
        tool_calls = message.tool_calls if isinstance(message.tool_calls, list) else None
        templates.append((role, content, tool_call_id, tool_calls))

    try:
        formatted = list(
            _formatted_messages(
                messages=templates,
                template_language=input.template_language,
                template_variables=revision.input,
            )
        )
    except TemplateFormatterError as formatting_error:
        yield ChatCompletionSubscriptionError(
            message=str(formatting_error), dataset_example_id=node_id
        )
        return

    llm_span = streaming_llm_span(
        input=input,
        messages=formatted,
        invocation_parameters=parameters,
    )
    spans[node_id] = llm_span
    # NOTE(review): nesting reconstructed from a diff that lost indentation; the
    # attributes are assumed to be set once, after the stream ends and before
    # the span context exits — confirm against the real file.
    async with llm_span:
        async for chunk in client.chat_completion_create(
            messages=formatted, tools=input.tools or [], **parameters
        ):
            llm_span.add_response_chunk(chunk)
            chunk.dataset_example_id = node_id
            yield chunk
        llm_span.set_attributes(client.attributes)
    # Presumably the span context records failures raised inside it instead of
    # propagating them, which is why the error is checked here — TODO confirm.
    error_message = llm_span.error_message
    if error_message is not None:
        yield ChatCompletionSubscriptionError(message=error_message, dataset_example_id=node_id)
519
350
 
520
351
 
521
- def _hex(number: int) -> str:
522
- """
523
- Converts an integer to a hexadecimal string.
524
- """
525
- return hex(number)[2:]
352
async def _merge_iterators(
    iterators: Collection[AsyncIterator[GenericType]],
) -> AsyncIterator[GenericType]:
    """
    Merges several async iterators into a single stream, yielding each item as
    soon as any of the underlying iterators produces one.

    One task per iterator awaits that iterator's next item. An exhausted
    iterator is dropped from the merge. An iterator that raises is logged and
    dropped, while the remaining iterators keep streaming. If the merged stream
    is closed or cancelled before all inputs are exhausted, the tasks that are
    still awaiting an item are cancelled instead of being left running.
    """
    # Keyed by task so the iterator behind a completed task is found directly
    # rather than by scanning the whole mapping.
    pending: dict[Task[GenericType], AsyncIterator[GenericType]] = {
        _as_task(iterator): iterator for iterator in iterators
    }
    try:
        while pending:
            completed_tasks, _ = await wait(pending.keys(), return_when=FIRST_COMPLETED)
            for task in completed_tasks:
                iterator = pending.pop(task)
                try:
                    item = task.result()
                except StopAsyncIteration:
                    # This iterator is exhausted; the others continue.
                    continue
                except Exception as error:
                    # A failing iterator must not bring down the whole stream.
                    logger.exception(error)
                    continue
                # The yield is deliberately outside the try block so that an
                # exception thrown in by the consumer is not mistaken for a
                # failure of the source iterator.
                yield item
                # Only request the next item once the consumer asks for more.
                pending[_as_task(iterator)] = iterator
    finally:
        # Reached with tasks still pending only on early close, cancellation or
        # an unexpected error; cancelling an already finished task is a no-op.
        for task in pending:
            task.cancel()
526
371
 
527
372
 
528
- def _datetime(*, epoch_nanoseconds: float) -> datetime:
529
- """
530
- Converts a Unix epoch timestamp in nanoseconds to a datetime.
531
- """
532
- epoch_seconds = epoch_nanoseconds / 1e9
533
- return datetime.fromtimestamp(epoch_seconds).replace(tzinfo=timezone.utc)
373
def _as_task(iterable: AsyncIterator[GenericType]) -> Task[GenericType]:
    """
    Schedules retrieval of the next item of the given async iterator as a task.
    """
    next_item = _as_coroutine(iterable)
    return create_task(next_item)
375
+
376
+
377
+ async def _as_coroutine(iterable: AsyncIterator[GenericType]) -> GenericType:
378
+ return await iterable.__anext__()
534
379
 
535
380
 
536
381
def _formatted_messages(
    *,
    messages: Iterable[ChatCompletionMessage],
    template_language: TemplateLanguage,
    template_variables: Mapping[str, Any],
) -> Iterator[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]]:
    """
    Formats the content of each message using the given template language and
    template variables.

    Each message is a (role, content template, tool_call_id, tool_calls) tuple.
    Only the content template is formatted; the other three fields are passed
    through unchanged. Formatting is lazy: a template is formatted (and any
    formatter error is raised) when the corresponding item is consumed from the
    returned iterator. An empty `messages` yields an empty iterator.
    """
    template_formatter = _template_formatter(template_language=template_language)
    # Formatting message by message, rather than transposing with zip(*messages),
    # keeps this well-defined when there are no messages at all.
    return (
        (
            role,
            template_formatter.format(template, **template_variables),
            tool_call_id,
            tool_calls,
        )
        for role, template, tool_call_id, tool_calls in messages
    )
550
403
 
551
404
 
@@ -560,29 +413,29 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatte
560
413
  assert_never(template_language)
561
414
 
562
415
 
563
- JSON = OpenInferenceMimeTypeValues.JSON.value
416
def _get_playground_experiment_task_output(
    span: streaming_llm_span,
) -> Any:
    """
    Returns the value found under the LLM output messages key of the span's
    attributes.
    """
    span_attributes = span.attributes
    task_output = get_attribute_value(span_attributes, LLM_OUTPUT_MESSAGES)
    return task_output
564
420
 
565
- LLM = OpenInferenceSpanKindValues.LLM.value
566
421
 
567
- OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
568
- INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
569
- INPUT_VALUE = SpanAttributes.INPUT_VALUE
570
- OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
571
- OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
572
- LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
573
- LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
574
- LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
575
- LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
576
- LLM_TOOLS = SpanAttributes.LLM_TOOLS
577
- LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
578
- LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
579
- LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
422
+ _DEFAULT_PLAYGROUND_EXPERIMENT_NAME = "playground-experiment"
423
+
424
+
425
+ def _default_playground_experiment_description(dataset_name: str) -> str:
426
+ return f'Playground experiment for dataset "{dataset_name}"'
580
427
 
581
- MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
582
- MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
583
- MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS
584
428
 
585
- TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
586
- TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
429
+ def _default_playground_experiment_metadata(
430
+ dataset_name: str, dataset_id: GlobalID, version_id: GlobalID
431
+ ) -> dict[str, Any]:
432
+ return {
433
+ "dataset_name": dataset_name,
434
+ "dataset_id": str(dataset_id),
435
+ "dataset_version_id": str(version_id),
436
+ }
587
437
 
588
- TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
438
+
439
+ LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
440
+ LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
441
+ LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT