arize-phoenix 5.7.0__py3-none-any.whl → 5.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/METADATA +3 -5
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/RECORD +31 -31
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/WHEEL +1 -1
- phoenix/config.py +19 -3
- phoenix/db/helpers.py +55 -1
- phoenix/server/api/helpers/playground_clients.py +283 -44
- phoenix/server/api/helpers/playground_spans.py +173 -76
- phoenix/server/api/input_types/InvocationParameters.py +7 -8
- phoenix/server/api/mutations/chat_mutations.py +244 -76
- phoenix/server/api/queries.py +5 -1
- phoenix/server/api/routers/v1/spans.py +25 -1
- phoenix/server/api/subscriptions.py +210 -158
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +5 -3
- phoenix/server/api/types/ExperimentRun.py +38 -1
- phoenix/server/api/types/GenerativeProvider.py +2 -1
- phoenix/server/app.py +21 -2
- phoenix/server/grpc_server.py +3 -1
- phoenix/server/static/.vite/manifest.json +32 -32
- phoenix/server/static/assets/{components-Csu8UKOs.js → components-DU-8CYbi.js} +370 -329
- phoenix/server/static/assets/{index-Bk5C9EA7.js → index-D9E16vvV.js} +2 -2
- phoenix/server/static/assets/pages-t09OI1rC.js +3966 -0
- phoenix/server/static/assets/{vendor-CtqfhlbC.js → vendor-D04tenE6.js} +181 -181
- phoenix/server/static/assets/{vendor-arizeai-C_3SBz56.js → vendor-arizeai-D3NxMQw0.js} +2 -2
- phoenix/server/static/assets/{vendor-codemirror-wfdk9cjp.js → vendor-codemirror-XTiZSlqq.js} +5 -5
- phoenix/server/static/assets/{vendor-recharts-BiVnSv90.js → vendor-recharts-p0L0neVs.js} +1 -1
- phoenix/session/client.py +27 -7
- phoenix/utilities/json.py +31 -1
- phoenix/version.py +1 -1
- phoenix/server/static/assets/pages-UeWaKXNs.js +0 -3737
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-5.7.0.dist-info → arize_phoenix-5.9.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,14 +1,16 @@
|
|
|
1
|
+
import asyncio
|
|
1
2
|
import logging
|
|
2
|
-
from asyncio import FIRST_COMPLETED, Task, create_task, wait
|
|
3
|
-
from collections.abc import Iterator
|
|
3
|
+
from asyncio import FIRST_COMPLETED, Queue, QueueEmpty, Task, create_task, wait, wait_for
|
|
4
|
+
from collections.abc import AsyncIterator, Iterator
|
|
5
|
+
from datetime import datetime, timezone
|
|
4
6
|
from typing import (
|
|
5
7
|
Any,
|
|
6
|
-
AsyncIterator,
|
|
7
|
-
Collection,
|
|
8
8
|
Iterable,
|
|
9
9
|
Mapping,
|
|
10
10
|
Optional,
|
|
11
|
+
Sequence,
|
|
11
12
|
TypeVar,
|
|
13
|
+
cast,
|
|
12
14
|
)
|
|
13
15
|
|
|
14
16
|
import strawberry
|
|
@@ -19,9 +21,10 @@ from strawberry.relay.types import GlobalID
|
|
|
19
21
|
from strawberry.types import Info
|
|
20
22
|
from typing_extensions import TypeAlias, assert_never
|
|
21
23
|
|
|
24
|
+
from phoenix.datetime_utils import local_now, normalize_datetime
|
|
22
25
|
from phoenix.db import models
|
|
23
26
|
from phoenix.server.api.context import Context
|
|
24
|
-
from phoenix.server.api.exceptions import BadRequest
|
|
27
|
+
from phoenix.server.api.exceptions import BadRequest, NotFound
|
|
25
28
|
from phoenix.server.api.helpers.playground_clients import (
|
|
26
29
|
PlaygroundStreamingClient,
|
|
27
30
|
initialize_playground_clients,
|
|
@@ -29,27 +32,33 @@ from phoenix.server.api.helpers.playground_clients import (
|
|
|
29
32
|
from phoenix.server.api.helpers.playground_registry import (
|
|
30
33
|
PLAYGROUND_CLIENT_REGISTRY,
|
|
31
34
|
)
|
|
32
|
-
from phoenix.server.api.helpers.playground_spans import
|
|
35
|
+
from phoenix.server.api.helpers.playground_spans import (
|
|
36
|
+
get_db_experiment_run,
|
|
37
|
+
get_db_span,
|
|
38
|
+
get_db_trace,
|
|
39
|
+
streaming_llm_span,
|
|
40
|
+
)
|
|
33
41
|
from phoenix.server.api.input_types.ChatCompletionInput import (
|
|
34
42
|
ChatCompletionInput,
|
|
35
43
|
ChatCompletionOverDatasetInput,
|
|
36
44
|
)
|
|
37
45
|
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
|
|
38
46
|
from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
|
|
39
|
-
ChatCompletionOverDatasetSubscriptionResult,
|
|
40
47
|
ChatCompletionSubscriptionError,
|
|
48
|
+
ChatCompletionSubscriptionExperiment,
|
|
41
49
|
ChatCompletionSubscriptionPayload,
|
|
42
|
-
|
|
50
|
+
ChatCompletionSubscriptionResult,
|
|
43
51
|
)
|
|
44
52
|
from phoenix.server.api.types.Dataset import Dataset
|
|
45
53
|
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
46
54
|
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
47
55
|
from phoenix.server.api.types.Experiment import to_gql_experiment
|
|
56
|
+
from phoenix.server.api.types.ExperimentRun import to_gql_experiment_run
|
|
48
57
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
49
58
|
from phoenix.server.api.types.Span import to_gql_span
|
|
50
59
|
from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
|
|
51
60
|
from phoenix.server.dml_event import SpanInsertEvent
|
|
52
|
-
from phoenix.
|
|
61
|
+
from phoenix.server.types import DbSessionFactory
|
|
53
62
|
from phoenix.utilities.template_formatters import (
|
|
54
63
|
FStringTemplateFormatter,
|
|
55
64
|
MustacheTemplateFormatter,
|
|
@@ -67,6 +76,9 @@ ChatCompletionMessage: TypeAlias = tuple[
|
|
|
67
76
|
ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]
|
|
68
77
|
]
|
|
69
78
|
DatasetExampleID: TypeAlias = GlobalID
|
|
79
|
+
ChatCompletionResult: TypeAlias = tuple[
|
|
80
|
+
DatasetExampleID, Optional[models.Span], models.ExperimentRun
|
|
81
|
+
]
|
|
70
82
|
PLAYGROUND_PROJECT_NAME = "playground"
|
|
71
83
|
|
|
72
84
|
|
|
@@ -116,8 +128,8 @@ class Subscription:
|
|
|
116
128
|
span.add_response_chunk(chunk)
|
|
117
129
|
yield chunk
|
|
118
130
|
span.set_attributes(llm_client.attributes)
|
|
119
|
-
if span.
|
|
120
|
-
yield ChatCompletionSubscriptionError(message=span.
|
|
131
|
+
if span.status_message is not None:
|
|
132
|
+
yield ChatCompletionSubscriptionError(message=span.status_message)
|
|
121
133
|
async with info.context.db() as session:
|
|
122
134
|
if (
|
|
123
135
|
playground_project_id := await session.scalar(
|
|
@@ -132,10 +144,12 @@ class Subscription:
|
|
|
132
144
|
description="Traces from prompt playground",
|
|
133
145
|
)
|
|
134
146
|
)
|
|
135
|
-
|
|
147
|
+
db_trace = get_db_trace(span, playground_project_id)
|
|
148
|
+
db_span = get_db_span(span, db_trace)
|
|
149
|
+
session.add(db_span)
|
|
136
150
|
await session.flush()
|
|
137
|
-
yield FinishedChatCompletion(span=to_gql_span(db_span))
|
|
138
151
|
info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
|
|
152
|
+
yield ChatCompletionSubscriptionResult(span=to_gql_span(db_span))
|
|
139
153
|
|
|
140
154
|
@strawberry.subscription
|
|
141
155
|
async def chat_completion_over_dataset(
|
|
@@ -154,58 +168,68 @@ class Subscription:
|
|
|
154
168
|
if input.dataset_version_id
|
|
155
169
|
else None
|
|
156
170
|
)
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
.group_by(models.DatasetExampleRevision.dataset_example_id)
|
|
162
|
-
)
|
|
163
|
-
if version_id:
|
|
164
|
-
version_id_subquery = (
|
|
165
|
-
select(models.DatasetVersion.id)
|
|
166
|
-
.where(models.DatasetVersion.dataset_id == dataset_id)
|
|
167
|
-
.where(models.DatasetVersion.id == version_id)
|
|
168
|
-
.scalar_subquery()
|
|
169
|
-
)
|
|
170
|
-
revision_ids = revision_ids.where(
|
|
171
|
-
models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
|
|
172
|
-
)
|
|
173
|
-
query = (
|
|
174
|
-
select(models.DatasetExampleRevision)
|
|
175
|
-
.where(
|
|
176
|
-
and_(
|
|
177
|
-
models.DatasetExampleRevision.id.in_(revision_ids),
|
|
178
|
-
models.DatasetExampleRevision.revision_kind != "DELETE",
|
|
171
|
+
async with info.context.db() as session:
|
|
172
|
+
if (
|
|
173
|
+
dataset := await session.scalar(
|
|
174
|
+
select(models.Dataset).where(models.Dataset.id == dataset_id)
|
|
179
175
|
)
|
|
180
|
-
)
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
176
|
+
) is None:
|
|
177
|
+
raise NotFound(f"Could not find dataset with ID {dataset_id}")
|
|
178
|
+
if version_id is None:
|
|
179
|
+
if (
|
|
180
|
+
resolved_version_id := await session.scalar(
|
|
181
|
+
select(models.DatasetVersion.id)
|
|
182
|
+
.where(models.DatasetVersion.dataset_id == dataset_id)
|
|
183
|
+
.order_by(models.DatasetVersion.id.desc())
|
|
184
|
+
.limit(1)
|
|
185
|
+
)
|
|
186
|
+
) is None:
|
|
187
|
+
raise NotFound(f"No versions found for dataset with ID {dataset_id}")
|
|
188
|
+
else:
|
|
189
|
+
if (
|
|
190
|
+
resolved_version_id := await session.scalar(
|
|
191
|
+
select(models.DatasetVersion.id).where(
|
|
192
|
+
and_(
|
|
193
|
+
models.DatasetVersion.dataset_id == dataset_id,
|
|
194
|
+
models.DatasetVersion.id == version_id,
|
|
195
|
+
)
|
|
196
|
+
)
|
|
197
|
+
)
|
|
198
|
+
) is None:
|
|
199
|
+
raise NotFound(f"Could not find dataset version with ID {version_id}")
|
|
200
|
+
revision_ids = (
|
|
201
|
+
select(func.max(models.DatasetExampleRevision.id))
|
|
202
|
+
.join(models.DatasetExample)
|
|
203
|
+
.where(
|
|
204
|
+
and_(
|
|
205
|
+
models.DatasetExample.dataset_id == dataset_id,
|
|
206
|
+
models.DatasetExampleRevision.dataset_version_id <= resolved_version_id,
|
|
207
|
+
)
|
|
186
208
|
)
|
|
209
|
+
.group_by(models.DatasetExampleRevision.dataset_example_id)
|
|
187
210
|
)
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
211
|
+
if not (
|
|
212
|
+
revisions := [
|
|
213
|
+
rev
|
|
214
|
+
async for rev in await session.stream_scalars(
|
|
215
|
+
select(models.DatasetExampleRevision)
|
|
216
|
+
.where(
|
|
217
|
+
and_(
|
|
218
|
+
models.DatasetExampleRevision.id.in_(revision_ids),
|
|
219
|
+
models.DatasetExampleRevision.revision_kind != "DELETE",
|
|
220
|
+
)
|
|
221
|
+
)
|
|
222
|
+
.order_by(models.DatasetExampleRevision.dataset_example_id.asc())
|
|
223
|
+
.options(
|
|
224
|
+
load_only(
|
|
225
|
+
models.DatasetExampleRevision.dataset_example_id,
|
|
226
|
+
models.DatasetExampleRevision.input,
|
|
227
|
+
)
|
|
228
|
+
)
|
|
229
|
+
)
|
|
230
|
+
]
|
|
231
|
+
):
|
|
232
|
+
raise NotFound("No examples found for the given dataset and version")
|
|
209
233
|
if (
|
|
210
234
|
playground_project_id := await session.scalar(
|
|
211
235
|
select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
|
|
@@ -219,81 +243,70 @@ class Subscription:
|
|
|
219
243
|
description="Traces from prompt playground",
|
|
220
244
|
)
|
|
221
245
|
)
|
|
222
|
-
db_spans = {
|
|
223
|
-
example_id: span.add_to_session(session, playground_project_id)
|
|
224
|
-
for example_id, span in spans.items()
|
|
225
|
-
}
|
|
226
|
-
assert (
|
|
227
|
-
dataset_name := await session.scalar(
|
|
228
|
-
select(models.Dataset.name).where(models.Dataset.id == dataset_id)
|
|
229
|
-
)
|
|
230
|
-
) is not None
|
|
231
|
-
if version_id is None:
|
|
232
|
-
resolved_version_id = await session.scalar(
|
|
233
|
-
select(models.DatasetVersion.id)
|
|
234
|
-
.where(models.DatasetVersion.dataset_id == dataset_id)
|
|
235
|
-
.order_by(models.DatasetVersion.id.desc())
|
|
236
|
-
.limit(1)
|
|
237
|
-
)
|
|
238
|
-
else:
|
|
239
|
-
resolved_version_id = await session.scalar(
|
|
240
|
-
select(models.DatasetVersion.id).where(
|
|
241
|
-
and_(
|
|
242
|
-
models.DatasetVersion.dataset_id == dataset_id,
|
|
243
|
-
models.DatasetVersion.id == version_id,
|
|
244
|
-
)
|
|
245
|
-
)
|
|
246
|
-
)
|
|
247
|
-
assert resolved_version_id is not None
|
|
248
|
-
resolved_version_node_id = GlobalID(DatasetVersion.__name__, str(resolved_version_id))
|
|
249
246
|
experiment = models.Experiment(
|
|
250
247
|
dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
|
|
251
248
|
dataset_version_id=resolved_version_id,
|
|
252
|
-
name=input.experiment_name or
|
|
249
|
+
name=input.experiment_name or _default_playground_experiment_name(),
|
|
253
250
|
description=input.experiment_description
|
|
254
|
-
or _default_playground_experiment_description(dataset_name=
|
|
251
|
+
or _default_playground_experiment_description(dataset_name=dataset.name),
|
|
255
252
|
repetitions=1,
|
|
256
253
|
metadata_=input.experiment_metadata
|
|
257
254
|
or _default_playground_experiment_metadata(
|
|
258
|
-
dataset_name=
|
|
255
|
+
dataset_name=dataset.name,
|
|
259
256
|
dataset_id=input.dataset_id,
|
|
260
|
-
version_id=
|
|
257
|
+
version_id=GlobalID(DatasetVersion.__name__, str(resolved_version_id)),
|
|
261
258
|
),
|
|
262
259
|
project_name=PLAYGROUND_PROJECT_NAME,
|
|
263
260
|
)
|
|
264
261
|
session.add(experiment)
|
|
265
262
|
await session.flush()
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
error=error_message
|
|
280
|
-
if (error_message := span.error_message) is not None
|
|
281
|
-
else None,
|
|
282
|
-
prompt_token_count=get_attribute_value(span.attributes, LLM_TOKEN_COUNT_PROMPT),
|
|
283
|
-
completion_token_count=get_attribute_value(
|
|
284
|
-
span.attributes, LLM_TOKEN_COUNT_COMPLETION
|
|
285
|
-
),
|
|
286
|
-
)
|
|
287
|
-
for example_id, span in spans.items()
|
|
288
|
-
]
|
|
289
|
-
session.add_all(runs)
|
|
290
|
-
await session.flush()
|
|
291
|
-
for example_id in spans:
|
|
292
|
-
yield FinishedChatCompletion(
|
|
293
|
-
span=to_gql_span(db_spans[example_id]),
|
|
294
|
-
dataset_example_id=example_id,
|
|
263
|
+
yield ChatCompletionSubscriptionExperiment(
|
|
264
|
+
experiment=to_gql_experiment(experiment)
|
|
265
|
+
) # eagerly yields experiment so it can be linked by consumers of the subscription
|
|
266
|
+
|
|
267
|
+
results_queue: Queue[ChatCompletionResult] = Queue()
|
|
268
|
+
chat_completion_streams = [
|
|
269
|
+
_stream_chat_completion_over_dataset_example(
|
|
270
|
+
input=input,
|
|
271
|
+
llm_client_class=llm_client_class,
|
|
272
|
+
revision=revision,
|
|
273
|
+
results_queue=results_queue,
|
|
274
|
+
experiment_id=experiment.id,
|
|
275
|
+
project_id=playground_project_id,
|
|
295
276
|
)
|
|
296
|
-
|
|
277
|
+
for revision in revisions
|
|
278
|
+
]
|
|
279
|
+
stream_to_async_tasks: dict[
|
|
280
|
+
AsyncIterator[ChatCompletionSubscriptionPayload],
|
|
281
|
+
Task[ChatCompletionSubscriptionPayload],
|
|
282
|
+
] = {iterator: _create_task_with_timeout(iterator) for iterator in chat_completion_streams}
|
|
283
|
+
batch_size = 10
|
|
284
|
+
while stream_to_async_tasks:
|
|
285
|
+
async_tasks_to_run = [task for task in stream_to_async_tasks.values()]
|
|
286
|
+
completed_tasks, _ = await wait(async_tasks_to_run, return_when=FIRST_COMPLETED)
|
|
287
|
+
for task in completed_tasks:
|
|
288
|
+
iterator = next(it for it, t in stream_to_async_tasks.items() if t == task)
|
|
289
|
+
try:
|
|
290
|
+
yield task.result()
|
|
291
|
+
except (StopAsyncIteration, asyncio.TimeoutError):
|
|
292
|
+
del stream_to_async_tasks[iterator] # removes exhausted iterator
|
|
293
|
+
except Exception as error:
|
|
294
|
+
del stream_to_async_tasks[iterator] # removes failed iterator
|
|
295
|
+
logger.exception(error)
|
|
296
|
+
else:
|
|
297
|
+
stream_to_async_tasks[iterator] = _create_task_with_timeout(iterator)
|
|
298
|
+
if results_queue.qsize() >= batch_size:
|
|
299
|
+
result_iterator = _chat_completion_result_payloads(
|
|
300
|
+
db=info.context.db, results=_drain_no_wait(results_queue)
|
|
301
|
+
)
|
|
302
|
+
stream_to_async_tasks[result_iterator] = _create_task_with_timeout(
|
|
303
|
+
result_iterator
|
|
304
|
+
)
|
|
305
|
+
if remaining_results := await _drain(results_queue):
|
|
306
|
+
async for result_payload in _chat_completion_result_payloads(
|
|
307
|
+
db=info.context.db, results=remaining_results
|
|
308
|
+
):
|
|
309
|
+
yield result_payload
|
|
297
310
|
|
|
298
311
|
|
|
299
312
|
async def _stream_chat_completion_over_dataset_example(
|
|
@@ -301,7 +314,9 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
301
314
|
input: ChatCompletionOverDatasetInput,
|
|
302
315
|
llm_client_class: type["PlaygroundStreamingClient"],
|
|
303
316
|
revision: models.DatasetExampleRevision,
|
|
304
|
-
|
|
317
|
+
results_queue: Queue[ChatCompletionResult],
|
|
318
|
+
experiment_id: int,
|
|
319
|
+
project_id: int,
|
|
305
320
|
) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
|
|
306
321
|
example_id = GlobalID(DatasetExample.__name__, str(revision.dataset_example_id))
|
|
307
322
|
llm_client = llm_client_class(
|
|
@@ -319,6 +334,7 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
319
334
|
for message in input.messages
|
|
320
335
|
]
|
|
321
336
|
try:
|
|
337
|
+
format_start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
|
|
322
338
|
messages = list(
|
|
323
339
|
_formatted_messages(
|
|
324
340
|
messages=messages,
|
|
@@ -327,15 +343,31 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
327
343
|
)
|
|
328
344
|
)
|
|
329
345
|
except TemplateFormatterError as error:
|
|
346
|
+
format_end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
|
|
330
347
|
yield ChatCompletionSubscriptionError(message=str(error), dataset_example_id=example_id)
|
|
348
|
+
await results_queue.put(
|
|
349
|
+
(
|
|
350
|
+
example_id,
|
|
351
|
+
None,
|
|
352
|
+
models.ExperimentRun(
|
|
353
|
+
experiment_id=experiment_id,
|
|
354
|
+
dataset_example_id=revision.dataset_example_id,
|
|
355
|
+
trace_id=None,
|
|
356
|
+
output={},
|
|
357
|
+
repetition_number=1,
|
|
358
|
+
start_time=format_start_time,
|
|
359
|
+
end_time=format_end_time,
|
|
360
|
+
error=str(error),
|
|
361
|
+
trace=None,
|
|
362
|
+
),
|
|
363
|
+
)
|
|
364
|
+
)
|
|
331
365
|
return
|
|
332
|
-
|
|
366
|
+
async with streaming_llm_span(
|
|
333
367
|
input=input,
|
|
334
368
|
messages=messages,
|
|
335
369
|
invocation_parameters=invocation_parameters,
|
|
336
|
-
)
|
|
337
|
-
spans[example_id] = span
|
|
338
|
-
async with span:
|
|
370
|
+
) as span:
|
|
339
371
|
async for chunk in llm_client.chat_completion_create(
|
|
340
372
|
messages=messages, tools=input.tools or [], **invocation_parameters
|
|
341
373
|
):
|
|
@@ -343,35 +375,60 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
343
375
|
chunk.dataset_example_id = example_id
|
|
344
376
|
yield chunk
|
|
345
377
|
span.set_attributes(llm_client.attributes)
|
|
346
|
-
|
|
378
|
+
db_trace = get_db_trace(span, project_id)
|
|
379
|
+
db_span = get_db_span(span, db_trace)
|
|
380
|
+
db_run = get_db_experiment_run(
|
|
381
|
+
db_span, db_trace, experiment_id=experiment_id, example_id=revision.dataset_example_id
|
|
382
|
+
)
|
|
383
|
+
await results_queue.put((example_id, db_span, db_run))
|
|
384
|
+
if span.status_message is not None:
|
|
347
385
|
yield ChatCompletionSubscriptionError(
|
|
348
|
-
message=span.
|
|
386
|
+
message=span.status_message, dataset_example_id=example_id
|
|
349
387
|
)
|
|
350
388
|
|
|
351
389
|
|
|
352
|
-
async def
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
for
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
390
|
+
async def _chat_completion_result_payloads(
|
|
391
|
+
*,
|
|
392
|
+
db: DbSessionFactory,
|
|
393
|
+
results: Sequence[ChatCompletionResult],
|
|
394
|
+
) -> AsyncIterator[ChatCompletionSubscriptionResult]:
|
|
395
|
+
if not results:
|
|
396
|
+
return
|
|
397
|
+
async with db() as session:
|
|
398
|
+
for _, span, run in results:
|
|
399
|
+
if span:
|
|
400
|
+
session.add(span)
|
|
401
|
+
session.add(run)
|
|
402
|
+
await session.flush()
|
|
403
|
+
for example_id, span, run in results:
|
|
404
|
+
yield ChatCompletionSubscriptionResult(
|
|
405
|
+
span=to_gql_span(span) if span else None,
|
|
406
|
+
experiment_run=to_gql_experiment_run(run),
|
|
407
|
+
dataset_example_id=example_id,
|
|
408
|
+
)
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def _create_task_with_timeout(
|
|
412
|
+
iterable: AsyncIterator[GenericType], timeout_in_seconds: int = 60
|
|
413
|
+
) -> Task[GenericType]:
|
|
414
|
+
return create_task(wait_for(_as_coroutine(iterable), timeout=timeout_in_seconds))
|
|
415
|
+
|
|
371
416
|
|
|
417
|
+
async def _drain(queue: Queue[GenericType]) -> list[GenericType]:
|
|
418
|
+
values: list[GenericType] = []
|
|
419
|
+
while not queue.empty():
|
|
420
|
+
values.append(await queue.get())
|
|
421
|
+
return values
|
|
372
422
|
|
|
373
|
-
|
|
374
|
-
|
|
423
|
+
|
|
424
|
+
def _drain_no_wait(queue: Queue[GenericType]) -> list[GenericType]:
|
|
425
|
+
values: list[GenericType] = []
|
|
426
|
+
while True:
|
|
427
|
+
try:
|
|
428
|
+
values.append(queue.get_nowait())
|
|
429
|
+
except QueueEmpty:
|
|
430
|
+
break
|
|
431
|
+
return values
|
|
375
432
|
|
|
376
433
|
|
|
377
434
|
async def _as_coroutine(iterable: AsyncIterator[GenericType]) -> GenericType:
|
|
@@ -413,13 +470,8 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatte
|
|
|
413
470
|
assert_never(template_language)
|
|
414
471
|
|
|
415
472
|
|
|
416
|
-
def
|
|
417
|
-
|
|
418
|
-
) -> Any:
|
|
419
|
-
return get_attribute_value(span.attributes, LLM_OUTPUT_MESSAGES)
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
_DEFAULT_PLAYGROUND_EXPERIMENT_NAME = "playground-experiment"
|
|
473
|
+
def _default_playground_experiment_name() -> str:
|
|
474
|
+
return "playground-experiment"
|
|
423
475
|
|
|
424
476
|
|
|
425
477
|
def _default_playground_experiment_description(dataset_name: str) -> str:
|
|
@@ -4,6 +4,7 @@ import strawberry
|
|
|
4
4
|
from strawberry.relay import GlobalID
|
|
5
5
|
|
|
6
6
|
from .Experiment import Experiment
|
|
7
|
+
from .ExperimentRun import ExperimentRun
|
|
7
8
|
from .Span import Span
|
|
8
9
|
|
|
9
10
|
|
|
@@ -30,8 +31,9 @@ class ToolCallChunk(ChatCompletionSubscriptionPayload):
|
|
|
30
31
|
|
|
31
32
|
|
|
32
33
|
@strawberry.type
|
|
33
|
-
class
|
|
34
|
-
span: Span
|
|
34
|
+
class ChatCompletionSubscriptionResult(ChatCompletionSubscriptionPayload):
|
|
35
|
+
span: Optional[Span] = None
|
|
36
|
+
experiment_run: Optional[ExperimentRun] = None
|
|
35
37
|
|
|
36
38
|
|
|
37
39
|
@strawberry.type
|
|
@@ -40,5 +42,5 @@ class ChatCompletionSubscriptionError(ChatCompletionSubscriptionPayload):
|
|
|
40
42
|
|
|
41
43
|
|
|
42
44
|
@strawberry.type
|
|
43
|
-
class
|
|
45
|
+
class ChatCompletionSubscriptionExperiment(ChatCompletionSubscriptionPayload):
|
|
44
46
|
experiment: Experiment
|
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
-
from typing import Optional
|
|
2
|
+
from typing import TYPE_CHECKING, Annotated, Optional
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
|
+
from sqlalchemy import select
|
|
6
|
+
from sqlalchemy.orm import load_only
|
|
5
7
|
from strawberry import UNSET
|
|
6
8
|
from strawberry.relay import Connection, GlobalID, Node, NodeID
|
|
7
9
|
from strawberry.scalars import JSON
|
|
@@ -20,6 +22,9 @@ from phoenix.server.api.types.pagination import (
|
|
|
20
22
|
)
|
|
21
23
|
from phoenix.server.api.types.Trace import Trace
|
|
22
24
|
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
27
|
+
|
|
23
28
|
|
|
24
29
|
@strawberry.type
|
|
25
30
|
class ExperimentRun(Node):
|
|
@@ -62,6 +67,38 @@ class ExperimentRun(Node):
|
|
|
62
67
|
trace_rowid, project_rowid = trace
|
|
63
68
|
return Trace(id_attr=trace_rowid, trace_id=self.trace_id, project_rowid=project_rowid)
|
|
64
69
|
|
|
70
|
+
@strawberry.field
|
|
71
|
+
async def example(
|
|
72
|
+
self, info: Info
|
|
73
|
+
) -> Annotated[
|
|
74
|
+
"DatasetExample", strawberry.lazy("phoenix.server.api.types.DatasetExample")
|
|
75
|
+
]: # use lazy types to avoid circular import: https://strawberry.rocks/docs/types/lazy
|
|
76
|
+
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
77
|
+
|
|
78
|
+
async with info.context.db() as session:
|
|
79
|
+
assert (
|
|
80
|
+
result := await session.execute(
|
|
81
|
+
select(models.DatasetExample, models.Experiment.dataset_version_id)
|
|
82
|
+
.select_from(models.ExperimentRun)
|
|
83
|
+
.join(
|
|
84
|
+
models.DatasetExample,
|
|
85
|
+
models.DatasetExample.id == models.ExperimentRun.dataset_example_id,
|
|
86
|
+
)
|
|
87
|
+
.join(
|
|
88
|
+
models.Experiment,
|
|
89
|
+
models.Experiment.id == models.ExperimentRun.experiment_id,
|
|
90
|
+
)
|
|
91
|
+
.where(models.ExperimentRun.id == self.id_attr)
|
|
92
|
+
.options(load_only(models.DatasetExample.id, models.DatasetExample.created_at))
|
|
93
|
+
)
|
|
94
|
+
) is not None
|
|
95
|
+
example, version_id = result.first()
|
|
96
|
+
return DatasetExample(
|
|
97
|
+
id_attr=example.id,
|
|
98
|
+
created_at=example.created_at,
|
|
99
|
+
version_id=version_id,
|
|
100
|
+
)
|
|
101
|
+
|
|
65
102
|
|
|
66
103
|
def to_gql_experiment_run(run: models.ExperimentRun) -> ExperimentRun:
|
|
67
104
|
"""
|
|
@@ -8,6 +8,7 @@ class GenerativeProviderKey(Enum):
|
|
|
8
8
|
OPENAI = "OpenAI"
|
|
9
9
|
ANTHROPIC = "Anthropic"
|
|
10
10
|
AZURE_OPENAI = "Azure OpenAI"
|
|
11
|
+
GEMINI = "Google AI Studio"
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
@strawberry.type
|
|
@@ -24,7 +25,7 @@ class GenerativeProvider:
|
|
|
24
25
|
|
|
25
26
|
default_client = PLAYGROUND_CLIENT_REGISTRY.get_client(self.key, PROVIDER_DEFAULT)
|
|
26
27
|
if default_client:
|
|
27
|
-
return default_client.dependencies()
|
|
28
|
+
return [dependency.name for dependency in default_client.dependencies()]
|
|
28
29
|
return []
|
|
29
30
|
|
|
30
31
|
@strawberry.field
|
phoenix/server/app.py
CHANGED
|
@@ -26,6 +26,7 @@ import strawberry
|
|
|
26
26
|
from fastapi import APIRouter, Depends, FastAPI
|
|
27
27
|
from fastapi.middleware.gzip import GZipMiddleware
|
|
28
28
|
from fastapi.utils import is_body_allowed_for_status_code
|
|
29
|
+
from grpc.aio import ServerInterceptor
|
|
29
30
|
from sqlalchemy import select
|
|
30
31
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
|
31
32
|
from starlette.datastructures import State as StarletteState
|
|
@@ -54,6 +55,7 @@ from phoenix.config import (
|
|
|
54
55
|
get_env_csrf_trusted_origins,
|
|
55
56
|
get_env_fastapi_middleware_paths,
|
|
56
57
|
get_env_gql_extension_paths,
|
|
58
|
+
get_env_grpc_interceptor_paths,
|
|
57
59
|
get_env_host,
|
|
58
60
|
get_env_port,
|
|
59
61
|
server_instrumentation_is_enabled,
|
|
@@ -251,7 +253,7 @@ class Static(StaticFiles):
|
|
|
251
253
|
|
|
252
254
|
|
|
253
255
|
class RequestOriginHostnameValidator(BaseHTTPMiddleware):
|
|
254
|
-
def __init__(self, trusted_hostnames: list[str],
|
|
256
|
+
def __init__(self, *args: Any, trusted_hostnames: list[str], **kwargs: Any) -> None:
|
|
255
257
|
super().__init__(*args, **kwargs)
|
|
256
258
|
self._trusted_hostnames = trusted_hostnames
|
|
257
259
|
|
|
@@ -305,6 +307,17 @@ def user_gql_extensions() -> list[Union[type[SchemaExtension], SchemaExtension]]
|
|
|
305
307
|
return extensions
|
|
306
308
|
|
|
307
309
|
|
|
310
|
+
def user_grpc_interceptors() -> list[ServerInterceptor]:
|
|
311
|
+
paths = get_env_grpc_interceptor_paths()
|
|
312
|
+
interceptors = []
|
|
313
|
+
for file_path, object_name in paths:
|
|
314
|
+
interceptor_class = import_object_from_file(file_path, object_name)
|
|
315
|
+
if not issubclass(interceptor_class, ServerInterceptor):
|
|
316
|
+
raise TypeError(f"{interceptor_class} is not a subclass of ServerInterceptor")
|
|
317
|
+
interceptors.append(interceptor_class)
|
|
318
|
+
return interceptors
|
|
319
|
+
|
|
320
|
+
|
|
308
321
|
ProjectRowId: TypeAlias = int
|
|
309
322
|
|
|
310
323
|
|
|
@@ -479,6 +492,7 @@ def _lifespan(
|
|
|
479
492
|
tracer_provider=tracer_provider,
|
|
480
493
|
enable_prometheus=enable_prometheus,
|
|
481
494
|
token_store=token_store,
|
|
495
|
+
interceptors=user_grpc_interceptors(),
|
|
482
496
|
)
|
|
483
497
|
await stack.enter_async_context(grpc_server)
|
|
484
498
|
await stack.enter_async_context(dml_event_handler)
|
|
@@ -753,7 +767,12 @@ def create_app(
|
|
|
753
767
|
middlewares.extend(user_fastapi_middlewares())
|
|
754
768
|
if origins := get_env_csrf_trusted_origins():
|
|
755
769
|
trusted_hostnames = [h for o in origins if o and (h := urlparse(o).hostname)]
|
|
756
|
-
middlewares.append(
|
|
770
|
+
middlewares.append(
|
|
771
|
+
Middleware(
|
|
772
|
+
RequestOriginHostnameValidator,
|
|
773
|
+
trusted_hostnames=trusted_hostnames,
|
|
774
|
+
)
|
|
775
|
+
)
|
|
757
776
|
elif email_sender or oauth2_client_configs:
|
|
758
777
|
logger.warning(
|
|
759
778
|
"CSRF protection can be enabled by listing trusted origins via "
|