arize-phoenix 5.8.0__py3-none-any.whl → 5.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-5.8.0.dist-info → arize_phoenix-5.9.1.dist-info}/METADATA +1 -1
- {arize_phoenix-5.8.0.dist-info → arize_phoenix-5.9.1.dist-info}/RECORD +27 -27
- {arize_phoenix-5.8.0.dist-info → arize_phoenix-5.9.1.dist-info}/WHEEL +1 -1
- phoenix/config.py +13 -1
- phoenix/db/helpers.py +55 -1
- phoenix/server/api/helpers/playground_clients.py +160 -8
- phoenix/server/api/mutations/chat_mutations.py +198 -11
- phoenix/server/api/queries.py +5 -1
- phoenix/server/api/routers/oauth2.py +55 -23
- phoenix/server/api/routers/v1/spans.py +25 -1
- phoenix/server/api/types/ExperimentRun.py +38 -1
- phoenix/server/api/types/GenerativeProvider.py +2 -1
- phoenix/server/app.py +7 -2
- phoenix/server/static/.vite/manifest.json +32 -32
- phoenix/server/static/assets/{components-MllbfxfJ.js → components-BcvRmBnN.js} +320 -297
- phoenix/server/static/assets/{index-BVO2YcT1.js → index-BF4RUiOz.js} +2 -2
- phoenix/server/static/assets/{pages-BHfC6jnL.js → pages-CM_Zho_x.js} +617 -454
- phoenix/server/static/assets/{vendor-BEuNhfwH.js → vendor-Bjm5T3cE.js} +181 -181
- phoenix/server/static/assets/{vendor-arizeai-Bskhzyjm.js → vendor-arizeai-CQhWGEdL.js} +2 -2
- phoenix/server/static/assets/{vendor-codemirror-DLlXCf0x.js → vendor-codemirror-CdtiO80y.js} +5 -5
- phoenix/server/static/assets/{vendor-recharts-CRqhvLYg.js → vendor-recharts-BqWon6Py.js} +1 -1
- phoenix/session/client.py +27 -7
- phoenix/utilities/json.py +31 -1
- phoenix/version.py +1 -1
- {arize_phoenix-5.8.0.dist-info → arize_phoenix-5.9.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-5.8.0.dist-info → arize_phoenix-5.9.1.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-5.8.0.dist-info → arize_phoenix-5.9.1.dist-info}/licenses/LICENSE +0 -0
|
phoenix/server/api/mutations/chat_mutations.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
import
|
|
2
|
-
from dataclasses import asdict
|
|
1
|
+
import asyncio
|
|
2
|
+
from dataclasses import asdict, field
|
|
3
3
|
from datetime import datetime, timezone
|
|
4
4
|
from itertools import chain
|
|
5
5
|
from traceback import format_exc
|
|
6
|
-
from typing import Any, Iterable, Iterator, List, Optional
|
|
6
|
+
from typing import Any, Iterable, Iterator, List, Optional, Union
|
|
7
7
|
|
|
8
8
|
import strawberry
|
|
9
9
|
from openinference.instrumentation import safe_json_dumps
|
|
@@ -18,14 +18,19 @@ from openinference.semconv.trace import (
|
|
|
18
18
|
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
|
|
19
19
|
from opentelemetry.trace import StatusCode
|
|
20
20
|
from sqlalchemy import insert, select
|
|
21
|
+
from strawberry.relay import GlobalID
|
|
21
22
|
from strawberry.types import Info
|
|
22
23
|
from typing_extensions import assert_never
|
|
23
24
|
|
|
24
25
|
from phoenix.datetime_utils import local_now, normalize_datetime
|
|
25
26
|
from phoenix.db import models
|
|
27
|
+
from phoenix.db.helpers import get_dataset_example_revisions
|
|
26
28
|
from phoenix.server.api.context import Context
|
|
27
|
-
from phoenix.server.api.exceptions import BadRequest
|
|
28
|
-
from phoenix.server.api.helpers.playground_clients import
|
|
29
|
+
from phoenix.server.api.exceptions import BadRequest, NotFound
|
|
30
|
+
from phoenix.server.api.helpers.playground_clients import (
|
|
31
|
+
PlaygroundStreamingClient,
|
|
32
|
+
initialize_playground_clients,
|
|
33
|
+
)
|
|
29
34
|
from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
|
|
30
35
|
from phoenix.server.api.helpers.playground_spans import (
|
|
31
36
|
input_value_and_mime_type,
|
|
@@ -35,17 +40,28 @@ from phoenix.server.api.helpers.playground_spans import (
|
|
|
35
40
|
llm_span_kind,
|
|
36
41
|
llm_tools,
|
|
37
42
|
)
|
|
38
|
-
from phoenix.server.api.input_types.ChatCompletionInput import
|
|
43
|
+
from phoenix.server.api.input_types.ChatCompletionInput import (
|
|
44
|
+
ChatCompletionInput,
|
|
45
|
+
ChatCompletionOverDatasetInput,
|
|
46
|
+
)
|
|
39
47
|
from phoenix.server.api.input_types.TemplateOptions import TemplateOptions
|
|
48
|
+
from phoenix.server.api.subscriptions import (
|
|
49
|
+
_default_playground_experiment_description,
|
|
50
|
+
_default_playground_experiment_metadata,
|
|
51
|
+
_default_playground_experiment_name,
|
|
52
|
+
)
|
|
40
53
|
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
|
|
41
54
|
from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
|
|
42
55
|
TextChunk,
|
|
43
56
|
ToolCallChunk,
|
|
44
57
|
)
|
|
58
|
+
from phoenix.server.api.types.Dataset import Dataset
|
|
59
|
+
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
60
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
45
61
|
from phoenix.server.api.types.Span import Span, to_gql_span
|
|
46
62
|
from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
|
|
47
63
|
from phoenix.server.dml_event import SpanInsertEvent
|
|
48
|
-
from phoenix.trace.attributes import unflatten
|
|
64
|
+
from phoenix.trace.attributes import get_attribute_value, unflatten
|
|
49
65
|
from phoenix.trace.schemas import SpanException
|
|
50
66
|
from phoenix.utilities.json import jsonify
|
|
51
67
|
from phoenix.utilities.template_formatters import (
|
|
@@ -79,21 +95,192 @@ class ChatCompletionMutationPayload:
|
|
|
79
95
|
error_message: Optional[str]
|
|
80
96
|
|
|
81
97
|
|
|
98
|
+
@strawberry.type
|
|
99
|
+
class ChatCompletionMutationError:
|
|
100
|
+
message: str
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@strawberry.type
|
|
104
|
+
class ChatCompletionOverDatasetMutationExamplePayload:
|
|
105
|
+
dataset_example_id: GlobalID
|
|
106
|
+
experiment_run_id: GlobalID
|
|
107
|
+
result: Union[ChatCompletionMutationPayload, ChatCompletionMutationError]
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@strawberry.type
|
|
111
|
+
class ChatCompletionOverDatasetMutationPayload:
|
|
112
|
+
dataset_id: GlobalID
|
|
113
|
+
dataset_version_id: GlobalID
|
|
114
|
+
experiment_id: GlobalID
|
|
115
|
+
examples: list[ChatCompletionOverDatasetMutationExamplePayload] = field(default_factory=list)
|
|
116
|
+
|
|
117
|
+
|
|
82
118
|
@strawberry.type
|
|
83
119
|
class ChatCompletionMutationMixin:
|
|
84
120
|
@strawberry.mutation
|
|
121
|
+
@classmethod
|
|
122
|
+
async def chat_completion_over_dataset(
|
|
123
|
+
cls,
|
|
124
|
+
info: Info[Context, None],
|
|
125
|
+
input: ChatCompletionOverDatasetInput,
|
|
126
|
+
) -> ChatCompletionOverDatasetMutationPayload:
|
|
127
|
+
provider_key = input.model.provider_key
|
|
128
|
+
llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
|
|
129
|
+
if llm_client_class is None:
|
|
130
|
+
raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
|
|
131
|
+
llm_client = llm_client_class(
|
|
132
|
+
model=input.model,
|
|
133
|
+
api_key=input.api_key,
|
|
134
|
+
)
|
|
135
|
+
dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
|
|
136
|
+
dataset_version_id = (
|
|
137
|
+
from_global_id_with_expected_type(
|
|
138
|
+
global_id=input.dataset_version_id, expected_type_name=DatasetVersion.__name__
|
|
139
|
+
)
|
|
140
|
+
if input.dataset_version_id
|
|
141
|
+
else None
|
|
142
|
+
)
|
|
143
|
+
async with info.context.db() as session:
|
|
144
|
+
dataset = await session.scalar(select(models.Dataset).filter_by(id=dataset_id))
|
|
145
|
+
if dataset is None:
|
|
146
|
+
raise NotFound("Dataset not found")
|
|
147
|
+
if dataset_version_id is None:
|
|
148
|
+
resolved_version_id = await session.scalar(
|
|
149
|
+
select(models.DatasetVersion.id)
|
|
150
|
+
.filter_by(dataset_id=dataset_id)
|
|
151
|
+
.order_by(models.DatasetVersion.id.desc())
|
|
152
|
+
.limit(1)
|
|
153
|
+
)
|
|
154
|
+
if resolved_version_id is None:
|
|
155
|
+
raise NotFound("No versions found for the given dataset")
|
|
156
|
+
else:
|
|
157
|
+
resolved_version_id = dataset_version_id
|
|
158
|
+
revisions = [
|
|
159
|
+
revision
|
|
160
|
+
async for revision in await session.stream_scalars(
|
|
161
|
+
get_dataset_example_revisions(resolved_version_id)
|
|
162
|
+
)
|
|
163
|
+
]
|
|
164
|
+
if not revisions:
|
|
165
|
+
raise NotFound("No examples found for the given dataset and version")
|
|
166
|
+
experiment = models.Experiment(
|
|
167
|
+
dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
|
|
168
|
+
dataset_version_id=resolved_version_id,
|
|
169
|
+
name=input.experiment_name or _default_playground_experiment_name(),
|
|
170
|
+
description=input.experiment_description
|
|
171
|
+
or _default_playground_experiment_description(dataset_name=dataset.name),
|
|
172
|
+
repetitions=1,
|
|
173
|
+
metadata_=input.experiment_metadata
|
|
174
|
+
or _default_playground_experiment_metadata(
|
|
175
|
+
dataset_name=dataset.name,
|
|
176
|
+
dataset_id=input.dataset_id,
|
|
177
|
+
version_id=GlobalID(DatasetVersion.__name__, str(resolved_version_id)),
|
|
178
|
+
),
|
|
179
|
+
project_name=PLAYGROUND_PROJECT_NAME,
|
|
180
|
+
)
|
|
181
|
+
session.add(experiment)
|
|
182
|
+
await session.flush()
|
|
183
|
+
|
|
184
|
+
start_time = datetime.now(timezone.utc)
|
|
185
|
+
results = await asyncio.gather(
|
|
186
|
+
*(
|
|
187
|
+
cls._chat_completion(
|
|
188
|
+
info,
|
|
189
|
+
llm_client,
|
|
190
|
+
ChatCompletionInput(
|
|
191
|
+
model=input.model,
|
|
192
|
+
api_key=input.api_key,
|
|
193
|
+
messages=input.messages,
|
|
194
|
+
tools=input.tools,
|
|
195
|
+
invocation_parameters=input.invocation_parameters,
|
|
196
|
+
template=TemplateOptions(
|
|
197
|
+
language=input.template_language,
|
|
198
|
+
variables=revision.input,
|
|
199
|
+
),
|
|
200
|
+
),
|
|
201
|
+
)
|
|
202
|
+
for revision in revisions
|
|
203
|
+
),
|
|
204
|
+
return_exceptions=True,
|
|
205
|
+
)
|
|
206
|
+
|
|
207
|
+
payload = ChatCompletionOverDatasetMutationPayload(
|
|
208
|
+
dataset_id=GlobalID(models.Dataset.__name__, str(dataset.id)),
|
|
209
|
+
dataset_version_id=GlobalID(DatasetVersion.__name__, str(resolved_version_id)),
|
|
210
|
+
experiment_id=GlobalID(models.Experiment.__name__, str(experiment.id)),
|
|
211
|
+
)
|
|
212
|
+
experiment_runs = []
|
|
213
|
+
for revision, result in zip(revisions, results):
|
|
214
|
+
if isinstance(result, BaseException):
|
|
215
|
+
experiment_run = models.ExperimentRun(
|
|
216
|
+
experiment_id=experiment.id,
|
|
217
|
+
dataset_example_id=revision.dataset_example_id,
|
|
218
|
+
output={},
|
|
219
|
+
repetition_number=1,
|
|
220
|
+
start_time=start_time,
|
|
221
|
+
end_time=start_time,
|
|
222
|
+
error=str(result),
|
|
223
|
+
)
|
|
224
|
+
else:
|
|
225
|
+
db_span = result.span.db_span
|
|
226
|
+
experiment_run = models.ExperimentRun(
|
|
227
|
+
experiment_id=experiment.id,
|
|
228
|
+
dataset_example_id=revision.dataset_example_id,
|
|
229
|
+
trace_id=str(result.span.context.trace_id),
|
|
230
|
+
output=models.ExperimentRunOutput(
|
|
231
|
+
task_output=get_attribute_value(db_span.attributes, LLM_OUTPUT_MESSAGES),
|
|
232
|
+
),
|
|
233
|
+
prompt_token_count=db_span.cumulative_llm_token_count_prompt,
|
|
234
|
+
completion_token_count=db_span.cumulative_llm_token_count_completion,
|
|
235
|
+
repetition_number=1,
|
|
236
|
+
start_time=result.span.start_time,
|
|
237
|
+
end_time=result.span.end_time,
|
|
238
|
+
error=str(result.error_message) if result.error_message else None,
|
|
239
|
+
)
|
|
240
|
+
experiment_runs.append(experiment_run)
|
|
241
|
+
|
|
242
|
+
async with info.context.db() as session:
|
|
243
|
+
session.add_all(experiment_runs)
|
|
244
|
+
await session.flush()
|
|
245
|
+
|
|
246
|
+
for revision, experiment_run, result in zip(revisions, experiment_runs, results):
|
|
247
|
+
dataset_example_id = GlobalID(
|
|
248
|
+
models.DatasetExample.__name__, str(revision.dataset_example_id)
|
|
249
|
+
)
|
|
250
|
+
experiment_run_id = GlobalID(models.ExperimentRun.__name__, str(experiment_run.id))
|
|
251
|
+
example_payload = ChatCompletionOverDatasetMutationExamplePayload(
|
|
252
|
+
dataset_example_id=dataset_example_id,
|
|
253
|
+
experiment_run_id=experiment_run_id,
|
|
254
|
+
result=result
|
|
255
|
+
if isinstance(result, ChatCompletionMutationPayload)
|
|
256
|
+
else ChatCompletionMutationError(message=str(result)),
|
|
257
|
+
)
|
|
258
|
+
payload.examples.append(example_payload)
|
|
259
|
+
return payload
|
|
260
|
+
|
|
261
|
+
@strawberry.mutation
|
|
262
|
+
@classmethod
|
|
85
263
|
async def chat_completion(
|
|
86
|
-
|
|
264
|
+
cls, info: Info[Context, None], input: ChatCompletionInput
|
|
87
265
|
) -> ChatCompletionMutationPayload:
|
|
88
266
|
provider_key = input.model.provider_key
|
|
89
267
|
llm_client_class = PLAYGROUND_CLIENT_REGISTRY.get_client(provider_key, input.model.name)
|
|
90
268
|
if llm_client_class is None:
|
|
91
269
|
raise BadRequest(f"No LLM client registered for provider '{provider_key}'")
|
|
92
|
-
attributes: dict[str, Any] = {}
|
|
93
270
|
llm_client = llm_client_class(
|
|
94
271
|
model=input.model,
|
|
95
272
|
api_key=input.api_key,
|
|
96
273
|
)
|
|
274
|
+
return await cls._chat_completion(info, llm_client, input)
|
|
275
|
+
|
|
276
|
+
@classmethod
|
|
277
|
+
async def _chat_completion(
|
|
278
|
+
cls,
|
|
279
|
+
info: Info[Context, None],
|
|
280
|
+
llm_client: PlaygroundStreamingClient,
|
|
281
|
+
input: ChatCompletionInput,
|
|
282
|
+
) -> ChatCompletionMutationPayload:
|
|
283
|
+
attributes: dict[str, Any] = {}
|
|
97
284
|
|
|
98
285
|
messages = [
|
|
99
286
|
(
|
|
@@ -122,7 +309,6 @@ class ChatCompletionMutationMixin:
|
|
|
122
309
|
llm_input_messages(messages),
|
|
123
310
|
llm_invocation_parameters(invocation_parameters),
|
|
124
311
|
input_value_and_mime_type(input),
|
|
125
|
-
**llm_client.attributes,
|
|
126
312
|
)
|
|
127
313
|
)
|
|
128
314
|
|
|
@@ -167,6 +353,7 @@ class ChatCompletionMutationMixin:
|
|
|
167
353
|
else:
|
|
168
354
|
end_time = normalize_datetime(dt=local_now(), tz=timezone.utc)
|
|
169
355
|
|
|
356
|
+
attributes.update(llm_client.attributes)
|
|
170
357
|
if text_content or tool_calls:
|
|
171
358
|
attributes.update(
|
|
172
359
|
chain(
|
|
@@ -306,7 +493,7 @@ def _llm_output_messages(
|
|
|
306
493
|
if arguments := tool_call.function.arguments:
|
|
307
494
|
yield (
|
|
308
495
|
f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
|
|
309
|
-
|
|
496
|
+
arguments,
|
|
310
497
|
)
|
|
311
498
|
|
|
312
499
|
|
phoenix/server/api/queries.py
CHANGED
|
@@ -48,6 +48,7 @@ from phoenix.server.api.input_types.DatasetSort import DatasetSort
|
|
|
48
48
|
from phoenix.server.api.input_types.InvocationParameters import (
|
|
49
49
|
InvocationParameter,
|
|
50
50
|
)
|
|
51
|
+
from phoenix.server.api.subscriptions import PLAYGROUND_PROJECT_NAME
|
|
51
52
|
from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
|
|
52
53
|
from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
|
|
53
54
|
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
@@ -237,7 +238,10 @@ class Query:
|
|
|
237
238
|
select(models.Project)
|
|
238
239
|
.outerjoin(
|
|
239
240
|
models.Experiment,
|
|
240
|
-
|
|
241
|
+
and_(
|
|
242
|
+
models.Project.name == models.Experiment.project_name,
|
|
243
|
+
models.Experiment.project_name != PLAYGROUND_PROJECT_NAME,
|
|
244
|
+
),
|
|
241
245
|
)
|
|
242
246
|
.where(models.Experiment.project_name.is_(None))
|
|
243
247
|
.order_by(models.Project.id)
|
|
phoenix/server/api/routers/oauth2.py
CHANGED
|
@@ -14,7 +14,7 @@ from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, upd
|
|
|
14
14
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
15
15
|
from sqlalchemy.orm import joinedload
|
|
16
16
|
from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
|
|
17
|
-
from starlette.datastructures import URL
|
|
17
|
+
from starlette.datastructures import URL, URLPath
|
|
18
18
|
from starlette.responses import RedirectResponse
|
|
19
19
|
from starlette.routing import Router
|
|
20
20
|
from starlette.status import HTTP_302_FOUND
|
|
@@ -86,8 +86,16 @@ async def login(
|
|
|
86
86
|
if not isinstance(
|
|
87
87
|
oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
|
|
88
88
|
):
|
|
89
|
-
return _redirect_to_login(error=f"Unknown IDP: {idp_name}.")
|
|
90
|
-
|
|
89
|
+
return _redirect_to_login(request=request, error=f"Unknown IDP: {idp_name}.")
|
|
90
|
+
if (referer := request.headers.get("referer")) is not None:
|
|
91
|
+
# if the referer header is present, use it as the origin URL
|
|
92
|
+
parsed_url = urlparse(referer)
|
|
93
|
+
origin_url = _append_root_path_if_exists(
|
|
94
|
+
request=request, base_url=f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
95
|
+
)
|
|
96
|
+
else:
|
|
97
|
+
# fall back to the base url as the origin URL
|
|
98
|
+
origin_url = str(request.base_url)
|
|
91
99
|
authorization_url_data = await oauth2_client.create_authorization_url(
|
|
92
100
|
redirect_uri=_get_create_tokens_endpoint(
|
|
93
101
|
request=request, origin_url=origin_url, idp_name=idp_name
|
|
@@ -124,22 +132,22 @@ async def create_tokens(
|
|
|
124
132
|
) -> RedirectResponse:
|
|
125
133
|
secret = request.app.state.get_secret()
|
|
126
134
|
if state != stored_state:
|
|
127
|
-
return _redirect_to_login(error=_INVALID_OAUTH2_STATE_MESSAGE)
|
|
135
|
+
return _redirect_to_login(request=request, error=_INVALID_OAUTH2_STATE_MESSAGE)
|
|
128
136
|
try:
|
|
129
137
|
payload = _parse_state_payload(secret=secret, state=state)
|
|
130
138
|
except JoseError:
|
|
131
|
-
return _redirect_to_login(error=_INVALID_OAUTH2_STATE_MESSAGE)
|
|
139
|
+
return _redirect_to_login(request=request, error=_INVALID_OAUTH2_STATE_MESSAGE)
|
|
132
140
|
if (return_url := payload.get("return_url")) is not None and not _is_relative_url(
|
|
133
141
|
unquote(return_url)
|
|
134
142
|
):
|
|
135
|
-
return _redirect_to_login(error="Attempting login with unsafe return URL.")
|
|
143
|
+
return _redirect_to_login(request=request, error="Attempting login with unsafe return URL.")
|
|
136
144
|
assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta)
|
|
137
145
|
assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta)
|
|
138
146
|
token_store: TokenStore = request.app.state.get_token_store()
|
|
139
147
|
if not isinstance(
|
|
140
148
|
oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
|
|
141
149
|
):
|
|
142
|
-
return _redirect_to_login(error=f"Unknown IDP: {idp_name}.")
|
|
150
|
+
return _redirect_to_login(request=request, error=f"Unknown IDP: {idp_name}.")
|
|
143
151
|
try:
|
|
144
152
|
token_data = await oauth2_client.fetch_access_token(
|
|
145
153
|
state=state,
|
|
@@ -149,11 +157,12 @@ async def create_tokens(
|
|
|
149
157
|
),
|
|
150
158
|
)
|
|
151
159
|
except OAuthError as error:
|
|
152
|
-
return _redirect_to_login(error=str(error))
|
|
160
|
+
return _redirect_to_login(request=request, error=str(error))
|
|
153
161
|
_validate_token_data(token_data)
|
|
154
162
|
if "id_token" not in token_data:
|
|
155
163
|
return _redirect_to_login(
|
|
156
|
-
|
|
164
|
+
request=request,
|
|
165
|
+
error=f"OAuth2 IDP {idp_name} does not appear to support OpenID Connect.",
|
|
157
166
|
)
|
|
158
167
|
user_info = await oauth2_client.parse_id_token(token_data, nonce=stored_nonce)
|
|
159
168
|
user_info = _parse_user_info(user_info)
|
|
@@ -165,14 +174,18 @@ async def create_tokens(
|
|
|
165
174
|
user_info=user_info,
|
|
166
175
|
)
|
|
167
176
|
except EmailAlreadyInUse as error:
|
|
168
|
-
return _redirect_to_login(error=str(error))
|
|
177
|
+
return _redirect_to_login(request=request, error=str(error))
|
|
169
178
|
access_token, refresh_token = await create_access_and_refresh_tokens(
|
|
170
179
|
user=user,
|
|
171
180
|
token_store=token_store,
|
|
172
181
|
access_token_expiry=access_token_expiry,
|
|
173
182
|
refresh_token_expiry=refresh_token_expiry,
|
|
174
183
|
)
|
|
175
|
-
|
|
184
|
+
redirect_path = _prepend_root_path_if_exists(request=request, path=return_url or "/")
|
|
185
|
+
response = RedirectResponse(
|
|
186
|
+
url=redirect_path,
|
|
187
|
+
status_code=HTTP_302_FOUND,
|
|
188
|
+
)
|
|
176
189
|
response = set_access_token_cookie(
|
|
177
190
|
response=response, access_token=access_token, max_age=access_token_expiry
|
|
178
191
|
)
|
|
@@ -352,17 +365,46 @@ class EmailAlreadyInUse(Exception):
|
|
|
352
365
|
pass
|
|
353
366
|
|
|
354
367
|
|
|
355
|
-
def _redirect_to_login(*, error: str) -> RedirectResponse:
|
|
368
|
+
def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse:
|
|
356
369
|
"""
|
|
357
370
|
Creates a RedirectResponse to the login page to display an error message.
|
|
358
371
|
"""
|
|
359
|
-
|
|
372
|
+
login_path = _prepend_root_path_if_exists(request=request, path="/login")
|
|
373
|
+
url = URL(login_path).include_query_params(error=error)
|
|
360
374
|
response = RedirectResponse(url=url)
|
|
361
375
|
response = delete_oauth2_state_cookie(response)
|
|
362
376
|
response = delete_oauth2_nonce_cookie(response)
|
|
363
377
|
return response
|
|
364
378
|
|
|
365
379
|
|
|
380
|
+
def _prepend_root_path_if_exists(*, request: Request, path: str) -> str:
|
|
381
|
+
"""
|
|
382
|
+
If a root path is configured, prepends it to the input path.
|
|
383
|
+
"""
|
|
384
|
+
if not path.startswith("/"):
|
|
385
|
+
raise ValueError("path must start with a forward slash")
|
|
386
|
+
root_path = _get_root_path(request=request)
|
|
387
|
+
if root_path.endswith("/"):
|
|
388
|
+
root_path = root_path.rstrip("/")
|
|
389
|
+
return root_path + path
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def _append_root_path_if_exists(*, request: Request, base_url: str) -> str:
|
|
393
|
+
"""
|
|
394
|
+
If a root path is configured, appends it to the input base url.
|
|
395
|
+
"""
|
|
396
|
+
if not (root_path := _get_root_path(request=request)):
|
|
397
|
+
return base_url
|
|
398
|
+
return str(URLPath(root_path).make_absolute_url(base_url=base_url))
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def _get_root_path(*, request: Request) -> str:
|
|
402
|
+
"""
|
|
403
|
+
Gets the root path from the request.
|
|
404
|
+
"""
|
|
405
|
+
return str(request.scope.get("root_path", ""))
|
|
406
|
+
|
|
407
|
+
|
|
366
408
|
def _get_create_tokens_endpoint(*, request: Request, origin_url: str, idp_name: str) -> str:
|
|
367
409
|
"""
|
|
368
410
|
Gets the endpoint for create tokens route.
|
|
@@ -427,16 +469,6 @@ def _with_random_suffix(string: str) -> str:
|
|
|
427
469
|
return f"{string}-{randrange(10_000, 100_000)}"
|
|
428
470
|
|
|
429
471
|
|
|
430
|
-
def _get_origin_url(request: Request) -> str:
|
|
431
|
-
"""
|
|
432
|
-
Infers the origin URL from the request.
|
|
433
|
-
"""
|
|
434
|
-
if (referer := request.headers.get("referer")) is None:
|
|
435
|
-
return str(request.base_url)
|
|
436
|
-
parsed_url = urlparse(referer)
|
|
437
|
-
return f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
438
|
-
|
|
439
|
-
|
|
440
472
|
def _is_oauth2_state_payload(maybe_state_payload: Any) -> TypeGuard[_OAuth2StatePayload]:
|
|
441
473
|
"""
|
|
442
474
|
Determines whether the given object is an OAuth2 state payload.
|
|
phoenix/server/api/routers/v1/spans.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
|
1
|
+
from asyncio import get_running_loop
|
|
1
2
|
from collections.abc import AsyncIterator
|
|
2
3
|
from datetime import datetime, timezone
|
|
4
|
+
from secrets import token_urlsafe
|
|
3
5
|
from typing import Any, Literal, Optional
|
|
4
6
|
|
|
5
|
-
|
|
7
|
+
import pandas as pd
|
|
8
|
+
from fastapi import APIRouter, Header, HTTPException, Query
|
|
6
9
|
from pydantic import Field
|
|
7
10
|
from sqlalchemy import select
|
|
8
11
|
from starlette.requests import Request
|
|
@@ -19,6 +22,7 @@ from phoenix.db.insertion.types import Precursors
|
|
|
19
22
|
from phoenix.server.api.routers.utils import df_to_bytes
|
|
20
23
|
from phoenix.server.dml_event import SpanAnnotationInsertEvent
|
|
21
24
|
from phoenix.trace.dsl import SpanQuery as SpanQuery_
|
|
25
|
+
from phoenix.utilities.json import encode_df_as_json_string
|
|
22
26
|
|
|
23
27
|
from .pydantic_compat import V1RoutesBaseModel
|
|
24
28
|
from .utils import RequestBody, ResponseBody, add_errors_to_responses
|
|
@@ -72,6 +76,7 @@ class QuerySpansRequestBody(V1RoutesBaseModel):
|
|
|
72
76
|
async def query_spans_handler(
|
|
73
77
|
request: Request,
|
|
74
78
|
request_body: QuerySpansRequestBody,
|
|
79
|
+
accept: Optional[str] = Header(None),
|
|
75
80
|
project_name: Optional[str] = Query(
|
|
76
81
|
default=None, description="The project name to get evaluations from"
|
|
77
82
|
),
|
|
@@ -116,6 +121,13 @@ async def query_spans_handler(
|
|
|
116
121
|
if not results:
|
|
117
122
|
raise HTTPException(status_code=HTTP_404_NOT_FOUND)
|
|
118
123
|
|
|
124
|
+
if accept == "application/json":
|
|
125
|
+
boundary_token = token_urlsafe(64)
|
|
126
|
+
return StreamingResponse(
|
|
127
|
+
content=_json_multipart(results, boundary_token),
|
|
128
|
+
media_type=f"multipart/mixed; boundary={boundary_token}",
|
|
129
|
+
)
|
|
130
|
+
|
|
119
131
|
async def content() -> AsyncIterator[bytes]:
|
|
120
132
|
for result in results:
|
|
121
133
|
yield df_to_bytes(result)
|
|
@@ -126,6 +138,18 @@ async def query_spans_handler(
|
|
|
126
138
|
)
|
|
127
139
|
|
|
128
140
|
|
|
141
|
+
async def _json_multipart(
|
|
142
|
+
results: list[pd.DataFrame],
|
|
143
|
+
boundary_token: str,
|
|
144
|
+
) -> AsyncIterator[str]:
|
|
145
|
+
for df in results:
|
|
146
|
+
yield f"--{boundary_token}\r\n"
|
|
147
|
+
yield "Content-Type: application/json\r\n\r\n"
|
|
148
|
+
yield await get_running_loop().run_in_executor(None, encode_df_as_json_string, df)
|
|
149
|
+
yield "\r\n"
|
|
150
|
+
yield f"--{boundary_token}--\r\n"
|
|
151
|
+
|
|
152
|
+
|
|
129
153
|
@router.get("/spans", include_in_schema=False, deprecated=True)
|
|
130
154
|
async def get_spans_handler(
|
|
131
155
|
request: Request,
|
|
phoenix/server/api/types/ExperimentRun.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
-
from typing import Optional
|
|
2
|
+
from typing import TYPE_CHECKING, Annotated, Optional
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
|
+
from sqlalchemy import select
|
|
6
|
+
from sqlalchemy.orm import load_only
|
|
5
7
|
from strawberry import UNSET
|
|
6
8
|
from strawberry.relay import Connection, GlobalID, Node, NodeID
|
|
7
9
|
from strawberry.scalars import JSON
|
|
@@ -20,6 +22,9 @@ from phoenix.server.api.types.pagination import (
|
|
|
20
22
|
)
|
|
21
23
|
from phoenix.server.api.types.Trace import Trace
|
|
22
24
|
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
27
|
+
|
|
23
28
|
|
|
24
29
|
@strawberry.type
|
|
25
30
|
class ExperimentRun(Node):
|
|
@@ -62,6 +67,38 @@ class ExperimentRun(Node):
|
|
|
62
67
|
trace_rowid, project_rowid = trace
|
|
63
68
|
return Trace(id_attr=trace_rowid, trace_id=self.trace_id, project_rowid=project_rowid)
|
|
64
69
|
|
|
70
|
+
@strawberry.field
|
|
71
|
+
async def example(
|
|
72
|
+
self, info: Info
|
|
73
|
+
) -> Annotated[
|
|
74
|
+
"DatasetExample", strawberry.lazy("phoenix.server.api.types.DatasetExample")
|
|
75
|
+
]: # use lazy types to avoid circular import: https://strawberry.rocks/docs/types/lazy
|
|
76
|
+
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
77
|
+
|
|
78
|
+
async with info.context.db() as session:
|
|
79
|
+
assert (
|
|
80
|
+
result := await session.execute(
|
|
81
|
+
select(models.DatasetExample, models.Experiment.dataset_version_id)
|
|
82
|
+
.select_from(models.ExperimentRun)
|
|
83
|
+
.join(
|
|
84
|
+
models.DatasetExample,
|
|
85
|
+
models.DatasetExample.id == models.ExperimentRun.dataset_example_id,
|
|
86
|
+
)
|
|
87
|
+
.join(
|
|
88
|
+
models.Experiment,
|
|
89
|
+
models.Experiment.id == models.ExperimentRun.experiment_id,
|
|
90
|
+
)
|
|
91
|
+
.where(models.ExperimentRun.id == self.id_attr)
|
|
92
|
+
.options(load_only(models.DatasetExample.id, models.DatasetExample.created_at))
|
|
93
|
+
)
|
|
94
|
+
) is not None
|
|
95
|
+
example, version_id = result.first()
|
|
96
|
+
return DatasetExample(
|
|
97
|
+
id_attr=example.id,
|
|
98
|
+
created_at=example.created_at,
|
|
99
|
+
version_id=version_id,
|
|
100
|
+
)
|
|
101
|
+
|
|
65
102
|
|
|
66
103
|
def to_gql_experiment_run(run: models.ExperimentRun) -> ExperimentRun:
|
|
67
104
|
"""
|
|
phoenix/server/api/types/GenerativeProvider.py
CHANGED
|
@@ -8,6 +8,7 @@ class GenerativeProviderKey(Enum):
|
|
|
8
8
|
OPENAI = "OpenAI"
|
|
9
9
|
ANTHROPIC = "Anthropic"
|
|
10
10
|
AZURE_OPENAI = "Azure OpenAI"
|
|
11
|
+
GEMINI = "Google AI Studio"
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
@strawberry.type
|
|
@@ -24,7 +25,7 @@ class GenerativeProvider:
|
|
|
24
25
|
|
|
25
26
|
default_client = PLAYGROUND_CLIENT_REGISTRY.get_client(self.key, PROVIDER_DEFAULT)
|
|
26
27
|
if default_client:
|
|
27
|
-
return default_client.dependencies()
|
|
28
|
+
return [dependency.name for dependency in default_client.dependencies()]
|
|
28
29
|
return []
|
|
29
30
|
|
|
30
31
|
@strawberry.field
|
phoenix/server/app.py
CHANGED
|
@@ -253,7 +253,7 @@ class Static(StaticFiles):
|
|
|
253
253
|
|
|
254
254
|
|
|
255
255
|
class RequestOriginHostnameValidator(BaseHTTPMiddleware):
|
|
256
|
-
def __init__(self, trusted_hostnames: list[str],
|
|
256
|
+
def __init__(self, *args: Any, trusted_hostnames: list[str], **kwargs: Any) -> None:
|
|
257
257
|
super().__init__(*args, **kwargs)
|
|
258
258
|
self._trusted_hostnames = trusted_hostnames
|
|
259
259
|
|
|
@@ -767,7 +767,12 @@ def create_app(
|
|
|
767
767
|
middlewares.extend(user_fastapi_middlewares())
|
|
768
768
|
if origins := get_env_csrf_trusted_origins():
|
|
769
769
|
trusted_hostnames = [h for o in origins if o and (h := urlparse(o).hostname)]
|
|
770
|
-
middlewares.append(
|
|
770
|
+
middlewares.append(
|
|
771
|
+
Middleware(
|
|
772
|
+
RequestOriginHostnameValidator,
|
|
773
|
+
trusted_hostnames=trusted_hostnames,
|
|
774
|
+
)
|
|
775
|
+
)
|
|
771
776
|
elif email_sender or oauth2_client_configs:
|
|
772
777
|
logger.warning(
|
|
773
778
|
"CSRF protection can be enabled by listing trusted origins via "
|