arize-phoenix 12.0.0__py3-none-any.whl → 12.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of arize-phoenix has been flagged as a potentially problematic release.

Files changed (44)
  1. {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.3.0.dist-info}/METADATA +1 -1
  2. {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.3.0.dist-info}/RECORD +43 -38
  3. phoenix/db/insertion/document_annotation.py +1 -1
  4. phoenix/db/insertion/session_annotation.py +1 -1
  5. phoenix/db/insertion/span_annotation.py +1 -1
  6. phoenix/db/insertion/trace_annotation.py +1 -1
  7. phoenix/db/insertion/types.py +0 -4
  8. phoenix/db/models.py +22 -1
  9. phoenix/server/api/auth_messages.py +46 -0
  10. phoenix/server/api/context.py +2 -0
  11. phoenix/server/api/dataloaders/__init__.py +2 -0
  12. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  13. phoenix/server/api/helpers/playground_clients.py +1 -0
  14. phoenix/server/api/mutations/__init__.py +2 -0
  15. phoenix/server/api/mutations/dataset_label_mutations.py +291 -0
  16. phoenix/server/api/mutations/dataset_split_mutations.py +38 -2
  17. phoenix/server/api/queries.py +21 -0
  18. phoenix/server/api/routers/auth.py +5 -5
  19. phoenix/server/api/routers/oauth2.py +53 -51
  20. phoenix/server/api/types/Dataset.py +8 -0
  21. phoenix/server/api/types/DatasetExample.py +7 -0
  22. phoenix/server/api/types/DatasetLabel.py +23 -0
  23. phoenix/server/api/types/Prompt.py +18 -1
  24. phoenix/server/app.py +12 -12
  25. phoenix/server/cost_tracking/model_cost_manifest.json +54 -54
  26. phoenix/server/oauth2.py +2 -4
  27. phoenix/server/static/.vite/manifest.json +39 -39
  28. phoenix/server/static/assets/{components-Dl9SUw1U.js → components-Bs8eJEpU.js} +699 -378
  29. phoenix/server/static/assets/{index-CqQS0dTo.js → index-C6WEu5UP.js} +3 -3
  30. phoenix/server/static/assets/{pages-DKSjVA_E.js → pages-D-n2pkoG.js} +1149 -1142
  31. phoenix/server/static/assets/vendor-D2eEI-6h.js +914 -0
  32. phoenix/server/static/assets/{vendor-arizeai-D-lWOwIS.js → vendor-arizeai-kfOei7nf.js} +15 -24
  33. phoenix/server/static/assets/{vendor-codemirror-BRBpy3_z.js → vendor-codemirror-1bq_t1Ec.js} +3 -3
  34. phoenix/server/static/assets/{vendor-recharts--KdSwB3m.js → vendor-recharts-DQ4xfrf4.js} +1 -1
  35. phoenix/server/static/assets/{vendor-shiki-CvRzZnIo.js → vendor-shiki-GGmcIQxA.js} +1 -1
  36. phoenix/server/templates/index.html +1 -0
  37. phoenix/server/utils.py +74 -0
  38. phoenix/session/session.py +25 -5
  39. phoenix/version.py +1 -1
  40. phoenix/server/static/assets/vendor-CtbHQYl8.js +0 -903
  41. {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.3.0.dist-info}/WHEEL +0 -0
  42. {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.3.0.dist-info}/entry_points.txt +0 -0
  43. {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.3.0.dist-info}/licenses/IP_NOTICE +0 -0
  44. {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.3.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/mutations/dataset_label_mutations.py (new file)
@@ -0,0 +1,291 @@
+ from typing import Optional
+
+ import sqlalchemy
+ import strawberry
+ from sqlalchemy import delete, select
+ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
+ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError  # type: ignore[import-untyped]
+ from strawberry import UNSET
+ from strawberry.relay.types import GlobalID
+ from strawberry.types import Info
+
+ from phoenix.db import models
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly
+ from phoenix.server.api.context import Context
+ from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
+ from phoenix.server.api.queries import Query
+ from phoenix.server.api.types.Dataset import Dataset
+ from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
+
+
+ @strawberry.input
+ class CreateDatasetLabelInput:
+     name: str
+     description: Optional[str] = UNSET
+     color: str
+
+
+ @strawberry.type
+ class CreateDatasetLabelMutationPayload:
+     dataset_label: DatasetLabel
+
+
+ @strawberry.input
+ class DeleteDatasetLabelsInput:
+     dataset_label_ids: list[GlobalID]
+
+
+ @strawberry.type
+ class DeleteDatasetLabelsMutationPayload:
+     dataset_labels: list[DatasetLabel]
+
+
+ @strawberry.input
+ class UpdateDatasetLabelInput:
+     dataset_label_id: GlobalID
+     name: str
+     description: Optional[str] = None
+     color: str
+
+
+ @strawberry.type
+ class UpdateDatasetLabelMutationPayload:
+     dataset_label: DatasetLabel
+
+
+ @strawberry.input
+ class SetDatasetLabelsInput:
+     dataset_label_ids: list[GlobalID]
+     dataset_ids: list[GlobalID]
+
+
+ @strawberry.type
+ class SetDatasetLabelsMutationPayload:
+     query: "Query"
+
+
+ @strawberry.input
+ class UnsetDatasetLabelsInput:
+     dataset_label_ids: list[GlobalID]
+     dataset_ids: list[GlobalID]
+
+
+ @strawberry.type
+ class UnsetDatasetLabelsMutationPayload:
+     query: "Query"
+
+
+ @strawberry.type
+ class DatasetLabelMutationMixin:
+     @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+     async def create_dataset_label(
+         self,
+         info: Info[Context, None],
+         input: CreateDatasetLabelInput,
+     ) -> CreateDatasetLabelMutationPayload:
+         name = input.name
+         description = input.description
+         color = input.color
+         async with info.context.db() as session:
+             dataset_label_orm = models.DatasetLabel(name=name, description=description, color=color)
+             session.add(dataset_label_orm)
+             try:
+                 await session.commit()
+             except (PostgreSQLIntegrityError, SQLiteIntegrityError):
+                 raise Conflict(f"A dataset label named '{name}' already exists")
+             except sqlalchemy.exc.StatementError as error:
+                 raise BadRequest(str(error.orig))
+             return CreateDatasetLabelMutationPayload(
+                 dataset_label=to_gql_dataset_label(dataset_label_orm)
+             )
+
+     @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+     async def update_dataset_label(
+         self, info: Info[Context, None], input: UpdateDatasetLabelInput
+     ) -> UpdateDatasetLabelMutationPayload:
+         if not input.name or not input.name.strip():
+             raise BadRequest("Dataset label name cannot be empty")
+
+         try:
+             dataset_label_id = from_global_id_with_expected_type(
+                 input.dataset_label_id, DatasetLabel.__name__
+             )
+         except ValueError:
+             raise BadRequest(f"Invalid dataset label ID: {input.dataset_label_id}")
+
+         async with info.context.db() as session:
+             dataset_label_orm = await session.get(models.DatasetLabel, dataset_label_id)
+             if not dataset_label_orm:
+                 raise NotFound(f"DatasetLabel with ID {input.dataset_label_id} not found")
+
+             dataset_label_orm.name = input.name.strip()
+             dataset_label_orm.description = input.description
+             dataset_label_orm.color = input.color.strip()
+
+             try:
+                 await session.commit()
+             except (PostgreSQLIntegrityError, SQLiteIntegrityError):
+                 raise Conflict(f"A dataset label named '{input.name}' already exists")
+             except sqlalchemy.exc.StatementError as error:
+                 raise BadRequest(str(error.orig))
+             return UpdateDatasetLabelMutationPayload(
+                 dataset_label=to_gql_dataset_label(dataset_label_orm)
+             )
+
+     @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+     async def delete_dataset_labels(
+         self, info: Info[Context, None], input: DeleteDatasetLabelsInput
+     ) -> DeleteDatasetLabelsMutationPayload:
+         dataset_label_row_ids: dict[int, None] = {}
+         for dataset_label_node_id in input.dataset_label_ids:
+             try:
+                 dataset_label_row_id = from_global_id_with_expected_type(
+                     dataset_label_node_id, DatasetLabel.__name__
+                 )
+             except ValueError:
+                 raise BadRequest(f"Unknown dataset label: {dataset_label_node_id}")
+             dataset_label_row_ids[dataset_label_row_id] = None
+         async with info.context.db() as session:
+             stmt = (
+                 delete(models.DatasetLabel)
+                 .where(models.DatasetLabel.id.in_(dataset_label_row_ids.keys()))
+                 .returning(models.DatasetLabel)
+             )
+             deleted_dataset_labels = (await session.scalars(stmt)).all()
+             if len(deleted_dataset_labels) < len(dataset_label_row_ids):
+                 await session.rollback()
+                 raise NotFound("Could not find one or more dataset labels with given IDs")
+             deleted_dataset_labels_by_id = {
+                 dataset_label.id: dataset_label for dataset_label in deleted_dataset_labels
+             }
+             return DeleteDatasetLabelsMutationPayload(
+                 dataset_labels=[
+                     to_gql_dataset_label(deleted_dataset_labels_by_id[dataset_label_row_id])
+                     for dataset_label_row_id in dataset_label_row_ids
+                 ]
+             )
+
+     @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+     async def set_dataset_labels(
+         self, info: Info[Context, None], input: SetDatasetLabelsInput
+     ) -> SetDatasetLabelsMutationPayload:
+         if not input.dataset_ids:
+             raise BadRequest("No datasets provided.")
+         if not input.dataset_label_ids:
+             raise BadRequest("No dataset labels provided.")
+
+         unique_dataset_rowids: set[int] = set()
+         for dataset_gid in input.dataset_ids:
+             try:
+                 dataset_rowid = from_global_id_with_expected_type(dataset_gid, Dataset.__name__)
+             except ValueError:
+                 raise BadRequest(f"Invalid dataset ID: {dataset_gid}")
+             unique_dataset_rowids.add(dataset_rowid)
+         dataset_rowids = list(unique_dataset_rowids)
+
+         unique_dataset_label_rowids: set[int] = set()
+         for dataset_label_gid in input.dataset_label_ids:
+             try:
+                 dataset_label_rowid = from_global_id_with_expected_type(
+                     dataset_label_gid, DatasetLabel.__name__
+                 )
+             except ValueError:
+                 raise BadRequest(f"Invalid dataset label ID: {dataset_label_gid}")
+             unique_dataset_label_rowids.add(dataset_label_rowid)
+         dataset_label_rowids = list(unique_dataset_label_rowids)
+
+         async with info.context.db() as session:
+             existing_dataset_ids = (
+                 await session.scalars(
+                     select(models.Dataset.id).where(models.Dataset.id.in_(dataset_rowids))
+                 )
+             ).all()
+             if len(existing_dataset_ids) != len(dataset_rowids):
+                 raise NotFound("One or more datasets not found")
+
+             existing_dataset_label_ids = (
+                 await session.scalars(
+                     select(models.DatasetLabel.id).where(
+                         models.DatasetLabel.id.in_(dataset_label_rowids)
+                     )
+                 )
+             ).all()
+             if len(existing_dataset_label_ids) != len(dataset_label_rowids):
+                 raise NotFound("One or more dataset labels not found")
+
+             existing_dataset_label_keys = await session.execute(
+                 select(
+                     models.DatasetsDatasetLabel.dataset_id,
+                     models.DatasetsDatasetLabel.dataset_label_id,
+                 ).where(
+                     models.DatasetsDatasetLabel.dataset_id.in_(dataset_rowids)
+                     & models.DatasetsDatasetLabel.dataset_label_id.in_(dataset_label_rowids)
+                 )
+             )
+             unique_dataset_label_keys = set(existing_dataset_label_keys.all())
+
+             datasets_dataset_labels = []
+             for dataset_rowid in dataset_rowids:
+                 for dataset_label_rowid in dataset_label_rowids:
+                     if (dataset_rowid, dataset_label_rowid) in unique_dataset_label_keys:
+                         continue
+                     datasets_dataset_labels.append(
+                         models.DatasetsDatasetLabel(
+                             dataset_id=dataset_rowid,
+                             dataset_label_id=dataset_label_rowid,
+                         )
+                     )
+             session.add_all(datasets_dataset_labels)
+
+             if datasets_dataset_labels:
+                 try:
+                     await session.commit()
+                 except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
+                     raise Conflict("Failed to add dataset labels to datasets.") from e
+
+         return SetDatasetLabelsMutationPayload(
+             query=Query(),
+         )
+
+     @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+     async def unset_dataset_labels(
+         self, info: Info[Context, None], input: UnsetDatasetLabelsInput
+     ) -> UnsetDatasetLabelsMutationPayload:
+         if not input.dataset_ids:
+             raise BadRequest("No datasets provided.")
+         if not input.dataset_label_ids:
+             raise BadRequest("No dataset labels provided.")
+
+         unique_dataset_rowids: set[int] = set()
+         for dataset_gid in input.dataset_ids:
+             try:
+                 dataset_rowid = from_global_id_with_expected_type(dataset_gid, Dataset.__name__)
+             except ValueError:
+                 raise BadRequest(f"Invalid dataset ID: {dataset_gid}")
+             unique_dataset_rowids.add(dataset_rowid)
+         dataset_rowids = list(unique_dataset_rowids)
+
+         unique_dataset_label_rowids: set[int] = set()
+         for dataset_label_gid in input.dataset_label_ids:
+             try:
+                 dataset_label_rowid = from_global_id_with_expected_type(
+                     dataset_label_gid, DatasetLabel.__name__
+                 )
+             except ValueError:
+                 raise BadRequest(f"Invalid dataset label ID: {dataset_label_gid}")
+             unique_dataset_label_rowids.add(dataset_label_rowid)
+         dataset_label_rowids = list(unique_dataset_label_rowids)
+
+         async with info.context.db() as session:
+             await session.execute(
+                 delete(models.DatasetsDatasetLabel).where(
+                     models.DatasetsDatasetLabel.dataset_id.in_(dataset_rowids)
+                     & models.DatasetsDatasetLabel.dataset_label_id.in_(dataset_label_rowids)
+                 )
+             )
+             await session.commit()
+
+         return UnsetDatasetLabelsMutationPayload(
+             query=Query(),
+         )
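
The new mutations above are served through Phoenix's GraphQL API. As a rough illustration only (not part of the diff), a client might call the new create_dataset_label mutation as sketched below; the camelCase field names assume Strawberry's default schema naming, the /graphql endpoint and port assume a default local Phoenix server, and the exact payload shape should be verified against the generated schema.

# Hypothetical usage sketch, not part of this release: calling the new
# createDatasetLabel mutation against a local Phoenix server's GraphQL endpoint.
from typing import Any, Optional

import httpx

CREATE_DATASET_LABEL = """
mutation CreateDatasetLabel($input: CreateDatasetLabelInput!) {
  createDatasetLabel(input: $input) {
    datasetLabel {
      id
      name
    }
  }
}
"""

def create_dataset_label(name: str, color: str, description: Optional[str] = None) -> dict[str, Any]:
    # POST the mutation document and its variables as JSON; the endpoint URL is an assumption.
    response = httpx.post(
        "http://localhost:6006/graphql",
        json={
            "query": CREATE_DATASET_LABEL,
            "variables": {"input": {"name": name, "color": color, "description": description}},
        },
    )
    response.raise_for_status()
    return response.json()

if __name__ == "__main__":
    print(create_dataset_label("golden-set", "#FFD700", "Curated golden examples"))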
phoenix/server/api/mutations/dataset_split_mutations.py
@@ -15,6 +15,7 @@ from phoenix.server.api.context import Context
  from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
  from phoenix.server.api.helpers.playground_users import get_user
  from phoenix.server.api.queries import Query
+ from phoenix.server.api.types.DatasetExample import DatasetExample, to_gql_dataset_example
  from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
  from phoenix.server.api.types.node import from_global_id_with_expected_type

@@ -68,6 +69,13 @@ class DatasetSplitMutationPayload:
      query: "Query"


+ @strawberry.type
+ class DatasetSplitMutationPayloadWithExamples:
+     dataset_split: DatasetSplit
+     query: "Query"
+     examples: list[DatasetExample]
+
+
  @strawberry.type
  class DeleteDatasetSplitsMutationPayload:
      dataset_splits: list[DatasetSplit]
@@ -77,11 +85,13 @@ class DeleteDatasetSplitsMutationPayload:
  @strawberry.type
  class AddDatasetExamplesToDatasetSplitsMutationPayload:
      query: "Query"
+     examples: list[DatasetExample]


  @strawberry.type
  class RemoveDatasetExamplesFromDatasetSplitsMutationPayload:
      query: "Query"
+     examples: list[DatasetExample]


  @strawberry.type
@@ -262,8 +272,16 @@ class DatasetSplitMutationMixin:
              except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
                  raise Conflict("Failed to add examples to dataset splits.") from e

+             examples = (
+                 await session.scalars(
+                     select(models.DatasetExample).where(
+                         models.DatasetExample.id.in_(example_rowids)
+                     )
+                 )
+             ).all()
              return AddDatasetExamplesToDatasetSplitsMutationPayload(
                  query=Query(),
+                 examples=[to_gql_dataset_example(example) for example in examples],
              )

      @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
@@ -314,14 +332,23 @@ class DatasetSplitMutationMixin:

              await session.execute(stmt)

+             examples = (
+                 await session.scalars(
+                     select(models.DatasetExample).where(
+                         models.DatasetExample.id.in_(example_rowids)
+                     )
+                 )
+             ).all()
+
              return RemoveDatasetExamplesFromDatasetSplitsMutationPayload(
                  query=Query(),
+                 examples=[to_gql_dataset_example(example) for example in examples],
              )

      @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
      async def create_dataset_split_with_examples(
          self, info: Info[Context, None], input: CreateDatasetSplitWithExamplesInput
-     ) -> DatasetSplitMutationPayload:
+     ) -> DatasetSplitMutationPayloadWithExamples:
          user_id = get_user(info)
          validated_name = _validated_name(input.name)
          unique_example_rowids: set[int] = set()
@@ -374,9 +401,18 @@ class DatasetSplitMutationMixin:
                      "Failed to associate examples with the new dataset split."
                  ) from e

-             return DatasetSplitMutationPayload(
+             examples = (
+                 await session.scalars(
+                     select(models.DatasetExample).where(
+                         models.DatasetExample.id.in_(example_rowids)
+                     )
+                 )
+             ).all()
+
+             return DatasetSplitMutationPayloadWithExamples(
                  dataset_split=to_gql_dataset_split(dataset_split_orm),
                  query=Query(),
+                 examples=[to_gql_dataset_example(example) for example in examples],
              )

phoenix/server/api/queries.py
@@ -48,6 +48,7 @@ from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_a
  from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
  from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
  from phoenix.server.api.types.DatasetExample import DatasetExample
+ from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
  from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
  from phoenix.server.api.types.Dimension import to_gql_dimension
  from phoenix.server.api.types.EmbeddingDimension import (
@@ -1149,6 +1150,26 @@ class Query:
              args=args,
          )

+     @strawberry.field
+     async def dataset_labels(
+         self,
+         info: Info[Context, None],
+         first: Optional[int] = 50,
+         last: Optional[int] = UNSET,
+         after: Optional[CursorString] = UNSET,
+         before: Optional[CursorString] = UNSET,
+     ) -> Connection[DatasetLabel]:
+         args = ConnectionArgs(
+             first=first,
+             after=after if isinstance(after, CursorString) else None,
+             last=last,
+             before=before if isinstance(before, CursorString) else None,
+         )
+         async with info.context.db() as session:
+             dataset_labels = await session.scalars(select(models.DatasetLabel))
+         data = [to_gql_dataset_label(dataset_label) for dataset_label in dataset_labels]
+         return connection_from_list(data=data, args=args)
+
      @strawberry.field
      async def dataset_splits(
          self,
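
The dataset_labels field above is exposed as a Relay-style connection. A sketch of how it might be queried is shown below; the camelCase names and the id/name/color selections are assumptions (the DatasetLabel GraphQL type itself lives in DatasetLabel.py, added in this release but not shown here). Note that the resolver loads all labels and paginates them in memory via connection_from_list, which is reasonable as long as the number of labels stays small.

# Hypothetical query sketch, not part of this release: paging the new
# datasetLabels connection. Selected fields are assumptions about the
# DatasetLabel type and should be verified against the generated schema.
import httpx

DATASET_LABELS = """
query DatasetLabels($first: Int) {
  datasetLabels(first: $first) {
    edges {
      node {
        id
        name
        color
      }
    }
  }
}
"""

result = httpx.post(
    "http://localhost:6006/graphql",  # assumed local Phoenix server
    json={"query": DATASET_LABELS, "variables": {"first": 10}},
).json()
print([edge["node"]["name"] for edge in result["data"]["datasetLabels"]["edges"]])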
phoenix/server/api/routers/auth.py
@@ -2,7 +2,6 @@ import asyncio
  import secrets
  from datetime import datetime, timedelta, timezone
  from functools import partial
- from pathlib import Path
  from urllib.parse import urlencode, urlparse, urlunparse

  from fastapi import APIRouter, Depends, HTTPException, Request, Response
@@ -38,7 +37,6 @@ from phoenix.config import (
      get_base_url,
      get_env_disable_basic_auth,
      get_env_disable_rate_limit,
-     get_env_host_root_path,
  )
  from phoenix.db import models
  from phoenix.server.bearer_auth import PhoenixUser, create_access_and_refresh_tokens
@@ -52,6 +50,7 @@ from phoenix.server.types import (
      TokenStore,
      UserId,
  )
+ from phoenix.server.utils import prepend_root_path

  rate_limiter = ServerRateLimiter(
      per_second_rate_limit=0.2,
@@ -145,7 +144,8 @@ async def logout(
      user_id = subject
      if user_id:
          await token_store.log_out(user_id)
-     redirect_url = "/logout" if get_env_disable_basic_auth() else "/login"
+     redirect_path = "/logout" if get_env_disable_basic_auth() else "/login"
+     redirect_url = prepend_root_path(request.scope, redirect_path)
      response = Response(status_code=HTTP_302_FOUND, headers={"Location": redirect_url})
      response = delete_access_token_cookie(response)
      response = delete_refresh_token_cookie(response)
@@ -242,9 +242,9 @@ async def initiate_password_reset(request: Request) -> Response:
      )
      token, _ = await token_store.create_password_reset_token(password_reset_token_claims)
      url = urlparse(request.headers.get("referer") or get_base_url())
-     path = Path(get_env_host_root_path()) / "reset-password-with-token"
+     path = prepend_root_path(request.scope, "/reset-password-with-token")
      query_string = urlencode(dict(token=token))
-     components = (url.scheme, url.netloc, path.as_posix(), "", query_string, "")
+     components = (url.scheme, url.netloc, path, "", query_string, "")
      reset_url = urlunparse(components)
      await sender.send_password_reset_email(email, reset_url)
      return Response(status_code=HTTP_204_NO_CONTENT)
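
Both routers now delegate root-path handling to helpers in the new phoenix/server/utils.py (+74 lines, listed above but not shown in this diff). The sketch below is only an inference from the call sites prepend_root_path(request.scope, path) and get_root_path(request.scope) and from the per-router helpers deleted from oauth2.py further down; the real module is larger and may differ.

# Hypothetical sketch of the root-path helpers in phoenix/server/utils.py.
# Inferred from the call sites in this diff; the actual implementation may differ.
from typing import Any, Mapping


def get_root_path(scope: Mapping[str, Any]) -> str:
    # The ASGI scope carries the configured root path, e.g. when Phoenix is
    # mounted behind a reverse proxy under a sub-path.
    return str(scope.get("root_path", ""))


def prepend_root_path(scope: Mapping[str, Any], path: str) -> str:
    # Join the root path (if any) with an application-absolute path.
    if not path.startswith("/"):
        raise ValueError("path must start with a forward slash")
    return get_root_path(scope).rstrip("/") + path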
phoenix/server/api/routers/oauth2.py
@@ -1,3 +1,4 @@
+ import logging
  import re
  from dataclasses import dataclass
  from datetime import timedelta
@@ -38,17 +39,20 @@ from phoenix.config import (
      get_env_disable_rate_limit,
  )
  from phoenix.db import models
+ from phoenix.server.api.auth_messages import AuthErrorCode
  from phoenix.server.bearer_auth import create_access_and_refresh_tokens
- from phoenix.server.oauth2 import OAuth2Client
  from phoenix.server.rate_limiters import (
      ServerRateLimiter,
      fastapi_ip_rate_limiter,
      fastapi_route_rate_limiter,
  )
  from phoenix.server.types import TokenStore
+ from phoenix.server.utils import get_root_path, prepend_root_path

  _LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+"

+ logger = logging.getLogger(__name__)
+
  login_rate_limiter = fastapi_ip_rate_limiter(
      ServerRateLimiter(
          per_second_rate_limit=0.2,
@@ -87,11 +91,12 @@ async def login(
      idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
      return_url: Optional[str] = Query(default=None, alias="returnUrl"),
  ) -> RedirectResponse:
+     # Security Note: Query parameters should be treated as untrusted user input. Never display
+     # these values directly to users as they could be manipulated for XSS, phishing, or social
+     # engineering attacks.
+     if (oauth2_client := request.app.state.oauth2_clients.get_client(idp_name)) is None:
+         return _redirect_to_login(request=request, error="unknown_idp")
      secret = request.app.state.get_secret()
-     if not isinstance(
-         oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
-     ):
-         return _redirect_to_login(request=request, error=f"Unknown IDP: {idp_name}.")
      if (referer := request.headers.get("referer")) is not None:
          # if the referer header is present, use it as the origin URL
          parsed_url = urlparse(referer)
@@ -131,28 +136,42 @@ async def create_tokens(
      request: Request,
      idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
      state: str = Query(),
-     authorization_code: str = Query(alias="code"),
+     authorization_code: Optional[str] = Query(default=None, alias="code"),
+     error: Optional[str] = Query(default=None),
+     error_description: Optional[str] = Query(default=None),
      stored_state: str = Cookie(alias=PHOENIX_OAUTH2_STATE_COOKIE_NAME),
      stored_nonce: str = Cookie(alias=PHOENIX_OAUTH2_NONCE_COOKIE_NAME),
  ) -> RedirectResponse:
+     # Security Note: Query parameters should be treated as untrusted user input. Never display
+     # these values directly to users as they could be manipulated for XSS, phishing, or social
+     # engineering attacks.
+     if (oauth2_client := request.app.state.oauth2_clients.get_client(idp_name)) is None:
+         return _redirect_to_login(request=request, error="unknown_idp")
+     if error or error_description:
+         logger.error(
+             "OAuth2 authentication failed for IDP %s: error=%s, description=%s",
+             idp_name,
+             error,
+             error_description,
+         )
+         return _redirect_to_login(request=request, error="auth_failed")
+     if authorization_code is None:
+         logger.error("OAuth2 callback missing authorization code for IDP %s", idp_name)
+         return _redirect_to_login(request=request, error="auth_failed")
      secret = request.app.state.get_secret()
      if state != stored_state:
-         return _redirect_to_login(request=request, error=_INVALID_OAUTH2_STATE_MESSAGE)
+         return _redirect_to_login(request=request, error="invalid_state")
      try:
          payload = _parse_state_payload(secret=secret, state=state)
      except JoseError:
-         return _redirect_to_login(request=request, error=_INVALID_OAUTH2_STATE_MESSAGE)
+         return _redirect_to_login(request=request, error="invalid_state")
      if (return_url := payload.get("return_url")) is not None and not _is_relative_url(
          unquote(return_url)
      ):
-         return _redirect_to_login(request=request, error="Attempting login with unsafe return URL.")
+         return _redirect_to_login(request=request, error="unsafe_return_url")
      assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta)
      assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta)
      token_store: TokenStore = request.app.state.get_token_store()
-     if not isinstance(
-         oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
-     ):
-         return _redirect_to_login(request=request, error=f"Unknown IDP: {idp_name}.")
      try:
          token_data = await oauth2_client.fetch_access_token(
              state=state,
@@ -161,19 +180,19 @@ async def create_tokens(
                  request=request, origin_url=payload["origin_url"], idp_name=idp_name
              ),
          )
-     except OAuthError as error:
-         return _redirect_to_login(request=request, error=str(error))
+     except OAuthError as e:
+         logger.error("OAuth2 error for IDP %s: %s", idp_name, e)
+         return _redirect_to_login(request=request, error="oauth_error")
      _validate_token_data(token_data)
      if "id_token" not in token_data:
-         return _redirect_to_login(
-             request=request,
-             error=f"OAuth2 IDP {idp_name} does not appear to support OpenID Connect.",
-         )
+         logger.error("OAuth2 IDP %s does not appear to support OpenID Connect", idp_name)
+         return _redirect_to_login(request=request, error="no_oidc_support")
      user_info = await oauth2_client.parse_id_token(token_data, nonce=stored_nonce)
      try:
          user_info = _parse_user_info(user_info)
-     except MissingEmailScope as error:
-         return _redirect_to_login(request=request, error=str(error))
+     except MissingEmailScope as e:
+         logger.error("Missing email scope for IDP %s: %s", idp_name, e)
+         return _redirect_to_login(request=request, error="missing_email_scope")

      try:
          async with request.app.state.db() as session:
@@ -183,15 +202,19 @@ async def create_tokens(
                  user_info=user_info,
                  allow_sign_up=oauth2_client.allow_sign_up,
              )
-     except (EmailAlreadyInUse, SignInNotAllowed) as error:
-         return _redirect_to_login(request=request, error=str(error))
+     except EmailAlreadyInUse as e:
+         logger.error("Email already in use for IDP %s: %s", idp_name, e)
+         return _redirect_to_login(request=request, error="email_in_use")
+     except SignInNotAllowed as e:
+         logger.error("Sign in not allowed for IDP %s: %s", idp_name, e)
+         return _redirect_to_login(request=request, error="sign_in_not_allowed")
      access_token, refresh_token = await create_access_and_refresh_tokens(
          user=user,
          token_store=token_store,
          access_token_expiry=access_token_expiry,
          refresh_token_expiry=refresh_token_expiry,
      )
-     redirect_path = _prepend_root_path_if_exists(request=request, path=return_url or "/")
+     redirect_path = prepend_root_path(request.scope, return_url or "/")
      response = RedirectResponse(
          url=redirect_path,
          status_code=HTTP_302_FOUND,
@@ -559,13 +582,14 @@ class MissingEmailScope(Exception):
      pass


- def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse:
+ def _redirect_to_login(*, request: Request, error: AuthErrorCode) -> RedirectResponse:
      """
-     Creates a RedirectResponse to the login page to display an error message.
+     Creates a RedirectResponse to the login page to display an error code.
+     The error code will be validated and mapped to a user-friendly message on the frontend.
      """
      # TODO: this needs some cleanup
-     login_path = _prepend_root_path_if_exists(
-         request=request, path="/login" if not get_env_disable_basic_auth() else "/logout"
+     login_path = prepend_root_path(
+         request.scope, "/login" if not get_env_disable_basic_auth() else "/logout"
      )
      url = URL(login_path).include_query_params(error=error)
      response = RedirectResponse(url=url)
@@ -574,34 +598,15 @@ def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse:
      return response


- def _prepend_root_path_if_exists(*, request: Request, path: str) -> str:
-     """
-     If a root path is configured, prepends it to the input path.
-     """
-     if not path.startswith("/"):
-         raise ValueError("path must start with a forward slash")
-     root_path = _get_root_path(request=request)
-     if root_path.endswith("/"):
-         root_path = root_path.rstrip("/")
-     return root_path + path
-
-
  def _append_root_path_if_exists(*, request: Request, base_url: str) -> str:
      """
      If a root path is configured, appends it to the input base url.
      """
-     if not (root_path := _get_root_path(request=request)):
+     if not (root_path := get_root_path(request.scope)):
          return base_url
      return str(URLPath(root_path).make_absolute_url(base_url=base_url))


- def _get_root_path(*, request: Request) -> str:
-     """
-     Gets the root path from the request.
-     """
-     return str(request.scope.get("root_path", ""))
-
-
  def _get_create_tokens_endpoint(*, request: Request, origin_url: str, idp_name: str) -> str:
      """
      Gets the endpoint for create tokens route.
@@ -679,7 +684,4 @@ def _is_oauth2_state_payload(maybe_state_payload: Any) -> TypeGuard[_OAuth2State


  _JWT_ALGORITHM = "HS256"
- _INVALID_OAUTH2_STATE_MESSAGE = (
-     "Received invalid state parameter during OAuth2 authorization code flow for IDP {idp_name}."
- )
  _RELATIVE_URL_PATTERN = re.compile(r"^/($|\w)")
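
Finally, the error strings passed to _redirect_to_login are now short machine-readable codes typed as AuthErrorCode, imported from the new phoenix/server/api/auth_messages.py (+46 lines, listed above but not shown here). A guess at its minimal shape, using only the codes that actually appear in this hunk, is sketched below; the real module presumably also maps each code to the user-facing message rendered by the frontend.

# Hypothetical sketch of phoenix/server/api/auth_messages.py, which is not shown
# in this diff. The codes are exactly those used in oauth2.py above; the real
# module is larger and likely pairs each code with a display message.
from typing import Literal

AuthErrorCode = Literal[
    "unknown_idp",
    "auth_failed",
    "invalid_state",
    "unsafe_return_url",
    "oauth_error",
    "no_oidc_support",
    "missing_email_scope",
    "email_in_use",
    "sign_in_not_allowed",
]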