arize-phoenix 5.5.2__py3-none-any.whl → 5.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/METADATA +3 -6
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/RECORD +171 -171
- phoenix/config.py +8 -8
- phoenix/core/model.py +3 -3
- phoenix/core/model_schema.py +41 -50
- phoenix/core/model_schema_adapter.py +17 -16
- phoenix/datetime_utils.py +2 -2
- phoenix/db/bulk_inserter.py +10 -20
- phoenix/db/engines.py +2 -1
- phoenix/db/enums.py +2 -2
- phoenix/db/helpers.py +8 -7
- phoenix/db/insertion/dataset.py +9 -19
- phoenix/db/insertion/document_annotation.py +14 -13
- phoenix/db/insertion/helpers.py +6 -16
- phoenix/db/insertion/span_annotation.py +14 -13
- phoenix/db/insertion/trace_annotation.py +14 -13
- phoenix/db/insertion/types.py +19 -30
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +8 -8
- phoenix/db/models.py +28 -28
- phoenix/experiments/evaluators/base.py +2 -1
- phoenix/experiments/evaluators/code_evaluators.py +4 -5
- phoenix/experiments/evaluators/llm_evaluators.py +157 -4
- phoenix/experiments/evaluators/utils.py +3 -2
- phoenix/experiments/functions.py +10 -21
- phoenix/experiments/tracing.py +2 -1
- phoenix/experiments/types.py +20 -29
- phoenix/experiments/utils.py +2 -1
- phoenix/inferences/errors.py +6 -5
- phoenix/inferences/fixtures.py +6 -5
- phoenix/inferences/inferences.py +37 -37
- phoenix/inferences/schema.py +11 -10
- phoenix/inferences/validation.py +13 -14
- phoenix/logging/_formatter.py +3 -3
- phoenix/metrics/__init__.py +5 -4
- phoenix/metrics/binning.py +2 -1
- phoenix/metrics/metrics.py +2 -1
- phoenix/metrics/mixins.py +7 -6
- phoenix/metrics/retrieval_metrics.py +2 -1
- phoenix/metrics/timeseries.py +5 -4
- phoenix/metrics/wrappers.py +2 -2
- phoenix/pointcloud/clustering.py +3 -4
- phoenix/pointcloud/pointcloud.py +7 -5
- phoenix/pointcloud/umap_parameters.py +2 -1
- phoenix/server/api/dataloaders/annotation_summaries.py +12 -19
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +2 -2
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +3 -2
- phoenix/server/api/dataloaders/dataset_example_revisions.py +3 -8
- phoenix/server/api/dataloaders/dataset_example_spans.py +2 -5
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +12 -18
- phoenix/server/api/dataloaders/document_evaluations.py +3 -7
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +6 -13
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +4 -8
- phoenix/server/api/dataloaders/experiment_error_rates.py +2 -5
- phoenix/server/api/dataloaders/experiment_run_annotations.py +3 -7
- phoenix/server/api/dataloaders/experiment_run_counts.py +1 -5
- phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -5
- phoenix/server/api/dataloaders/latency_ms_quantile.py +21 -30
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +7 -13
- phoenix/server/api/dataloaders/project_by_name.py +3 -3
- phoenix/server/api/dataloaders/record_counts.py +11 -18
- phoenix/server/api/dataloaders/span_annotations.py +3 -7
- phoenix/server/api/dataloaders/span_dataset_examples.py +3 -8
- phoenix/server/api/dataloaders/span_descendants.py +3 -7
- phoenix/server/api/dataloaders/span_projects.py +2 -2
- phoenix/server/api/dataloaders/token_counts.py +12 -19
- phoenix/server/api/dataloaders/trace_row_ids.py +3 -7
- phoenix/server/api/dataloaders/user_roles.py +3 -3
- phoenix/server/api/dataloaders/users.py +3 -3
- phoenix/server/api/helpers/__init__.py +4 -3
- phoenix/server/api/helpers/dataset_helpers.py +10 -9
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +2 -2
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +2 -2
- phoenix/server/api/input_types/ChatCompletionMessageInput.py +13 -1
- phoenix/server/api/input_types/ClusterInput.py +2 -2
- phoenix/server/api/input_types/DeleteAnnotationsInput.py +1 -3
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +2 -2
- phoenix/server/api/input_types/DeleteExperimentsInput.py +1 -3
- phoenix/server/api/input_types/DimensionFilter.py +4 -4
- phoenix/server/api/input_types/Granularity.py +1 -1
- phoenix/server/api/input_types/InvocationParameters.py +2 -2
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +2 -2
- phoenix/server/api/mutations/dataset_mutations.py +4 -4
- phoenix/server/api/mutations/experiment_mutations.py +1 -2
- phoenix/server/api/mutations/export_events_mutations.py +7 -7
- phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/user_mutations.py +4 -4
- phoenix/server/api/openapi/schema.py +2 -2
- phoenix/server/api/queries.py +20 -20
- phoenix/server/api/routers/oauth2.py +4 -4
- phoenix/server/api/routers/v1/datasets.py +22 -36
- phoenix/server/api/routers/v1/evaluations.py +6 -5
- phoenix/server/api/routers/v1/experiment_evaluations.py +2 -2
- phoenix/server/api/routers/v1/experiment_runs.py +2 -2
- phoenix/server/api/routers/v1/experiments.py +4 -4
- phoenix/server/api/routers/v1/spans.py +13 -12
- phoenix/server/api/routers/v1/traces.py +5 -5
- phoenix/server/api/routers/v1/utils.py +5 -5
- phoenix/server/api/subscriptions.py +284 -162
- phoenix/server/api/types/AnnotationSummary.py +3 -3
- phoenix/server/api/types/Cluster.py +8 -7
- phoenix/server/api/types/Dataset.py +5 -4
- phoenix/server/api/types/Dimension.py +3 -3
- phoenix/server/api/types/DocumentEvaluationSummary.py +8 -7
- phoenix/server/api/types/EmbeddingDimension.py +6 -5
- phoenix/server/api/types/EvaluationSummary.py +3 -3
- phoenix/server/api/types/Event.py +7 -7
- phoenix/server/api/types/Experiment.py +3 -3
- phoenix/server/api/types/ExperimentComparison.py +2 -4
- phoenix/server/api/types/Inferences.py +9 -8
- phoenix/server/api/types/InferencesRole.py +2 -2
- phoenix/server/api/types/Model.py +2 -2
- phoenix/server/api/types/Project.py +11 -18
- phoenix/server/api/types/Segments.py +3 -3
- phoenix/server/api/types/Span.py +8 -7
- phoenix/server/api/types/TimeSeries.py +8 -7
- phoenix/server/api/types/Trace.py +2 -2
- phoenix/server/api/types/UMAPPoints.py +6 -6
- phoenix/server/api/types/User.py +3 -3
- phoenix/server/api/types/node.py +1 -3
- phoenix/server/api/types/pagination.py +4 -4
- phoenix/server/api/utils.py +2 -4
- phoenix/server/app.py +16 -25
- phoenix/server/bearer_auth.py +4 -10
- phoenix/server/dml_event.py +3 -3
- phoenix/server/dml_event_handler.py +10 -24
- phoenix/server/grpc_server.py +3 -2
- phoenix/server/jwt_store.py +22 -21
- phoenix/server/main.py +3 -3
- phoenix/server/oauth2.py +3 -2
- phoenix/server/rate_limiters.py +5 -8
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/components-C70HJiXz.js +1612 -0
- phoenix/server/static/assets/{index-DCzakdJq.js → index-DLe1Oo3l.js} +2 -2
- phoenix/server/static/assets/{pages-CAL1FDMt.js → pages-C8-Sl7JI.js} +269 -434
- phoenix/server/static/assets/{vendor-6IcPAw_j.js → vendor-CtqfhlbC.js} +6 -6
- phoenix/server/static/assets/{vendor-arizeai-DRZuoyuF.js → vendor-arizeai-C_3SBz56.js} +2 -2
- phoenix/server/static/assets/{vendor-codemirror-DVE2_WBr.js → vendor-codemirror-wfdk9cjp.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-DwrexFA4.js → vendor-recharts-BiVnSv90.js} +1 -1
- phoenix/server/thread_server.py +1 -1
- phoenix/server/types.py +17 -29
- phoenix/services.py +4 -3
- phoenix/session/client.py +12 -24
- phoenix/session/data_extractor.py +3 -3
- phoenix/session/evaluation.py +1 -2
- phoenix/session/session.py +11 -20
- phoenix/trace/attributes.py +16 -28
- phoenix/trace/dsl/filter.py +17 -21
- phoenix/trace/dsl/helpers.py +3 -3
- phoenix/trace/dsl/query.py +13 -22
- phoenix/trace/fixtures.py +11 -17
- phoenix/trace/otel.py +5 -15
- phoenix/trace/projects.py +3 -2
- phoenix/trace/schemas.py +2 -2
- phoenix/trace/span_evaluations.py +9 -8
- phoenix/trace/span_json_decoder.py +3 -3
- phoenix/trace/span_json_encoder.py +2 -2
- phoenix/trace/trace_dataset.py +6 -5
- phoenix/trace/utils.py +6 -6
- phoenix/utilities/deprecation.py +3 -2
- phoenix/utilities/error_handling.py +3 -2
- phoenix/utilities/json.py +2 -1
- phoenix/utilities/logging.py +2 -2
- phoenix/utilities/project.py +1 -1
- phoenix/utilities/re.py +3 -4
- phoenix/utilities/template_formatters.py +5 -4
- phoenix/version.py +1 -1
- phoenix/server/static/assets/components-hX0LgYz3.js +0 -1428
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.6.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,25 +1,13 @@
|
|
|
1
1
|
import json
|
|
2
2
|
from abc import ABC, abstractmethod
|
|
3
3
|
from collections import defaultdict
|
|
4
|
+
from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Mapping
|
|
5
|
+
from dataclasses import asdict
|
|
4
6
|
from datetime import datetime, timezone
|
|
5
7
|
from enum import Enum
|
|
6
8
|
from itertools import chain
|
|
7
|
-
from
|
|
8
|
-
|
|
9
|
-
Annotated,
|
|
10
|
-
Any,
|
|
11
|
-
AsyncIterator,
|
|
12
|
-
Callable,
|
|
13
|
-
DefaultDict,
|
|
14
|
-
Dict,
|
|
15
|
-
Iterable,
|
|
16
|
-
Iterator,
|
|
17
|
-
List,
|
|
18
|
-
Optional,
|
|
19
|
-
Tuple,
|
|
20
|
-
Type,
|
|
21
|
-
Union,
|
|
22
|
-
)
|
|
9
|
+
from traceback import format_exc
|
|
10
|
+
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
|
|
23
11
|
|
|
24
12
|
import strawberry
|
|
25
13
|
from openinference.instrumentation import safe_json_dumps
|
|
@@ -31,9 +19,7 @@ from openinference.semconv.trace import (
|
|
|
31
19
|
ToolAttributes,
|
|
32
20
|
ToolCallAttributes,
|
|
33
21
|
)
|
|
34
|
-
from opentelemetry.sdk.trace import
|
|
35
|
-
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
|
36
|
-
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
|
|
22
|
+
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
|
|
37
23
|
from opentelemetry.trace import StatusCode
|
|
38
24
|
from sqlalchemy import insert, select
|
|
39
25
|
from strawberry import UNSET
|
|
@@ -41,8 +27,10 @@ from strawberry.scalars import JSON as JSONScalarType
|
|
|
41
27
|
from strawberry.types import Info
|
|
42
28
|
from typing_extensions import TypeAlias, assert_never
|
|
43
29
|
|
|
30
|
+
from phoenix.datetime_utils import local_now, normalize_datetime
|
|
44
31
|
from phoenix.db import models
|
|
45
32
|
from phoenix.server.api.context import Context
|
|
33
|
+
from phoenix.server.api.exceptions import BadRequest
|
|
46
34
|
from phoenix.server.api.input_types.ChatCompletionMessageInput import ChatCompletionMessageInput
|
|
47
35
|
from phoenix.server.api.input_types.InvocationParameters import InvocationParameters
|
|
48
36
|
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
|
|
@@ -50,6 +38,10 @@ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
|
|
|
50
38
|
from phoenix.server.api.types.Span import Span, to_gql_span
|
|
51
39
|
from phoenix.server.dml_event import SpanInsertEvent
|
|
52
40
|
from phoenix.trace.attributes import unflatten
|
|
41
|
+
from phoenix.trace.schemas import (
|
|
42
|
+
SpanEvent,
|
|
43
|
+
SpanException,
|
|
44
|
+
)
|
|
53
45
|
from phoenix.utilities.json import jsonify
|
|
54
46
|
from phoenix.utilities.template_formatters import (
|
|
55
47
|
FStringTemplateFormatter,
|
|
@@ -60,11 +52,15 @@ from phoenix.utilities.template_formatters import (
|
|
|
60
52
|
if TYPE_CHECKING:
|
|
61
53
|
from anthropic.types import MessageParam
|
|
62
54
|
from openai.types import CompletionUsage
|
|
63
|
-
from openai.types.chat import
|
|
55
|
+
from openai.types.chat import (
|
|
56
|
+
ChatCompletionMessageParam,
|
|
57
|
+
ChatCompletionMessageToolCallParam,
|
|
58
|
+
)
|
|
64
59
|
|
|
65
60
|
PLAYGROUND_PROJECT_NAME = "playground"
|
|
66
61
|
|
|
67
62
|
ToolCallID: TypeAlias = str
|
|
63
|
+
SetSpanAttributesFn: TypeAlias = Callable[[Mapping[str, Any]], None]
|
|
68
64
|
|
|
69
65
|
|
|
70
66
|
@strawberry.enum
|
|
@@ -96,13 +92,20 @@ class ToolCallChunk:
|
|
|
96
92
|
function: FunctionCallChunk
|
|
97
93
|
|
|
98
94
|
|
|
95
|
+
@strawberry.type
|
|
96
|
+
class ChatCompletionSubscriptionError:
|
|
97
|
+
message: str
|
|
98
|
+
|
|
99
|
+
|
|
99
100
|
@strawberry.type
|
|
100
101
|
class FinishedChatCompletion:
|
|
101
102
|
span: Span
|
|
102
103
|
|
|
103
104
|
|
|
105
|
+
ChatCompletionChunk: TypeAlias = Union[TextChunk, ToolCallChunk]
|
|
106
|
+
|
|
104
107
|
ChatCompletionSubscriptionPayload: TypeAlias = Annotated[
|
|
105
|
-
Union[TextChunk, ToolCallChunk, FinishedChatCompletion],
|
|
108
|
+
Union[TextChunk, ToolCallChunk, FinishedChatCompletion, ChatCompletionSubscriptionError],
|
|
106
109
|
strawberry.union("ChatCompletionSubscriptionPayload"),
|
|
107
110
|
]
|
|
108
111
|
|
|
@@ -120,23 +123,23 @@ class GenerativeModelInput:
|
|
|
120
123
|
|
|
121
124
|
@strawberry.input
|
|
122
125
|
class ChatCompletionInput:
|
|
123
|
-
messages:
|
|
126
|
+
messages: list[ChatCompletionMessageInput]
|
|
124
127
|
model: GenerativeModelInput
|
|
125
|
-
invocation_parameters: InvocationParameters
|
|
126
|
-
tools: Optional[
|
|
128
|
+
invocation_parameters: InvocationParameters = strawberry.field(default_factory=dict)
|
|
129
|
+
tools: Optional[list[JSONScalarType]] = UNSET
|
|
127
130
|
template: Optional[TemplateOptions] = UNSET
|
|
128
131
|
api_key: Optional[str] = strawberry.field(default=None)
|
|
129
132
|
|
|
130
133
|
|
|
131
|
-
PLAYGROUND_STREAMING_CLIENT_REGISTRY:
|
|
132
|
-
GenerativeProviderKey,
|
|
134
|
+
PLAYGROUND_STREAMING_CLIENT_REGISTRY: dict[
|
|
135
|
+
GenerativeProviderKey, type["PlaygroundStreamingClient"]
|
|
133
136
|
] = {}
|
|
134
137
|
|
|
135
138
|
|
|
136
139
|
def register_llm_client(
|
|
137
140
|
provider_key: GenerativeProviderKey,
|
|
138
|
-
) -> Callable[[
|
|
139
|
-
def decorator(cls:
|
|
141
|
+
) -> Callable[[type["PlaygroundStreamingClient"]], type["PlaygroundStreamingClient"]]:
|
|
142
|
+
def decorator(cls: type["PlaygroundStreamingClient"]) -> type["PlaygroundStreamingClient"]:
|
|
140
143
|
PLAYGROUND_STREAMING_CLIENT_REGISTRY[provider_key] = cls
|
|
141
144
|
return cls
|
|
142
145
|
|
|
@@ -144,45 +147,56 @@ def register_llm_client(
|
|
|
144
147
|
|
|
145
148
|
|
|
146
149
|
class PlaygroundStreamingClient(ABC):
|
|
147
|
-
def __init__(
|
|
150
|
+
def __init__(
|
|
151
|
+
self,
|
|
152
|
+
model: GenerativeModelInput,
|
|
153
|
+
api_key: Optional[str] = None,
|
|
154
|
+
set_span_attributes: Optional[SetSpanAttributesFn] = None,
|
|
155
|
+
) -> None:
|
|
156
|
+
self._set_span_attributes = set_span_attributes
|
|
148
157
|
|
|
149
158
|
@abstractmethod
|
|
150
159
|
async def chat_completion_create(
|
|
151
160
|
self,
|
|
152
|
-
messages:
|
|
153
|
-
|
|
161
|
+
messages: list[
|
|
162
|
+
tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
|
|
163
|
+
],
|
|
164
|
+
tools: list[JSONScalarType],
|
|
154
165
|
**invocation_parameters: Any,
|
|
155
|
-
) -> AsyncIterator[
|
|
166
|
+
) -> AsyncIterator[ChatCompletionChunk]:
|
|
156
167
|
# a yield statement is needed to satisfy the type-checker
|
|
157
168
|
# https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
|
|
158
169
|
yield TextChunk(content="")
|
|
159
170
|
|
|
160
|
-
@property
|
|
161
|
-
@abstractmethod
|
|
162
|
-
def attributes(self) -> Dict[str, Any]: ...
|
|
163
|
-
|
|
164
171
|
|
|
165
172
|
@register_llm_client(GenerativeProviderKey.OPENAI)
|
|
166
173
|
class OpenAIStreamingClient(PlaygroundStreamingClient):
|
|
167
|
-
def __init__(
|
|
174
|
+
def __init__(
|
|
175
|
+
self,
|
|
176
|
+
model: GenerativeModelInput,
|
|
177
|
+
api_key: Optional[str] = None,
|
|
178
|
+
set_span_attributes: Optional[SetSpanAttributesFn] = None,
|
|
179
|
+
) -> None:
|
|
168
180
|
from openai import AsyncOpenAI
|
|
169
181
|
|
|
182
|
+
super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
|
|
170
183
|
self.client = AsyncOpenAI(api_key=api_key)
|
|
171
184
|
self.model_name = model.name
|
|
172
|
-
self._attributes: Dict[str, Any] = {}
|
|
173
185
|
|
|
174
186
|
async def chat_completion_create(
|
|
175
187
|
self,
|
|
176
|
-
messages:
|
|
177
|
-
|
|
188
|
+
messages: list[
|
|
189
|
+
tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
|
|
190
|
+
],
|
|
191
|
+
tools: list[JSONScalarType],
|
|
178
192
|
**invocation_parameters: Any,
|
|
179
|
-
) -> AsyncIterator[
|
|
193
|
+
) -> AsyncIterator[ChatCompletionChunk]:
|
|
180
194
|
from openai import NOT_GIVEN
|
|
181
195
|
from openai.types.chat import ChatCompletionStreamOptionsParam
|
|
182
196
|
|
|
183
197
|
# Convert standard messages to OpenAI messages
|
|
184
198
|
openai_messages = [self.to_openai_chat_completion_param(*message) for message in messages]
|
|
185
|
-
tool_call_ids:
|
|
199
|
+
tool_call_ids: dict[int, str] = {}
|
|
186
200
|
token_usage: Optional["CompletionUsage"] = None
|
|
187
201
|
async for chunk in await self.client.chat.completions.create(
|
|
188
202
|
messages=openai_messages,
|
|
@@ -218,15 +232,20 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
|
|
|
218
232
|
),
|
|
219
233
|
)
|
|
220
234
|
yield tool_call_chunk
|
|
221
|
-
if token_usage is not None:
|
|
222
|
-
self.
|
|
235
|
+
if token_usage is not None and self._set_span_attributes:
|
|
236
|
+
self._set_span_attributes(dict(self._llm_token_counts(token_usage)))
|
|
223
237
|
|
|
224
238
|
def to_openai_chat_completion_param(
|
|
225
|
-
self,
|
|
239
|
+
self,
|
|
240
|
+
role: ChatCompletionMessageRole,
|
|
241
|
+
content: JSONScalarType,
|
|
242
|
+
tool_call_id: Optional[str] = None,
|
|
243
|
+
tool_calls: Optional[list[JSONScalarType]] = None,
|
|
226
244
|
) -> "ChatCompletionMessageParam":
|
|
227
245
|
from openai.types.chat import (
|
|
228
246
|
ChatCompletionAssistantMessageParam,
|
|
229
247
|
ChatCompletionSystemMessageParam,
|
|
248
|
+
ChatCompletionToolMessageParam,
|
|
230
249
|
ChatCompletionUserMessageParam,
|
|
231
250
|
)
|
|
232
251
|
|
|
@@ -245,26 +264,64 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
|
|
|
245
264
|
}
|
|
246
265
|
)
|
|
247
266
|
if role is ChatCompletionMessageRole.AI:
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
267
|
+
if tool_calls is None:
|
|
268
|
+
return ChatCompletionAssistantMessageParam(
|
|
269
|
+
{
|
|
270
|
+
"content": content,
|
|
271
|
+
"role": "assistant",
|
|
272
|
+
}
|
|
273
|
+
)
|
|
274
|
+
else:
|
|
275
|
+
return ChatCompletionAssistantMessageParam(
|
|
276
|
+
{
|
|
277
|
+
"content": content,
|
|
278
|
+
"role": "assistant",
|
|
279
|
+
"tool_calls": [
|
|
280
|
+
self.to_openai_tool_call_param(tool_call) for tool_call in tool_calls
|
|
281
|
+
],
|
|
282
|
+
}
|
|
283
|
+
)
|
|
254
284
|
if role is ChatCompletionMessageRole.TOOL:
|
|
255
|
-
|
|
285
|
+
if tool_call_id is None:
|
|
286
|
+
raise ValueError("tool_call_id is required for tool messages")
|
|
287
|
+
return ChatCompletionToolMessageParam(
|
|
288
|
+
{"content": content, "role": "tool", "tool_call_id": tool_call_id}
|
|
289
|
+
)
|
|
256
290
|
assert_never(role)
|
|
257
291
|
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
292
|
+
def to_openai_tool_call_param(
|
|
293
|
+
self,
|
|
294
|
+
tool_call: JSONScalarType,
|
|
295
|
+
) -> "ChatCompletionMessageToolCallParam":
|
|
296
|
+
from openai.types.chat import ChatCompletionMessageToolCallParam
|
|
297
|
+
|
|
298
|
+
return ChatCompletionMessageToolCallParam(
|
|
299
|
+
id=tool_call.get("id", ""),
|
|
300
|
+
function={
|
|
301
|
+
"name": tool_call.get("function", {}).get("name", ""),
|
|
302
|
+
"arguments": safe_json_dumps(tool_call.get("function", {}).get("arguments", "")),
|
|
303
|
+
},
|
|
304
|
+
type="function",
|
|
305
|
+
)
|
|
306
|
+
|
|
307
|
+
@staticmethod
|
|
308
|
+
def _llm_token_counts(usage: "CompletionUsage") -> Iterator[tuple[str, Any]]:
|
|
309
|
+
yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
|
|
310
|
+
yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
|
|
311
|
+
yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
|
|
261
312
|
|
|
262
313
|
|
|
263
314
|
@register_llm_client(GenerativeProviderKey.AZURE_OPENAI)
|
|
264
315
|
class AzureOpenAIStreamingClient(OpenAIStreamingClient):
|
|
265
|
-
def __init__(
|
|
316
|
+
def __init__(
|
|
317
|
+
self,
|
|
318
|
+
model: GenerativeModelInput,
|
|
319
|
+
api_key: Optional[str] = None,
|
|
320
|
+
set_span_attributes: Optional[SetSpanAttributesFn] = None,
|
|
321
|
+
):
|
|
266
322
|
from openai import AsyncAzureOpenAI
|
|
267
323
|
|
|
324
|
+
super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
|
|
268
325
|
if model.endpoint is None or model.api_version is None:
|
|
269
326
|
raise ValueError("endpoint and api_version are required for Azure OpenAI models")
|
|
270
327
|
self.client = AsyncAzureOpenAI(
|
|
@@ -276,18 +333,29 @@ class AzureOpenAIStreamingClient(OpenAIStreamingClient):
|
|
|
276
333
|
|
|
277
334
|
@register_llm_client(GenerativeProviderKey.ANTHROPIC)
|
|
278
335
|
class AnthropicStreamingClient(PlaygroundStreamingClient):
|
|
279
|
-
def __init__(
|
|
336
|
+
def __init__(
|
|
337
|
+
self,
|
|
338
|
+
model: GenerativeModelInput,
|
|
339
|
+
api_key: Optional[str] = None,
|
|
340
|
+
set_span_attributes: Optional[SetSpanAttributesFn] = None,
|
|
341
|
+
) -> None:
|
|
280
342
|
import anthropic
|
|
281
343
|
|
|
344
|
+
super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
|
|
282
345
|
self.client = anthropic.AsyncAnthropic(api_key=api_key)
|
|
283
346
|
self.model_name = model.name
|
|
284
347
|
|
|
285
348
|
async def chat_completion_create(
|
|
286
349
|
self,
|
|
287
|
-
messages:
|
|
288
|
-
|
|
350
|
+
messages: list[
|
|
351
|
+
tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
|
|
352
|
+
],
|
|
353
|
+
tools: list[JSONScalarType],
|
|
289
354
|
**invocation_parameters: Any,
|
|
290
|
-
) -> AsyncIterator[
|
|
355
|
+
) -> AsyncIterator[ChatCompletionChunk]:
|
|
356
|
+
import anthropic.lib.streaming as anthropic_streaming
|
|
357
|
+
import anthropic.types as anthropic_types
|
|
358
|
+
|
|
291
359
|
anthropic_messages, system_prompt = self._build_anthropic_messages(messages)
|
|
292
360
|
|
|
293
361
|
anthropic_params = {
|
|
@@ -297,17 +365,43 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
|
|
|
297
365
|
"max_tokens": 1024,
|
|
298
366
|
**invocation_parameters,
|
|
299
367
|
}
|
|
300
|
-
|
|
301
368
|
async with self.client.messages.stream(**anthropic_params) as stream:
|
|
302
|
-
async for
|
|
303
|
-
|
|
369
|
+
async for event in stream:
|
|
370
|
+
if isinstance(event, anthropic_types.RawMessageStartEvent):
|
|
371
|
+
if self._set_span_attributes:
|
|
372
|
+
self._set_span_attributes(
|
|
373
|
+
{LLM_TOKEN_COUNT_PROMPT: event.message.usage.input_tokens}
|
|
374
|
+
)
|
|
375
|
+
elif isinstance(event, anthropic_streaming.TextEvent):
|
|
376
|
+
yield TextChunk(content=event.text)
|
|
377
|
+
elif isinstance(event, anthropic_streaming.MessageStopEvent):
|
|
378
|
+
if self._set_span_attributes:
|
|
379
|
+
self._set_span_attributes(
|
|
380
|
+
{LLM_TOKEN_COUNT_COMPLETION: event.message.usage.output_tokens}
|
|
381
|
+
)
|
|
382
|
+
elif isinstance(
|
|
383
|
+
event,
|
|
384
|
+
(
|
|
385
|
+
anthropic_types.RawContentBlockStartEvent,
|
|
386
|
+
anthropic_types.RawContentBlockDeltaEvent,
|
|
387
|
+
anthropic_types.RawMessageDeltaEvent,
|
|
388
|
+
anthropic_streaming.ContentBlockStopEvent,
|
|
389
|
+
),
|
|
390
|
+
):
|
|
391
|
+
# event types emitted by the stream that don't contain useful information
|
|
392
|
+
pass
|
|
393
|
+
elif isinstance(event, anthropic_streaming.InputJsonEvent):
|
|
394
|
+
raise NotImplementedError
|
|
395
|
+
else:
|
|
396
|
+
assert_never(event)
|
|
304
397
|
|
|
305
398
|
def _build_anthropic_messages(
|
|
306
|
-
self,
|
|
307
|
-
|
|
308
|
-
|
|
399
|
+
self,
|
|
400
|
+
messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
|
|
401
|
+
) -> tuple[list["MessageParam"], str]:
|
|
402
|
+
anthropic_messages: list["MessageParam"] = []
|
|
309
403
|
system_prompt = ""
|
|
310
|
-
for role, content in messages:
|
|
404
|
+
for role, content, _tool_call_id, _tool_calls in messages:
|
|
311
405
|
if role == ChatCompletionMessageRole.USER:
|
|
312
406
|
anthropic_messages.append({"role": "user", "content": content})
|
|
313
407
|
elif role == ChatCompletionMessageRole.AI:
|
|
@@ -321,10 +415,6 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
|
|
|
321
415
|
|
|
322
416
|
return anthropic_messages, system_prompt
|
|
323
417
|
|
|
324
|
-
@property
|
|
325
|
-
def attributes(self) -> Dict[str, Any]:
|
|
326
|
-
return dict()
|
|
327
|
-
|
|
328
418
|
|
|
329
419
|
@strawberry.type
|
|
330
420
|
class Subscription:
|
|
@@ -334,43 +424,45 @@ class Subscription:
|
|
|
334
424
|
) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
|
|
335
425
|
# Determine which LLM client to use based on provider_key
|
|
336
426
|
provider_key = input.model.provider_key
|
|
337
|
-
llm_client_class
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
427
|
+
if (llm_client_class := PLAYGROUND_STREAMING_CLIENT_REGISTRY.get(provider_key)) is None:
|
|
428
|
+
raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
|
|
429
|
+
llm_client = llm_client_class(
|
|
430
|
+
model=input.model,
|
|
431
|
+
api_key=input.api_key,
|
|
432
|
+
set_span_attributes=lambda attrs: attributes.update(attrs),
|
|
433
|
+
)
|
|
434
|
+
messages = [
|
|
435
|
+
(
|
|
436
|
+
message.role,
|
|
437
|
+
message.content,
|
|
438
|
+
message.tool_call_id if isinstance(message.tool_call_id, str) else None,
|
|
439
|
+
message.tool_calls if isinstance(message.tool_calls, list) else None,
|
|
440
|
+
)
|
|
441
|
+
for message in input.messages
|
|
442
|
+
]
|
|
345
443
|
if template_options := input.template:
|
|
346
444
|
messages = list(_formatted_messages(messages, template_options))
|
|
347
|
-
|
|
348
445
|
invocation_parameters = jsonify(input.invocation_parameters)
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
446
|
+
attributes = dict(
|
|
447
|
+
chain(
|
|
448
|
+
_llm_span_kind(),
|
|
449
|
+
_llm_model_name(input.model.name),
|
|
450
|
+
_llm_tools(input.tools or []),
|
|
451
|
+
_llm_input_messages(messages),
|
|
452
|
+
_llm_invocation_parameters(invocation_parameters),
|
|
453
|
+
_input_value_and_mime_type(input),
|
|
454
|
+
)
|
|
354
455
|
)
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
_llm_invocation_parameters(invocation_parameters),
|
|
366
|
-
_input_value_and_mime_type(input),
|
|
367
|
-
)
|
|
368
|
-
),
|
|
369
|
-
) as span:
|
|
370
|
-
response_chunks = []
|
|
371
|
-
text_chunks: List[TextChunk] = []
|
|
372
|
-
tool_call_chunks: DefaultDict[ToolCallID, List[ToolCallChunk]] = defaultdict(list)
|
|
373
|
-
|
|
456
|
+
status_code: StatusCode
|
|
457
|
+
status_message = ""
|
|
458
|
+
events: list[SpanEvent] = []
|
|
459
|
+
start_time: datetime
|
|
460
|
+
end_time: datetime
|
|
461
|
+
response_chunks = []
|
|
462
|
+
text_chunks: list[TextChunk] = []
|
|
463
|
+
tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]] = defaultdict(list)
|
|
464
|
+
try:
|
|
465
|
+
start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
|
|
374
466
|
async for chunk in llm_client.chat_completion_create(
|
|
375
467
|
messages=messages,
|
|
376
468
|
tools=input.tools or [],
|
|
@@ -383,31 +475,35 @@ class Subscription:
|
|
|
383
475
|
elif isinstance(chunk, ToolCallChunk):
|
|
384
476
|
yield chunk
|
|
385
477
|
tool_call_chunks[chunk.id].append(chunk)
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
478
|
+
else:
|
|
479
|
+
assert_never(chunk)
|
|
480
|
+
status_code = StatusCode.OK
|
|
481
|
+
except Exception as error:
|
|
482
|
+
end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
|
|
483
|
+
status_code = StatusCode.ERROR
|
|
484
|
+
status_message = str(error)
|
|
485
|
+
events.append(
|
|
486
|
+
SpanException(
|
|
487
|
+
timestamp=end_time,
|
|
488
|
+
message=status_message,
|
|
489
|
+
exception_type=type(error).__name__,
|
|
490
|
+
exception_escaped=False,
|
|
491
|
+
exception_stacktrace=format_exc(),
|
|
397
492
|
)
|
|
398
493
|
)
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
494
|
+
yield ChatCompletionSubscriptionError(message=status_message)
|
|
495
|
+
else:
|
|
496
|
+
end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
|
|
497
|
+
attributes.update(
|
|
498
|
+
chain(
|
|
499
|
+
_output_value_and_mime_type(response_chunks),
|
|
500
|
+
_llm_output_messages(text_chunks, tool_call_chunks),
|
|
501
|
+
)
|
|
502
|
+
)
|
|
503
|
+
prompt_tokens = attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
|
|
504
|
+
completion_tokens = attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
|
|
505
|
+
trace_id = _generate_trace_id()
|
|
506
|
+
span_id = _generate_span_id()
|
|
411
507
|
async with info.context.db() as session:
|
|
412
508
|
if (
|
|
413
509
|
playground_project_id := await session.scalar(
|
|
@@ -432,15 +528,15 @@ class Subscription:
|
|
|
432
528
|
trace_rowid=playground_trace.id,
|
|
433
529
|
span_id=span_id,
|
|
434
530
|
parent_id=None,
|
|
435
|
-
name=
|
|
531
|
+
name="ChatCompletion",
|
|
436
532
|
span_kind=LLM,
|
|
437
533
|
start_time=start_time,
|
|
438
534
|
end_time=end_time,
|
|
439
535
|
attributes=unflatten(attributes.items()),
|
|
440
|
-
events=
|
|
441
|
-
status_code=
|
|
442
|
-
status_message=
|
|
443
|
-
cumulative_error_count=int(
|
|
536
|
+
events=[_serialize_event(event) for event in events],
|
|
537
|
+
status_code=status_code.name,
|
|
538
|
+
status_message=status_message,
|
|
539
|
+
cumulative_error_count=int(status_code is StatusCode.ERROR),
|
|
444
540
|
cumulative_llm_token_count_prompt=prompt_tokens,
|
|
445
541
|
cumulative_llm_token_count_completion=completion_tokens,
|
|
446
542
|
llm_token_count_prompt=prompt_tokens,
|
|
@@ -454,30 +550,24 @@ class Subscription:
|
|
|
454
550
|
info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
|
|
455
551
|
|
|
456
552
|
|
|
457
|
-
def _llm_span_kind() -> Iterator[
|
|
553
|
+
def _llm_span_kind() -> Iterator[tuple[str, Any]]:
|
|
458
554
|
yield OPENINFERENCE_SPAN_KIND, LLM
|
|
459
555
|
|
|
460
556
|
|
|
461
|
-
def _llm_model_name(model_name: str) -> Iterator[
|
|
557
|
+
def _llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
|
|
462
558
|
yield LLM_MODEL_NAME, model_name
|
|
463
559
|
|
|
464
560
|
|
|
465
|
-
def _llm_invocation_parameters(invocation_parameters:
|
|
561
|
+
def _llm_invocation_parameters(invocation_parameters: dict[str, Any]) -> Iterator[tuple[str, Any]]:
|
|
466
562
|
yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)
|
|
467
563
|
|
|
468
564
|
|
|
469
|
-
def _llm_tools(tools:
|
|
565
|
+
def _llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
    """
    Yields one attribute pair per tool. The key embeds the tool's position in
    the list and the value is the tool definition dumped as a JSON string.
    """
    yield from (
        (f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool))
        for tool_index, tool in enumerate(tools)
    )
|
|
472
568
|
|
|
473
569
|
|
|
474
|
-
def
|
|
475
|
-
yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
|
|
476
|
-
yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
|
|
477
|
-
yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]:
|
|
570
|
+
def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[tuple[str, Any]]:
|
|
481
571
|
assert (api_key := "api_key") in (input_data := jsonify(input))
|
|
482
572
|
input_data = {k: v for k, v in input_data.items() if k != api_key}
|
|
483
573
|
assert api_key not in input_data
|
|
@@ -485,27 +575,40 @@ def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str
|
|
|
485
575
|
yield INPUT_VALUE, safe_json_dumps(input_data)
|
|
486
576
|
|
|
487
577
|
|
|
488
|
-
def _output_value_and_mime_type(output: Any) -> Iterator[
|
|
578
|
+
def _output_value_and_mime_type(output: Any) -> Iterator[tuple[str, Any]]:
    """
    Yields the output MIME type (JSON) followed by the output value, which is
    the result of jsonify(output) serialized with safe_json_dumps.
    """
    yield OUTPUT_MIME_TYPE, JSON
    # Serialization is deferred until the consumer asks for the second pair.
    serialized_output = safe_json_dumps(jsonify(output))
    yield OUTPUT_VALUE, serialized_output
|
|
491
581
|
|
|
492
582
|
|
|
493
583
|
def _llm_input_messages(
    messages: Iterable[
        tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
    ],
) -> Iterator[tuple[str, Any]]:
    """
    Yields flattened input-message attributes.

    Each message is a (role, content, tool call ID, tool calls) tuple. For every
    message the lower-cased role value and the content are yielded; the tool call
    ID is ignored. When tool calls are present, each one contributes its function
    name and, if its arguments are truthy, the arguments serialized as JSON.
    """
    for message_index, (role, content, _tool_call_id, tool_calls) in enumerate(messages):
        message_prefix = f"{LLM_INPUT_MESSAGES}.{message_index}"
        yield f"{message_prefix}.{MESSAGE_ROLE}", role.value.lower()
        yield f"{message_prefix}.{MESSAGE_CONTENT}", content
        if tool_calls is None:
            continue
        for tool_call_index, tool_call in enumerate(tool_calls):
            tool_call_prefix = f"{message_prefix}.{MESSAGE_TOOL_CALLS}.{tool_call_index}"
            function = tool_call["function"]
            yield f"{tool_call_prefix}.{TOOL_CALL_FUNCTION_NAME}", function["name"]
            arguments = function["arguments"]
            if arguments:
                yield (
                    f"{tool_call_prefix}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
                    safe_json_dumps(jsonify(arguments)),
                )
|
|
499
602
|
|
|
500
603
|
|
|
501
604
|
def _llm_output_messages(
|
|
502
|
-
text_chunks:
|
|
503
|
-
tool_call_chunks:
|
|
504
|
-
) -> Iterator[
|
|
605
|
+
text_chunks: list[TextChunk],
|
|
606
|
+
tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]],
|
|
607
|
+
) -> Iterator[tuple[str, Any]]:
|
|
505
608
|
yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
|
|
506
609
|
if content := "".join(chunk.content for chunk in text_chunks):
|
|
507
610
|
yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", content
|
|
508
|
-
for tool_call_index, tool_call_chunks_ in tool_call_chunks.items():
|
|
611
|
+
for tool_call_index, (_tool_call_id, tool_call_chunks_) in enumerate(tool_call_chunks.items()):
|
|
509
612
|
if tool_call_chunks_ and (name := tool_call_chunks_[0].function.name):
|
|
510
613
|
yield (
|
|
511
614
|
f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
|
|
@@ -518,34 +621,46 @@ def _llm_output_messages(
|
|
|
518
621
|
)
|
|
519
622
|
|
|
520
623
|
|
|
521
|
-
def
|
|
624
|
+
def _generate_trace_id() -> str:
    """
    Generates a random trace ID in hexadecimal format.

    The value is zero-padded to 32 lowercase hex characters (128 bits) so that
    IDs whose leading hex digits happen to be zero keep the fixed width of an
    OpenTelemetry trace ID instead of coming out shorter.
    """
    return format(DefaultOTelIDGenerator().generate_trace_id(), "032x")
|
|
526
629
|
|
|
527
630
|
|
|
528
|
-
def
|
|
631
|
+
def _generate_span_id() -> str:
    """
    Generates a random span ID in hexadecimal format.

    The value is zero-padded to 16 lowercase hex characters (64 bits) so that
    IDs whose leading hex digits happen to be zero keep the fixed width of an
    OpenTelemetry span ID instead of coming out shorter.
    """
    return format(DefaultOTelIDGenerator().generate_span_id(), "016x")
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
def _hex(number: int) -> str:
|
|
639
|
+
"""
|
|
640
|
+
Converts an integer to a hexadecimal string.
|
|
641
|
+
"""
|
|
642
|
+
return hex(number)[2:]
|
|
534
643
|
|
|
535
644
|
|
|
536
645
|
def _formatted_messages(
    messages: Iterable[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
    template_options: TemplateOptions,
) -> Iterator[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]]:
    """
    Formats the messages using the given template options.

    Each message is a (role, template, tool call ID, tool calls) tuple. Only the
    template is transformed: it is passed through the formatter selected by
    template_options.language together with template_options.variables. The
    other three fields are passed through unchanged.

    The input is consumed eagerly, while the formatting itself happens lazily as
    the returned iterator is advanced. An empty input yields an empty iterator.
    """
    template_formatter = _template_formatter(template_language=template_options.language)
    # Transposing with zip(*messages) and unpacking into four names raises
    # ValueError when there are no messages, so format message by message.
    message_list = list(messages)
    return (
        (
            role,
            template_formatter.format(template, **template_options.variables),
            tool_call_id,
            tool_calls,
        )
        for role, template, tool_call_id, tool_calls in message_list
    )
|
|
550
665
|
|
|
551
666
|
|
|
@@ -560,6 +675,13 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatte
|
|
|
560
675
|
assert_never(template_language)
|
|
561
676
|
|
|
562
677
|
|
|
678
|
+
def _serialize_event(event: SpanEvent) -> dict[str, Any]:
    """
    Serializes a SpanEvent to a dictionary.

    The event is converted with dataclasses.asdict; any top-level value that is
    a datetime is replaced by its ISO 8601 string, and all other values are
    kept as they are.
    """
    serialized: dict[str, Any] = {}
    for field_name, field_value in asdict(event).items():
        if isinstance(field_value, datetime):
            field_value = field_value.isoformat()
        serialized[field_name] = field_value
    return serialized
|
|
683
|
+
|
|
684
|
+
|
|
563
685
|
JSON = OpenInferenceMimeTypeValues.JSON.value
|
|
564
686
|
|
|
565
687
|
LLM = OpenInferenceSpanKindValues.LLM.value
|