arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +2 -1
- phoenix/auth.py +27 -2
- phoenix/config.py +1594 -81
- phoenix/db/README.md +546 -28
- phoenix/db/bulk_inserter.py +119 -116
- phoenix/db/engines.py +140 -33
- phoenix/db/facilitator.py +22 -1
- phoenix/db/helpers.py +818 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +133 -1
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +41 -18
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +364 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/trace_retention.py +7 -6
- phoenix/experiments/functions.py +69 -19
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +60 -0
- phoenix/server/api/dataloaders/__init__.py +36 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/exceptions.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +263 -83
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +61 -19
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +5 -2
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/input_types/SpanSort.py +3 -2
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +15 -20
- phoenix/server/api/mutations/chat_mutations.py +106 -37
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +11 -9
- phoenix/server/api/mutations/project_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
- phoenix/server/api/mutations/trace_mutations.py +3 -3
- phoenix/server/api/mutations/user_mutations.py +55 -26
- phoenix/server/api/queries.py +501 -617
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +141 -87
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +349 -101
- phoenix/server/api/routers/v1/__init__.py +22 -4
- phoenix/server/api/routers/v1/annotation_configs.py +19 -30
- phoenix/server/api/routers/v1/annotations.py +455 -13
- phoenix/server/api/routers/v1/datasets.py +355 -68
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +20 -28
- phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
- phoenix/server/api/routers/v1/experiment_runs.py +335 -59
- phoenix/server/api/routers/v1/experiments.py +475 -47
- phoenix/server/api/routers/v1/projects.py +16 -50
- phoenix/server/api/routers/v1/prompts.py +50 -39
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +156 -96
- phoenix/server/api/routers/v1/traces.py +51 -77
- phoenix/server/api/routers/v1/users.py +64 -24
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +257 -93
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/Dataset.py +199 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +215 -68
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +120 -70
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/GenerativeProvider.py +1 -1
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +218 -185
- phoenix/server/api/types/ProjectSession.py +146 -29
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +130 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +184 -53
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +128 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +154 -36
- phoenix/server/authorization.py +5 -4
- phoenix/server/bearer_auth.py +13 -5
- phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
- phoenix/server/daemons/generative_model_store.py +61 -9
- phoenix/server/daemons/span_cost_calculator.py +10 -8
- phoenix/server/dml_event.py +13 -0
- phoenix/server/email/sender.py +29 -2
- phoenix/server/grpc_server.py +9 -9
- phoenix/server/jwt_store.py +8 -6
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +9 -3
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +43 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +51 -53
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
- phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +7 -1
- phoenix/server/thread_server.py +1 -2
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +44 -8
- phoenix/settings.py +2 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/query.py +2 -0
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
- phoenix/server/static/assets/pages-Creyamao.js +0 -8612
- phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
- phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
- phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
import re
|
|
2
|
-
from dataclasses import dataclass
|
|
3
|
+
from dataclasses import dataclass, field
|
|
3
4
|
from datetime import timedelta
|
|
4
5
|
from random import randrange
|
|
5
6
|
from typing import Any, Optional, TypedDict
|
|
@@ -10,7 +11,7 @@ from authlib.integrations.starlette_client import OAuthError
|
|
|
10
11
|
from authlib.jose import jwt
|
|
11
12
|
from authlib.jose.errors import JoseError
|
|
12
13
|
from fastapi import APIRouter, Cookie, Depends, Path, Query, Request
|
|
13
|
-
from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select
|
|
14
|
+
from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select
|
|
14
15
|
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
|
|
15
16
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
16
17
|
from sqlalchemy.orm import joinedload
|
|
@@ -18,27 +19,32 @@ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore
|
|
|
18
19
|
from starlette.datastructures import URL, Secret, URLPath
|
|
19
20
|
from starlette.responses import RedirectResponse
|
|
20
21
|
from starlette.routing import Router
|
|
21
|
-
from starlette.status import HTTP_302_FOUND
|
|
22
22
|
from typing_extensions import Annotated, NotRequired, TypeGuard
|
|
23
23
|
|
|
24
24
|
from phoenix.auth import (
|
|
25
25
|
DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES,
|
|
26
|
+
PHOENIX_OAUTH2_CODE_VERIFIER_COOKIE_NAME,
|
|
26
27
|
PHOENIX_OAUTH2_NONCE_COOKIE_NAME,
|
|
27
28
|
PHOENIX_OAUTH2_STATE_COOKIE_NAME,
|
|
29
|
+
delete_oauth2_code_verifier_cookie,
|
|
28
30
|
delete_oauth2_nonce_cookie,
|
|
29
31
|
delete_oauth2_state_cookie,
|
|
30
32
|
sanitize_email,
|
|
31
33
|
set_access_token_cookie,
|
|
34
|
+
set_oauth2_code_verifier_cookie,
|
|
32
35
|
set_oauth2_nonce_cookie,
|
|
33
36
|
set_oauth2_state_cookie,
|
|
34
37
|
set_refresh_token_cookie,
|
|
35
38
|
)
|
|
36
39
|
from phoenix.config import (
|
|
40
|
+
AssignableUserRoleName,
|
|
37
41
|
get_env_disable_basic_auth,
|
|
38
42
|
get_env_disable_rate_limit,
|
|
39
43
|
)
|
|
40
44
|
from phoenix.db import models
|
|
45
|
+
from phoenix.server.api.auth_messages import AuthErrorCode
|
|
41
46
|
from phoenix.server.bearer_auth import create_access_and_refresh_tokens
|
|
47
|
+
from phoenix.server.ldap import is_ldap_user
|
|
42
48
|
from phoenix.server.oauth2 import OAuth2Client
|
|
43
49
|
from phoenix.server.rate_limiters import (
|
|
44
50
|
ServerRateLimiter,
|
|
@@ -46,9 +52,12 @@ from phoenix.server.rate_limiters import (
|
|
|
46
52
|
fastapi_route_rate_limiter,
|
|
47
53
|
)
|
|
48
54
|
from phoenix.server.types import TokenStore
|
|
55
|
+
from phoenix.server.utils import get_root_path, prepend_root_path
|
|
49
56
|
|
|
50
57
|
_LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+"
|
|
51
58
|
|
|
59
|
+
logger = logging.getLogger(__name__)
|
|
60
|
+
|
|
52
61
|
login_rate_limiter = fastapi_ip_rate_limiter(
|
|
53
62
|
ServerRateLimiter(
|
|
54
63
|
per_second_rate_limit=0.2,
|
|
@@ -87,11 +96,12 @@ async def login(
|
|
|
87
96
|
idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
|
|
88
97
|
return_url: Optional[str] = Query(default=None, alias="returnUrl"),
|
|
89
98
|
) -> RedirectResponse:
|
|
99
|
+
# Security Note: Query parameters should be treated as untrusted user input. Never display
|
|
100
|
+
# these values directly to users as they could be manipulated for XSS, phishing, or social
|
|
101
|
+
# engineering attacks.
|
|
102
|
+
if (oauth2_client := request.app.state.oauth2_clients.get_client(idp_name)) is None:
|
|
103
|
+
return _redirect_to_login(request=request, error="unknown_idp")
|
|
90
104
|
secret = request.app.state.get_secret()
|
|
91
|
-
if not isinstance(
|
|
92
|
-
oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
|
|
93
|
-
):
|
|
94
|
-
return _redirect_to_login(request=request, error=f"Unknown IDP: {idp_name}.")
|
|
95
105
|
if (referer := request.headers.get("referer")) is not None:
|
|
96
106
|
# if the referer header is present, use it as the origin URL
|
|
97
107
|
parsed_url = urlparse(referer)
|
|
@@ -112,7 +122,7 @@ async def login(
|
|
|
112
122
|
assert isinstance(authorization_url := authorization_url_data.get("url"), str)
|
|
113
123
|
assert isinstance(state := authorization_url_data.get("state"), str)
|
|
114
124
|
assert isinstance(nonce := authorization_url_data.get("nonce"), str)
|
|
115
|
-
response = RedirectResponse(url=authorization_url, status_code=
|
|
125
|
+
response = RedirectResponse(url=authorization_url, status_code=302)
|
|
116
126
|
response = set_oauth2_state_cookie(
|
|
117
127
|
response=response,
|
|
118
128
|
state=state,
|
|
@@ -123,6 +133,12 @@ async def login(
|
|
|
123
133
|
nonce=nonce,
|
|
124
134
|
max_age=timedelta(minutes=DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES),
|
|
125
135
|
)
|
|
136
|
+
if code_verifier := authorization_url_data.get("code_verifier"):
|
|
137
|
+
response = set_oauth2_code_verifier_cookie(
|
|
138
|
+
response=response,
|
|
139
|
+
code_verifier=code_verifier,
|
|
140
|
+
max_age=timedelta(minutes=DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES),
|
|
141
|
+
)
|
|
126
142
|
return response
|
|
127
143
|
|
|
128
144
|
|
|
@@ -130,50 +146,106 @@ async def login(
|
|
|
130
146
|
async def create_tokens(
|
|
131
147
|
request: Request,
|
|
132
148
|
idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
|
|
133
|
-
state: str = Query(),
|
|
134
|
-
authorization_code: str = Query(alias="code"),
|
|
149
|
+
state: str = Query(), # RFC 6749 §4.1.1: CSRF protection via state parameter
|
|
150
|
+
authorization_code: Optional[str] = Query(default=None, alias="code"), # RFC 6749 §4.1.2
|
|
151
|
+
error: Optional[str] = Query(default=None), # RFC 6749 §4.1.2.1: Error response
|
|
152
|
+
error_description: Optional[str] = Query(default=None),
|
|
135
153
|
stored_state: str = Cookie(alias=PHOENIX_OAUTH2_STATE_COOKIE_NAME),
|
|
136
|
-
stored_nonce: str = Cookie(alias=PHOENIX_OAUTH2_NONCE_COOKIE_NAME),
|
|
154
|
+
stored_nonce: str = Cookie(alias=PHOENIX_OAUTH2_NONCE_COOKIE_NAME), # OIDC Core §3.1.2.1
|
|
155
|
+
code_verifier: Optional[str] = Cookie(
|
|
156
|
+
default=None, alias=PHOENIX_OAUTH2_CODE_VERIFIER_COOKIE_NAME
|
|
157
|
+
), # RFC 7636 §4.1
|
|
137
158
|
) -> RedirectResponse:
|
|
159
|
+
# Security Note: Query parameters should be treated as untrusted user input. Never display
|
|
160
|
+
# these values directly to users as they could be manipulated for XSS, phishing, or social
|
|
161
|
+
# engineering attacks.
|
|
162
|
+
if (oauth2_client := request.app.state.oauth2_clients.get_client(idp_name)) is None:
|
|
163
|
+
return _redirect_to_login(request=request, error="unknown_idp")
|
|
164
|
+
if error or error_description:
|
|
165
|
+
logger.error(
|
|
166
|
+
"OAuth2 authentication failed for IDP %s: error=%s, description=%s",
|
|
167
|
+
idp_name,
|
|
168
|
+
error,
|
|
169
|
+
error_description,
|
|
170
|
+
)
|
|
171
|
+
return _redirect_to_login(request=request, error="auth_failed")
|
|
172
|
+
if authorization_code is None:
|
|
173
|
+
logger.error("OAuth2 callback missing authorization code for IDP %s", idp_name)
|
|
174
|
+
return _redirect_to_login(request=request, error="auth_failed")
|
|
138
175
|
secret = request.app.state.get_secret()
|
|
176
|
+
# RFC 6749 §10.12: CSRF protection - validate state parameter
|
|
139
177
|
if state != stored_state:
|
|
140
|
-
return _redirect_to_login(request=request, error=
|
|
178
|
+
return _redirect_to_login(request=request, error="invalid_state")
|
|
141
179
|
try:
|
|
142
180
|
payload = _parse_state_payload(secret=secret, state=state)
|
|
143
181
|
except JoseError:
|
|
144
|
-
return _redirect_to_login(request=request, error=
|
|
182
|
+
return _redirect_to_login(request=request, error="invalid_state")
|
|
145
183
|
if (return_url := payload.get("return_url")) is not None and not _is_relative_url(
|
|
146
184
|
unquote(return_url)
|
|
147
185
|
):
|
|
148
|
-
return _redirect_to_login(request=request, error="
|
|
186
|
+
return _redirect_to_login(request=request, error="unsafe_return_url")
|
|
149
187
|
assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta)
|
|
150
188
|
assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta)
|
|
151
189
|
token_store: TokenStore = request.app.state.get_token_store()
|
|
152
|
-
if not isinstance(
|
|
153
|
-
oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
|
|
154
|
-
):
|
|
155
|
-
return _redirect_to_login(request=request, error=f"Unknown IDP: {idp_name}.")
|
|
156
190
|
try:
|
|
157
|
-
|
|
191
|
+
# RFC 6749 §4.1.3: Token request - exchange authorization code for tokens
|
|
192
|
+
fetch_kwargs: dict[str, Any] = dict(
|
|
158
193
|
state=state,
|
|
159
194
|
code=authorization_code,
|
|
160
|
-
redirect_uri=_get_create_tokens_endpoint(
|
|
195
|
+
redirect_uri=_get_create_tokens_endpoint( # RFC 6749 §3.1.2
|
|
161
196
|
request=request, origin_url=payload["origin_url"], idp_name=idp_name
|
|
162
197
|
),
|
|
163
198
|
)
|
|
164
|
-
|
|
165
|
-
|
|
199
|
+
# PKCE validation: code_verifier is required when PKCE is enabled (RFC 7636 §4.5)
|
|
200
|
+
if oauth2_client.use_pkce:
|
|
201
|
+
if not code_verifier:
|
|
202
|
+
logger.error(
|
|
203
|
+
"PKCE enabled but code_verifier cookie missing for IDP %s. "
|
|
204
|
+
"This may indicate a cookie issue, CORS misconfiguration, or "
|
|
205
|
+
"browser compatibility problem.",
|
|
206
|
+
idp_name,
|
|
207
|
+
)
|
|
208
|
+
return _redirect_to_login(request=request, error="auth_failed")
|
|
209
|
+
fetch_kwargs["code_verifier"] = code_verifier
|
|
210
|
+
token_data = await oauth2_client.fetch_access_token(**fetch_kwargs)
|
|
211
|
+
except OAuthError as e:
|
|
212
|
+
logger.error("OAuth2 error for IDP %s: %s", idp_name, e)
|
|
213
|
+
return _redirect_to_login(request=request, error="oauth_error")
|
|
166
214
|
_validate_token_data(token_data)
|
|
167
215
|
if "id_token" not in token_data:
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
216
|
+
logger.error("OAuth2 IDP %s does not appear to support OpenID Connect", idp_name)
|
|
217
|
+
return _redirect_to_login(request=request, error="no_oidc_support")
|
|
218
|
+
|
|
219
|
+
try:
|
|
220
|
+
id_token_claims = await oauth2_client.parse_id_token(token_data, nonce=stored_nonce)
|
|
221
|
+
except JoseError as e:
|
|
222
|
+
logger.error("ID token validation failed for IDP %s: %s", idp_name, e)
|
|
223
|
+
return _redirect_to_login(request=request, error="auth_failed")
|
|
224
|
+
|
|
225
|
+
if oauth2_client.has_sufficient_claims(id_token_claims):
|
|
226
|
+
user_claims = id_token_claims
|
|
227
|
+
else:
|
|
228
|
+
user_claims = await _fetch_and_merge_userinfo_claims(
|
|
229
|
+
oauth2_client, token_data, id_token_claims
|
|
171
230
|
)
|
|
172
|
-
|
|
231
|
+
|
|
232
|
+
try:
|
|
233
|
+
user_info = _parse_user_info(user_claims)
|
|
234
|
+
except (MissingEmailScope, InvalidUserInfo) as e:
|
|
235
|
+
logger.error("Error parsing user info for IDP %s: %s", idp_name, e)
|
|
236
|
+
return _redirect_to_login(request=request, error="missing_email_scope")
|
|
237
|
+
|
|
238
|
+
# Validate access and extract role from claims
|
|
239
|
+
# Both validate_access and extract_and_map_role may raise PermissionError
|
|
173
240
|
try:
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
241
|
+
oauth2_client.validate_access(user_info.claims)
|
|
242
|
+
# Extract and map role from claims
|
|
243
|
+
# Returns None if role mapping not configured (preserves existing user roles)
|
|
244
|
+
# Raises PermissionError if strict mode enabled and role validation fails
|
|
245
|
+
role_name = oauth2_client.extract_and_map_role(user_info.claims)
|
|
246
|
+
except PermissionError as e:
|
|
247
|
+
logger.error("Access validation failed for IDP %s: %s", idp_name, e)
|
|
248
|
+
return _redirect_to_login(request=request, error="auth_failed")
|
|
177
249
|
|
|
178
250
|
try:
|
|
179
251
|
async with request.app.state.db() as session:
|
|
@@ -182,19 +254,24 @@ async def create_tokens(
|
|
|
182
254
|
oauth2_client_id=str(oauth2_client.client_id),
|
|
183
255
|
user_info=user_info,
|
|
184
256
|
allow_sign_up=oauth2_client.allow_sign_up,
|
|
257
|
+
role_name=role_name,
|
|
185
258
|
)
|
|
186
|
-
except
|
|
187
|
-
|
|
259
|
+
except EmailAlreadyInUse as e:
|
|
260
|
+
logger.error("Email already in use for IDP %s: %s", idp_name, e)
|
|
261
|
+
return _redirect_to_login(request=request, error="email_in_use")
|
|
262
|
+
except SignInNotAllowed as e:
|
|
263
|
+
logger.error("Sign in not allowed for IDP %s: %s", idp_name, e)
|
|
264
|
+
return _redirect_to_login(request=request, error="sign_in_not_allowed")
|
|
188
265
|
access_token, refresh_token = await create_access_and_refresh_tokens(
|
|
189
266
|
user=user,
|
|
190
267
|
token_store=token_store,
|
|
191
268
|
access_token_expiry=access_token_expiry,
|
|
192
269
|
refresh_token_expiry=refresh_token_expiry,
|
|
193
270
|
)
|
|
194
|
-
redirect_path =
|
|
271
|
+
redirect_path = prepend_root_path(request.scope, return_url or "/")
|
|
195
272
|
response = RedirectResponse(
|
|
196
273
|
url=redirect_path,
|
|
197
|
-
status_code=
|
|
274
|
+
status_code=302,
|
|
198
275
|
)
|
|
199
276
|
response = set_access_token_cookie(
|
|
200
277
|
response=response, access_token=access_token, max_age=access_token_expiry
|
|
@@ -204,6 +281,7 @@ async def create_tokens(
|
|
|
204
281
|
)
|
|
205
282
|
response = delete_oauth2_state_cookie(response)
|
|
206
283
|
response = delete_oauth2_nonce_cookie(response)
|
|
284
|
+
response = delete_oauth2_code_verifier_cookie(response)
|
|
207
285
|
return response
|
|
208
286
|
|
|
209
287
|
|
|
@@ -213,6 +291,7 @@ class UserInfo:
|
|
|
213
291
|
email: str
|
|
214
292
|
username: Optional[str] = None
|
|
215
293
|
profile_picture_url: Optional[str] = None
|
|
294
|
+
claims: dict[str, Any] = field(default_factory=dict)
|
|
216
295
|
|
|
217
296
|
def __post_init__(self) -> None:
|
|
218
297
|
if not (idp_user_id := (self.idp_user_id or "").strip()):
|
|
@@ -227,9 +306,64 @@ class UserInfo:
|
|
|
227
306
|
object.__setattr__(self, "profile_picture_url", profile_picture_url)
|
|
228
307
|
|
|
229
308
|
|
|
309
|
+
async def _fetch_and_merge_userinfo_claims(
|
|
310
|
+
oauth2_client: OAuth2Client,
|
|
311
|
+
token_data: dict[str, Any],
|
|
312
|
+
id_token_claims: dict[str, Any],
|
|
313
|
+
) -> dict[str, Any]:
|
|
314
|
+
"""
|
|
315
|
+
Fetch claims from UserInfo endpoint and merge with ID token claims.
|
|
316
|
+
|
|
317
|
+
Why this is necessary (OIDC Core §5.4, §5.5):
|
|
318
|
+
When claims are requested via scopes (e.g., "profile", "email"), OIDC Core §5.4
|
|
319
|
+
specifies which claims are "REQUESTED" but does not mandate WHERE they must be
|
|
320
|
+
returned. Similarly, §5.5 allows requesting specific claims via the "claims"
|
|
321
|
+
parameter, but providers have discretion on whether to return them in the ID token
|
|
322
|
+
or UserInfo response. In practice, providers often return certain claims (especially
|
|
323
|
+
large ones like groups) only via UserInfo to keep ID tokens compact.
|
|
324
|
+
|
|
325
|
+
The UserInfo endpoint (OIDC Core §5.3) provides additional claims beyond what's
|
|
326
|
+
in the ID token, such as group memberships or custom attributes. This function:
|
|
327
|
+
|
|
328
|
+
1. Calls the UserInfo endpoint using the access token (OIDC Core §5.3.1, RFC 6750)
|
|
329
|
+
2. Merges userinfo claims with ID token claims
|
|
330
|
+
3. ID token claims override userinfo claims when both contain the same claim
|
|
331
|
+
|
|
332
|
+
Why ID token takes precedence (OIDC Core §5.3.2):
|
|
333
|
+
- ID tokens are signed JWTs that have been cryptographically verified
|
|
334
|
+
- UserInfo responses may be unsigned
|
|
335
|
+
- Signed claims are the authoritative source when present in both
|
|
336
|
+
|
|
337
|
+
Fallback behavior:
|
|
338
|
+
If the UserInfo request fails, returns only ID token claims. The returned claims
|
|
339
|
+
may be incomplete (missing email or groups), but subsequent validation will catch this:
|
|
340
|
+
- Missing email: _parse_user_info() raises MissingEmailScope
|
|
341
|
+
- Missing groups: validate_access() raises PermissionError if access is denied
|
|
342
|
+
|
|
343
|
+
Args:
|
|
344
|
+
oauth2_client: The OAuth2 client to use for fetching userinfo
|
|
345
|
+
token_data: Token response containing the access token (RFC 6749 §5.1)
|
|
346
|
+
id_token_claims: Claims from the verified ID token (OIDC Core §3.1.3.3)
|
|
347
|
+
|
|
348
|
+
Returns:
|
|
349
|
+
Merged claims dictionary with ID token claims overriding userinfo claims
|
|
350
|
+
"""
|
|
351
|
+
try:
|
|
352
|
+
# OIDC Core §5.3.1: UserInfo request authenticated with access token
|
|
353
|
+
userinfo_claims = await oauth2_client.userinfo(token=token_data)
|
|
354
|
+
# ID token claims take precedence (signed and verified)
|
|
355
|
+
return {**userinfo_claims, **id_token_claims}
|
|
356
|
+
except Exception:
|
|
357
|
+
# Fallback: ID token has essential claims for authentication
|
|
358
|
+
return id_token_claims
|
|
359
|
+
|
|
360
|
+
|
|
230
361
|
def _validate_token_data(token_data: dict[str, Any]) -> None:
|
|
231
362
|
"""
|
|
232
363
|
Performs basic validations on the token data returned by the IDP.
|
|
364
|
+
|
|
365
|
+
RFC 6749 §5.1: Successful response must include access_token and token_type.
|
|
366
|
+
RFC 6750 §1.1: Bearer token type for HTTP authentication.
|
|
233
367
|
"""
|
|
234
368
|
assert isinstance(token_data.get("access_token"), str)
|
|
235
369
|
assert isinstance(token_type := token_data.get("token_type"), str)
|
|
@@ -239,25 +373,107 @@ def _validate_token_data(token_data: dict[str, Any]) -> None:
|
|
|
239
373
|
def _parse_user_info(user_info: dict[str, Any]) -> UserInfo:
|
|
240
374
|
"""
|
|
241
375
|
Parses user info from the IDP's ID token.
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
376
|
+
|
|
377
|
+
Validates required OIDC claims and extracts user information according to the
|
|
378
|
+
OpenID Connect Core 1.0 specification.
|
|
379
|
+
|
|
380
|
+
Args:
|
|
381
|
+
user_info: Claims from the ID token (validated JWT payload)
|
|
382
|
+
|
|
383
|
+
Returns:
|
|
384
|
+
UserInfo object with validated user data
|
|
385
|
+
|
|
386
|
+
Raises:
|
|
387
|
+
InvalidUserInfo: If required claims are missing or malformed
|
|
388
|
+
MissingEmailScope: If email claim is missing or invalid
|
|
389
|
+
|
|
390
|
+
ID Token Required Claims (OIDC Core §2, §3.1.3.3):
|
|
391
|
+
- iss (issuer): Identifier for the OpenID Provider
|
|
392
|
+
- sub (subject): Unique identifier for the End-User at the Issuer
|
|
393
|
+
- aud (audience): Client ID this ID token is intended for
|
|
394
|
+
- exp (expiration): Expiration time
|
|
395
|
+
- iat (issued at): Time the JWT was issued
|
|
396
|
+
- nonce (if sent in auth request): Value sent in the Authentication Request
|
|
397
|
+
|
|
398
|
+
Application-Required Claims:
|
|
399
|
+
- email: Required by this application for user identification
|
|
400
|
+
|
|
401
|
+
Optional Standard Claims (OIDC Core §5.1):
|
|
402
|
+
- name: Full name
|
|
403
|
+
- picture: Profile picture URL
|
|
404
|
+
- Other profile, email, address, and phone claims
|
|
405
|
+
|
|
406
|
+
Note: While iss, sub, aud, exp, iat are REQUIRED in all ID tokens per spec,
|
|
407
|
+
other claims like email, name, groups are optional and may appear in the ID token,
|
|
408
|
+
UserInfo response, or both depending on what was requested and provider implementation.
|
|
409
|
+
"""
|
|
410
|
+
# Validate 'sub' claim (OIDC required, MUST be a string per spec)
|
|
411
|
+
subject = user_info.get("sub")
|
|
412
|
+
if subject is None:
|
|
413
|
+
raise InvalidUserInfo(
|
|
414
|
+
"Missing required 'sub' claim in ID token. "
|
|
415
|
+
"Please check your OIDC provider configuration."
|
|
416
|
+
)
|
|
417
|
+
|
|
418
|
+
# OIDC spec: sub MUST be a string, but some IDPs send integers
|
|
419
|
+
# Convert to string for compatibility
|
|
420
|
+
if isinstance(subject, (str, int)):
|
|
421
|
+
idp_user_id = str(subject).strip()
|
|
422
|
+
else:
|
|
423
|
+
raise InvalidUserInfo(
|
|
424
|
+
f"Invalid 'sub' claim type: {type(subject).__name__}. Expected string or integer."
|
|
425
|
+
)
|
|
426
|
+
|
|
427
|
+
if not idp_user_id:
|
|
428
|
+
raise InvalidUserInfo("The 'sub' claim cannot be empty.")
|
|
429
|
+
|
|
430
|
+
# Validate 'email' claim (application requirement)
|
|
245
431
|
email = user_info.get("email")
|
|
246
|
-
if not isinstance(email, str):
|
|
432
|
+
if not isinstance(email, str) or not email.strip():
|
|
247
433
|
raise MissingEmailScope(
|
|
248
|
-
"
|
|
434
|
+
"Missing or invalid 'email' claim. "
|
|
435
|
+
"Please ensure your OIDC provider is configured to include the 'email' scope."
|
|
249
436
|
)
|
|
437
|
+
email = email.strip()
|
|
438
|
+
|
|
439
|
+
# Optional: 'name' claim (Full name)
|
|
440
|
+
username = user_info.get("name")
|
|
441
|
+
if username is not None:
|
|
442
|
+
if not isinstance(username, str):
|
|
443
|
+
# Some IDPs might send unexpected types; ignore gracefully
|
|
444
|
+
username = None
|
|
445
|
+
else:
|
|
446
|
+
username = username.strip() or None
|
|
447
|
+
|
|
448
|
+
# Optional: 'picture' claim (Profile picture URL)
|
|
449
|
+
profile_picture_url = user_info.get("picture")
|
|
450
|
+
if profile_picture_url is not None:
|
|
451
|
+
if not isinstance(profile_picture_url, str):
|
|
452
|
+
# Some IDPs might send unexpected types; ignore gracefully
|
|
453
|
+
profile_picture_url = None
|
|
454
|
+
else:
|
|
455
|
+
profile_picture_url = profile_picture_url.strip() or None
|
|
456
|
+
|
|
457
|
+
# Keep only non-empty claim values for downstream processing
|
|
458
|
+
def _has_value(v: Any) -> bool:
|
|
459
|
+
"""Check if a claim value is considered non-empty."""
|
|
460
|
+
if v is None:
|
|
461
|
+
return False
|
|
462
|
+
if isinstance(v, str):
|
|
463
|
+
return bool(v.strip())
|
|
464
|
+
if isinstance(v, (list, dict, set, tuple)):
|
|
465
|
+
return len(v) > 0
|
|
466
|
+
# Include all other types (numbers, booleans, etc.)
|
|
467
|
+
return True
|
|
468
|
+
|
|
469
|
+
filtered_claims = {k: v for k, v in user_info.items() if _has_value(v)}
|
|
250
470
|
|
|
251
|
-
assert isinstance(username := user_info.get("name"), str) or username is None
|
|
252
|
-
assert (
|
|
253
|
-
isinstance(profile_picture_url := user_info.get("picture"), str)
|
|
254
|
-
or profile_picture_url is None
|
|
255
|
-
)
|
|
256
471
|
return UserInfo(
|
|
257
472
|
idp_user_id=idp_user_id,
|
|
258
473
|
email=email,
|
|
259
474
|
username=username,
|
|
260
475
|
profile_picture_url=profile_picture_url,
|
|
476
|
+
claims=filtered_claims,
|
|
261
477
|
)
|
|
262
478
|
|
|
263
479
|
|
|
@@ -268,6 +484,7 @@ async def _process_oauth2_user(
|
|
|
268
484
|
oauth2_client_id: str,
|
|
269
485
|
user_info: UserInfo,
|
|
270
486
|
allow_sign_up: bool,
|
|
487
|
+
role_name: Optional[AssignableUserRoleName],
|
|
271
488
|
) -> models.User:
|
|
272
489
|
"""
|
|
273
490
|
Processes an OAuth2 user, either signing in an existing user or creating/updating one.
|
|
@@ -276,11 +493,12 @@ async def _process_oauth2_user(
|
|
|
276
493
|
1. When sign-up is not allowed (allow_sign_up=False):
|
|
277
494
|
- Checks if the user exists and can sign in with the given OAuth2 credentials
|
|
278
495
|
- Updates placeholder OAuth2 credentials if needed (e.g., temporary IDs)
|
|
496
|
+
- Updates the user's role if role_name is provided (role mapping configured)
|
|
279
497
|
- If the user doesn't exist or has a password set, raises SignInNotAllowed
|
|
280
498
|
2. When sign-up is allowed (allow_sign_up=True):
|
|
281
499
|
- Finds the user by OAuth2 credentials (client_id and user_id)
|
|
282
|
-
- Creates a new user if one doesn't exist, with
|
|
283
|
-
- Updates the user's email if
|
|
500
|
+
- Creates a new user if one doesn't exist, with the provided role (or VIEWER if None)
|
|
501
|
+
- Updates the user's email and role (if role_name provided) if they have changed
|
|
284
502
|
- Handles username conflicts by adding a random suffix if needed
|
|
285
503
|
|
|
286
504
|
The allow_sign_up parameter is typically controlled by the PHOENIX_OAUTH2_{IDP_NAME}_ALLOW_SIGN_UP
|
|
@@ -291,6 +509,8 @@ async def _process_oauth2_user(
|
|
|
291
509
|
oauth2_client_id: The ID of the OAuth2 client
|
|
292
510
|
user_info: User information from the OAuth2 provider
|
|
293
511
|
allow_sign_up: Whether to allow creating new users
|
|
512
|
+
role_name: The Phoenix role name to assign (ADMIN, MEMBER, VIEWER), or None to preserve
|
|
513
|
+
existing user roles (backward compatibility when role mapping not configured)
|
|
294
514
|
|
|
295
515
|
Returns:
|
|
296
516
|
The user object
|
|
@@ -300,24 +520,27 @@ async def _process_oauth2_user(
|
|
|
300
520
|
EmailAlreadyInUse: When the email is already in use by another account
|
|
301
521
|
""" # noqa: E501
|
|
302
522
|
if not allow_sign_up:
|
|
303
|
-
return await
|
|
523
|
+
return await _sign_in_existing_oauth2_user(
|
|
304
524
|
session,
|
|
305
525
|
oauth2_client_id=oauth2_client_id,
|
|
306
526
|
user_info=user_info,
|
|
527
|
+
role_name=role_name,
|
|
307
528
|
)
|
|
308
529
|
return await _create_or_update_user(
|
|
309
530
|
session,
|
|
310
531
|
oauth2_client_id=oauth2_client_id,
|
|
311
532
|
user_info=user_info,
|
|
533
|
+
role_name=role_name,
|
|
312
534
|
)
|
|
313
535
|
|
|
314
536
|
|
|
315
|
-
async def
|
|
537
|
+
async def _sign_in_existing_oauth2_user(
|
|
316
538
|
session: AsyncSession,
|
|
317
539
|
/,
|
|
318
540
|
*,
|
|
319
541
|
oauth2_client_id: str,
|
|
320
542
|
user_info: UserInfo,
|
|
543
|
+
role_name: Optional[AssignableUserRoleName],
|
|
321
544
|
) -> models.User:
|
|
322
545
|
"""Signs in an existing user with OAuth2 credentials.
|
|
323
546
|
|
|
@@ -339,6 +562,7 @@ async def _get_existing_oauth2_user(
|
|
|
339
562
|
Profile Updates:
|
|
340
563
|
- Email: Updated if different from IDP info
|
|
341
564
|
- Profile Picture: Updated if provided in user_info
|
|
565
|
+
- Role: Updated ONLY if role_name is provided (role mapping configured)
|
|
342
566
|
- Username: Never updated (remains unchanged)
|
|
343
567
|
- OAuth2 Credentials: Updated based on the three cases above
|
|
344
568
|
|
|
@@ -346,6 +570,8 @@ async def _get_existing_oauth2_user(
|
|
|
346
570
|
session: The database session
|
|
347
571
|
oauth2_client_id: The ID of the OAuth2 client
|
|
348
572
|
user_info: User information from the OAuth2 provider
|
|
573
|
+
role_name: The Phoenix role name to assign (ADMIN, MEMBER, VIEWER), or None to preserve
|
|
574
|
+
existing role (backward compatibility when role mapping not configured)
|
|
349
575
|
|
|
350
576
|
Returns:
|
|
351
577
|
The signed-in user
|
|
@@ -374,6 +600,11 @@ async def _get_existing_oauth2_user(
|
|
|
374
600
|
user = await session.scalar(stmt.where(func.lower(models.User.email) == email))
|
|
375
601
|
if user is None or not isinstance(user, models.OAuth2User):
|
|
376
602
|
raise SignInNotAllowed("Sign in is not allowed.")
|
|
603
|
+
# Security: Prevent OIDC from hijacking LDAP users
|
|
604
|
+
# LDAP users are identified by the special Unicode marker in oauth2_client_id
|
|
605
|
+
# Use generic error message to avoid revealing auth method (username enumeration)
|
|
606
|
+
if is_ldap_user(user.oauth2_client_id):
|
|
607
|
+
raise SignInNotAllowed("Sign in is not allowed.")
|
|
377
608
|
# Case 1: Different OAuth2 client - update both client and user IDs
|
|
378
609
|
if oauth2_client_id != user.oauth2_client_id:
|
|
379
610
|
user.oauth2_client_id = oauth2_client_id
|
|
@@ -386,6 +617,16 @@ async def _get_existing_oauth2_user(
|
|
|
386
617
|
raise SignInNotAllowed("Sign in is not allowed.")
|
|
387
618
|
if profile_picture_url != user.profile_picture_url:
|
|
388
619
|
user.profile_picture_url = profile_picture_url
|
|
620
|
+
|
|
621
|
+
# Update role ONLY if role mapping is configured (role_name is not None)
|
|
622
|
+
# This preserves existing user roles when role mapping is not configured
|
|
623
|
+
if role_name is not None and user.role.name != role_name:
|
|
624
|
+
role = await session.scalar(
|
|
625
|
+
select(models.UserRole).where(models.UserRole.name == role_name)
|
|
626
|
+
)
|
|
627
|
+
if role is not None:
|
|
628
|
+
user.role = role
|
|
629
|
+
|
|
389
630
|
if user in session.dirty:
|
|
390
631
|
await session.flush()
|
|
391
632
|
return user
|
|
@@ -397,6 +638,7 @@ async def _create_or_update_user(
|
|
|
397
638
|
*,
|
|
398
639
|
oauth2_client_id: str,
|
|
399
640
|
user_info: UserInfo,
|
|
641
|
+
role_name: Optional[AssignableUserRoleName],
|
|
400
642
|
) -> models.User:
|
|
401
643
|
"""
|
|
402
644
|
Creates a new user or updates an existing one with OAuth2 credentials.
|
|
@@ -405,6 +647,8 @@ async def _create_or_update_user(
|
|
|
405
647
|
session: The database session
|
|
406
648
|
oauth2_client_id: The ID of the OAuth2 client
|
|
407
649
|
user_info: User information from the OAuth2 provider
|
|
650
|
+
role_name: The Phoenix role name to assign (ADMIN, MEMBER, VIEWER), or None to use
|
|
651
|
+
VIEWER for new users and preserve existing users' roles (backward compatibility)
|
|
408
652
|
|
|
409
653
|
Returns:
|
|
410
654
|
The created or updated user
|
|
@@ -418,9 +662,37 @@ async def _create_or_update_user(
|
|
|
418
662
|
idp_user_id=user_info.idp_user_id,
|
|
419
663
|
)
|
|
420
664
|
if user is None:
|
|
421
|
-
user
|
|
422
|
-
|
|
423
|
-
|
|
665
|
+
# New user: use provided role_name, or default to VIEWER if role mapping not configured
|
|
666
|
+
user = await _create_user(
|
|
667
|
+
session,
|
|
668
|
+
oauth2_client_id=oauth2_client_id,
|
|
669
|
+
user_info=user_info,
|
|
670
|
+
role_name=role_name or "VIEWER", # Default for new users
|
|
671
|
+
)
|
|
672
|
+
else:
|
|
673
|
+
# Existing user: update email, profile picture, and/or role if changed
|
|
674
|
+
if user.email != user_info.email:
|
|
675
|
+
user.email = user_info.email
|
|
676
|
+
|
|
677
|
+
# Update profile picture if changed
|
|
678
|
+
if user.profile_picture_url != user_info.profile_picture_url:
|
|
679
|
+
user.profile_picture_url = user_info.profile_picture_url
|
|
680
|
+
|
|
681
|
+
# Update role ONLY if role mapping is configured (role_name is not None)
|
|
682
|
+
# This preserves existing user roles when role mapping is not configured
|
|
683
|
+
if role_name is not None and user.role.name != role_name:
|
|
684
|
+
role = await session.scalar(
|
|
685
|
+
select(models.UserRole).where(models.UserRole.name == role_name)
|
|
686
|
+
)
|
|
687
|
+
if role is not None:
|
|
688
|
+
user.role = role
|
|
689
|
+
|
|
690
|
+
# Flush to execute the UPDATE and catch any email conflicts
|
|
691
|
+
if user in session.dirty:
|
|
692
|
+
try:
|
|
693
|
+
await session.flush()
|
|
694
|
+
except (PostgreSQLIntegrityError, SQLiteIntegrityError):
|
|
695
|
+
raise EmailAlreadyInUse(f"An account for {user_info.email} is already in use.")
|
|
424
696
|
return user
|
|
425
697
|
|
|
426
698
|
|
|
@@ -454,9 +726,19 @@ async def _create_user(
|
|
|
454
726
|
*,
|
|
455
727
|
oauth2_client_id: str,
|
|
456
728
|
user_info: UserInfo,
|
|
729
|
+
role_name: AssignableUserRoleName,
|
|
457
730
|
) -> models.User:
|
|
458
731
|
"""
|
|
459
732
|
Creates a new user with the user info from the IDP.
|
|
733
|
+
|
|
734
|
+
Args:
|
|
735
|
+
session: The database session
|
|
736
|
+
oauth2_client_id: The ID of the OAuth2 client
|
|
737
|
+
user_info: User information from the OAuth2 provider
|
|
738
|
+
role_name: The Phoenix role name to assign (ADMIN, MEMBER, VIEWER)
|
|
739
|
+
|
|
740
|
+
Returns:
|
|
741
|
+
The created user
|
|
460
742
|
"""
|
|
461
743
|
email_exists, username_exists = await _email_and_username_exist(
|
|
462
744
|
session,
|
|
@@ -465,14 +747,12 @@ async def _create_user(
|
|
|
465
747
|
)
|
|
466
748
|
if email_exists:
|
|
467
749
|
raise EmailAlreadyInUse(f"An account for {email} is already in use.")
|
|
468
|
-
|
|
469
|
-
select(models.UserRole.id).where(models.UserRole.name == "MEMBER").scalar_subquery()
|
|
470
|
-
)
|
|
750
|
+
role_id = select(models.UserRole.id).where(models.UserRole.name == role_name).scalar_subquery()
|
|
471
751
|
user_id = await session.scalar(
|
|
472
752
|
insert(models.User)
|
|
473
753
|
.returning(models.User.id)
|
|
474
754
|
.values(
|
|
475
|
-
user_role_id=
|
|
755
|
+
user_role_id=role_id,
|
|
476
756
|
oauth2_client_id=oauth2_client_id,
|
|
477
757
|
oauth2_user_id=user_info.idp_user_id,
|
|
478
758
|
username=_with_random_suffix(username) if username and username_exists else username,
|
|
@@ -490,26 +770,6 @@ async def _create_user(
|
|
|
490
770
|
return user
|
|
491
771
|
|
|
492
772
|
|
|
493
|
-
async def _update_user_email(session: AsyncSession, /, *, user_id: int, email: str) -> models.User:
|
|
494
|
-
"""
|
|
495
|
-
Updates an existing user's email.
|
|
496
|
-
"""
|
|
497
|
-
try:
|
|
498
|
-
await session.execute(
|
|
499
|
-
update(models.User)
|
|
500
|
-
.where(models.User.id == user_id)
|
|
501
|
-
.values(email=email)
|
|
502
|
-
.options(joinedload(models.User.role))
|
|
503
|
-
)
|
|
504
|
-
except (PostgreSQLIntegrityError, SQLiteIntegrityError):
|
|
505
|
-
raise EmailAlreadyInUse(f"An account for {email} is already in use.")
|
|
506
|
-
user = await session.scalar(
|
|
507
|
-
select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
|
|
508
|
-
) # query user again for joined load
|
|
509
|
-
assert isinstance(user, models.User)
|
|
510
|
-
return user
|
|
511
|
-
|
|
512
|
-
|
|
513
773
|
async def _email_and_username_exist(
|
|
514
774
|
session: AsyncSession, /, *, email: str, username: Optional[str]
|
|
515
775
|
) -> tuple[bool, bool]:
|
|
@@ -559,49 +819,40 @@ class MissingEmailScope(Exception):
|
|
|
559
819
|
pass
|
|
560
820
|
|
|
561
821
|
|
|
562
|
-
|
|
822
|
+
class InvalidUserInfo(Exception):
|
|
823
|
+
"""
|
|
824
|
+
Raised when the OIDC user info is malformed or missing required claims.
|
|
825
|
+
"""
|
|
826
|
+
|
|
827
|
+
pass
|
|
828
|
+
|
|
829
|
+
|
|
830
|
+
def _redirect_to_login(*, request: Request, error: AuthErrorCode) -> RedirectResponse:
|
|
563
831
|
"""
|
|
564
|
-
Creates a RedirectResponse to the login page to display an error
|
|
832
|
+
Creates a RedirectResponse to the login page to display an error code.
|
|
833
|
+
The error code will be validated and mapped to a user-friendly message on the frontend.
|
|
565
834
|
"""
|
|
566
835
|
# TODO: this needs some cleanup
|
|
567
|
-
login_path =
|
|
568
|
-
request
|
|
836
|
+
login_path = prepend_root_path(
|
|
837
|
+
request.scope, "/login" if not get_env_disable_basic_auth() else "/logout"
|
|
569
838
|
)
|
|
570
839
|
url = URL(login_path).include_query_params(error=error)
|
|
571
840
|
response = RedirectResponse(url=url)
|
|
572
841
|
response = delete_oauth2_state_cookie(response)
|
|
573
842
|
response = delete_oauth2_nonce_cookie(response)
|
|
843
|
+
response = delete_oauth2_code_verifier_cookie(response)
|
|
574
844
|
return response
|
|
575
845
|
|
|
576
846
|
|
|
577
|
-
def _prepend_root_path_if_exists(*, request: Request, path: str) -> str:
|
|
578
|
-
"""
|
|
579
|
-
If a root path is configured, prepends it to the input path.
|
|
580
|
-
"""
|
|
581
|
-
if not path.startswith("/"):
|
|
582
|
-
raise ValueError("path must start with a forward slash")
|
|
583
|
-
root_path = _get_root_path(request=request)
|
|
584
|
-
if root_path.endswith("/"):
|
|
585
|
-
root_path = root_path.rstrip("/")
|
|
586
|
-
return root_path + path
|
|
587
|
-
|
|
588
|
-
|
|
589
847
|
def _append_root_path_if_exists(*, request: Request, base_url: str) -> str:
|
|
590
848
|
"""
|
|
591
849
|
If a root path is configured, appends it to the input base url.
|
|
592
850
|
"""
|
|
593
|
-
if not (root_path :=
|
|
851
|
+
if not (root_path := get_root_path(request.scope)):
|
|
594
852
|
return base_url
|
|
595
853
|
return str(URLPath(root_path).make_absolute_url(base_url=base_url))
|
|
596
854
|
|
|
597
855
|
|
|
598
|
-
def _get_root_path(*, request: Request) -> str:
|
|
599
|
-
"""
|
|
600
|
-
Gets the root path from the request.
|
|
601
|
-
"""
|
|
602
|
-
return str(request.scope.get("root_path", ""))
|
|
603
|
-
|
|
604
|
-
|
|
605
856
|
def _get_create_tokens_endpoint(*, request: Request, origin_url: str, idp_name: str) -> str:
|
|
606
857
|
"""
|
|
607
858
|
Gets the endpoint for create tokens route.
|
|
@@ -679,7 +930,4 @@ def _is_oauth2_state_payload(maybe_state_payload: Any) -> TypeGuard[_OAuth2State
|
|
|
679
930
|
|
|
680
931
|
|
|
681
932
|
_JWT_ALGORITHM = "HS256"
|
|
682
|
-
_INVALID_OAUTH2_STATE_MESSAGE = (
|
|
683
|
-
"Received invalid state parameter during OAuth2 authorization code flow for IDP {idp_name}."
|
|
684
|
-
)
|
|
685
933
|
_RELATIVE_URL_PATTERN = re.compile(r"^/($|\w)")
|