arize-phoenix 12.3.0__py3-none-any.whl → 12.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-12.3.0.dist-info → arize_phoenix-12.5.0.dist-info}/METADATA +2 -1
- {arize_phoenix-12.3.0.dist-info → arize_phoenix-12.5.0.dist-info}/RECORD +73 -72
- phoenix/auth.py +27 -2
- phoenix/config.py +302 -53
- phoenix/db/README.md +546 -28
- phoenix/db/models.py +3 -3
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/context.py +2 -0
- phoenix/server/api/dataloaders/__init__.py +2 -0
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
- phoenix/server/api/input_types/SpanSort.py +2 -1
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +6 -6
- phoenix/server/api/mutations/api_key_mutations.py +13 -5
- phoenix/server/api/mutations/chat_mutations.py +3 -3
- phoenix/server/api/mutations/dataset_label_mutations.py +6 -6
- phoenix/server/api/mutations/dataset_mutations.py +8 -8
- phoenix/server/api/mutations/dataset_split_mutations.py +7 -7
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +4 -4
- phoenix/server/api/mutations/project_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +7 -7
- phoenix/server/api/mutations/prompt_mutations.py +7 -7
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +3 -3
- phoenix/server/api/mutations/span_annotations_mutations.py +5 -5
- phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/trace_mutations.py +3 -3
- phoenix/server/api/mutations/user_mutations.py +8 -5
- phoenix/server/api/routers/auth.py +23 -32
- phoenix/server/api/routers/oauth2.py +213 -24
- phoenix/server/api/routers/v1/__init__.py +18 -4
- phoenix/server/api/routers/v1/annotation_configs.py +19 -30
- phoenix/server/api/routers/v1/annotations.py +21 -22
- phoenix/server/api/routers/v1/datasets.py +86 -64
- phoenix/server/api/routers/v1/documents.py +2 -3
- phoenix/server/api/routers/v1/evaluations.py +12 -24
- phoenix/server/api/routers/v1/experiment_evaluations.py +2 -3
- phoenix/server/api/routers/v1/experiment_runs.py +16 -11
- phoenix/server/api/routers/v1/experiments.py +57 -22
- phoenix/server/api/routers/v1/projects.py +16 -50
- phoenix/server/api/routers/v1/prompts.py +30 -31
- phoenix/server/api/routers/v1/sessions.py +2 -5
- phoenix/server/api/routers/v1/spans.py +35 -26
- phoenix/server/api/routers/v1/traces.py +11 -19
- phoenix/server/api/routers/v1/users.py +13 -29
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +3 -3
- phoenix/server/api/types/Dataset.py +95 -6
- phoenix/server/api/types/Project.py +24 -68
- phoenix/server/app.py +3 -2
- phoenix/server/authorization.py +5 -4
- phoenix/server/bearer_auth.py +13 -5
- phoenix/server/jwt_store.py +8 -6
- phoenix/server/oauth2.py +172 -5
- phoenix/server/static/.vite/manifest.json +39 -39
- phoenix/server/static/assets/{components-Bs8eJEpU.js → components-cwdYEs7B.js} +501 -404
- phoenix/server/static/assets/{index-C6WEu5UP.js → index-Dc0vD1Rn.js} +1 -1
- phoenix/server/static/assets/{pages-D-n2pkoG.js → pages-BDkB3a_a.js} +577 -533
- phoenix/server/static/assets/{vendor-D2eEI-6h.js → vendor-Ce6GTAin.js} +1 -1
- phoenix/server/static/assets/{vendor-arizeai-kfOei7nf.js → vendor-arizeai-CSF-1Kc5.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-1bq_t1Ec.js → vendor-codemirror-Bv8J_7an.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts-DQ4xfrf4.js → vendor-recharts-DcLgzI7g.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-GGmcIQxA.js → vendor-shiki-BF8rh_7m.js} +1 -1
- phoenix/trace/attributes.py +80 -13
- phoenix/version.py +1 -1
- {arize_phoenix-12.3.0.dist-info → arize_phoenix-12.5.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-12.3.0.dist-info → arize_phoenix-12.5.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-12.3.0.dist-info → arize_phoenix-12.5.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-12.3.0.dist-info → arize_phoenix-12.5.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -27,7 +27,7 @@ from phoenix.config import PLAYGROUND_PROJECT_NAME
|
|
|
27
27
|
from phoenix.datetime_utils import local_now, normalize_datetime
|
|
28
28
|
from phoenix.db import models
|
|
29
29
|
from phoenix.db.helpers import insert_experiment_with_examples_snapshot
|
|
30
|
-
from phoenix.server.api.auth import IsLocked, IsNotReadOnly
|
|
30
|
+
from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
|
|
31
31
|
from phoenix.server.api.context import Context
|
|
32
32
|
from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
|
|
33
33
|
from phoenix.server.api.helpers.playground_clients import (
|
|
@@ -94,7 +94,7 @@ ChatStream: TypeAlias = AsyncGenerator[ChatCompletionSubscriptionPayload, None]
|
|
|
94
94
|
|
|
95
95
|
@strawberry.type
|
|
96
96
|
class Subscription:
|
|
97
|
-
@strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
97
|
+
@strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
98
98
|
async def chat_completion(
|
|
99
99
|
self, info: Info[Context, None], input: ChatCompletionInput
|
|
100
100
|
) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
|
|
@@ -193,7 +193,7 @@ class Subscription:
|
|
|
193
193
|
info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
|
|
194
194
|
yield ChatCompletionSubscriptionResult(span=Span(span_rowid=db_span.id, db_span=db_span))
|
|
195
195
|
|
|
196
|
-
@strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
196
|
+
@strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
197
197
|
async def chat_completion_over_dataset(
|
|
198
198
|
self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
|
|
199
199
|
) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
|
|
@@ -3,7 +3,7 @@ from datetime import datetime
|
|
|
3
3
|
from typing import ClassVar, Optional, cast
|
|
4
4
|
|
|
5
5
|
import strawberry
|
|
6
|
-
from sqlalchemy import and_, func, or_, select
|
|
6
|
+
from sqlalchemy import Text, and_, func, or_, select
|
|
7
7
|
from sqlalchemy.sql.functions import count
|
|
8
8
|
from strawberry import UNSET
|
|
9
9
|
from strawberry.relay import Connection, GlobalID, Node, NodeID
|
|
@@ -19,6 +19,7 @@ from phoenix.server.api.types.DatasetExperimentAnnotationSummary import (
|
|
|
19
19
|
DatasetExperimentAnnotationSummary,
|
|
20
20
|
)
|
|
21
21
|
from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
|
|
22
|
+
from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
|
|
22
23
|
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
23
24
|
from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
|
|
24
25
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
@@ -87,6 +88,7 @@ class Dataset(Node):
|
|
|
87
88
|
self,
|
|
88
89
|
info: Info[Context, None],
|
|
89
90
|
dataset_version_id: Optional[GlobalID] = UNSET,
|
|
91
|
+
split_ids: Optional[list[GlobalID]] = UNSET,
|
|
90
92
|
) -> int:
|
|
91
93
|
dataset_id = self.id_attr
|
|
92
94
|
version_id = (
|
|
@@ -97,6 +99,20 @@ class Dataset(Node):
|
|
|
97
99
|
if dataset_version_id
|
|
98
100
|
else None
|
|
99
101
|
)
|
|
102
|
+
|
|
103
|
+
# Parse split IDs if provided
|
|
104
|
+
split_rowids: Optional[list[int]] = None
|
|
105
|
+
if split_ids:
|
|
106
|
+
split_rowids = []
|
|
107
|
+
for split_id in split_ids:
|
|
108
|
+
try:
|
|
109
|
+
split_rowid = from_global_id_with_expected_type(
|
|
110
|
+
global_id=split_id, expected_type_name=models.DatasetSplit.__name__
|
|
111
|
+
)
|
|
112
|
+
split_rowids.append(split_rowid)
|
|
113
|
+
except Exception:
|
|
114
|
+
raise BadRequest(f"Invalid split ID: {split_id}")
|
|
115
|
+
|
|
100
116
|
revision_ids = (
|
|
101
117
|
select(func.max(models.DatasetExampleRevision.id))
|
|
102
118
|
.join(models.DatasetExample)
|
|
@@ -113,11 +129,36 @@ class Dataset(Node):
|
|
|
113
129
|
revision_ids = revision_ids.where(
|
|
114
130
|
models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
|
|
115
131
|
)
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
132
|
+
|
|
133
|
+
# Build the count query
|
|
134
|
+
if split_rowids:
|
|
135
|
+
# When filtering by splits, count distinct examples that belong to those splits
|
|
136
|
+
stmt = (
|
|
137
|
+
select(count(models.DatasetExample.id.distinct()))
|
|
138
|
+
.join(
|
|
139
|
+
models.DatasetExampleRevision,
|
|
140
|
+
onclause=(
|
|
141
|
+
models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id
|
|
142
|
+
),
|
|
143
|
+
)
|
|
144
|
+
.join(
|
|
145
|
+
models.DatasetSplitDatasetExample,
|
|
146
|
+
onclause=(
|
|
147
|
+
models.DatasetExample.id
|
|
148
|
+
== models.DatasetSplitDatasetExample.dataset_example_id
|
|
149
|
+
),
|
|
150
|
+
)
|
|
151
|
+
.where(models.DatasetExampleRevision.id.in_(revision_ids))
|
|
152
|
+
.where(models.DatasetExampleRevision.revision_kind != "DELETE")
|
|
153
|
+
.where(models.DatasetSplitDatasetExample.dataset_split_id.in_(split_rowids))
|
|
154
|
+
)
|
|
155
|
+
else:
|
|
156
|
+
stmt = (
|
|
157
|
+
select(count(models.DatasetExampleRevision.id))
|
|
158
|
+
.where(models.DatasetExampleRevision.id.in_(revision_ids))
|
|
159
|
+
.where(models.DatasetExampleRevision.revision_kind != "DELETE")
|
|
160
|
+
)
|
|
161
|
+
|
|
121
162
|
async with info.context.db() as session:
|
|
122
163
|
return (await session.scalar(stmt)) or 0
|
|
123
164
|
|
|
@@ -126,10 +167,12 @@ class Dataset(Node):
|
|
|
126
167
|
self,
|
|
127
168
|
info: Info[Context, None],
|
|
128
169
|
dataset_version_id: Optional[GlobalID] = UNSET,
|
|
170
|
+
split_ids: Optional[list[GlobalID]] = UNSET,
|
|
129
171
|
first: Optional[int] = 50,
|
|
130
172
|
last: Optional[int] = UNSET,
|
|
131
173
|
after: Optional[CursorString] = UNSET,
|
|
132
174
|
before: Optional[CursorString] = UNSET,
|
|
175
|
+
filter: Optional[str] = UNSET,
|
|
133
176
|
) -> Connection[DatasetExample]:
|
|
134
177
|
args = ConnectionArgs(
|
|
135
178
|
first=first,
|
|
@@ -145,6 +188,20 @@ class Dataset(Node):
|
|
|
145
188
|
if dataset_version_id
|
|
146
189
|
else None
|
|
147
190
|
)
|
|
191
|
+
|
|
192
|
+
# Parse split IDs if provided
|
|
193
|
+
split_rowids: Optional[list[int]] = None
|
|
194
|
+
if split_ids:
|
|
195
|
+
split_rowids = []
|
|
196
|
+
for split_id in split_ids:
|
|
197
|
+
try:
|
|
198
|
+
split_rowid = from_global_id_with_expected_type(
|
|
199
|
+
global_id=split_id, expected_type_name=models.DatasetSplit.__name__
|
|
200
|
+
)
|
|
201
|
+
split_rowids.append(split_rowid)
|
|
202
|
+
except Exception:
|
|
203
|
+
raise BadRequest(f"Invalid split ID: {split_id}")
|
|
204
|
+
|
|
148
205
|
revision_ids = (
|
|
149
206
|
select(func.max(models.DatasetExampleRevision.id))
|
|
150
207
|
.join(models.DatasetExample)
|
|
@@ -176,6 +233,31 @@ class Dataset(Node):
|
|
|
176
233
|
)
|
|
177
234
|
.order_by(models.DatasetExampleRevision.dataset_example_id.desc())
|
|
178
235
|
)
|
|
236
|
+
|
|
237
|
+
# Filter by split IDs if provided
|
|
238
|
+
if split_rowids:
|
|
239
|
+
query = (
|
|
240
|
+
query.join(
|
|
241
|
+
models.DatasetSplitDatasetExample,
|
|
242
|
+
onclause=(
|
|
243
|
+
models.DatasetExample.id
|
|
244
|
+
== models.DatasetSplitDatasetExample.dataset_example_id
|
|
245
|
+
),
|
|
246
|
+
)
|
|
247
|
+
.where(models.DatasetSplitDatasetExample.dataset_split_id.in_(split_rowids))
|
|
248
|
+
.distinct()
|
|
249
|
+
)
|
|
250
|
+
# Apply filter if provided - search through JSON fields (input, output, metadata)
|
|
251
|
+
if filter is not UNSET and filter:
|
|
252
|
+
# Create a filter that searches for the filter string in JSON fields
|
|
253
|
+
# Using PostgreSQL's JSON operators for case-insensitive text search
|
|
254
|
+
filter_condition = or_(
|
|
255
|
+
func.cast(models.DatasetExampleRevision.input, Text).ilike(f"%{filter}%"),
|
|
256
|
+
func.cast(models.DatasetExampleRevision.output, Text).ilike(f"%{filter}%"),
|
|
257
|
+
func.cast(models.DatasetExampleRevision.metadata_, Text).ilike(f"%{filter}%"),
|
|
258
|
+
)
|
|
259
|
+
query = query.where(filter_condition)
|
|
260
|
+
|
|
179
261
|
async with info.context.db() as session:
|
|
180
262
|
dataset_examples = [
|
|
181
263
|
DatasetExample(
|
|
@@ -187,6 +269,13 @@ class Dataset(Node):
|
|
|
187
269
|
]
|
|
188
270
|
return connection_from_list(data=dataset_examples, args=args)
|
|
189
271
|
|
|
272
|
+
@strawberry.field
|
|
273
|
+
async def splits(self, info: Info[Context, None]) -> list[DatasetSplit]:
|
|
274
|
+
return [
|
|
275
|
+
to_gql_dataset_split(split)
|
|
276
|
+
for split in await info.context.data_loaders.dataset_dataset_splits.load(self.id_attr)
|
|
277
|
+
]
|
|
278
|
+
|
|
190
279
|
@strawberry.field(
|
|
191
280
|
description="Number of experiments for a specific version if version is specified, "
|
|
192
281
|
"or for all versions if version is not specified."
|
|
@@ -7,7 +7,6 @@ from aioitertools.itertools import groupby, islice
|
|
|
7
7
|
from openinference.semconv.trace import SpanAttributes
|
|
8
8
|
from sqlalchemy import and_, case, desc, distinct, exists, func, or_, select
|
|
9
9
|
from sqlalchemy.dialects import postgresql, sqlite
|
|
10
|
-
from sqlalchemy.sql.elements import ColumnElement
|
|
11
10
|
from sqlalchemy.sql.expression import tuple_
|
|
12
11
|
from sqlalchemy.sql.functions import percentile_cont
|
|
13
12
|
from strawberry import ID, UNSET, Private, lazy
|
|
@@ -21,8 +20,8 @@ from phoenix.db.helpers import SupportedSQLDialect, date_trunc
|
|
|
21
20
|
from phoenix.server.api.context import Context
|
|
22
21
|
from phoenix.server.api.exceptions import BadRequest
|
|
23
22
|
from phoenix.server.api.input_types.ProjectSessionSort import (
|
|
24
|
-
ProjectSessionColumn,
|
|
25
23
|
ProjectSessionSort,
|
|
24
|
+
ProjectSessionSortConfig,
|
|
26
25
|
)
|
|
27
26
|
from phoenix.server.api.input_types.SpanSort import SpanColumn, SpanSort, SpanSortConfig
|
|
28
27
|
from phoenix.server.api.input_types.TimeBinConfig import TimeBinConfig, TimeBinScale
|
|
@@ -459,74 +458,31 @@ class Project(Node):
|
|
|
459
458
|
end_time=time_range.end if time_range else None,
|
|
460
459
|
)
|
|
461
460
|
stmt = stmt.where(table.id.in_(filtered_session_rowids))
|
|
461
|
+
sort_config: Optional[ProjectSessionSortConfig] = None
|
|
462
|
+
cursor_rowid_column: Any = table.id
|
|
462
463
|
if sort:
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
if
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
models.Trace.project_session_rowid.label("id"),
|
|
476
|
-
func.sum(models.Span.cumulative_llm_token_count_total).label("key"),
|
|
477
|
-
)
|
|
478
|
-
.join_from(models.Trace, models.Span)
|
|
479
|
-
.where(models.Span.parent_id.is_(None))
|
|
480
|
-
.group_by(models.Trace.project_session_rowid)
|
|
481
|
-
).subquery()
|
|
482
|
-
elif sort.col is ProjectSessionColumn.numTraces:
|
|
483
|
-
sort_subq = (
|
|
484
|
-
select(
|
|
485
|
-
models.Trace.project_session_rowid.label("id"),
|
|
486
|
-
func.count(models.Trace.id).label("key"),
|
|
487
|
-
).group_by(models.Trace.project_session_rowid)
|
|
488
|
-
).subquery()
|
|
464
|
+
sort_config = sort.update_orm_expr(stmt)
|
|
465
|
+
stmt = sort_config.stmt
|
|
466
|
+
if sort_config.dir is SortDir.desc:
|
|
467
|
+
cursor_rowid_column = desc(cursor_rowid_column)
|
|
468
|
+
if after:
|
|
469
|
+
cursor = Cursor.from_string(after)
|
|
470
|
+
if sort_config and cursor.sort_column:
|
|
471
|
+
sort_column = cursor.sort_column
|
|
472
|
+
compare = operator.lt if sort_config.dir is SortDir.desc else operator.gt
|
|
473
|
+
if sort_column.type is CursorSortColumnDataType.NULL:
|
|
474
|
+
stmt = stmt.where(sort_config.orm_expression.is_(None))
|
|
475
|
+
stmt = stmt.where(compare(table.id, cursor.rowid))
|
|
489
476
|
else:
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
select(
|
|
496
|
-
models.Trace.project_session_rowid.label("id"),
|
|
497
|
-
func.sum(models.SpanCost.total_cost).label("key"),
|
|
498
|
-
)
|
|
499
|
-
.join_from(
|
|
500
|
-
models.Trace,
|
|
501
|
-
models.SpanCost,
|
|
502
|
-
models.Trace.id == models.SpanCost.trace_rowid,
|
|
477
|
+
stmt = stmt.where(
|
|
478
|
+
compare(
|
|
479
|
+
tuple_(sort_config.orm_expression, table.id),
|
|
480
|
+
(sort_column.value, cursor.rowid),
|
|
481
|
+
)
|
|
503
482
|
)
|
|
504
|
-
.group_by(models.Trace.project_session_rowid)
|
|
505
|
-
).subquery()
|
|
506
|
-
key = sort_subq.c.key
|
|
507
|
-
stmt = stmt.join(sort_subq, table.id == sort_subq.c.id)
|
|
508
|
-
else:
|
|
509
|
-
assert_never(sort.col)
|
|
510
|
-
stmt = stmt.add_columns(key)
|
|
511
|
-
if sort.dir is SortDir.asc:
|
|
512
|
-
stmt = stmt.order_by(key.asc(), table.id.asc())
|
|
513
483
|
else:
|
|
514
|
-
stmt = stmt.order_by(key.desc(), table.id.desc())
|
|
515
|
-
if after:
|
|
516
|
-
cursor = Cursor.from_string(after)
|
|
517
|
-
assert cursor.sort_column is not None
|
|
518
|
-
compare = operator.lt if sort.dir is SortDir.desc else operator.gt
|
|
519
|
-
stmt = stmt.where(
|
|
520
|
-
compare(
|
|
521
|
-
tuple_(key, table.id),
|
|
522
|
-
(cursor.sort_column.value, cursor.rowid),
|
|
523
|
-
)
|
|
524
|
-
)
|
|
525
|
-
else:
|
|
526
|
-
stmt = stmt.order_by(table.id.desc())
|
|
527
|
-
if after:
|
|
528
|
-
cursor = Cursor.from_string(after)
|
|
529
484
|
stmt = stmt.where(table.id < cursor.rowid)
|
|
485
|
+
stmt = stmt.order_by(cursor_rowid_column)
|
|
530
486
|
if first:
|
|
531
487
|
stmt = stmt.limit(
|
|
532
488
|
first + 1 # over-fetch by one to determine whether there's a next page
|
|
@@ -537,10 +493,10 @@ class Project(Node):
|
|
|
537
493
|
async for record in islice(records, first):
|
|
538
494
|
project_session = record[0]
|
|
539
495
|
cursor = Cursor(rowid=project_session.id)
|
|
540
|
-
if
|
|
496
|
+
if sort_config:
|
|
541
497
|
assert len(record) > 1
|
|
542
498
|
cursor.sort_column = CursorSortColumn(
|
|
543
|
-
type=
|
|
499
|
+
type=sort_config.column_data_type,
|
|
544
500
|
value=record[1],
|
|
545
501
|
)
|
|
546
502
|
cursors_and_nodes.append((cursor, to_gql_project_session(project_session)))
|
|
@@ -724,7 +680,7 @@ class Project(Node):
|
|
|
724
680
|
stmt = span_filter(select(models.Span))
|
|
725
681
|
dialect = info.context.db.dialect
|
|
726
682
|
if dialect is SupportedSQLDialect.POSTGRESQL:
|
|
727
|
-
str(stmt.compile(dialect=sqlite.dialect()))
|
|
683
|
+
str(stmt.compile(dialect=sqlite.dialect()))
|
|
728
684
|
elif dialect is SupportedSQLDialect.SQLITE:
|
|
729
685
|
str(stmt.compile(dialect=postgresql.dialect())) # type: ignore[no-untyped-call]
|
|
730
686
|
else:
|
phoenix/server/app.py
CHANGED
|
@@ -45,7 +45,6 @@ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoin
|
|
|
45
45
|
from starlette.requests import Request
|
|
46
46
|
from starlette.responses import JSONResponse, PlainTextResponse, RedirectResponse, Response
|
|
47
47
|
from starlette.staticfiles import StaticFiles
|
|
48
|
-
from starlette.status import HTTP_401_UNAUTHORIZED
|
|
49
48
|
from starlette.templating import Jinja2Templates
|
|
50
49
|
from starlette.types import Scope, StatefulLifespan
|
|
51
50
|
from strawberry.extensions import SchemaExtension
|
|
@@ -89,6 +88,7 @@ from phoenix.server.api.dataloaders import (
|
|
|
89
88
|
AverageExperimentRepeatedRunGroupLatencyDataLoader,
|
|
90
89
|
AverageExperimentRunLatencyDataLoader,
|
|
91
90
|
CacheForDataLoaders,
|
|
91
|
+
DatasetDatasetSplitsDataLoader,
|
|
92
92
|
DatasetExampleRevisionsDataLoader,
|
|
93
93
|
DatasetExamplesAndVersionsByExperimentRunDataLoader,
|
|
94
94
|
DatasetExampleSpansDataLoader,
|
|
@@ -352,7 +352,7 @@ class RequestOriginHostnameValidator(BaseHTTPMiddleware):
|
|
|
352
352
|
if not (url := headers.get(key)):
|
|
353
353
|
continue
|
|
354
354
|
if urlparse(url).hostname not in self._trusted_hostnames:
|
|
355
|
-
return Response(f"untrusted {key}", status_code=
|
|
355
|
+
return Response(f"untrusted {key}", status_code=401)
|
|
356
356
|
return await call_next(request)
|
|
357
357
|
|
|
358
358
|
|
|
@@ -710,6 +710,7 @@ def create_graphql_router(
|
|
|
710
710
|
db
|
|
711
711
|
),
|
|
712
712
|
average_experiment_run_latency=AverageExperimentRunLatencyDataLoader(db),
|
|
713
|
+
dataset_dataset_splits=DatasetDatasetSplitsDataLoader(db),
|
|
713
714
|
dataset_example_revisions=DatasetExampleRevisionsDataLoader(db),
|
|
714
715
|
dataset_example_spans=DatasetExampleSpansDataLoader(db),
|
|
715
716
|
dataset_examples_and_versions_by_experiment_run=DatasetExamplesAndVersionsByExperimentRunDataLoader(
|
phoenix/server/authorization.py
CHANGED
|
@@ -23,7 +23,6 @@ Usage:
|
|
|
23
23
|
"""
|
|
24
24
|
|
|
25
25
|
from fastapi import HTTPException, Request
|
|
26
|
-
from fastapi import status as fastapi_status
|
|
27
26
|
|
|
28
27
|
from phoenix.config import get_env_support_email
|
|
29
28
|
from phoenix.server.bearer_auth import PhoenixUser
|
|
@@ -43,13 +42,15 @@ def require_admin(request: Request) -> None:
|
|
|
43
42
|
Behavior:
|
|
44
43
|
- Allows access if the authenticated user is an admin or a system user.
|
|
45
44
|
- Raises HTTP 403 Forbidden if the user is not authorized.
|
|
46
|
-
-
|
|
45
|
+
- Allows access if authentication is not enabled.
|
|
47
46
|
"""
|
|
47
|
+
if not request.app.state.authentication_enabled:
|
|
48
|
+
return
|
|
48
49
|
user = getattr(request, "user", None)
|
|
49
50
|
# System users have all privileges
|
|
50
51
|
if not (isinstance(user, PhoenixUser) and user.is_admin):
|
|
51
52
|
raise HTTPException(
|
|
52
|
-
status_code=
|
|
53
|
+
status_code=403,
|
|
53
54
|
detail="Only admin or system users can perform this action.",
|
|
54
55
|
)
|
|
55
56
|
|
|
@@ -82,6 +83,6 @@ def is_not_locked(request: Request) -> None:
|
|
|
82
83
|
if support_email := get_env_support_email():
|
|
83
84
|
detail += f" Need help? Contact us at {support_email}"
|
|
84
85
|
raise HTTPException(
|
|
85
|
-
status_code=
|
|
86
|
+
status_code=507,
|
|
86
87
|
detail=detail,
|
|
87
88
|
)
|
phoenix/server/bearer_auth.py
CHANGED
|
@@ -9,7 +9,6 @@ from fastapi import HTTPException, Request, WebSocket, WebSocketException
|
|
|
9
9
|
from grpc_interceptor import AsyncServerInterceptor
|
|
10
10
|
from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
|
|
11
11
|
from starlette.requests import HTTPConnection
|
|
12
|
-
from starlette.status import HTTP_401_UNAUTHORIZED
|
|
13
12
|
from typing_extensions import override
|
|
14
13
|
|
|
15
14
|
from phoenix import config
|
|
@@ -76,11 +75,18 @@ class PhoenixUser(BaseUser):
|
|
|
76
75
|
self._is_admin = (
|
|
77
76
|
claims.status is ClaimSetStatus.VALID and claims.attributes.user_role == "ADMIN"
|
|
78
77
|
)
|
|
78
|
+
self._is_viewer = (
|
|
79
|
+
claims.status is ClaimSetStatus.VALID and claims.attributes.user_role == "VIEWER"
|
|
80
|
+
)
|
|
79
81
|
|
|
80
82
|
@cached_property
|
|
81
83
|
def is_admin(self) -> bool:
|
|
82
84
|
return self._is_admin
|
|
83
85
|
|
|
86
|
+
@cached_property
|
|
87
|
+
def is_viewer(self) -> bool:
|
|
88
|
+
return self._is_viewer
|
|
89
|
+
|
|
84
90
|
@cached_property
|
|
85
91
|
def identity(self) -> UserId:
|
|
86
92
|
return self._user_id
|
|
@@ -93,6 +99,8 @@ class PhoenixUser(BaseUser):
|
|
|
93
99
|
class PhoenixSystemUser(PhoenixUser):
|
|
94
100
|
def __init__(self, user_id: UserId) -> None:
|
|
95
101
|
self._user_id = user_id
|
|
102
|
+
self._is_admin = True # System users have admin privileges
|
|
103
|
+
self._is_viewer = False # System users are not viewers
|
|
96
104
|
|
|
97
105
|
@property
|
|
98
106
|
def is_admin(self) -> bool:
|
|
@@ -144,16 +152,16 @@ async def is_authenticated(
|
|
|
144
152
|
"""
|
|
145
153
|
assert request or websocket
|
|
146
154
|
if request and not isinstance((user := request.user), PhoenixUser):
|
|
147
|
-
raise HTTPException(status_code=
|
|
155
|
+
raise HTTPException(status_code=401, detail="Invalid token")
|
|
148
156
|
if websocket and not isinstance((user := websocket.user), PhoenixUser):
|
|
149
|
-
raise WebSocketException(code=
|
|
157
|
+
raise WebSocketException(code=401, reason="Invalid token")
|
|
150
158
|
if isinstance(user, PhoenixSystemUser):
|
|
151
159
|
return
|
|
152
160
|
claims = user.claims
|
|
153
161
|
if claims.status is ClaimSetStatus.EXPIRED:
|
|
154
|
-
raise HTTPException(status_code=
|
|
162
|
+
raise HTTPException(status_code=401, detail="Expired token")
|
|
155
163
|
if claims.status is not ClaimSetStatus.VALID:
|
|
156
|
-
raise HTTPException(status_code=
|
|
164
|
+
raise HTTPException(status_code=401, detail="Invalid token")
|
|
157
165
|
|
|
158
166
|
|
|
159
167
|
async def create_access_and_refresh_tokens(
|
phoenix/server/jwt_store.py
CHANGED
|
@@ -164,7 +164,7 @@ class JwtStore:
|
|
|
164
164
|
for token_id in token_ids:
|
|
165
165
|
if isinstance(token_id, PasswordResetTokenId):
|
|
166
166
|
password_reset_token_ids.append(token_id)
|
|
167
|
-
|
|
167
|
+
elif isinstance(token_id, AccessTokenId):
|
|
168
168
|
access_token_ids.append(token_id)
|
|
169
169
|
elif isinstance(token_id, RefreshTokenId):
|
|
170
170
|
refresh_token_ids.append(token_id)
|
|
@@ -182,10 +182,10 @@ class JwtStore:
|
|
|
182
182
|
await gather(*coroutines)
|
|
183
183
|
|
|
184
184
|
async def log_out(self, user_id: UserId) -> None:
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
185
|
+
async with self._db() as session:
|
|
186
|
+
for cls in (AccessTokenId, RefreshTokenId):
|
|
187
|
+
table = cls.table
|
|
188
|
+
stmt = delete(table).where(table.user_id == int(user_id)).returning(table.id)
|
|
189
189
|
async for id_ in await session.stream_scalars(stmt):
|
|
190
190
|
await self._evict(cls(id_))
|
|
191
191
|
|
|
@@ -314,7 +314,9 @@ class _Store(DaemonTask, Generic[_ClaimSetT, _TokenT, _TokenIdT, _RecordT], ABC)
|
|
|
314
314
|
|
|
315
315
|
async def _delete_expired_tokens(self, session: Any) -> None:
|
|
316
316
|
now = datetime.now(timezone.utc)
|
|
317
|
-
|
|
317
|
+
# Per JWT RFC 7519 Section 4.1.4, tokens expire "on or after" the expiration time.
|
|
318
|
+
# Use <= to include tokens expiring at exactly this moment.
|
|
319
|
+
await session.execute(delete(self._table).where(self._table.expires_at <= now))
|
|
318
320
|
|
|
319
321
|
async def _run(self) -> None:
|
|
320
322
|
while self._running:
|