arize-phoenix 5.5.2__py3-none-any.whl → 5.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of arize-phoenix has been flagged as potentially problematic.
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/METADATA +4 -7
- arize_phoenix-5.7.0.dist-info/RECORD +330 -0
- phoenix/config.py +50 -8
- phoenix/core/model.py +3 -3
- phoenix/core/model_schema.py +41 -50
- phoenix/core/model_schema_adapter.py +17 -16
- phoenix/datetime_utils.py +2 -2
- phoenix/db/bulk_inserter.py +10 -20
- phoenix/db/engines.py +2 -1
- phoenix/db/enums.py +2 -2
- phoenix/db/helpers.py +8 -7
- phoenix/db/insertion/dataset.py +9 -19
- phoenix/db/insertion/document_annotation.py +14 -13
- phoenix/db/insertion/helpers.py +6 -16
- phoenix/db/insertion/span_annotation.py +14 -13
- phoenix/db/insertion/trace_annotation.py +14 -13
- phoenix/db/insertion/types.py +19 -30
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +8 -8
- phoenix/db/models.py +28 -28
- phoenix/experiments/evaluators/base.py +2 -1
- phoenix/experiments/evaluators/code_evaluators.py +4 -5
- phoenix/experiments/evaluators/llm_evaluators.py +157 -4
- phoenix/experiments/evaluators/utils.py +3 -2
- phoenix/experiments/functions.py +10 -21
- phoenix/experiments/tracing.py +2 -1
- phoenix/experiments/types.py +20 -29
- phoenix/experiments/utils.py +2 -1
- phoenix/inferences/errors.py +6 -5
- phoenix/inferences/fixtures.py +6 -5
- phoenix/inferences/inferences.py +37 -37
- phoenix/inferences/schema.py +11 -10
- phoenix/inferences/validation.py +13 -14
- phoenix/logging/_formatter.py +3 -3
- phoenix/metrics/__init__.py +5 -4
- phoenix/metrics/binning.py +2 -1
- phoenix/metrics/metrics.py +2 -1
- phoenix/metrics/mixins.py +7 -6
- phoenix/metrics/retrieval_metrics.py +2 -1
- phoenix/metrics/timeseries.py +5 -4
- phoenix/metrics/wrappers.py +2 -2
- phoenix/pointcloud/clustering.py +3 -4
- phoenix/pointcloud/pointcloud.py +7 -5
- phoenix/pointcloud/umap_parameters.py +2 -1
- phoenix/server/api/dataloaders/annotation_summaries.py +12 -19
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +2 -2
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +3 -2
- phoenix/server/api/dataloaders/dataset_example_revisions.py +3 -8
- phoenix/server/api/dataloaders/dataset_example_spans.py +2 -5
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +12 -18
- phoenix/server/api/dataloaders/document_evaluations.py +3 -7
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +6 -13
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +4 -8
- phoenix/server/api/dataloaders/experiment_error_rates.py +2 -5
- phoenix/server/api/dataloaders/experiment_run_annotations.py +3 -7
- phoenix/server/api/dataloaders/experiment_run_counts.py +1 -5
- phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -5
- phoenix/server/api/dataloaders/latency_ms_quantile.py +21 -30
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +7 -13
- phoenix/server/api/dataloaders/project_by_name.py +3 -3
- phoenix/server/api/dataloaders/record_counts.py +11 -18
- phoenix/server/api/dataloaders/span_annotations.py +3 -7
- phoenix/server/api/dataloaders/span_dataset_examples.py +3 -8
- phoenix/server/api/dataloaders/span_descendants.py +3 -7
- phoenix/server/api/dataloaders/span_projects.py +2 -2
- phoenix/server/api/dataloaders/token_counts.py +12 -19
- phoenix/server/api/dataloaders/trace_row_ids.py +3 -7
- phoenix/server/api/dataloaders/user_roles.py +3 -3
- phoenix/server/api/dataloaders/users.py +3 -3
- phoenix/server/api/helpers/__init__.py +4 -3
- phoenix/server/api/helpers/dataset_helpers.py +10 -9
- phoenix/server/api/helpers/playground_clients.py +671 -0
- phoenix/server/api/helpers/playground_registry.py +70 -0
- phoenix/server/api/helpers/playground_spans.py +325 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +2 -2
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +2 -2
- phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
- phoenix/server/api/input_types/ChatCompletionMessageInput.py +13 -1
- phoenix/server/api/input_types/ClusterInput.py +2 -2
- phoenix/server/api/input_types/DeleteAnnotationsInput.py +1 -3
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +2 -2
- phoenix/server/api/input_types/DeleteExperimentsInput.py +1 -3
- phoenix/server/api/input_types/DimensionFilter.py +4 -4
- phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
- phoenix/server/api/input_types/Granularity.py +1 -1
- phoenix/server/api/input_types/InvocationParameters.py +156 -13
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +2 -2
- phoenix/server/api/input_types/TemplateOptions.py +10 -0
- phoenix/server/api/mutations/__init__.py +4 -0
- phoenix/server/api/mutations/chat_mutations.py +374 -0
- phoenix/server/api/mutations/dataset_mutations.py +4 -4
- phoenix/server/api/mutations/experiment_mutations.py +1 -2
- phoenix/server/api/mutations/export_events_mutations.py +7 -7
- phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/user_mutations.py +4 -4
- phoenix/server/api/openapi/schema.py +2 -2
- phoenix/server/api/queries.py +61 -72
- phoenix/server/api/routers/oauth2.py +4 -4
- phoenix/server/api/routers/v1/datasets.py +22 -36
- phoenix/server/api/routers/v1/evaluations.py +6 -5
- phoenix/server/api/routers/v1/experiment_evaluations.py +2 -2
- phoenix/server/api/routers/v1/experiment_runs.py +2 -2
- phoenix/server/api/routers/v1/experiments.py +4 -4
- phoenix/server/api/routers/v1/spans.py +13 -12
- phoenix/server/api/routers/v1/traces.py +5 -5
- phoenix/server/api/routers/v1/utils.py +5 -5
- phoenix/server/api/schema.py +42 -10
- phoenix/server/api/subscriptions.py +347 -494
- phoenix/server/api/types/AnnotationSummary.py +3 -3
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +44 -0
- phoenix/server/api/types/Cluster.py +8 -7
- phoenix/server/api/types/Dataset.py +5 -4
- phoenix/server/api/types/Dimension.py +3 -3
- phoenix/server/api/types/DocumentEvaluationSummary.py +8 -7
- phoenix/server/api/types/EmbeddingDimension.py +6 -5
- phoenix/server/api/types/EvaluationSummary.py +3 -3
- phoenix/server/api/types/Event.py +7 -7
- phoenix/server/api/types/Experiment.py +3 -3
- phoenix/server/api/types/ExperimentComparison.py +2 -4
- phoenix/server/api/types/GenerativeProvider.py +27 -3
- phoenix/server/api/types/Inferences.py +9 -8
- phoenix/server/api/types/InferencesRole.py +2 -2
- phoenix/server/api/types/Model.py +2 -2
- phoenix/server/api/types/Project.py +11 -18
- phoenix/server/api/types/Segments.py +3 -3
- phoenix/server/api/types/Span.py +45 -7
- phoenix/server/api/types/TemplateLanguage.py +9 -0
- phoenix/server/api/types/TimeSeries.py +8 -7
- phoenix/server/api/types/Trace.py +2 -2
- phoenix/server/api/types/UMAPPoints.py +6 -6
- phoenix/server/api/types/User.py +3 -3
- phoenix/server/api/types/node.py +1 -3
- phoenix/server/api/types/pagination.py +4 -4
- phoenix/server/api/utils.py +2 -4
- phoenix/server/app.py +76 -37
- phoenix/server/bearer_auth.py +4 -10
- phoenix/server/dml_event.py +3 -3
- phoenix/server/dml_event_handler.py +10 -24
- phoenix/server/grpc_server.py +3 -2
- phoenix/server/jwt_store.py +22 -21
- phoenix/server/main.py +17 -4
- phoenix/server/oauth2.py +3 -2
- phoenix/server/rate_limiters.py +5 -8
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/components-Csu8UKOs.js +1612 -0
- phoenix/server/static/assets/{index-DCzakdJq.js → index-Bk5C9EA7.js} +2 -2
- phoenix/server/static/assets/{pages-CAL1FDMt.js → pages-UeWaKXNs.js} +337 -442
- phoenix/server/static/assets/{vendor-6IcPAw_j.js → vendor-CtqfhlbC.js} +6 -6
- phoenix/server/static/assets/{vendor-arizeai-DRZuoyuF.js → vendor-arizeai-C_3SBz56.js} +2 -2
- phoenix/server/static/assets/{vendor-codemirror-DVE2_WBr.js → vendor-codemirror-wfdk9cjp.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-DwrexFA4.js → vendor-recharts-BiVnSv90.js} +1 -1
- phoenix/server/templates/index.html +1 -0
- phoenix/server/thread_server.py +1 -1
- phoenix/server/types.py +17 -29
- phoenix/services.py +8 -3
- phoenix/session/client.py +12 -24
- phoenix/session/data_extractor.py +3 -3
- phoenix/session/evaluation.py +1 -2
- phoenix/session/session.py +26 -21
- phoenix/trace/attributes.py +16 -28
- phoenix/trace/dsl/filter.py +17 -21
- phoenix/trace/dsl/helpers.py +3 -3
- phoenix/trace/dsl/query.py +13 -22
- phoenix/trace/fixtures.py +11 -17
- phoenix/trace/otel.py +5 -15
- phoenix/trace/projects.py +3 -2
- phoenix/trace/schemas.py +2 -2
- phoenix/trace/span_evaluations.py +9 -8
- phoenix/trace/span_json_decoder.py +3 -3
- phoenix/trace/span_json_encoder.py +2 -2
- phoenix/trace/trace_dataset.py +6 -5
- phoenix/trace/utils.py +6 -6
- phoenix/utilities/deprecation.py +3 -2
- phoenix/utilities/error_handling.py +3 -2
- phoenix/utilities/json.py +2 -1
- phoenix/utilities/logging.py +2 -2
- phoenix/utilities/project.py +1 -1
- phoenix/utilities/re.py +3 -4
- phoenix/utilities/template_formatters.py +16 -5
- phoenix/version.py +1 -1
- arize_phoenix-5.5.2.dist-info/RECORD +0 -321
- phoenix/server/static/assets/components-hX0LgYz3.js +0 -1428
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-5.5.2.dist-info → arize_phoenix-5.7.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/subscriptions.py

@@ -1,329 +1,73 @@
-import
-from
-from collections import
-from datetime import datetime, timezone
-from enum import Enum
-from itertools import chain
+import logging
+from asyncio import FIRST_COMPLETED, Task, create_task, wait
+from collections.abc import Iterator
 from typing import (
-    TYPE_CHECKING,
-    Annotated,
     Any,
     AsyncIterator,
-
-    DefaultDict,
-    Dict,
+    Collection,
     Iterable,
-
-    List,
+    Mapping,
     Optional,
-
-    Type,
-    Union,
+    TypeVar,
 )

 import strawberry
-from openinference.
-from
-
-
-    OpenInferenceSpanKindValues,
-    SpanAttributes,
-    ToolAttributes,
-    ToolCallAttributes,
-)
-from opentelemetry.sdk.trace import TracerProvider
-from opentelemetry.sdk.trace.export import SimpleSpanProcessor
-from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
-from opentelemetry.trace import StatusCode
-from sqlalchemy import insert, select
-from strawberry import UNSET
-from strawberry.scalars import JSON as JSONScalarType
+from openinference.semconv.trace import SpanAttributes
+from sqlalchemy import and_, func, insert, select
+from sqlalchemy.orm import load_only
+from strawberry.relay.types import GlobalID
 from strawberry.types import Info
 from typing_extensions import TypeAlias, assert_never

 from phoenix.db import models
 from phoenix.server.api.context import Context
-from phoenix.server.api.
-from phoenix.server.api.
+from phoenix.server.api.exceptions import BadRequest
+from phoenix.server.api.helpers.playground_clients import (
+    PlaygroundStreamingClient,
+    initialize_playground_clients,
+)
+from phoenix.server.api.helpers.playground_registry import (
+    PLAYGROUND_CLIENT_REGISTRY,
+)
+from phoenix.server.api.helpers.playground_spans import streaming_llm_span
+from phoenix.server.api.input_types.ChatCompletionInput import (
+    ChatCompletionInput,
+    ChatCompletionOverDatasetInput,
+)
 from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
-from phoenix.server.api.types.
-
+from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
+    ChatCompletionOverDatasetSubscriptionResult,
+    ChatCompletionSubscriptionError,
+    ChatCompletionSubscriptionPayload,
+    FinishedChatCompletion,
+)
+from phoenix.server.api.types.Dataset import Dataset
+from phoenix.server.api.types.DatasetExample import DatasetExample
+from phoenix.server.api.types.DatasetVersion import DatasetVersion
+from phoenix.server.api.types.Experiment import to_gql_experiment
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+from phoenix.server.api.types.Span import to_gql_span
+from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
 from phoenix.server.dml_event import SpanInsertEvent
-from phoenix.trace.attributes import
-from phoenix.utilities.json import jsonify
+from phoenix.trace.attributes import get_attribute_value
 from phoenix.utilities.template_formatters import (
     FStringTemplateFormatter,
     MustacheTemplateFormatter,
     TemplateFormatter,
+    TemplateFormatterError,
 )

-
-    from anthropic.types import MessageParam
-    from openai.types import CompletionUsage
-    from openai.types.chat import ChatCompletionMessageParam
-
-PLAYGROUND_PROJECT_NAME = "playground"
-
-ToolCallID: TypeAlias = str
-
-
-@strawberry.enum
-class TemplateLanguage(Enum):
-    MUSTACHE = "MUSTACHE"
-    F_STRING = "F_STRING"
+GenericType = TypeVar("GenericType")

+logger = logging.getLogger(__name__)

-
-class TemplateOptions:
-    variables: JSONScalarType
-    language: TemplateLanguage
-
-
-@strawberry.type
-class TextChunk:
-    content: str
-
-
-@strawberry.type
-class FunctionCallChunk:
-    name: str
-    arguments: str
-
-
-@strawberry.type
-class ToolCallChunk:
-    id: str
-    function: FunctionCallChunk
-
+initialize_playground_clients()

-
-
-    span: Span
-
-
-ChatCompletionSubscriptionPayload: TypeAlias = Annotated[
-    Union[TextChunk, ToolCallChunk, FinishedChatCompletion],
-    strawberry.union("ChatCompletionSubscriptionPayload"),
+ChatCompletionMessage: TypeAlias = tuple[
+    ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]
 ]
-
-
-@strawberry.input
-class GenerativeModelInput:
-    provider_key: GenerativeProviderKey
-    name: str
-    """ The name of the model. Or the Deployment Name for Azure OpenAI models. """
-    endpoint: Optional[str] = UNSET
-    """ The endpoint to use for the model. Only required for Azure OpenAI models. """
-    api_version: Optional[str] = UNSET
-    """ The API version to use for the model. """
-
-
-@strawberry.input
-class ChatCompletionInput:
-    messages: List[ChatCompletionMessageInput]
-    model: GenerativeModelInput
-    invocation_parameters: InvocationParameters
-    tools: Optional[List[JSONScalarType]] = UNSET
-    template: Optional[TemplateOptions] = UNSET
-    api_key: Optional[str] = strawberry.field(default=None)
-
-
-PLAYGROUND_STREAMING_CLIENT_REGISTRY: Dict[
-    GenerativeProviderKey, Type["PlaygroundStreamingClient"]
-] = {}
-
-
-def register_llm_client(
-    provider_key: GenerativeProviderKey,
-) -> Callable[[Type["PlaygroundStreamingClient"]], Type["PlaygroundStreamingClient"]]:
-    def decorator(cls: Type["PlaygroundStreamingClient"]) -> Type["PlaygroundStreamingClient"]:
-        PLAYGROUND_STREAMING_CLIENT_REGISTRY[provider_key] = cls
-        return cls
-
-    return decorator
-
-
-class PlaygroundStreamingClient(ABC):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None: ...
-
-    @abstractmethod
-    async def chat_completion_create(
-        self,
-        messages: List[Tuple[ChatCompletionMessageRole, str]],
-        tools: List[JSONScalarType],
-        **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
-        # a yield statement is needed to satisfy the type-checker
-        # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
-        yield TextChunk(content="")
-
-    @property
-    @abstractmethod
-    def attributes(self) -> Dict[str, Any]: ...
-
-
-@register_llm_client(GenerativeProviderKey.OPENAI)
-class OpenAIStreamingClient(PlaygroundStreamingClient):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
-        from openai import AsyncOpenAI
-
-        self.client = AsyncOpenAI(api_key=api_key)
-        self.model_name = model.name
-        self._attributes: Dict[str, Any] = {}
-
-    async def chat_completion_create(
-        self,
-        messages: List[Tuple[ChatCompletionMessageRole, str]],
-        tools: List[JSONScalarType],
-        **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
-        from openai import NOT_GIVEN
-        from openai.types.chat import ChatCompletionStreamOptionsParam
-
-        # Convert standard messages to OpenAI messages
-        openai_messages = [self.to_openai_chat_completion_param(*message) for message in messages]
-        tool_call_ids: Dict[int, str] = {}
-        token_usage: Optional["CompletionUsage"] = None
-        async for chunk in await self.client.chat.completions.create(
-            messages=openai_messages,
-            model=self.model_name,
-            stream=True,
-            stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
-            tools=tools or NOT_GIVEN,
-            **invocation_parameters,
-        ):
-            if (usage := chunk.usage) is not None:
-                token_usage = usage
-                continue
-            choice = chunk.choices[0]
-            delta = choice.delta
-            if choice.finish_reason is None:
-                if isinstance(chunk_content := delta.content, str):
-                    text_chunk = TextChunk(content=chunk_content)
-                    yield text_chunk
-                if (tool_calls := delta.tool_calls) is not None:
-                    for tool_call_index, tool_call in enumerate(tool_calls):
-                        tool_call_id = (
-                            tool_call.id
-                            if tool_call.id is not None
-                            else tool_call_ids[tool_call_index]
-                        )
-                        tool_call_ids[tool_call_index] = tool_call_id
-                        if (function := tool_call.function) is not None:
-                            tool_call_chunk = ToolCallChunk(
-                                id=tool_call_id,
-                                function=FunctionCallChunk(
-                                    name=function.name or "",
-                                    arguments=function.arguments or "",
-                                ),
-                            )
-                            yield tool_call_chunk
-        if token_usage is not None:
-            self._attributes.update(_llm_token_counts(token_usage))
-
-    def to_openai_chat_completion_param(
-        self, role: ChatCompletionMessageRole, content: JSONScalarType
-    ) -> "ChatCompletionMessageParam":
-        from openai.types.chat import (
-            ChatCompletionAssistantMessageParam,
-            ChatCompletionSystemMessageParam,
-            ChatCompletionUserMessageParam,
-        )
-
-        if role is ChatCompletionMessageRole.USER:
-            return ChatCompletionUserMessageParam(
-                {
-                    "content": content,
-                    "role": "user",
-                }
-            )
-        if role is ChatCompletionMessageRole.SYSTEM:
-            return ChatCompletionSystemMessageParam(
-                {
-                    "content": content,
-                    "role": "system",
-                }
-            )
-        if role is ChatCompletionMessageRole.AI:
-            return ChatCompletionAssistantMessageParam(
-                {
-                    "content": content,
-                    "role": "assistant",
-                }
-            )
-        if role is ChatCompletionMessageRole.TOOL:
-            raise NotImplementedError
-        assert_never(role)
-
-    @property
-    def attributes(self) -> Dict[str, Any]:
-        return self._attributes
-
-
-@register_llm_client(GenerativeProviderKey.AZURE_OPENAI)
-class AzureOpenAIStreamingClient(OpenAIStreamingClient):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None):
-        from openai import AsyncAzureOpenAI
-
-        if model.endpoint is None or model.api_version is None:
-            raise ValueError("endpoint and api_version are required for Azure OpenAI models")
-        self.client = AsyncAzureOpenAI(
-            api_key=api_key,
-            azure_endpoint=model.endpoint,
-            api_version=model.api_version,
-        )
-
-
-@register_llm_client(GenerativeProviderKey.ANTHROPIC)
-class AnthropicStreamingClient(PlaygroundStreamingClient):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
-        import anthropic
-
-        self.client = anthropic.AsyncAnthropic(api_key=api_key)
-        self.model_name = model.name
-
-    async def chat_completion_create(
-        self,
-        messages: List[Tuple[ChatCompletionMessageRole, str]],
-        tools: List[JSONScalarType],
-        **invocation_parameters: Any,
-    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
-        anthropic_messages, system_prompt = self._build_anthropic_messages(messages)
-
-        anthropic_params = {
-            "messages": anthropic_messages,
-            "model": self.model_name,
-            "system": system_prompt,
-            "max_tokens": 1024,
-            **invocation_parameters,
-        }
-
-        async with self.client.messages.stream(**anthropic_params) as stream:
-            async for text in stream.text_stream:
-                yield TextChunk(content=text)
-
-    def _build_anthropic_messages(
-        self, messages: List[Tuple[ChatCompletionMessageRole, str]]
-    ) -> Tuple[List["MessageParam"], str]:
-        anthropic_messages: List["MessageParam"] = []
-        system_prompt = ""
-        for role, content in messages:
-            if role == ChatCompletionMessageRole.USER:
-                anthropic_messages.append({"role": "user", "content": content})
-            elif role == ChatCompletionMessageRole.AI:
-                anthropic_messages.append({"role": "assistant", "content": content})
-            elif role == ChatCompletionMessageRole.SYSTEM:
-                system_prompt += content + "\n"
-            elif role == ChatCompletionMessageRole.TOOL:
-                raise NotImplementedError
-            else:
-                assert_never(role)
-
-        return anthropic_messages, system_prompt
-
-    @property
-    def attributes(self) -> Dict[str, Any]:
-        return dict()
+DatasetExampleID: TypeAlias = GlobalID
+PLAYGROUND_PROJECT_NAME = "playground"


 @strawberry.type
@@ -332,82 +76,48 @@ class Subscription:
     async def chat_completion(
         self, info: Info[Context, None], input: ChatCompletionInput
     ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
-        # Determine which LLM client to use based on provider_key
         provider_key = input.model.provider_key
-        llm_client_class =
+        llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
         if llm_client_class is None:
-            raise
-
-
-
-
+            raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
+        llm_client = llm_client_class(
+            model=input.model,
+            api_key=input.api_key,
+        )

+        messages = [
+            (
+                message.role,
+                message.content,
+                message.tool_call_id if isinstance(message.tool_call_id, str) else None,
+                message.tool_calls if isinstance(message.tool_calls, list) else None,
+            )
+            for message in input.messages
+        ]
         if template_options := input.template:
-            messages = list(
-
-
-
-
-        tracer_provider = TracerProvider()
-        tracer_provider.add_span_processor(
-            span_processor=SimpleSpanProcessor(span_exporter=in_memory_span_exporter)
-        )
-        tracer = tracer_provider.get_tracer(__name__)
-        span_name = "ChatCompletion"
-        with tracer.start_span(
-            span_name,
-            attributes=dict(
-                chain(
-                    _llm_span_kind(),
-                    _llm_model_name(input.model.name),
-                    _llm_tools(input.tools or []),
-                    _llm_input_messages(messages),
-                    _llm_invocation_parameters(invocation_parameters),
-                    _input_value_and_mime_type(input),
+            messages = list(
+                _formatted_messages(
+                    messages=messages,
+                    template_language=template_options.language,
+                    template_variables=template_options.variables,
                 )
-            )
+            )
+        invocation_parameters = llm_client.construct_invocation_parameters(
+            input.invocation_parameters
+        )
+        async with streaming_llm_span(
+            input=input,
+            messages=messages,
+            invocation_parameters=invocation_parameters,
         ) as span:
-            response_chunks = []
-            text_chunks: List[TextChunk] = []
-            tool_call_chunks: DefaultDict[ToolCallID, List[ToolCallChunk]] = defaultdict(list)
-
             async for chunk in llm_client.chat_completion_create(
-                messages=messages,
-                tools=input.tools or [],
-                **invocation_parameters,
+                messages=messages, tools=input.tools or [], **invocation_parameters
             ):
-
-
-
-
-
-                yield chunk
-                tool_call_chunks[chunk.id].append(chunk)
-
-            span.set_status(StatusCode.OK)
-            llm_client_attributes = llm_client.attributes
-
-            span.set_attributes(
-                dict(
-                    chain(
-                        _output_value_and_mime_type(response_chunks),
-                        llm_client_attributes.items(),
-                        _llm_output_messages(text_chunks, tool_call_chunks),
-                    )
-                )
-            )
-        assert len(spans := in_memory_span_exporter.get_finished_spans()) == 1
-        finished_span = spans[0]
-        assert finished_span.start_time is not None
-        assert finished_span.end_time is not None
-        assert (attributes := finished_span.attributes) is not None
-        start_time = _datetime(epoch_nanoseconds=finished_span.start_time)
-        end_time = _datetime(epoch_nanoseconds=finished_span.end_time)
-        prompt_tokens = llm_client_attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
-        completion_tokens = llm_client_attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
-        trace_id = _hex(finished_span.context.trace_id)
-        span_id = _hex(finished_span.context.span_id)
-        status = finished_span.status
+                span.add_response_chunk(chunk)
+                yield chunk
+            span.set_attributes(llm_client.attributes)
+        if span.error_message is not None:
+            yield ChatCompletionSubscriptionError(message=span.error_message)
         async with info.context.db() as session:
             if (
                 playground_project_id := await session.scalar(
@@ -422,130 +132,273 @@ class Subscription:
                         description="Traces from prompt playground",
                     )
                 )
-
-                project_rowid=playground_project_id,
-                trace_id=trace_id,
-                start_time=start_time,
-                end_time=end_time,
-            )
-            playground_span = models.Span(
-                trace_rowid=playground_trace.id,
-                span_id=span_id,
-                parent_id=None,
-                name=span_name,
-                span_kind=LLM,
-                start_time=start_time,
-                end_time=end_time,
-                attributes=unflatten(attributes.items()),
-                events=finished_span.events,
-                status_code=status.status_code.name,
-                status_message=status.description or "",
-                cumulative_error_count=int(not status.is_ok),
-                cumulative_llm_token_count_prompt=prompt_tokens,
-                cumulative_llm_token_count_completion=completion_tokens,
-                llm_token_count_prompt=prompt_tokens,
-                llm_token_count_completion=completion_tokens,
-                trace=playground_trace,
-            )
-            session.add(playground_trace)
-            session.add(playground_span)
+            db_span = span.add_to_session(session, playground_project_id)
             await session.flush()
-            yield FinishedChatCompletion(span=to_gql_span(
+        yield FinishedChatCompletion(span=to_gql_span(db_span))
         info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))

+    @strawberry.subscription
+    async def chat_completion_over_dataset(
+        self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
+    ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+        provider_key = input.model.provider_key
+        llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
+        if llm_client_class is None:
+            raise BadRequest(f"No LLM client registered for provider '{provider_key}'")

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-)
-
-
-
-
+        dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
+        version_id = (
+            from_global_id_with_expected_type(
+                global_id=input.dataset_version_id, expected_type_name=DatasetVersion.__name__
+            )
+            if input.dataset_version_id
+            else None
+        )
+        revision_ids = (
+            select(func.max(models.DatasetExampleRevision.id))
+            .join(models.DatasetExample)
+            .where(models.DatasetExample.dataset_id == dataset_id)
+            .group_by(models.DatasetExampleRevision.dataset_example_id)
+        )
+        if version_id:
+            version_id_subquery = (
+                select(models.DatasetVersion.id)
+                .where(models.DatasetVersion.dataset_id == dataset_id)
+                .where(models.DatasetVersion.id == version_id)
+                .scalar_subquery()
+            )
+            revision_ids = revision_ids.where(
+                models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
+            )
+        query = (
+            select(models.DatasetExampleRevision)
+            .where(
+                and_(
+                    models.DatasetExampleRevision.id.in_(revision_ids),
+                    models.DatasetExampleRevision.revision_kind != "DELETE",
+                )
+            )
+            .order_by(models.DatasetExampleRevision.dataset_example_id.asc())
+            .options(
+                load_only(
+                    models.DatasetExampleRevision.dataset_example_id,
+                    models.DatasetExampleRevision.input,
+                )
+            )
+        )
+        async with info.context.db() as session:
+            revisions = [revision async for revision in await session.stream_scalars(query)]
+        if not revisions:
+            raise BadRequest("No examples found for the given dataset and version")
+
+        spans: dict[DatasetExampleID, streaming_llm_span] = {}
+        async for payload in _merge_iterators(
+            [
+                _stream_chat_completion_over_dataset_example(
+                    input=input,
+                    llm_client_class=llm_client_class,
+                    revision=revision,
+                    spans=spans,
+                )
+                for revision in revisions
+            ]
+        ):
+            yield payload

-
-
-
-)
-
-
-
-
-
-
-
-
+        async with info.context.db() as session:
+            if (
+                playground_project_id := await session.scalar(
+                    select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
+                )
+            ) is None:
+                playground_project_id = await session.scalar(
+                    insert(models.Project)
+                    .returning(models.Project.id)
+                    .values(
+                        name=PLAYGROUND_PROJECT_NAME,
+                        description="Traces from prompt playground",
+                    )
+                )
+            db_spans = {
+                example_id: span.add_to_session(session, playground_project_id)
+                for example_id, span in spans.items()
+            }
+            assert (
+                dataset_name := await session.scalar(
+                    select(models.Dataset.name).where(models.Dataset.id == dataset_id)
+                )
+            ) is not None
+            if version_id is None:
+                resolved_version_id = await session.scalar(
+                    select(models.DatasetVersion.id)
+                    .where(models.DatasetVersion.dataset_id == dataset_id)
+                    .order_by(models.DatasetVersion.id.desc())
+                    .limit(1)
+                )
+            else:
+                resolved_version_id = await session.scalar(
+                    select(models.DatasetVersion.id).where(
+                        and_(
+                            models.DatasetVersion.dataset_id == dataset_id,
+                            models.DatasetVersion.id == version_id,
+                        )
+                    )
+                )
+            assert resolved_version_id is not None
+            resolved_version_node_id = GlobalID(DatasetVersion.__name__, str(resolved_version_id))
+            experiment = models.Experiment(
+                dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
+                dataset_version_id=resolved_version_id,
+                name=input.experiment_name or _DEFAULT_PLAYGROUND_EXPERIMENT_NAME,
+                description=input.experiment_description
+                or _default_playground_experiment_description(dataset_name=dataset_name),
+                repetitions=1,
+                metadata_=input.experiment_metadata
+                or _default_playground_experiment_metadata(
+                    dataset_name=dataset_name,
+                    dataset_id=input.dataset_id,
+                    version_id=resolved_version_node_id,
+                ),
+                project_name=PLAYGROUND_PROJECT_NAME,
             )
-
-
-
-
+            session.add(experiment)
+            await session.flush()
+            runs = [
+                models.ExperimentRun(
+                    experiment_id=experiment.id,
+                    dataset_example_id=from_global_id_with_expected_type(
+                        example_id, DatasetExample.__name__
+                    ),
+                    trace_id=span.trace_id,
+                    output=models.ExperimentRunOutput(
+                        task_output=_get_playground_experiment_task_output(span)
+                    ),
+                    repetition_number=1,
+                    start_time=span.start_time,
+                    end_time=span.end_time,
+                    error=error_message
+                    if (error_message := span.error_message) is not None
+                    else None,
+                    prompt_token_count=get_attribute_value(span.attributes, LLM_TOKEN_COUNT_PROMPT),
+                    completion_token_count=get_attribute_value(
+                        span.attributes, LLM_TOKEN_COUNT_COMPLETION
+                    ),
+                )
+                for example_id, span in spans.items()
+            ]
+            session.add_all(runs)
+            await session.flush()
+        for example_id in spans:
+            yield FinishedChatCompletion(
+                span=to_gql_span(db_spans[example_id]),
+                dataset_example_id=example_id,
             )
+        yield ChatCompletionOverDatasetSubscriptionResult(experiment=to_gql_experiment(experiment))
+
+
+async def _stream_chat_completion_over_dataset_example(
+    *,
+    input: ChatCompletionOverDatasetInput,
+    llm_client_class: type["PlaygroundStreamingClient"],
+    revision: models.DatasetExampleRevision,
+    spans: dict[DatasetExampleID, streaming_llm_span],
+) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+    example_id = GlobalID(DatasetExample.__name__, str(revision.dataset_example_id))
+    llm_client = llm_client_class(
+        model=input.model,
+        api_key=input.api_key,
+    )
+    invocation_parameters = llm_client.construct_invocation_parameters(input.invocation_parameters)
+    messages = [
+        (
+            message.role,
+            message.content,
+            message.tool_call_id if isinstance(message.tool_call_id, str) else None,
+            message.tool_calls if isinstance(message.tool_calls, list) else None,
+        )
+        for message in input.messages
+    ]
+    try:
+        messages = list(
+            _formatted_messages(
+                messages=messages,
+                template_language=input.template_language,
+                template_variables=revision.input,
+            )
+        )
+    except TemplateFormatterError as error:
+        yield ChatCompletionSubscriptionError(message=str(error), dataset_example_id=example_id)
+        return
+    span = streaming_llm_span(
+        input=input,
+        messages=messages,
+        invocation_parameters=invocation_parameters,
+    )
+    spans[example_id] = span
+    async with span:
+        async for chunk in llm_client.chat_completion_create(
+            messages=messages, tools=input.tools or [], **invocation_parameters
+        ):
+            span.add_response_chunk(chunk)
+            chunk.dataset_example_id = example_id
+            yield chunk
+    span.set_attributes(llm_client.attributes)
+    if span.error_message is not None:
+        yield ChatCompletionSubscriptionError(
+            message=span.error_message, dataset_example_id=example_id
+        )


-def
-
-
-
-
+async def _merge_iterators(
+    iterators: Collection[AsyncIterator[GenericType]],
+) -> AsyncIterator[GenericType]:
+    tasks: dict[AsyncIterator[GenericType], Task[GenericType]] = {
+        iterable: _as_task(iterable) for iterable in iterators
+    }
+    while tasks:
+        completed_tasks, _ = await wait(tasks.values(), return_when=FIRST_COMPLETED)
+        for task in completed_tasks:
+            iterator = next(it for it, t in tasks.items() if t == task)
+            try:
+                yield task.result()
+            except StopAsyncIteration:
+                del tasks[iterator]
+            except Exception as error:
+                del tasks[iterator]
+                logger.exception(error)
+            else:
+                tasks[iterator] = _as_task(iterator)


-def
-
-
-
-
-    return
+def _as_task(iterable: AsyncIterator[GenericType]) -> Task[GenericType]:
+    return create_task(_as_coroutine(iterable))
+
+
+async def _as_coroutine(iterable: AsyncIterator[GenericType]) -> GenericType:
+    return await iterable.__anext__()


 def _formatted_messages(
-
-
+    *,
+    messages: Iterable[ChatCompletionMessage],
+    template_language: TemplateLanguage,
+    template_variables: Mapping[str, Any],
+) -> Iterator[tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]]]:
     """
     Formats the messages using the given template options.
     """
-    template_formatter = _template_formatter(template_language=
-
+    template_formatter = _template_formatter(template_language=template_language)
+    (
+        roles,
+        templates,
+        tool_call_id,
+        tool_calls,
+    ) = zip(*messages)
     formatted_templates = map(
-        lambda template: template_formatter.format(template, **
+        lambda template: template_formatter.format(template, **template_variables),
         templates,
     )
-    formatted_messages = zip(roles, formatted_templates)
+    formatted_messages = zip(roles, formatted_templates, tool_call_id, tool_calls)
     return formatted_messages

@@ -560,29 +413,29 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
     assert_never(template_language)


-
+def _get_playground_experiment_task_output(
+    span: streaming_llm_span,
+) -> Any:
+    return get_attribute_value(span.attributes, LLM_OUTPUT_MESSAGES)

-LLM = OpenInferenceSpanKindValues.LLM.value

-
-
-
-
-
-LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
-LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
-LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
-LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
-LLM_TOOLS = SpanAttributes.LLM_TOOLS
-LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
-LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
-LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
+_DEFAULT_PLAYGROUND_EXPERIMENT_NAME = "playground-experiment"
+
+
+def _default_playground_experiment_description(dataset_name: str) -> str:
+    return f'Playground experiment for dataset "{dataset_name}"'

-MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
-MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
-MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS

-
-
+def _default_playground_experiment_metadata(
+    dataset_name: str, dataset_id: GlobalID, version_id: GlobalID
+) -> dict[str, Any]:
+    return {
+        "dataset_name": dataset_name,
+        "dataset_id": str(dataset_id),
+        "dataset_version_id": str(version_id),
+    }

-
+
+LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
+LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
+LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT