arize-phoenix 5.5.1__py3-none-any.whl → 5.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this release of arize-phoenix has been marked as potentially problematic.
- {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/METADATA +8 -11
- {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/RECORD +171 -171
- phoenix/config.py +8 -8
- phoenix/core/model.py +3 -3
- phoenix/core/model_schema.py +41 -50
- phoenix/core/model_schema_adapter.py +17 -16
- phoenix/datetime_utils.py +2 -2
- phoenix/db/bulk_inserter.py +10 -20
- phoenix/db/engines.py +2 -1
- phoenix/db/enums.py +2 -2
- phoenix/db/helpers.py +8 -7
- phoenix/db/insertion/dataset.py +9 -19
- phoenix/db/insertion/document_annotation.py +14 -13
- phoenix/db/insertion/helpers.py +6 -16
- phoenix/db/insertion/span_annotation.py +14 -13
- phoenix/db/insertion/trace_annotation.py +14 -13
- phoenix/db/insertion/types.py +19 -30
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +8 -8
- phoenix/db/models.py +28 -28
- phoenix/experiments/evaluators/base.py +2 -1
- phoenix/experiments/evaluators/code_evaluators.py +4 -5
- phoenix/experiments/evaluators/llm_evaluators.py +157 -4
- phoenix/experiments/evaluators/utils.py +3 -2
- phoenix/experiments/functions.py +10 -21
- phoenix/experiments/tracing.py +2 -1
- phoenix/experiments/types.py +20 -29
- phoenix/experiments/utils.py +2 -1
- phoenix/inferences/errors.py +6 -5
- phoenix/inferences/fixtures.py +6 -5
- phoenix/inferences/inferences.py +37 -37
- phoenix/inferences/schema.py +11 -10
- phoenix/inferences/validation.py +13 -14
- phoenix/logging/_formatter.py +3 -3
- phoenix/metrics/__init__.py +5 -4
- phoenix/metrics/binning.py +2 -1
- phoenix/metrics/metrics.py +2 -1
- phoenix/metrics/mixins.py +7 -6
- phoenix/metrics/retrieval_metrics.py +2 -1
- phoenix/metrics/timeseries.py +5 -4
- phoenix/metrics/wrappers.py +2 -2
- phoenix/pointcloud/clustering.py +3 -4
- phoenix/pointcloud/pointcloud.py +7 -5
- phoenix/pointcloud/umap_parameters.py +2 -1
- phoenix/server/api/dataloaders/annotation_summaries.py +12 -19
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +2 -2
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +3 -2
- phoenix/server/api/dataloaders/dataset_example_revisions.py +3 -8
- phoenix/server/api/dataloaders/dataset_example_spans.py +2 -5
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +12 -18
- phoenix/server/api/dataloaders/document_evaluations.py +3 -7
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +6 -13
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +4 -8
- phoenix/server/api/dataloaders/experiment_error_rates.py +2 -5
- phoenix/server/api/dataloaders/experiment_run_annotations.py +3 -7
- phoenix/server/api/dataloaders/experiment_run_counts.py +1 -5
- phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -5
- phoenix/server/api/dataloaders/latency_ms_quantile.py +21 -30
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +7 -13
- phoenix/server/api/dataloaders/project_by_name.py +3 -3
- phoenix/server/api/dataloaders/record_counts.py +11 -18
- phoenix/server/api/dataloaders/span_annotations.py +3 -7
- phoenix/server/api/dataloaders/span_dataset_examples.py +3 -8
- phoenix/server/api/dataloaders/span_descendants.py +3 -7
- phoenix/server/api/dataloaders/span_projects.py +2 -2
- phoenix/server/api/dataloaders/token_counts.py +12 -19
- phoenix/server/api/dataloaders/trace_row_ids.py +3 -7
- phoenix/server/api/dataloaders/user_roles.py +3 -3
- phoenix/server/api/dataloaders/users.py +3 -3
- phoenix/server/api/helpers/__init__.py +4 -3
- phoenix/server/api/helpers/dataset_helpers.py +10 -9
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +2 -2
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +2 -2
- phoenix/server/api/input_types/ChatCompletionMessageInput.py +13 -1
- phoenix/server/api/input_types/ClusterInput.py +2 -2
- phoenix/server/api/input_types/DeleteAnnotationsInput.py +1 -3
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +2 -2
- phoenix/server/api/input_types/DeleteExperimentsInput.py +1 -3
- phoenix/server/api/input_types/DimensionFilter.py +4 -4
- phoenix/server/api/input_types/Granularity.py +1 -1
- phoenix/server/api/input_types/InvocationParameters.py +2 -2
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +2 -2
- phoenix/server/api/mutations/dataset_mutations.py +4 -4
- phoenix/server/api/mutations/experiment_mutations.py +1 -2
- phoenix/server/api/mutations/export_events_mutations.py +7 -7
- phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/user_mutations.py +4 -4
- phoenix/server/api/openapi/schema.py +2 -2
- phoenix/server/api/queries.py +20 -20
- phoenix/server/api/routers/oauth2.py +4 -4
- phoenix/server/api/routers/v1/datasets.py +22 -36
- phoenix/server/api/routers/v1/evaluations.py +6 -5
- phoenix/server/api/routers/v1/experiment_evaluations.py +2 -2
- phoenix/server/api/routers/v1/experiment_runs.py +2 -2
- phoenix/server/api/routers/v1/experiments.py +4 -4
- phoenix/server/api/routers/v1/spans.py +13 -12
- phoenix/server/api/routers/v1/traces.py +5 -5
- phoenix/server/api/routers/v1/utils.py +5 -5
- phoenix/server/api/subscriptions.py +289 -167
- phoenix/server/api/types/AnnotationSummary.py +3 -3
- phoenix/server/api/types/Cluster.py +8 -7
- phoenix/server/api/types/Dataset.py +5 -4
- phoenix/server/api/types/Dimension.py +3 -3
- phoenix/server/api/types/DocumentEvaluationSummary.py +8 -7
- phoenix/server/api/types/EmbeddingDimension.py +6 -5
- phoenix/server/api/types/EvaluationSummary.py +3 -3
- phoenix/server/api/types/Event.py +7 -7
- phoenix/server/api/types/Experiment.py +3 -3
- phoenix/server/api/types/ExperimentComparison.py +2 -4
- phoenix/server/api/types/Inferences.py +9 -8
- phoenix/server/api/types/InferencesRole.py +2 -2
- phoenix/server/api/types/Model.py +2 -2
- phoenix/server/api/types/Project.py +11 -18
- phoenix/server/api/types/Segments.py +3 -3
- phoenix/server/api/types/Span.py +8 -7
- phoenix/server/api/types/TimeSeries.py +8 -7
- phoenix/server/api/types/Trace.py +2 -2
- phoenix/server/api/types/UMAPPoints.py +6 -6
- phoenix/server/api/types/User.py +3 -3
- phoenix/server/api/types/node.py +1 -3
- phoenix/server/api/types/pagination.py +4 -4
- phoenix/server/api/utils.py +2 -4
- phoenix/server/app.py +16 -25
- phoenix/server/bearer_auth.py +4 -10
- phoenix/server/dml_event.py +3 -3
- phoenix/server/dml_event_handler.py +10 -24
- phoenix/server/grpc_server.py +3 -2
- phoenix/server/jwt_store.py +22 -21
- phoenix/server/main.py +3 -3
- phoenix/server/oauth2.py +3 -2
- phoenix/server/rate_limiters.py +5 -8
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/components-C70HJiXz.js +1612 -0
- phoenix/server/static/assets/{index-BHfTZ6x_.js → index-DLe1Oo3l.js} +2 -2
- phoenix/server/static/assets/{pages-aAez_Ntk.js → pages-C8-Sl7JI.js} +269 -434
- phoenix/server/static/assets/{vendor-6IcPAw_j.js → vendor-CtqfhlbC.js} +6 -6
- phoenix/server/static/assets/{vendor-arizeai-DRZuoyuF.js → vendor-arizeai-C_3SBz56.js} +2 -2
- phoenix/server/static/assets/{vendor-codemirror-DVE2_WBr.js → vendor-codemirror-wfdk9cjp.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-DwrexFA4.js → vendor-recharts-BiVnSv90.js} +1 -1
- phoenix/server/thread_server.py +1 -1
- phoenix/server/types.py +17 -29
- phoenix/services.py +4 -3
- phoenix/session/client.py +12 -24
- phoenix/session/data_extractor.py +3 -3
- phoenix/session/evaluation.py +1 -2
- phoenix/session/session.py +11 -20
- phoenix/trace/attributes.py +16 -28
- phoenix/trace/dsl/filter.py +17 -21
- phoenix/trace/dsl/helpers.py +3 -3
- phoenix/trace/dsl/query.py +13 -22
- phoenix/trace/fixtures.py +11 -17
- phoenix/trace/otel.py +5 -15
- phoenix/trace/projects.py +3 -2
- phoenix/trace/schemas.py +2 -2
- phoenix/trace/span_evaluations.py +9 -8
- phoenix/trace/span_json_decoder.py +3 -3
- phoenix/trace/span_json_encoder.py +2 -2
- phoenix/trace/trace_dataset.py +6 -5
- phoenix/trace/utils.py +6 -6
- phoenix/utilities/deprecation.py +3 -2
- phoenix/utilities/error_handling.py +3 -2
- phoenix/utilities/json.py +2 -1
- phoenix/utilities/logging.py +2 -2
- phoenix/utilities/project.py +1 -1
- phoenix/utilities/re.py +3 -4
- phoenix/utilities/template_formatters.py +5 -4
- phoenix/version.py +1 -1
- phoenix/server/static/assets/components-mVBxvljU.js +0 -1428
- {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-5.5.1.dist-info → arize_phoenix-5.6.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/subscriptions.py

@@ -1,26 +1,13 @@
 import json
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from dataclasses import fields
-from datetime import …
+from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Mapping
+from dataclasses import asdict
+from datetime import datetime, timezone
 from enum import Enum
 from itertools import chain
-from typing import (
-    TYPE_CHECKING,
-    Annotated,
-    Any,
-    AsyncIterator,
-    Callable,
-    DefaultDict,
-    Dict,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Tuple,
-    Type,
-    Union,
-)
+from traceback import format_exc
+from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
 
 import strawberry
 from openinference.instrumentation import safe_json_dumps
@@ -32,9 +19,7 @@ from openinference.semconv.trace import (
     ToolAttributes,
     ToolCallAttributes,
 )
-from opentelemetry.sdk.trace import …
-from opentelemetry.sdk.trace.export import SimpleSpanProcessor
-from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
 from opentelemetry.trace import StatusCode
 from sqlalchemy import insert, select
 from strawberry import UNSET
@@ -42,8 +27,10 @@ from strawberry.scalars import JSON as JSONScalarType
 from strawberry.types import Info
 from typing_extensions import TypeAlias, assert_never
 
+from phoenix.datetime_utils import local_now, normalize_datetime
 from phoenix.db import models
 from phoenix.server.api.context import Context
+from phoenix.server.api.exceptions import BadRequest
 from phoenix.server.api.input_types.ChatCompletionMessageInput import ChatCompletionMessageInput
 from phoenix.server.api.input_types.InvocationParameters import InvocationParameters
 from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
@@ -51,6 +38,10 @@ from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
 from phoenix.server.api.types.Span import Span, to_gql_span
 from phoenix.server.dml_event import SpanInsertEvent
 from phoenix.trace.attributes import unflatten
+from phoenix.trace.schemas import (
+    SpanEvent,
+    SpanException,
+)
 from phoenix.utilities.json import jsonify
 from phoenix.utilities.template_formatters import (
     FStringTemplateFormatter,
@@ -61,11 +52,15 @@ from phoenix.utilities.template_formatters import (
 if TYPE_CHECKING:
     from anthropic.types import MessageParam
     from openai.types import CompletionUsage
-    from openai.types.chat import ChatCompletionMessageParam
+    from openai.types.chat import (
+        ChatCompletionMessageParam,
+        ChatCompletionMessageToolCallParam,
+    )
 
 PLAYGROUND_PROJECT_NAME = "playground"
 
 ToolCallID: TypeAlias = str
+SetSpanAttributesFn: TypeAlias = Callable[[Mapping[str, Any]], None]
 
 
 @strawberry.enum
@@ -97,13 +92,20 @@ class ToolCallChunk:
     function: FunctionCallChunk
 
 
+@strawberry.type
+class ChatCompletionSubscriptionError:
+    message: str
+
+
 @strawberry.type
 class FinishedChatCompletion:
     span: Span
 
 
+ChatCompletionChunk: TypeAlias = Union[TextChunk, ToolCallChunk]
+
 ChatCompletionSubscriptionPayload: TypeAlias = Annotated[
-    Union[TextChunk, ToolCallChunk, FinishedChatCompletion],
+    Union[TextChunk, ToolCallChunk, FinishedChatCompletion, ChatCompletionSubscriptionError],
     strawberry.union("ChatCompletionSubscriptionPayload"),
 ]
 
@@ -121,23 +123,23 @@ class GenerativeModelInput:
 
 @strawberry.input
 class ChatCompletionInput:
-    messages: List[ChatCompletionMessageInput]
+    messages: list[ChatCompletionMessageInput]
     model: GenerativeModelInput
-    invocation_parameters: InvocationParameters
-    tools: Optional[List[JSONScalarType]] = UNSET
+    invocation_parameters: InvocationParameters = strawberry.field(default_factory=dict)
+    tools: Optional[list[JSONScalarType]] = UNSET
     template: Optional[TemplateOptions] = UNSET
     api_key: Optional[str] = strawberry.field(default=None)
 
 
-PLAYGROUND_STREAMING_CLIENT_REGISTRY: Dict[
-    GenerativeProviderKey, Type["PlaygroundStreamingClient"]
+PLAYGROUND_STREAMING_CLIENT_REGISTRY: dict[
+    GenerativeProviderKey, type["PlaygroundStreamingClient"]
 ] = {}
 
 
 def register_llm_client(
     provider_key: GenerativeProviderKey,
-) -> Callable[[Type["PlaygroundStreamingClient"]], Type["PlaygroundStreamingClient"]]:
-    def decorator(cls: Type["PlaygroundStreamingClient"]) -> Type["PlaygroundStreamingClient"]:
+) -> Callable[[type["PlaygroundStreamingClient"]], type["PlaygroundStreamingClient"]]:
+    def decorator(cls: type["PlaygroundStreamingClient"]) -> type["PlaygroundStreamingClient"]:
         PLAYGROUND_STREAMING_CLIENT_REGISTRY[provider_key] = cls
         return cls
 
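The registry above is a plain dict populated by a class decorator: registering a client is a side effect of class definition, and the class is returned unchanged. A self-contained sketch of the same pattern with generic names:

    # Sketch: a provider registry filled by a class decorator.
    from typing import Callable, TypeVar

    C = TypeVar("C", bound=type)

    REGISTRY: dict[str, type] = {}

    def register(provider_key: str) -> Callable[[C], C]:
        def decorator(cls: C) -> C:
            REGISTRY[provider_key] = cls  # record the class under its provider key
            return cls                    # hand the class back unmodified
        return decorator

    @register("openai")
    class OpenAIClient:
        pass

    assert REGISTRY["openai"] is OpenAIClient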
@@ -145,45 +147,56 @@ def register_llm_client(
 
 
 class PlaygroundStreamingClient(ABC):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None: ...
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+        set_span_attributes: Optional[SetSpanAttributesFn] = None,
+    ) -> None:
+        self._set_span_attributes = set_span_attributes
 
     @abstractmethod
     async def chat_completion_create(
         self,
-        messages: List[Tuple[ChatCompletionMessageRole, str]],
-        tools: List[JSONScalarType],
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
         **invocation_parameters: Any,
-    ) -> AsyncIterator[Union[TextChunk, ToolCallChunk]]:
+    ) -> AsyncIterator[ChatCompletionChunk]:
         # a yield statement is needed to satisfy the type-checker
         # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
         yield TextChunk(content="")
 
-    @property
-    @abstractmethod
-    def attributes(self) -> Dict[str, Any]: ...
-
 
 @register_llm_client(GenerativeProviderKey.OPENAI)
 class OpenAIStreamingClient(PlaygroundStreamingClient):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+        set_span_attributes: Optional[SetSpanAttributesFn] = None,
+    ) -> None:
         from openai import AsyncOpenAI
 
+        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
         self.client = AsyncOpenAI(api_key=api_key)
         self.model_name = model.name
-        self._attributes: Dict[str, Any] = {}
 
     async def chat_completion_create(
         self,
-        messages: List[Tuple[ChatCompletionMessageRole, str]],
-        tools: List[JSONScalarType],
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
         **invocation_parameters: Any,
-    ) -> AsyncIterator[Union[TextChunk, ToolCallChunk]]:
+    ) -> AsyncIterator[ChatCompletionChunk]:
         from openai import NOT_GIVEN
         from openai.types.chat import ChatCompletionStreamOptionsParam
 
         # Convert standard messages to OpenAI messages
         openai_messages = [self.to_openai_chat_completion_param(*message) for message in messages]
-        tool_call_ids: Dict[int, str] = {}
+        tool_call_ids: dict[int, str] = {}
         token_usage: Optional["CompletionUsage"] = None
         async for chunk in await self.client.chat.completions.create(
             messages=openai_messages,
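The constructor change inverts how attributes leave the client: instead of accumulating them on the instance and exposing them through the removed attributes property, each client is handed a set_span_attributes callback and pushes attributes out as they become known. A toy version of that inversion, assuming a plain dict as the sink:

    # Sketch: pushing span attributes through a callback instead of a property.
    from typing import Any, Callable, Mapping, Optional

    SetSpanAttributesFn = Callable[[Mapping[str, Any]], None]

    class Client:
        def __init__(self, set_span_attributes: Optional[SetSpanAttributesFn] = None) -> None:
            self._set_span_attributes = set_span_attributes

        def finish(self, prompt_tokens: int) -> None:
            # report attributes to the caller the moment they are known
            if self._set_span_attributes:
                self._set_span_attributes({"llm.token_count.prompt": prompt_tokens})

    attributes: dict[str, Any] = {}
    Client(set_span_attributes=attributes.update).finish(7)
    assert attributes == {"llm.token_count.prompt": 7}

This is also why the subscription can pass `lambda attrs: attributes.update(attrs)` later in the diff: the caller owns the dict, and the client only reports into it.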
@@ -219,15 +232,20 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
             ),
         )
         yield tool_call_chunk
-        if token_usage is not None:
-            self.…
+        if token_usage is not None and self._set_span_attributes:
+            self._set_span_attributes(dict(self._llm_token_counts(token_usage)))
 
     def to_openai_chat_completion_param(
-        self, role: ChatCompletionMessageRole, content: JSONScalarType
+        self,
+        role: ChatCompletionMessageRole,
+        content: JSONScalarType,
+        tool_call_id: Optional[str] = None,
+        tool_calls: Optional[list[JSONScalarType]] = None,
     ) -> "ChatCompletionMessageParam":
         from openai.types.chat import (
             ChatCompletionAssistantMessageParam,
             ChatCompletionSystemMessageParam,
+            ChatCompletionToolMessageParam,
             ChatCompletionUserMessageParam,
         )
 
@@ -246,26 +264,64 @@ class OpenAIStreamingClient(PlaygroundStreamingClient):
             }
         )
         if role is ChatCompletionMessageRole.AI:
-            return ChatCompletionAssistantMessageParam(
-                {
-                    "content": content,
-                    "role": "assistant",
-                }
-            )
+            if tool_calls is None:
+                return ChatCompletionAssistantMessageParam(
+                    {
+                        "content": content,
+                        "role": "assistant",
+                    }
+                )
+            else:
+                return ChatCompletionAssistantMessageParam(
+                    {
+                        "content": content,
+                        "role": "assistant",
+                        "tool_calls": [
+                            self.to_openai_tool_call_param(tool_call) for tool_call in tool_calls
+                        ],
+                    }
+                )
         if role is ChatCompletionMessageRole.TOOL:
-… (truncated in the diff view)
+            if tool_call_id is None:
+                raise ValueError("tool_call_id is required for tool messages")
+            return ChatCompletionToolMessageParam(
+                {"content": content, "role": "tool", "tool_call_id": tool_call_id}
+            )
         assert_never(role)
 
-    @property
-    def attributes(self) -> Dict[str, Any]:
-        return self._attributes
+    def to_openai_tool_call_param(
+        self,
+        tool_call: JSONScalarType,
+    ) -> "ChatCompletionMessageToolCallParam":
+        from openai.types.chat import ChatCompletionMessageToolCallParam
+
+        return ChatCompletionMessageToolCallParam(
+            id=tool_call.get("id", ""),
+            function={
+                "name": tool_call.get("function", {}).get("name", ""),
+                "arguments": safe_json_dumps(tool_call.get("function", {}).get("arguments", "")),
+            },
+            type="function",
+        )
+
+    @staticmethod
+    def _llm_token_counts(usage: "CompletionUsage") -> Iterator[tuple[str, Any]]:
+        yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
+        yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
+        yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
 
 
 @register_llm_client(GenerativeProviderKey.AZURE_OPENAI)
 class AzureOpenAIStreamingClient(OpenAIStreamingClient):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+        set_span_attributes: Optional[SetSpanAttributesFn] = None,
+    ):
         from openai import AsyncAzureOpenAI
 
+        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
         if model.endpoint is None or model.api_version is None:
             raise ValueError("endpoint and api_version are required for Azure OpenAI models")
         self.client = AsyncAzureOpenAI(
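`_llm_token_counts` is a generator of `(key, value)` pairs, which the caller materializes with `dict(...)` before handing it to the callback. The shape in miniature, with the semconv-style keys written out literally for illustration:

    # Sketch: yielding (key, value) pairs and collecting them with dict().
    from typing import Any, Iterator

    def token_counts(prompt: int, completion: int) -> Iterator[tuple[str, Any]]:
        yield "llm.token_count.prompt", prompt
        yield "llm.token_count.completion", completion
        yield "llm.token_count.total", prompt + completion

    assert dict(token_counts(3, 4)) == {
        "llm.token_count.prompt": 3,
        "llm.token_count.completion": 4,
        "llm.token_count.total": 7,
    }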
@@ -277,18 +333,29 @@ class AzureOpenAIStreamingClient(OpenAIStreamingClient):
 
 @register_llm_client(GenerativeProviderKey.ANTHROPIC)
 class AnthropicStreamingClient(PlaygroundStreamingClient):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+        set_span_attributes: Optional[SetSpanAttributesFn] = None,
+    ) -> None:
         import anthropic
 
+        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
         self.client = anthropic.AsyncAnthropic(api_key=api_key)
         self.model_name = model.name
 
     async def chat_completion_create(
         self,
-        messages: List[Tuple[ChatCompletionMessageRole, str]],
-        tools: List[JSONScalarType],
+        messages: list[
+            tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+        ],
+        tools: list[JSONScalarType],
         **invocation_parameters: Any,
-    ) -> AsyncIterator[Union[TextChunk, ToolCallChunk]]:
+    ) -> AsyncIterator[ChatCompletionChunk]:
+        import anthropic.lib.streaming as anthropic_streaming
+        import anthropic.types as anthropic_types
+
         anthropic_messages, system_prompt = self._build_anthropic_messages(messages)
 
         anthropic_params = {
@@ -298,17 +365,43 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
             "max_tokens": 1024,
             **invocation_parameters,
         }
-
         async with self.client.messages.stream(**anthropic_params) as stream:
-            async for text in stream.text_stream:
-                yield TextChunk(content=text)
+            async for event in stream:
+                if isinstance(event, anthropic_types.RawMessageStartEvent):
+                    if self._set_span_attributes:
+                        self._set_span_attributes(
+                            {LLM_TOKEN_COUNT_PROMPT: event.message.usage.input_tokens}
+                        )
+                elif isinstance(event, anthropic_streaming.TextEvent):
+                    yield TextChunk(content=event.text)
+                elif isinstance(event, anthropic_streaming.MessageStopEvent):
+                    if self._set_span_attributes:
+                        self._set_span_attributes(
+                            {LLM_TOKEN_COUNT_COMPLETION: event.message.usage.output_tokens}
+                        )
+                elif isinstance(
+                    event,
+                    (
+                        anthropic_types.RawContentBlockStartEvent,
+                        anthropic_types.RawContentBlockDeltaEvent,
+                        anthropic_types.RawMessageDeltaEvent,
+                        anthropic_streaming.ContentBlockStopEvent,
+                    ),
+                ):
+                    # event types emitted by the stream that don't contain useful information
+                    pass
+                elif isinstance(event, anthropic_streaming.InputJsonEvent):
+                    raise NotImplementedError
+                else:
+                    assert_never(event)
 
     def _build_anthropic_messages(
-        self, messages: List[Tuple[ChatCompletionMessageRole, str]]
-    ) -> Tuple[List["MessageParam"], str]:
-        anthropic_messages: List["MessageParam"] = []
+        self,
+        messages: list[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
+    ) -> tuple[list["MessageParam"], str]:
+        anthropic_messages: list["MessageParam"] = []
         system_prompt = ""
-        for role, content in messages:
+        for role, content, _tool_call_id, _tool_calls in messages:
             if role == ChatCompletionMessageRole.USER:
                 anthropic_messages.append({"role": "user", "content": content})
             elif role == ChatCompletionMessageRole.AI:
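The rewritten Anthropic loop dispatches on the event type and ends in `assert_never`, so a new SDK event type surfaces as a type-checker error (and a runtime exception) instead of being silently dropped. The exhaustiveness idiom in miniature:

    # Sketch: exhaustive dispatch over a closed union with assert_never.
    from dataclasses import dataclass
    from typing import Union
    from typing_extensions import assert_never

    @dataclass
    class Text:
        text: str

    @dataclass
    class Stop:
        pass

    Event = Union[Text, Stop]

    def describe(event: Event) -> str:
        if isinstance(event, Text):
            return event.text
        if isinstance(event, Stop):
            return "<stop>"
        assert_never(event)  # unreachable while every Event variant is handled above

    assert describe(Text("hi")) == "hi"
    assert describe(Stop()) == "<stop>"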
@@ -322,10 +415,6 @@ class AnthropicStreamingClient(PlaygroundStreamingClient):
 
         return anthropic_messages, system_prompt
 
-    @property
-    def attributes(self) -> Dict[str, Any]:
-        return dict()
-
 
 @strawberry.type
 class Subscription:
@@ -335,44 +424,45 @@ class Subscription:
     ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
         # Determine which LLM client to use based on provider_key
         provider_key = input.model.provider_key
-        llm_client_class …
-… (7 lines truncated in the diff view)
+        if (llm_client_class := PLAYGROUND_STREAMING_CLIENT_REGISTRY.get(provider_key)) is None:
+            raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
+        llm_client = llm_client_class(
+            model=input.model,
+            api_key=input.api_key,
+            set_span_attributes=lambda attrs: attributes.update(attrs),
+        )
+        messages = [
+            (
+                message.role,
+                message.content,
+                message.tool_call_id if isinstance(message.tool_call_id, str) else None,
+                message.tool_calls if isinstance(message.tool_calls, list) else None,
+            )
+            for message in input.messages
+        ]
         if template_options := input.template:
             messages = list(_formatted_messages(messages, template_options))
-
         invocation_parameters = jsonify(input.invocation_parameters)
-… (5 lines truncated in the diff view)
+        attributes = dict(
+            chain(
+                _llm_span_kind(),
+                _llm_model_name(input.model.name),
+                _llm_tools(input.tools or []),
+                _llm_input_messages(messages),
+                _llm_invocation_parameters(invocation_parameters),
+                _input_value_and_mime_type(input),
+            )
         )
-… (10 lines truncated in the diff view)
-                    _llm_input_messages(messages),
-                    _llm_invocation_parameters(invocation_parameters),
-                    _input_value_and_mime_type(input),
-                )
-            ),
-        ) as span:
-            response_chunks = []
-            text_chunks: List[TextChunk] = []
-            tool_call_chunks: DefaultDict[ToolCallID, List[ToolCallChunk]] = defaultdict(list)
-
+        status_code: StatusCode
+        status_message = ""
+        events: list[SpanEvent] = []
+        start_time: datetime
+        end_time: datetime
+        response_chunks = []
+        text_chunks: list[TextChunk] = []
+        tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]] = defaultdict(list)
+        try:
+            start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
             async for chunk in llm_client.chat_completion_create(
                 messages=messages,
                 tools=input.tools or [],
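With the in-memory OTel tracer gone, the subscription now brackets the stream itself: take a wall-clock start time, drain the stream inside a try block, and record status, timestamps, and events in the except and else arms. A reduced sketch of that control flow with generic stand-ins (not the Phoenix helpers):

    # Sketch: wall-clock span bookkeeping around a stream, try/except style.
    from datetime import datetime, timezone
    from typing import Any, Iterable

    def consume(stream: Iterable[str]) -> dict[str, Any]:
        status_code, status_message = "OK", ""
        start_time = datetime.now(timezone.utc)   # span start
        try:
            chunks = list(stream)                 # drain the stream
        except Exception as error:
            status_code, status_message = "ERROR", str(error)
            chunks = []
        end_time = datetime.now(timezone.utc)     # span end, success or failure
        return {
            "start_time": start_time.isoformat(),
            "end_time": end_time.isoformat(),
            "status_code": status_code,
            "status_message": status_message,
            "cumulative_error_count": int(status_code == "ERROR"),
            "output": "".join(chunks),
        }

    def _boom() -> Iterable[str]:
        raise RuntimeError("boom")
        yield ""  # makes this function a generator

    assert consume(["a", "b"])["output"] == "ab"
    assert consume(_boom())["status_message"] == "boom"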
@@ -385,31 +475,35 @@
             elif isinstance(chunk, ToolCallChunk):
                 yield chunk
                 tool_call_chunks[chunk.id].append(chunk)
-… (11 lines truncated in the diff view)
+            else:
+                assert_never(chunk)
+            status_code = StatusCode.OK
+        except Exception as error:
+            end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
+            status_code = StatusCode.ERROR
+            status_message = str(error)
+            events.append(
+                SpanException(
+                    timestamp=end_time,
+                    message=status_message,
+                    exception_type=type(error).__name__,
+                    exception_escaped=False,
+                    exception_stacktrace=format_exc(),
                 )
             )
-… (12 lines truncated in the diff view)
+            yield ChatCompletionSubscriptionError(message=status_message)
+        else:
+            end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
+            attributes.update(
+                chain(
+                    _output_value_and_mime_type(response_chunks),
+                    _llm_output_messages(text_chunks, tool_call_chunks),
+                )
+            )
+            prompt_tokens = attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
+            completion_tokens = attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
+            trace_id = _generate_trace_id()
+            span_id = _generate_span_id()
         async with info.context.db() as session:
             if (
                 playground_project_id := await session.scalar(
@@ -434,15 +528,15 @@
                 trace_rowid=playground_trace.id,
                 span_id=span_id,
                 parent_id=None,
-                name=…
+                name="ChatCompletion",
                 span_kind=LLM,
                 start_time=start_time,
                 end_time=end_time,
                 attributes=unflatten(attributes.items()),
-                events=…
-                status_code=…
-                status_message=…
-                cumulative_error_count=int(…
+                events=[_serialize_event(event) for event in events],
+                status_code=status_code.name,
+                status_message=status_message,
+                cumulative_error_count=int(status_code is StatusCode.ERROR),
                 cumulative_llm_token_count_prompt=prompt_tokens,
                 cumulative_llm_token_count_completion=completion_tokens,
                 llm_token_count_prompt=prompt_tokens,
@@ -456,56 +550,65 @@
             info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
 
 
-def _llm_span_kind() -> Iterator[Tuple[str, Any]]:
+def _llm_span_kind() -> Iterator[tuple[str, Any]]:
     yield OPENINFERENCE_SPAN_KIND, LLM
 
 
-def _llm_model_name(model_name: str) -> Iterator[Tuple[str, Any]]:
+def _llm_model_name(model_name: str) -> Iterator[tuple[str, Any]]:
     yield LLM_MODEL_NAME, model_name
 
 
-def _llm_invocation_parameters(invocation_parameters: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
+def _llm_invocation_parameters(invocation_parameters: dict[str, Any]) -> Iterator[tuple[str, Any]]:
     yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_parameters)
 
 
-def _llm_tools(tools: List[JSONScalarType]) -> Iterator[Tuple[str, Any]]:
+def _llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
     for tool_index, tool in enumerate(tools):
         yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
 
 
-def …
-… (5 lines truncated in the diff view)
-def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]:
-    assert any(field.name == (api_key := "api_key") for field in fields(ChatCompletionInput))
+def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[tuple[str, Any]]:
+    assert (api_key := "api_key") in (input_data := jsonify(input))
+    input_data = {k: v for k, v in input_data.items() if k != api_key}
+    assert api_key not in input_data
     yield INPUT_MIME_TYPE, JSON
-    yield INPUT_VALUE, safe_json_dumps(…
+    yield INPUT_VALUE, safe_json_dumps(input_data)
 
 
-def _output_value_and_mime_type(output: Any) -> Iterator[Tuple[str, Any]]:
+def _output_value_and_mime_type(output: Any) -> Iterator[tuple[str, Any]]:
     yield OUTPUT_MIME_TYPE, JSON
     yield OUTPUT_VALUE, safe_json_dumps(jsonify(output))
 
 
 def _llm_input_messages(
-    messages: Iterable[Tuple[ChatCompletionMessageRole, str]],
-) -> Iterator[Tuple[str, Any]]:
-    for i, (role, content) in enumerate(messages):
+    messages: Iterable[
+        tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
+    ],
+) -> Iterator[tuple[str, Any]]:
+    for i, (role, content, _tool_call_id, tool_calls) in enumerate(messages):
         yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_ROLE}", role.value.lower()
         yield f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_CONTENT}", content
+        if tool_calls is not None:
+            for tool_call_index, tool_call in enumerate(tool_calls):
+                yield (
+                    f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
+                    tool_call["function"]["name"],
+                )
+                if arguments := tool_call["function"]["arguments"]:
+                    yield (
+                        f"{LLM_INPUT_MESSAGES}.{i}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
+                        safe_json_dumps(jsonify(arguments)),
+                    )
 
 
 def _llm_output_messages(
-    text_chunks: List[TextChunk],
-    tool_call_chunks: DefaultDict[ToolCallID, List[ToolCallChunk]],
-) -> Iterator[Tuple[str, Any]]:
+    text_chunks: list[TextChunk],
+    tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]],
+) -> Iterator[tuple[str, Any]]:
     yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
     if content := "".join(chunk.content for chunk in text_chunks):
         yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}", content
-    for tool_call_index, tool_call_chunks_ in tool_call_chunks.items():
+    for tool_call_index, (_tool_call_id, tool_call_chunks_) in enumerate(tool_call_chunks.items()):
         if tool_call_chunks_ and (name := tool_call_chunks_[0].function.name):
             yield (
                 f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
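The `_llm_*` helpers all emit flattened, dot-indexed keys (for example `llm.input_messages.0.message.role`), which `unflatten(attributes.items())` later folds back into nested structure before the span row is written. What the flattening convention produces, in a standalone sketch:

    # Sketch: the dot-indexed flattened keys the _llm_* generators emit.
    from typing import Any, Iterator

    def llm_input_messages(messages: list[tuple[str, str]]) -> Iterator[tuple[str, Any]]:
        for i, (role, content) in enumerate(messages):
            yield f"llm.input_messages.{i}.message.role", role
            yield f"llm.input_messages.{i}.message.content", content

    flat = dict(llm_input_messages([("user", "hi"), ("assistant", "hello")]))
    assert flat["llm.input_messages.1.message.content"] == "hello"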
@@ -518,34 +621,46 @@ def _llm_output_messages(
             )
 
 
-def …
+def _generate_trace_id() -> str:
     """
-… (truncated in the diff view)
+    Generates a random trace ID in hexadecimal format.
     """
-    return …
+    return _hex(DefaultOTelIDGenerator().generate_trace_id())
 
 
-def …
+def _generate_span_id() -> str:
     """
-… (truncated in the diff view)
+    Generates a random span ID in hexadecimal format.
     """
-… (2 lines truncated in the diff view)
+    return _hex(DefaultOTelIDGenerator().generate_span_id())
+
+
+def _hex(number: int) -> str:
+    """
+    Converts an integer to a hexadecimal string.
+    """
+    return hex(number)[2:]
 
 
 def _formatted_messages(
-    messages: Iterable[…
-… (truncated in the diff view)
+    messages: Iterable[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]],
+    template_options: TemplateOptions,
+) -> Iterator[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]]:
     """
     Formats the messages using the given template options.
     """
     template_formatter = _template_formatter(template_language=template_options.language)
-    roles, templates = zip(*messages)
+    (
+        roles,
+        templates,
+        tool_call_id,
+        tool_calls,
+    ) = zip(*messages)
     formatted_templates = map(
         lambda template: template_formatter.format(template, **template_options.variables),
         templates,
     )
-    formatted_messages = zip(roles, formatted_templates)
+    formatted_messages = zip(roles, formatted_templates, tool_call_id, tool_calls)
     return formatted_messages
 
 
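Trace and span IDs now come from OTel's `RandomIdGenerator`, rendered as bare hex with the `0x` prefix stripped. A minimal sketch, assuming only that the opentelemetry-sdk package is installed:

    # Sketch: hex-encoded ids from OTel's RandomIdGenerator.
    from opentelemetry.sdk.trace.id_generator import RandomIdGenerator

    def _hex(number: int) -> str:
        return hex(number)[2:]  # strip the "0x" prefix

    generator = RandomIdGenerator()
    trace_id = _hex(generator.generate_trace_id())  # 128-bit id
    span_id = _hex(generator.generate_span_id())    # 64-bit id
    assert set(trace_id) <= set("0123456789abcdef")
    assert set(span_id) <= set("0123456789abcdef")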
@@ -560,6 +675,13 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter
     assert_never(template_language)
 
 
+def _serialize_event(event: SpanEvent) -> dict[str, Any]:
+    """
+    Serializes a SpanEvent to a dictionary.
+    """
+    return {k: (v.isoformat() if isinstance(v, datetime) else v) for k, v in asdict(event).items()}
+
+
 JSON = OpenInferenceMimeTypeValues.JSON.value
 
 LLM = OpenInferenceSpanKindValues.LLM.value