arize-phoenix 12.3.0__py3-none-any.whl → 12.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; see the package registry's advisory page for more details.

Files changed (73)
  1. {arize_phoenix-12.3.0.dist-info → arize_phoenix-12.5.0.dist-info}/METADATA +2 -1
  2. {arize_phoenix-12.3.0.dist-info → arize_phoenix-12.5.0.dist-info}/RECORD +73 -72
  3. phoenix/auth.py +27 -2
  4. phoenix/config.py +302 -53
  5. phoenix/db/README.md +546 -28
  6. phoenix/db/models.py +3 -3
  7. phoenix/server/api/auth.py +9 -0
  8. phoenix/server/api/context.py +2 -0
  9. phoenix/server/api/dataloaders/__init__.py +2 -0
  10. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  11. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  12. phoenix/server/api/input_types/SpanSort.py +2 -1
  13. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  14. phoenix/server/api/mutations/annotation_config_mutations.py +6 -6
  15. phoenix/server/api/mutations/api_key_mutations.py +13 -5
  16. phoenix/server/api/mutations/chat_mutations.py +3 -3
  17. phoenix/server/api/mutations/dataset_label_mutations.py +6 -6
  18. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  19. phoenix/server/api/mutations/dataset_split_mutations.py +7 -7
  20. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  21. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  22. phoenix/server/api/mutations/model_mutations.py +4 -4
  23. phoenix/server/api/mutations/project_mutations.py +4 -4
  24. phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -4
  25. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  26. phoenix/server/api/mutations/prompt_label_mutations.py +7 -7
  27. phoenix/server/api/mutations/prompt_mutations.py +7 -7
  28. phoenix/server/api/mutations/prompt_version_tag_mutations.py +3 -3
  29. phoenix/server/api/mutations/span_annotations_mutations.py +5 -5
  30. phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
  31. phoenix/server/api/mutations/trace_mutations.py +3 -3
  32. phoenix/server/api/mutations/user_mutations.py +8 -5
  33. phoenix/server/api/routers/auth.py +23 -32
  34. phoenix/server/api/routers/oauth2.py +213 -24
  35. phoenix/server/api/routers/v1/__init__.py +18 -4
  36. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  37. phoenix/server/api/routers/v1/annotations.py +21 -22
  38. phoenix/server/api/routers/v1/datasets.py +86 -64
  39. phoenix/server/api/routers/v1/documents.py +2 -3
  40. phoenix/server/api/routers/v1/evaluations.py +12 -24
  41. phoenix/server/api/routers/v1/experiment_evaluations.py +2 -3
  42. phoenix/server/api/routers/v1/experiment_runs.py +16 -11
  43. phoenix/server/api/routers/v1/experiments.py +57 -22
  44. phoenix/server/api/routers/v1/projects.py +16 -50
  45. phoenix/server/api/routers/v1/prompts.py +30 -31
  46. phoenix/server/api/routers/v1/sessions.py +2 -5
  47. phoenix/server/api/routers/v1/spans.py +35 -26
  48. phoenix/server/api/routers/v1/traces.py +11 -19
  49. phoenix/server/api/routers/v1/users.py +13 -29
  50. phoenix/server/api/routers/v1/utils.py +3 -7
  51. phoenix/server/api/subscriptions.py +3 -3
  52. phoenix/server/api/types/Dataset.py +95 -6
  53. phoenix/server/api/types/Project.py +24 -68
  54. phoenix/server/app.py +3 -2
  55. phoenix/server/authorization.py +5 -4
  56. phoenix/server/bearer_auth.py +13 -5
  57. phoenix/server/jwt_store.py +8 -6
  58. phoenix/server/oauth2.py +172 -5
  59. phoenix/server/static/.vite/manifest.json +39 -39
  60. phoenix/server/static/assets/{components-Bs8eJEpU.js → components-cwdYEs7B.js} +501 -404
  61. phoenix/server/static/assets/{index-C6WEu5UP.js → index-Dc0vD1Rn.js} +1 -1
  62. phoenix/server/static/assets/{pages-D-n2pkoG.js → pages-BDkB3a_a.js} +577 -533
  63. phoenix/server/static/assets/{vendor-D2eEI-6h.js → vendor-Ce6GTAin.js} +1 -1
  64. phoenix/server/static/assets/{vendor-arizeai-kfOei7nf.js → vendor-arizeai-CSF-1Kc5.js} +1 -1
  65. phoenix/server/static/assets/{vendor-codemirror-1bq_t1Ec.js → vendor-codemirror-Bv8J_7an.js} +3 -3
  66. phoenix/server/static/assets/{vendor-recharts-DQ4xfrf4.js → vendor-recharts-DcLgzI7g.js} +1 -1
  67. phoenix/server/static/assets/{vendor-shiki-GGmcIQxA.js → vendor-shiki-BF8rh_7m.js} +1 -1
  68. phoenix/trace/attributes.py +80 -13
  69. phoenix/version.py +1 -1
  70. {arize_phoenix-12.3.0.dist-info → arize_phoenix-12.5.0.dist-info}/WHEEL +0 -0
  71. {arize_phoenix-12.3.0.dist-info → arize_phoenix-12.5.0.dist-info}/entry_points.txt +0 -0
  72. {arize_phoenix-12.3.0.dist-info → arize_phoenix-12.5.0.dist-info}/licenses/IP_NOTICE +0 -0
  73. {arize_phoenix-12.3.0.dist-info → arize_phoenix-12.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -27,7 +27,7 @@ from phoenix.config import PLAYGROUND_PROJECT_NAME
27
27
  from phoenix.datetime_utils import local_now, normalize_datetime
28
28
  from phoenix.db import models
29
29
  from phoenix.db.helpers import insert_experiment_with_examples_snapshot
30
- from phoenix.server.api.auth import IsLocked, IsNotReadOnly
30
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
31
31
  from phoenix.server.api.context import Context
32
32
  from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
33
33
  from phoenix.server.api.helpers.playground_clients import (
@@ -94,7 +94,7 @@ ChatStream: TypeAlias = AsyncGenerator[ChatCompletionSubscriptionPayload, None]
94
94
 
95
95
  @strawberry.type
96
96
  class Subscription:
97
- @strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
97
+ @strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
98
98
  async def chat_completion(
99
99
  self, info: Info[Context, None], input: ChatCompletionInput
100
100
  ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
@@ -193,7 +193,7 @@ class Subscription:
193
193
  info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
194
194
  yield ChatCompletionSubscriptionResult(span=Span(span_rowid=db_span.id, db_span=db_span))
195
195
 
196
- @strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
196
+ @strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
197
197
  async def chat_completion_over_dataset(
198
198
  self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
199
199
  ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
@@ -3,7 +3,7 @@ from datetime import datetime
3
3
  from typing import ClassVar, Optional, cast
4
4
 
5
5
  import strawberry
6
- from sqlalchemy import and_, func, or_, select
6
+ from sqlalchemy import Text, and_, func, or_, select
7
7
  from sqlalchemy.sql.functions import count
8
8
  from strawberry import UNSET
9
9
  from strawberry.relay import Connection, GlobalID, Node, NodeID
@@ -19,6 +19,7 @@ from phoenix.server.api.types.DatasetExperimentAnnotationSummary import (
19
19
  DatasetExperimentAnnotationSummary,
20
20
  )
21
21
  from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
22
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
22
23
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
23
24
  from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
24
25
  from phoenix.server.api.types.node import from_global_id_with_expected_type
@@ -87,6 +88,7 @@ class Dataset(Node):
87
88
  self,
88
89
  info: Info[Context, None],
89
90
  dataset_version_id: Optional[GlobalID] = UNSET,
91
+ split_ids: Optional[list[GlobalID]] = UNSET,
90
92
  ) -> int:
91
93
  dataset_id = self.id_attr
92
94
  version_id = (
@@ -97,6 +99,20 @@ class Dataset(Node):
97
99
  if dataset_version_id
98
100
  else None
99
101
  )
102
+
103
+ # Parse split IDs if provided
104
+ split_rowids: Optional[list[int]] = None
105
+ if split_ids:
106
+ split_rowids = []
107
+ for split_id in split_ids:
108
+ try:
109
+ split_rowid = from_global_id_with_expected_type(
110
+ global_id=split_id, expected_type_name=models.DatasetSplit.__name__
111
+ )
112
+ split_rowids.append(split_rowid)
113
+ except Exception:
114
+ raise BadRequest(f"Invalid split ID: {split_id}")
115
+
100
116
  revision_ids = (
101
117
  select(func.max(models.DatasetExampleRevision.id))
102
118
  .join(models.DatasetExample)
@@ -113,11 +129,36 @@ class Dataset(Node):
113
129
  revision_ids = revision_ids.where(
114
130
  models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
115
131
  )
116
- stmt = (
117
- select(count(models.DatasetExampleRevision.id))
118
- .where(models.DatasetExampleRevision.id.in_(revision_ids))
119
- .where(models.DatasetExampleRevision.revision_kind != "DELETE")
120
- )
132
+
133
+ # Build the count query
134
+ if split_rowids:
135
+ # When filtering by splits, count distinct examples that belong to those splits
136
+ stmt = (
137
+ select(count(models.DatasetExample.id.distinct()))
138
+ .join(
139
+ models.DatasetExampleRevision,
140
+ onclause=(
141
+ models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id
142
+ ),
143
+ )
144
+ .join(
145
+ models.DatasetSplitDatasetExample,
146
+ onclause=(
147
+ models.DatasetExample.id
148
+ == models.DatasetSplitDatasetExample.dataset_example_id
149
+ ),
150
+ )
151
+ .where(models.DatasetExampleRevision.id.in_(revision_ids))
152
+ .where(models.DatasetExampleRevision.revision_kind != "DELETE")
153
+ .where(models.DatasetSplitDatasetExample.dataset_split_id.in_(split_rowids))
154
+ )
155
+ else:
156
+ stmt = (
157
+ select(count(models.DatasetExampleRevision.id))
158
+ .where(models.DatasetExampleRevision.id.in_(revision_ids))
159
+ .where(models.DatasetExampleRevision.revision_kind != "DELETE")
160
+ )
161
+
121
162
  async with info.context.db() as session:
122
163
  return (await session.scalar(stmt)) or 0
123
164
 
@@ -126,10 +167,12 @@ class Dataset(Node):
126
167
  self,
127
168
  info: Info[Context, None],
128
169
  dataset_version_id: Optional[GlobalID] = UNSET,
170
+ split_ids: Optional[list[GlobalID]] = UNSET,
129
171
  first: Optional[int] = 50,
130
172
  last: Optional[int] = UNSET,
131
173
  after: Optional[CursorString] = UNSET,
132
174
  before: Optional[CursorString] = UNSET,
175
+ filter: Optional[str] = UNSET,
133
176
  ) -> Connection[DatasetExample]:
134
177
  args = ConnectionArgs(
135
178
  first=first,
@@ -145,6 +188,20 @@ class Dataset(Node):
145
188
  if dataset_version_id
146
189
  else None
147
190
  )
191
+
192
+ # Parse split IDs if provided
193
+ split_rowids: Optional[list[int]] = None
194
+ if split_ids:
195
+ split_rowids = []
196
+ for split_id in split_ids:
197
+ try:
198
+ split_rowid = from_global_id_with_expected_type(
199
+ global_id=split_id, expected_type_name=models.DatasetSplit.__name__
200
+ )
201
+ split_rowids.append(split_rowid)
202
+ except Exception:
203
+ raise BadRequest(f"Invalid split ID: {split_id}")
204
+
148
205
  revision_ids = (
149
206
  select(func.max(models.DatasetExampleRevision.id))
150
207
  .join(models.DatasetExample)
@@ -176,6 +233,31 @@ class Dataset(Node):
176
233
  )
177
234
  .order_by(models.DatasetExampleRevision.dataset_example_id.desc())
178
235
  )
236
+
237
+ # Filter by split IDs if provided
238
+ if split_rowids:
239
+ query = (
240
+ query.join(
241
+ models.DatasetSplitDatasetExample,
242
+ onclause=(
243
+ models.DatasetExample.id
244
+ == models.DatasetSplitDatasetExample.dataset_example_id
245
+ ),
246
+ )
247
+ .where(models.DatasetSplitDatasetExample.dataset_split_id.in_(split_rowids))
248
+ .distinct()
249
+ )
250
+ # Apply filter if provided - search through JSON fields (input, output, metadata)
251
+ if filter is not UNSET and filter:
252
+ # Create a filter that searches for the filter string in JSON fields
253
+ # Using PostgreSQL's JSON operators for case-insensitive text search
254
+ filter_condition = or_(
255
+ func.cast(models.DatasetExampleRevision.input, Text).ilike(f"%{filter}%"),
256
+ func.cast(models.DatasetExampleRevision.output, Text).ilike(f"%{filter}%"),
257
+ func.cast(models.DatasetExampleRevision.metadata_, Text).ilike(f"%{filter}%"),
258
+ )
259
+ query = query.where(filter_condition)
260
+
179
261
  async with info.context.db() as session:
180
262
  dataset_examples = [
181
263
  DatasetExample(
@@ -187,6 +269,13 @@ class Dataset(Node):
187
269
  ]
188
270
  return connection_from_list(data=dataset_examples, args=args)
189
271
 
272
+ @strawberry.field
273
+ async def splits(self, info: Info[Context, None]) -> list[DatasetSplit]:
274
+ return [
275
+ to_gql_dataset_split(split)
276
+ for split in await info.context.data_loaders.dataset_dataset_splits.load(self.id_attr)
277
+ ]
278
+
190
279
  @strawberry.field(
191
280
  description="Number of experiments for a specific version if version is specified, "
192
281
  "or for all versions if version is not specified."
@@ -7,7 +7,6 @@ from aioitertools.itertools import groupby, islice
7
7
  from openinference.semconv.trace import SpanAttributes
8
8
  from sqlalchemy import and_, case, desc, distinct, exists, func, or_, select
9
9
  from sqlalchemy.dialects import postgresql, sqlite
10
- from sqlalchemy.sql.elements import ColumnElement
11
10
  from sqlalchemy.sql.expression import tuple_
12
11
  from sqlalchemy.sql.functions import percentile_cont
13
12
  from strawberry import ID, UNSET, Private, lazy
@@ -21,8 +20,8 @@ from phoenix.db.helpers import SupportedSQLDialect, date_trunc
21
20
  from phoenix.server.api.context import Context
22
21
  from phoenix.server.api.exceptions import BadRequest
23
22
  from phoenix.server.api.input_types.ProjectSessionSort import (
24
- ProjectSessionColumn,
25
23
  ProjectSessionSort,
24
+ ProjectSessionSortConfig,
26
25
  )
27
26
  from phoenix.server.api.input_types.SpanSort import SpanColumn, SpanSort, SpanSortConfig
28
27
  from phoenix.server.api.input_types.TimeBinConfig import TimeBinConfig, TimeBinScale
@@ -459,74 +458,31 @@ class Project(Node):
459
458
  end_time=time_range.end if time_range else None,
460
459
  )
461
460
  stmt = stmt.where(table.id.in_(filtered_session_rowids))
461
+ sort_config: Optional[ProjectSessionSortConfig] = None
462
+ cursor_rowid_column: Any = table.id
462
463
  if sort:
463
- key: ColumnElement[Any]
464
- if sort.col is ProjectSessionColumn.startTime:
465
- key = table.start_time.label("key")
466
- elif sort.col is ProjectSessionColumn.endTime:
467
- key = table.end_time.label("key")
468
- elif (
469
- sort.col is ProjectSessionColumn.tokenCountTotal
470
- or sort.col is ProjectSessionColumn.numTraces
471
- ):
472
- if sort.col is ProjectSessionColumn.tokenCountTotal:
473
- sort_subq = (
474
- select(
475
- models.Trace.project_session_rowid.label("id"),
476
- func.sum(models.Span.cumulative_llm_token_count_total).label("key"),
477
- )
478
- .join_from(models.Trace, models.Span)
479
- .where(models.Span.parent_id.is_(None))
480
- .group_by(models.Trace.project_session_rowid)
481
- ).subquery()
482
- elif sort.col is ProjectSessionColumn.numTraces:
483
- sort_subq = (
484
- select(
485
- models.Trace.project_session_rowid.label("id"),
486
- func.count(models.Trace.id).label("key"),
487
- ).group_by(models.Trace.project_session_rowid)
488
- ).subquery()
464
+ sort_config = sort.update_orm_expr(stmt)
465
+ stmt = sort_config.stmt
466
+ if sort_config.dir is SortDir.desc:
467
+ cursor_rowid_column = desc(cursor_rowid_column)
468
+ if after:
469
+ cursor = Cursor.from_string(after)
470
+ if sort_config and cursor.sort_column:
471
+ sort_column = cursor.sort_column
472
+ compare = operator.lt if sort_config.dir is SortDir.desc else operator.gt
473
+ if sort_column.type is CursorSortColumnDataType.NULL:
474
+ stmt = stmt.where(sort_config.orm_expression.is_(None))
475
+ stmt = stmt.where(compare(table.id, cursor.rowid))
489
476
  else:
490
- assert_never(sort.col)
491
- key = sort_subq.c.key
492
- stmt = stmt.join(sort_subq, table.id == sort_subq.c.id)
493
- elif sort.col is ProjectSessionColumn.costTotal:
494
- sort_subq = (
495
- select(
496
- models.Trace.project_session_rowid.label("id"),
497
- func.sum(models.SpanCost.total_cost).label("key"),
498
- )
499
- .join_from(
500
- models.Trace,
501
- models.SpanCost,
502
- models.Trace.id == models.SpanCost.trace_rowid,
477
+ stmt = stmt.where(
478
+ compare(
479
+ tuple_(sort_config.orm_expression, table.id),
480
+ (sort_column.value, cursor.rowid),
481
+ )
503
482
  )
504
- .group_by(models.Trace.project_session_rowid)
505
- ).subquery()
506
- key = sort_subq.c.key
507
- stmt = stmt.join(sort_subq, table.id == sort_subq.c.id)
508
- else:
509
- assert_never(sort.col)
510
- stmt = stmt.add_columns(key)
511
- if sort.dir is SortDir.asc:
512
- stmt = stmt.order_by(key.asc(), table.id.asc())
513
483
  else:
514
- stmt = stmt.order_by(key.desc(), table.id.desc())
515
- if after:
516
- cursor = Cursor.from_string(after)
517
- assert cursor.sort_column is not None
518
- compare = operator.lt if sort.dir is SortDir.desc else operator.gt
519
- stmt = stmt.where(
520
- compare(
521
- tuple_(key, table.id),
522
- (cursor.sort_column.value, cursor.rowid),
523
- )
524
- )
525
- else:
526
- stmt = stmt.order_by(table.id.desc())
527
- if after:
528
- cursor = Cursor.from_string(after)
529
484
  stmt = stmt.where(table.id < cursor.rowid)
485
+ stmt = stmt.order_by(cursor_rowid_column)
530
486
  if first:
531
487
  stmt = stmt.limit(
532
488
  first + 1 # over-fetch by one to determine whether there's a next page
@@ -537,10 +493,10 @@ class Project(Node):
537
493
  async for record in islice(records, first):
538
494
  project_session = record[0]
539
495
  cursor = Cursor(rowid=project_session.id)
540
- if sort:
496
+ if sort_config:
541
497
  assert len(record) > 1
542
498
  cursor.sort_column = CursorSortColumn(
543
- type=sort.col.data_type,
499
+ type=sort_config.column_data_type,
544
500
  value=record[1],
545
501
  )
546
502
  cursors_and_nodes.append((cursor, to_gql_project_session(project_session)))
@@ -724,7 +680,7 @@ class Project(Node):
724
680
  stmt = span_filter(select(models.Span))
725
681
  dialect = info.context.db.dialect
726
682
  if dialect is SupportedSQLDialect.POSTGRESQL:
727
- str(stmt.compile(dialect=sqlite.dialect())) # type: ignore[no-untyped-call]
683
+ str(stmt.compile(dialect=sqlite.dialect()))
728
684
  elif dialect is SupportedSQLDialect.SQLITE:
729
685
  str(stmt.compile(dialect=postgresql.dialect())) # type: ignore[no-untyped-call]
730
686
  else:
phoenix/server/app.py CHANGED
@@ -45,7 +45,6 @@ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoin
45
45
  from starlette.requests import Request
46
46
  from starlette.responses import JSONResponse, PlainTextResponse, RedirectResponse, Response
47
47
  from starlette.staticfiles import StaticFiles
48
- from starlette.status import HTTP_401_UNAUTHORIZED
49
48
  from starlette.templating import Jinja2Templates
50
49
  from starlette.types import Scope, StatefulLifespan
51
50
  from strawberry.extensions import SchemaExtension
@@ -89,6 +88,7 @@ from phoenix.server.api.dataloaders import (
89
88
  AverageExperimentRepeatedRunGroupLatencyDataLoader,
90
89
  AverageExperimentRunLatencyDataLoader,
91
90
  CacheForDataLoaders,
91
+ DatasetDatasetSplitsDataLoader,
92
92
  DatasetExampleRevisionsDataLoader,
93
93
  DatasetExamplesAndVersionsByExperimentRunDataLoader,
94
94
  DatasetExampleSpansDataLoader,
@@ -352,7 +352,7 @@ class RequestOriginHostnameValidator(BaseHTTPMiddleware):
352
352
  if not (url := headers.get(key)):
353
353
  continue
354
354
  if urlparse(url).hostname not in self._trusted_hostnames:
355
- return Response(f"untrusted {key}", status_code=HTTP_401_UNAUTHORIZED)
355
+ return Response(f"untrusted {key}", status_code=401)
356
356
  return await call_next(request)
357
357
 
358
358
 
@@ -710,6 +710,7 @@ def create_graphql_router(
710
710
  db
711
711
  ),
712
712
  average_experiment_run_latency=AverageExperimentRunLatencyDataLoader(db),
713
+ dataset_dataset_splits=DatasetDatasetSplitsDataLoader(db),
713
714
  dataset_example_revisions=DatasetExampleRevisionsDataLoader(db),
714
715
  dataset_example_spans=DatasetExampleSpansDataLoader(db),
715
716
  dataset_examples_and_versions_by_experiment_run=DatasetExamplesAndVersionsByExperimentRunDataLoader(
@@ -23,7 +23,6 @@ Usage:
23
23
  """
24
24
 
25
25
  from fastapi import HTTPException, Request
26
- from fastapi import status as fastapi_status
27
26
 
28
27
  from phoenix.config import get_env_support_email
29
28
  from phoenix.server.bearer_auth import PhoenixUser
@@ -43,13 +42,15 @@ def require_admin(request: Request) -> None:
43
42
  Behavior:
44
43
  - Allows access if the authenticated user is an admin or a system user.
45
44
  - Raises HTTP 403 Forbidden if the user is not authorized.
46
- - Expects authentication to be enabled and request.user to be set by the authentication.
45
+ - Allows access if authentication is not enabled.
47
46
  """
47
+ if not request.app.state.authentication_enabled:
48
+ return
48
49
  user = getattr(request, "user", None)
49
50
  # System users have all privileges
50
51
  if not (isinstance(user, PhoenixUser) and user.is_admin):
51
52
  raise HTTPException(
52
- status_code=fastapi_status.HTTP_403_FORBIDDEN,
53
+ status_code=403,
53
54
  detail="Only admin or system users can perform this action.",
54
55
  )
55
56
 
@@ -82,6 +83,6 @@ def is_not_locked(request: Request) -> None:
82
83
  if support_email := get_env_support_email():
83
84
  detail += f" Need help? Contact us at {support_email}"
84
85
  raise HTTPException(
85
- status_code=fastapi_status.HTTP_507_INSUFFICIENT_STORAGE,
86
+ status_code=507,
86
87
  detail=detail,
87
88
  )
@@ -9,7 +9,6 @@ from fastapi import HTTPException, Request, WebSocket, WebSocketException
9
9
  from grpc_interceptor import AsyncServerInterceptor
10
10
  from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
11
11
  from starlette.requests import HTTPConnection
12
- from starlette.status import HTTP_401_UNAUTHORIZED
13
12
  from typing_extensions import override
14
13
 
15
14
  from phoenix import config
@@ -76,11 +75,18 @@ class PhoenixUser(BaseUser):
76
75
  self._is_admin = (
77
76
  claims.status is ClaimSetStatus.VALID and claims.attributes.user_role == "ADMIN"
78
77
  )
78
+ self._is_viewer = (
79
+ claims.status is ClaimSetStatus.VALID and claims.attributes.user_role == "VIEWER"
80
+ )
79
81
 
80
82
  @cached_property
81
83
  def is_admin(self) -> bool:
82
84
  return self._is_admin
83
85
 
86
+ @cached_property
87
+ def is_viewer(self) -> bool:
88
+ return self._is_viewer
89
+
84
90
  @cached_property
85
91
  def identity(self) -> UserId:
86
92
  return self._user_id
@@ -93,6 +99,8 @@ class PhoenixUser(BaseUser):
93
99
  class PhoenixSystemUser(PhoenixUser):
94
100
  def __init__(self, user_id: UserId) -> None:
95
101
  self._user_id = user_id
102
+ self._is_admin = True # System users have admin privileges
103
+ self._is_viewer = False # System users are not viewers
96
104
 
97
105
  @property
98
106
  def is_admin(self) -> bool:
@@ -144,16 +152,16 @@ async def is_authenticated(
144
152
  """
145
153
  assert request or websocket
146
154
  if request and not isinstance((user := request.user), PhoenixUser):
147
- raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
155
+ raise HTTPException(status_code=401, detail="Invalid token")
148
156
  if websocket and not isinstance((user := websocket.user), PhoenixUser):
149
- raise WebSocketException(code=HTTP_401_UNAUTHORIZED, reason="Invalid token")
157
+ raise WebSocketException(code=401, reason="Invalid token")
150
158
  if isinstance(user, PhoenixSystemUser):
151
159
  return
152
160
  claims = user.claims
153
161
  if claims.status is ClaimSetStatus.EXPIRED:
154
- raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Expired token")
162
+ raise HTTPException(status_code=401, detail="Expired token")
155
163
  if claims.status is not ClaimSetStatus.VALID:
156
- raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
164
+ raise HTTPException(status_code=401, detail="Invalid token")
157
165
 
158
166
 
159
167
  async def create_access_and_refresh_tokens(
@@ -164,7 +164,7 @@ class JwtStore:
164
164
  for token_id in token_ids:
165
165
  if isinstance(token_id, PasswordResetTokenId):
166
166
  password_reset_token_ids.append(token_id)
167
- if isinstance(token_id, AccessTokenId):
167
+ elif isinstance(token_id, AccessTokenId):
168
168
  access_token_ids.append(token_id)
169
169
  elif isinstance(token_id, RefreshTokenId):
170
170
  refresh_token_ids.append(token_id)
@@ -182,10 +182,10 @@ class JwtStore:
182
182
  await gather(*coroutines)
183
183
 
184
184
  async def log_out(self, user_id: UserId) -> None:
185
- for cls in (AccessTokenId, RefreshTokenId):
186
- table = cls.table
187
- stmt = delete(table).where(table.user_id == int(user_id)).returning(table.id)
188
- async with self._db() as session:
185
+ async with self._db() as session:
186
+ for cls in (AccessTokenId, RefreshTokenId):
187
+ table = cls.table
188
+ stmt = delete(table).where(table.user_id == int(user_id)).returning(table.id)
189
189
  async for id_ in await session.stream_scalars(stmt):
190
190
  await self._evict(cls(id_))
191
191
 
@@ -314,7 +314,9 @@ class _Store(DaemonTask, Generic[_ClaimSetT, _TokenT, _TokenIdT, _RecordT], ABC)
314
314
 
315
315
  async def _delete_expired_tokens(self, session: Any) -> None:
316
316
  now = datetime.now(timezone.utc)
317
- await session.execute(delete(self._table).where(self._table.expires_at < now))
317
+ # Per JWT RFC 7519 Section 4.1.4, tokens expire "on or after" the expiration time.
318
+ # Use <= to include tokens expiring at exactly this moment.
319
+ await session.execute(delete(self._table).where(self._table.expires_at <= now))
318
320
 
319
321
  async def _run(self) -> None:
320
322
  while self._running: