arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
import re
|
|
2
|
-
from dataclasses import dataclass
|
|
3
|
+
from dataclasses import dataclass, field
|
|
3
4
|
from datetime import timedelta
|
|
4
5
|
from random import randrange
|
|
5
6
|
from typing import Any, Optional, TypedDict
|
|
@@ -10,7 +11,7 @@ from authlib.integrations.starlette_client import OAuthError
|
|
|
10
11
|
from authlib.jose import jwt
|
|
11
12
|
from authlib.jose.errors import JoseError
|
|
12
13
|
from fastapi import APIRouter, Cookie, Depends, Path, Query, Request
|
|
13
|
-
from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select
|
|
14
|
+
from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select
|
|
14
15
|
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
|
|
15
16
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
16
17
|
from sqlalchemy.orm import joinedload
|
|
@@ -18,27 +19,32 @@ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore
|
|
|
18
19
|
from starlette.datastructures import URL, Secret, URLPath
|
|
19
20
|
from starlette.responses import RedirectResponse
|
|
20
21
|
from starlette.routing import Router
|
|
21
|
-
from starlette.status import HTTP_302_FOUND
|
|
22
22
|
from typing_extensions import Annotated, NotRequired, TypeGuard
|
|
23
23
|
|
|
24
24
|
from phoenix.auth import (
|
|
25
25
|
DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES,
|
|
26
|
+
PHOENIX_OAUTH2_CODE_VERIFIER_COOKIE_NAME,
|
|
26
27
|
PHOENIX_OAUTH2_NONCE_COOKIE_NAME,
|
|
27
28
|
PHOENIX_OAUTH2_STATE_COOKIE_NAME,
|
|
29
|
+
delete_oauth2_code_verifier_cookie,
|
|
28
30
|
delete_oauth2_nonce_cookie,
|
|
29
31
|
delete_oauth2_state_cookie,
|
|
32
|
+
sanitize_email,
|
|
30
33
|
set_access_token_cookie,
|
|
34
|
+
set_oauth2_code_verifier_cookie,
|
|
31
35
|
set_oauth2_nonce_cookie,
|
|
32
36
|
set_oauth2_state_cookie,
|
|
33
37
|
set_refresh_token_cookie,
|
|
34
38
|
)
|
|
35
39
|
from phoenix.config import (
|
|
40
|
+
AssignableUserRoleName,
|
|
36
41
|
get_env_disable_basic_auth,
|
|
37
42
|
get_env_disable_rate_limit,
|
|
38
43
|
)
|
|
39
44
|
from phoenix.db import models
|
|
40
|
-
from phoenix.
|
|
45
|
+
from phoenix.server.api.auth_messages import AuthErrorCode
|
|
41
46
|
from phoenix.server.bearer_auth import create_access_and_refresh_tokens
|
|
47
|
+
from phoenix.server.ldap import is_ldap_user
|
|
42
48
|
from phoenix.server.oauth2 import OAuth2Client
|
|
43
49
|
from phoenix.server.rate_limiters import (
|
|
44
50
|
ServerRateLimiter,
|
|
@@ -46,9 +52,12 @@ from phoenix.server.rate_limiters import (
|
|
|
46
52
|
fastapi_route_rate_limiter,
|
|
47
53
|
)
|
|
48
54
|
from phoenix.server.types import TokenStore
|
|
55
|
+
from phoenix.server.utils import get_root_path, prepend_root_path
|
|
49
56
|
|
|
50
57
|
_LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+"
|
|
51
58
|
|
|
59
|
+
logger = logging.getLogger(__name__)
|
|
60
|
+
|
|
52
61
|
login_rate_limiter = fastapi_ip_rate_limiter(
|
|
53
62
|
ServerRateLimiter(
|
|
54
63
|
per_second_rate_limit=0.2,
|
|
@@ -87,11 +96,12 @@ async def login(
|
|
|
87
96
|
idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
|
|
88
97
|
return_url: Optional[str] = Query(default=None, alias="returnUrl"),
|
|
89
98
|
) -> RedirectResponse:
|
|
99
|
+
# Security Note: Query parameters should be treated as untrusted user input. Never display
|
|
100
|
+
# these values directly to users as they could be manipulated for XSS, phishing, or social
|
|
101
|
+
# engineering attacks.
|
|
102
|
+
if (oauth2_client := request.app.state.oauth2_clients.get_client(idp_name)) is None:
|
|
103
|
+
return _redirect_to_login(request=request, error="unknown_idp")
|
|
90
104
|
secret = request.app.state.get_secret()
|
|
91
|
-
if not isinstance(
|
|
92
|
-
oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
|
|
93
|
-
):
|
|
94
|
-
return _redirect_to_login(request=request, error=f"Unknown IDP: {idp_name}.")
|
|
95
105
|
if (referer := request.headers.get("referer")) is not None:
|
|
96
106
|
# if the referer header is present, use it as the origin URL
|
|
97
107
|
parsed_url = urlparse(referer)
|
|
@@ -112,7 +122,7 @@ async def login(
|
|
|
112
122
|
assert isinstance(authorization_url := authorization_url_data.get("url"), str)
|
|
113
123
|
assert isinstance(state := authorization_url_data.get("state"), str)
|
|
114
124
|
assert isinstance(nonce := authorization_url_data.get("nonce"), str)
|
|
115
|
-
response = RedirectResponse(url=authorization_url, status_code=
|
|
125
|
+
response = RedirectResponse(url=authorization_url, status_code=302)
|
|
116
126
|
response = set_oauth2_state_cookie(
|
|
117
127
|
response=response,
|
|
118
128
|
state=state,
|
|
@@ -123,6 +133,12 @@ async def login(
|
|
|
123
133
|
nonce=nonce,
|
|
124
134
|
max_age=timedelta(minutes=DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES),
|
|
125
135
|
)
|
|
136
|
+
if code_verifier := authorization_url_data.get("code_verifier"):
|
|
137
|
+
response = set_oauth2_code_verifier_cookie(
|
|
138
|
+
response=response,
|
|
139
|
+
code_verifier=code_verifier,
|
|
140
|
+
max_age=timedelta(minutes=DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES),
|
|
141
|
+
)
|
|
126
142
|
return response
|
|
127
143
|
|
|
128
144
|
|
|
@@ -130,47 +146,107 @@ async def login(
|
|
|
130
146
|
async def create_tokens(
|
|
131
147
|
request: Request,
|
|
132
148
|
idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
|
|
133
|
-
state: str = Query(),
|
|
134
|
-
authorization_code: str = Query(alias="code"),
|
|
149
|
+
state: str = Query(), # RFC 6749 §4.1.1: CSRF protection via state parameter
|
|
150
|
+
authorization_code: Optional[str] = Query(default=None, alias="code"), # RFC 6749 §4.1.2
|
|
151
|
+
error: Optional[str] = Query(default=None), # RFC 6749 §4.1.2.1: Error response
|
|
152
|
+
error_description: Optional[str] = Query(default=None),
|
|
135
153
|
stored_state: str = Cookie(alias=PHOENIX_OAUTH2_STATE_COOKIE_NAME),
|
|
136
|
-
stored_nonce: str = Cookie(alias=PHOENIX_OAUTH2_NONCE_COOKIE_NAME),
|
|
154
|
+
stored_nonce: str = Cookie(alias=PHOENIX_OAUTH2_NONCE_COOKIE_NAME), # OIDC Core §3.1.2.1
|
|
155
|
+
code_verifier: Optional[str] = Cookie(
|
|
156
|
+
default=None, alias=PHOENIX_OAUTH2_CODE_VERIFIER_COOKIE_NAME
|
|
157
|
+
), # RFC 7636 §4.1
|
|
137
158
|
) -> RedirectResponse:
|
|
159
|
+
# Security Note: Query parameters should be treated as untrusted user input. Never display
|
|
160
|
+
# these values directly to users as they could be manipulated for XSS, phishing, or social
|
|
161
|
+
# engineering attacks.
|
|
162
|
+
if (oauth2_client := request.app.state.oauth2_clients.get_client(idp_name)) is None:
|
|
163
|
+
return _redirect_to_login(request=request, error="unknown_idp")
|
|
164
|
+
if error or error_description:
|
|
165
|
+
logger.error(
|
|
166
|
+
"OAuth2 authentication failed for IDP %s: error=%s, description=%s",
|
|
167
|
+
idp_name,
|
|
168
|
+
error,
|
|
169
|
+
error_description,
|
|
170
|
+
)
|
|
171
|
+
return _redirect_to_login(request=request, error="auth_failed")
|
|
172
|
+
if authorization_code is None:
|
|
173
|
+
logger.error("OAuth2 callback missing authorization code for IDP %s", idp_name)
|
|
174
|
+
return _redirect_to_login(request=request, error="auth_failed")
|
|
138
175
|
secret = request.app.state.get_secret()
|
|
176
|
+
# RFC 6749 §10.12: CSRF protection - validate state parameter
|
|
139
177
|
if state != stored_state:
|
|
140
|
-
return _redirect_to_login(request=request, error=
|
|
178
|
+
return _redirect_to_login(request=request, error="invalid_state")
|
|
141
179
|
try:
|
|
142
180
|
payload = _parse_state_payload(secret=secret, state=state)
|
|
143
181
|
except JoseError:
|
|
144
|
-
return _redirect_to_login(request=request, error=
|
|
182
|
+
return _redirect_to_login(request=request, error="invalid_state")
|
|
145
183
|
if (return_url := payload.get("return_url")) is not None and not _is_relative_url(
|
|
146
184
|
unquote(return_url)
|
|
147
185
|
):
|
|
148
|
-
return _redirect_to_login(request=request, error="
|
|
186
|
+
return _redirect_to_login(request=request, error="unsafe_return_url")
|
|
149
187
|
assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta)
|
|
150
188
|
assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta)
|
|
151
189
|
token_store: TokenStore = request.app.state.get_token_store()
|
|
152
|
-
if not isinstance(
|
|
153
|
-
oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
|
|
154
|
-
):
|
|
155
|
-
return _redirect_to_login(request=request, error=f"Unknown IDP: {idp_name}.")
|
|
156
190
|
try:
|
|
157
|
-
|
|
191
|
+
# RFC 6749 §4.1.3: Token request - exchange authorization code for tokens
|
|
192
|
+
fetch_kwargs: dict[str, Any] = dict(
|
|
158
193
|
state=state,
|
|
159
194
|
code=authorization_code,
|
|
160
|
-
redirect_uri=_get_create_tokens_endpoint(
|
|
195
|
+
redirect_uri=_get_create_tokens_endpoint( # RFC 6749 §3.1.2
|
|
161
196
|
request=request, origin_url=payload["origin_url"], idp_name=idp_name
|
|
162
197
|
),
|
|
163
198
|
)
|
|
164
|
-
|
|
165
|
-
|
|
199
|
+
# PKCE validation: code_verifier is required when PKCE is enabled (RFC 7636 §4.5)
|
|
200
|
+
if oauth2_client.use_pkce:
|
|
201
|
+
if not code_verifier:
|
|
202
|
+
logger.error(
|
|
203
|
+
"PKCE enabled but code_verifier cookie missing for IDP %s. "
|
|
204
|
+
"This may indicate a cookie issue, CORS misconfiguration, or "
|
|
205
|
+
"browser compatibility problem.",
|
|
206
|
+
idp_name,
|
|
207
|
+
)
|
|
208
|
+
return _redirect_to_login(request=request, error="auth_failed")
|
|
209
|
+
fetch_kwargs["code_verifier"] = code_verifier
|
|
210
|
+
token_data = await oauth2_client.fetch_access_token(**fetch_kwargs)
|
|
211
|
+
except OAuthError as e:
|
|
212
|
+
logger.error("OAuth2 error for IDP %s: %s", idp_name, e)
|
|
213
|
+
return _redirect_to_login(request=request, error="oauth_error")
|
|
166
214
|
_validate_token_data(token_data)
|
|
167
215
|
if "id_token" not in token_data:
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
216
|
+
logger.error("OAuth2 IDP %s does not appear to support OpenID Connect", idp_name)
|
|
217
|
+
return _redirect_to_login(request=request, error="no_oidc_support")
|
|
218
|
+
|
|
219
|
+
try:
|
|
220
|
+
id_token_claims = await oauth2_client.parse_id_token(token_data, nonce=stored_nonce)
|
|
221
|
+
except JoseError as e:
|
|
222
|
+
logger.error("ID token validation failed for IDP %s: %s", idp_name, e)
|
|
223
|
+
return _redirect_to_login(request=request, error="auth_failed")
|
|
224
|
+
|
|
225
|
+
if oauth2_client.has_sufficient_claims(id_token_claims):
|
|
226
|
+
user_claims = id_token_claims
|
|
227
|
+
else:
|
|
228
|
+
user_claims = await _fetch_and_merge_userinfo_claims(
|
|
229
|
+
oauth2_client, token_data, id_token_claims
|
|
171
230
|
)
|
|
172
|
-
|
|
173
|
-
|
|
231
|
+
|
|
232
|
+
try:
|
|
233
|
+
user_info = _parse_user_info(user_claims)
|
|
234
|
+
except (MissingEmailScope, InvalidUserInfo) as e:
|
|
235
|
+
logger.error("Error parsing user info for IDP %s: %s", idp_name, e)
|
|
236
|
+
return _redirect_to_login(request=request, error="missing_email_scope")
|
|
237
|
+
|
|
238
|
+
# Validate access and extract role from claims
|
|
239
|
+
# Both validate_access and extract_and_map_role may raise PermissionError
|
|
240
|
+
try:
|
|
241
|
+
oauth2_client.validate_access(user_info.claims)
|
|
242
|
+
# Extract and map role from claims
|
|
243
|
+
# Returns None if role mapping not configured (preserves existing user roles)
|
|
244
|
+
# Raises PermissionError if strict mode enabled and role validation fails
|
|
245
|
+
role_name = oauth2_client.extract_and_map_role(user_info.claims)
|
|
246
|
+
except PermissionError as e:
|
|
247
|
+
logger.error("Access validation failed for IDP %s: %s", idp_name, e)
|
|
248
|
+
return _redirect_to_login(request=request, error="auth_failed")
|
|
249
|
+
|
|
174
250
|
try:
|
|
175
251
|
async with request.app.state.db() as session:
|
|
176
252
|
user = await _process_oauth2_user(
|
|
@@ -178,19 +254,24 @@ async def create_tokens(
|
|
|
178
254
|
oauth2_client_id=str(oauth2_client.client_id),
|
|
179
255
|
user_info=user_info,
|
|
180
256
|
allow_sign_up=oauth2_client.allow_sign_up,
|
|
257
|
+
role_name=role_name,
|
|
181
258
|
)
|
|
182
|
-
except
|
|
183
|
-
|
|
259
|
+
except EmailAlreadyInUse as e:
|
|
260
|
+
logger.error("Email already in use for IDP %s: %s", idp_name, e)
|
|
261
|
+
return _redirect_to_login(request=request, error="email_in_use")
|
|
262
|
+
except SignInNotAllowed as e:
|
|
263
|
+
logger.error("Sign in not allowed for IDP %s: %s", idp_name, e)
|
|
264
|
+
return _redirect_to_login(request=request, error="sign_in_not_allowed")
|
|
184
265
|
access_token, refresh_token = await create_access_and_refresh_tokens(
|
|
185
266
|
user=user,
|
|
186
267
|
token_store=token_store,
|
|
187
268
|
access_token_expiry=access_token_expiry,
|
|
188
269
|
refresh_token_expiry=refresh_token_expiry,
|
|
189
270
|
)
|
|
190
|
-
redirect_path =
|
|
271
|
+
redirect_path = prepend_root_path(request.scope, return_url or "/")
|
|
191
272
|
response = RedirectResponse(
|
|
192
273
|
url=redirect_path,
|
|
193
|
-
status_code=
|
|
274
|
+
status_code=302,
|
|
194
275
|
)
|
|
195
276
|
response = set_access_token_cookie(
|
|
196
277
|
response=response, access_token=access_token, max_age=access_token_expiry
|
|
@@ -200,6 +281,7 @@ async def create_tokens(
|
|
|
200
281
|
)
|
|
201
282
|
response = delete_oauth2_state_cookie(response)
|
|
202
283
|
response = delete_oauth2_nonce_cookie(response)
|
|
284
|
+
response = delete_oauth2_code_verifier_cookie(response)
|
|
203
285
|
return response
|
|
204
286
|
|
|
205
287
|
|
|
@@ -209,12 +291,13 @@ class UserInfo:
|
|
|
209
291
|
email: str
|
|
210
292
|
username: Optional[str] = None
|
|
211
293
|
profile_picture_url: Optional[str] = None
|
|
294
|
+
claims: dict[str, Any] = field(default_factory=dict)
|
|
212
295
|
|
|
213
296
|
def __post_init__(self) -> None:
|
|
214
297
|
if not (idp_user_id := (self.idp_user_id or "").strip()):
|
|
215
298
|
raise ValueError("idp_user_id cannot be empty")
|
|
216
299
|
object.__setattr__(self, "idp_user_id", idp_user_id)
|
|
217
|
-
if not (email := (self.email or "")
|
|
300
|
+
if not (email := sanitize_email(self.email or "")):
|
|
218
301
|
raise ValueError("email cannot be empty")
|
|
219
302
|
object.__setattr__(self, "email", email)
|
|
220
303
|
if username := (self.username or "").strip():
|
|
@@ -223,9 +306,64 @@ class UserInfo:
|
|
|
223
306
|
object.__setattr__(self, "profile_picture_url", profile_picture_url)
|
|
224
307
|
|
|
225
308
|
|
|
309
|
+
async def _fetch_and_merge_userinfo_claims(
|
|
310
|
+
oauth2_client: OAuth2Client,
|
|
311
|
+
token_data: dict[str, Any],
|
|
312
|
+
id_token_claims: dict[str, Any],
|
|
313
|
+
) -> dict[str, Any]:
|
|
314
|
+
"""
|
|
315
|
+
Fetch claims from UserInfo endpoint and merge with ID token claims.
|
|
316
|
+
|
|
317
|
+
Why this is necessary (OIDC Core §5.4, §5.5):
|
|
318
|
+
When claims are requested via scopes (e.g., "profile", "email"), OIDC Core §5.4
|
|
319
|
+
specifies which claims are "REQUESTED" but does not mandate WHERE they must be
|
|
320
|
+
returned. Similarly, §5.5 allows requesting specific claims via the "claims"
|
|
321
|
+
parameter, but providers have discretion on whether to return them in the ID token
|
|
322
|
+
or UserInfo response. In practice, providers often return certain claims (especially
|
|
323
|
+
large ones like groups) only via UserInfo to keep ID tokens compact.
|
|
324
|
+
|
|
325
|
+
The UserInfo endpoint (OIDC Core §5.3) provides additional claims beyond what's
|
|
326
|
+
in the ID token, such as group memberships or custom attributes. This function:
|
|
327
|
+
|
|
328
|
+
1. Calls the UserInfo endpoint using the access token (OIDC Core §5.3.1, RFC 6750)
|
|
329
|
+
2. Merges userinfo claims with ID token claims
|
|
330
|
+
3. ID token claims override userinfo claims when both contain the same claim
|
|
331
|
+
|
|
332
|
+
Why ID token takes precedence (OIDC Core §5.3.2):
|
|
333
|
+
- ID tokens are signed JWTs that have been cryptographically verified
|
|
334
|
+
- UserInfo responses may be unsigned
|
|
335
|
+
- Signed claims are the authoritative source when present in both
|
|
336
|
+
|
|
337
|
+
Fallback behavior:
|
|
338
|
+
If the UserInfo request fails, returns only ID token claims. The returned claims
|
|
339
|
+
may be incomplete (missing email or groups), but subsequent validation will catch this:
|
|
340
|
+
- Missing email: _parse_user_info() raises MissingEmailScope
|
|
341
|
+
- Missing groups: validate_access() raises PermissionError if access is denied
|
|
342
|
+
|
|
343
|
+
Args:
|
|
344
|
+
oauth2_client: The OAuth2 client to use for fetching userinfo
|
|
345
|
+
token_data: Token response containing the access token (RFC 6749 §5.1)
|
|
346
|
+
id_token_claims: Claims from the verified ID token (OIDC Core §3.1.3.3)
|
|
347
|
+
|
|
348
|
+
Returns:
|
|
349
|
+
Merged claims dictionary with ID token claims overriding userinfo claims
|
|
350
|
+
"""
|
|
351
|
+
try:
|
|
352
|
+
# OIDC Core §5.3.1: UserInfo request authenticated with access token
|
|
353
|
+
userinfo_claims = await oauth2_client.userinfo(token=token_data)
|
|
354
|
+
# ID token claims take precedence (signed and verified)
|
|
355
|
+
return {**userinfo_claims, **id_token_claims}
|
|
356
|
+
except Exception:
|
|
357
|
+
# Fallback: ID token has essential claims for authentication
|
|
358
|
+
return id_token_claims
|
|
359
|
+
|
|
360
|
+
|
|
226
361
|
def _validate_token_data(token_data: dict[str, Any]) -> None:
|
|
227
362
|
"""
|
|
228
363
|
Performs basic validations on the token data returned by the IDP.
|
|
364
|
+
|
|
365
|
+
RFC 6749 §5.1: Successful response must include access_token and token_type.
|
|
366
|
+
RFC 6750 §1.1: Bearer token type for HTTP authentication.
|
|
229
367
|
"""
|
|
230
368
|
assert isinstance(token_data.get("access_token"), str)
|
|
231
369
|
assert isinstance(token_type := token_data.get("token_type"), str)
|
|
@@ -235,20 +373,107 @@ def _validate_token_data(token_data: dict[str, Any]) -> None:
|
|
|
235
373
|
def _parse_user_info(user_info: dict[str, Any]) -> UserInfo:
|
|
236
374
|
"""
|
|
237
375
|
Parses user info from the IDP's ID token.
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
376
|
+
|
|
377
|
+
Validates required OIDC claims and extracts user information according to the
|
|
378
|
+
OpenID Connect Core 1.0 specification.
|
|
379
|
+
|
|
380
|
+
Args:
|
|
381
|
+
user_info: Claims from the ID token (validated JWT payload)
|
|
382
|
+
|
|
383
|
+
Returns:
|
|
384
|
+
UserInfo object with validated user data
|
|
385
|
+
|
|
386
|
+
Raises:
|
|
387
|
+
InvalidUserInfo: If required claims are missing or malformed
|
|
388
|
+
MissingEmailScope: If email claim is missing or invalid
|
|
389
|
+
|
|
390
|
+
ID Token Required Claims (OIDC Core §2, §3.1.3.3):
|
|
391
|
+
- iss (issuer): Identifier for the OpenID Provider
|
|
392
|
+
- sub (subject): Unique identifier for the End-User at the Issuer
|
|
393
|
+
- aud (audience): Client ID this ID token is intended for
|
|
394
|
+
- exp (expiration): Expiration time
|
|
395
|
+
- iat (issued at): Time the JWT was issued
|
|
396
|
+
- nonce (if sent in auth request): Value sent in the Authentication Request
|
|
397
|
+
|
|
398
|
+
Application-Required Claims:
|
|
399
|
+
- email: Required by this application for user identification
|
|
400
|
+
|
|
401
|
+
Optional Standard Claims (OIDC Core §5.1):
|
|
402
|
+
- name: Full name
|
|
403
|
+
- picture: Profile picture URL
|
|
404
|
+
- Other profile, email, address, and phone claims
|
|
405
|
+
|
|
406
|
+
Note: While iss, sub, aud, exp, iat are REQUIRED in all ID tokens per spec,
|
|
407
|
+
other claims like email, name, groups are optional and may appear in the ID token,
|
|
408
|
+
UserInfo response, or both depending on what was requested and provider implementation.
|
|
409
|
+
"""
|
|
410
|
+
# Validate 'sub' claim (OIDC required, MUST be a string per spec)
|
|
411
|
+
subject = user_info.get("sub")
|
|
412
|
+
if subject is None:
|
|
413
|
+
raise InvalidUserInfo(
|
|
414
|
+
"Missing required 'sub' claim in ID token. "
|
|
415
|
+
"Please check your OIDC provider configuration."
|
|
416
|
+
)
|
|
417
|
+
|
|
418
|
+
# OIDC spec: sub MUST be a string, but some IDPs send integers
|
|
419
|
+
# Convert to string for compatibility
|
|
420
|
+
if isinstance(subject, (str, int)):
|
|
421
|
+
idp_user_id = str(subject).strip()
|
|
422
|
+
else:
|
|
423
|
+
raise InvalidUserInfo(
|
|
424
|
+
f"Invalid 'sub' claim type: {type(subject).__name__}. Expected string or integer."
|
|
425
|
+
)
|
|
426
|
+
|
|
427
|
+
if not idp_user_id:
|
|
428
|
+
raise InvalidUserInfo("The 'sub' claim cannot be empty.")
|
|
429
|
+
|
|
430
|
+
# Validate 'email' claim (application requirement)
|
|
431
|
+
email = user_info.get("email")
|
|
432
|
+
if not isinstance(email, str) or not email.strip():
|
|
433
|
+
raise MissingEmailScope(
|
|
434
|
+
"Missing or invalid 'email' claim. "
|
|
435
|
+
"Please ensure your OIDC provider is configured to include the 'email' scope."
|
|
436
|
+
)
|
|
437
|
+
email = email.strip()
|
|
438
|
+
|
|
439
|
+
# Optional: 'name' claim (Full name)
|
|
440
|
+
username = user_info.get("name")
|
|
441
|
+
if username is not None:
|
|
442
|
+
if not isinstance(username, str):
|
|
443
|
+
# Some IDPs might send unexpected types; ignore gracefully
|
|
444
|
+
username = None
|
|
445
|
+
else:
|
|
446
|
+
username = username.strip() or None
|
|
447
|
+
|
|
448
|
+
# Optional: 'picture' claim (Profile picture URL)
|
|
449
|
+
profile_picture_url = user_info.get("picture")
|
|
450
|
+
if profile_picture_url is not None:
|
|
451
|
+
if not isinstance(profile_picture_url, str):
|
|
452
|
+
# Some IDPs might send unexpected types; ignore gracefully
|
|
453
|
+
profile_picture_url = None
|
|
454
|
+
else:
|
|
455
|
+
profile_picture_url = profile_picture_url.strip() or None
|
|
456
|
+
|
|
457
|
+
# Keep only non-empty claim values for downstream processing
|
|
458
|
+
def _has_value(v: Any) -> bool:
|
|
459
|
+
"""Check if a claim value is considered non-empty."""
|
|
460
|
+
if v is None:
|
|
461
|
+
return False
|
|
462
|
+
if isinstance(v, str):
|
|
463
|
+
return bool(v.strip())
|
|
464
|
+
if isinstance(v, (list, dict, set, tuple)):
|
|
465
|
+
return len(v) > 0
|
|
466
|
+
# Include all other types (numbers, booleans, etc.)
|
|
467
|
+
return True
|
|
468
|
+
|
|
469
|
+
filtered_claims = {k: v for k, v in user_info.items() if _has_value(v)}
|
|
470
|
+
|
|
247
471
|
return UserInfo(
|
|
248
472
|
idp_user_id=idp_user_id,
|
|
249
473
|
email=email,
|
|
250
474
|
username=username,
|
|
251
475
|
profile_picture_url=profile_picture_url,
|
|
476
|
+
claims=filtered_claims,
|
|
252
477
|
)
|
|
253
478
|
|
|
254
479
|
|
|
@@ -259,6 +484,7 @@ async def _process_oauth2_user(
|
|
|
259
484
|
oauth2_client_id: str,
|
|
260
485
|
user_info: UserInfo,
|
|
261
486
|
allow_sign_up: bool,
|
|
487
|
+
role_name: Optional[AssignableUserRoleName],
|
|
262
488
|
) -> models.User:
|
|
263
489
|
"""
|
|
264
490
|
Processes an OAuth2 user, either signing in an existing user or creating/updating one.
|
|
@@ -267,11 +493,12 @@ async def _process_oauth2_user(
|
|
|
267
493
|
1. When sign-up is not allowed (allow_sign_up=False):
|
|
268
494
|
- Checks if the user exists and can sign in with the given OAuth2 credentials
|
|
269
495
|
- Updates placeholder OAuth2 credentials if needed (e.g., temporary IDs)
|
|
496
|
+
- Updates the user's role if role_name is provided (role mapping configured)
|
|
270
497
|
- If the user doesn't exist or has a password set, raises SignInNotAllowed
|
|
271
498
|
2. When sign-up is allowed (allow_sign_up=True):
|
|
272
499
|
- Finds the user by OAuth2 credentials (client_id and user_id)
|
|
273
|
-
- Creates a new user if one doesn't exist, with
|
|
274
|
-
- Updates the user's email if
|
|
500
|
+
- Creates a new user if one doesn't exist, with the provided role (or VIEWER if None)
|
|
501
|
+
- Updates the user's email and role (if role_name provided) if they have changed
|
|
275
502
|
- Handles username conflicts by adding a random suffix if needed
|
|
276
503
|
|
|
277
504
|
The allow_sign_up parameter is typically controlled by the PHOENIX_OAUTH2_{IDP_NAME}_ALLOW_SIGN_UP
|
|
@@ -282,6 +509,8 @@ async def _process_oauth2_user(
|
|
|
282
509
|
oauth2_client_id: The ID of the OAuth2 client
|
|
283
510
|
user_info: User information from the OAuth2 provider
|
|
284
511
|
allow_sign_up: Whether to allow creating new users
|
|
512
|
+
role_name: The Phoenix role name to assign (ADMIN, MEMBER, VIEWER), or None to preserve
|
|
513
|
+
existing user roles (backward compatibility when role mapping not configured)
|
|
285
514
|
|
|
286
515
|
Returns:
|
|
287
516
|
The user object
|
|
@@ -291,24 +520,27 @@ async def _process_oauth2_user(
|
|
|
291
520
|
EmailAlreadyInUse: When the email is already in use by another account
|
|
292
521
|
""" # noqa: E501
|
|
293
522
|
if not allow_sign_up:
|
|
294
|
-
return await
|
|
523
|
+
return await _sign_in_existing_oauth2_user(
|
|
295
524
|
session,
|
|
296
525
|
oauth2_client_id=oauth2_client_id,
|
|
297
526
|
user_info=user_info,
|
|
527
|
+
role_name=role_name,
|
|
298
528
|
)
|
|
299
529
|
return await _create_or_update_user(
|
|
300
530
|
session,
|
|
301
531
|
oauth2_client_id=oauth2_client_id,
|
|
302
532
|
user_info=user_info,
|
|
533
|
+
role_name=role_name,
|
|
303
534
|
)
|
|
304
535
|
|
|
305
536
|
|
|
306
|
-
async def
|
|
537
|
+
async def _sign_in_existing_oauth2_user(
|
|
307
538
|
session: AsyncSession,
|
|
308
539
|
/,
|
|
309
540
|
*,
|
|
310
541
|
oauth2_client_id: str,
|
|
311
542
|
user_info: UserInfo,
|
|
543
|
+
role_name: Optional[AssignableUserRoleName],
|
|
312
544
|
) -> models.User:
|
|
313
545
|
"""Signs in an existing user with OAuth2 credentials.
|
|
314
546
|
|
|
@@ -330,6 +562,7 @@ async def _get_existing_oauth2_user(
|
|
|
330
562
|
Profile Updates:
|
|
331
563
|
- Email: Updated if different from IDP info
|
|
332
564
|
- Profile Picture: Updated if provided in user_info
|
|
565
|
+
- Role: Updated ONLY if role_name is provided (role mapping configured)
|
|
333
566
|
- Username: Never updated (remains unchanged)
|
|
334
567
|
- OAuth2 Credentials: Updated based on the three cases above
|
|
335
568
|
|
|
@@ -337,6 +570,8 @@ async def _get_existing_oauth2_user(
|
|
|
337
570
|
session: The database session
|
|
338
571
|
oauth2_client_id: The ID of the OAuth2 client
|
|
339
572
|
user_info: User information from the OAuth2 provider
|
|
573
|
+
role_name: The Phoenix role name to assign (ADMIN, MEMBER, VIEWER), or None to preserve
|
|
574
|
+
existing role (backward compatibility when role mapping not configured)
|
|
340
575
|
|
|
341
576
|
Returns:
|
|
342
577
|
The signed-in user
|
|
@@ -348,7 +583,7 @@ async def _get_existing_oauth2_user(
|
|
|
348
583
|
- User has a password set
|
|
349
584
|
- User has mismatched OAuth2 credentials
|
|
350
585
|
""" # noqa: E501
|
|
351
|
-
if not (email := (user_info.email or "")
|
|
586
|
+
if not (email := sanitize_email(user_info.email or "")):
|
|
352
587
|
raise ValueError("Email is required.")
|
|
353
588
|
if not (oauth2_user_id := (user_info.idp_user_id or "").strip()):
|
|
354
589
|
raise ValueError("OAuth2 user ID is required.")
|
|
@@ -362,9 +597,14 @@ async def _get_existing_oauth2_user(
|
|
|
362
597
|
if email and email != user.email:
|
|
363
598
|
user.email = email
|
|
364
599
|
else:
|
|
365
|
-
user = await session.scalar(stmt.
|
|
600
|
+
user = await session.scalar(stmt.where(func.lower(models.User.email) == email))
|
|
366
601
|
if user is None or not isinstance(user, models.OAuth2User):
|
|
367
602
|
raise SignInNotAllowed("Sign in is not allowed.")
|
|
603
|
+
# Security: Prevent OIDC from hijacking LDAP users
|
|
604
|
+
# LDAP users are identified by the special Unicode marker in oauth2_client_id
|
|
605
|
+
# Use generic error message to avoid revealing auth method (username enumeration)
|
|
606
|
+
if is_ldap_user(user.oauth2_client_id):
|
|
607
|
+
raise SignInNotAllowed("Sign in is not allowed.")
|
|
368
608
|
# Case 1: Different OAuth2 client - update both client and user IDs
|
|
369
609
|
if oauth2_client_id != user.oauth2_client_id:
|
|
370
610
|
user.oauth2_client_id = oauth2_client_id
|
|
@@ -377,6 +617,16 @@ async def _get_existing_oauth2_user(
|
|
|
377
617
|
raise SignInNotAllowed("Sign in is not allowed.")
|
|
378
618
|
if profile_picture_url != user.profile_picture_url:
|
|
379
619
|
user.profile_picture_url = profile_picture_url
|
|
620
|
+
|
|
621
|
+
# Update role ONLY if role mapping is configured (role_name is not None)
|
|
622
|
+
# This preserves existing user roles when role mapping is not configured
|
|
623
|
+
if role_name is not None and user.role.name != role_name:
|
|
624
|
+
role = await session.scalar(
|
|
625
|
+
select(models.UserRole).where(models.UserRole.name == role_name)
|
|
626
|
+
)
|
|
627
|
+
if role is not None:
|
|
628
|
+
user.role = role
|
|
629
|
+
|
|
380
630
|
if user in session.dirty:
|
|
381
631
|
await session.flush()
|
|
382
632
|
return user
|
|
@@ -388,6 +638,7 @@ async def _create_or_update_user(
|
|
|
388
638
|
*,
|
|
389
639
|
oauth2_client_id: str,
|
|
390
640
|
user_info: UserInfo,
|
|
641
|
+
role_name: Optional[AssignableUserRoleName],
|
|
391
642
|
) -> models.User:
|
|
392
643
|
"""
|
|
393
644
|
Creates a new user or updates an existing one with OAuth2 credentials.
|
|
@@ -396,6 +647,8 @@ async def _create_or_update_user(
|
|
|
396
647
|
session: The database session
|
|
397
648
|
oauth2_client_id: The ID of the OAuth2 client
|
|
398
649
|
user_info: User information from the OAuth2 provider
|
|
650
|
+
role_name: The Phoenix role name to assign (ADMIN, MEMBER, VIEWER), or None to use
|
|
651
|
+
VIEWER for new users and preserve existing users' roles (backward compatibility)
|
|
399
652
|
|
|
400
653
|
Returns:
|
|
401
654
|
The created or updated user
|
|
@@ -409,9 +662,37 @@ async def _create_or_update_user(
|
|
|
409
662
|
idp_user_id=user_info.idp_user_id,
|
|
410
663
|
)
|
|
411
664
|
if user is None:
|
|
412
|
-
user
|
|
413
|
-
|
|
414
|
-
|
|
665
|
+
# New user: use provided role_name, or default to VIEWER if role mapping not configured
|
|
666
|
+
user = await _create_user(
|
|
667
|
+
session,
|
|
668
|
+
oauth2_client_id=oauth2_client_id,
|
|
669
|
+
user_info=user_info,
|
|
670
|
+
role_name=role_name or "VIEWER", # Default for new users
|
|
671
|
+
)
|
|
672
|
+
else:
|
|
673
|
+
# Existing user: update email, profile picture, and/or role if changed
|
|
674
|
+
if user.email != user_info.email:
|
|
675
|
+
user.email = user_info.email
|
|
676
|
+
|
|
677
|
+
# Update profile picture if changed
|
|
678
|
+
if user.profile_picture_url != user_info.profile_picture_url:
|
|
679
|
+
user.profile_picture_url = user_info.profile_picture_url
|
|
680
|
+
|
|
681
|
+
# Update role ONLY if role mapping is configured (role_name is not None)
|
|
682
|
+
# This preserves existing user roles when role mapping is not configured
|
|
683
|
+
if role_name is not None and user.role.name != role_name:
|
|
684
|
+
role = await session.scalar(
|
|
685
|
+
select(models.UserRole).where(models.UserRole.name == role_name)
|
|
686
|
+
)
|
|
687
|
+
if role is not None:
|
|
688
|
+
user.role = role
|
|
689
|
+
|
|
690
|
+
# Flush to execute the UPDATE and catch any email conflicts
|
|
691
|
+
if user in session.dirty:
|
|
692
|
+
try:
|
|
693
|
+
await session.flush()
|
|
694
|
+
except (PostgreSQLIntegrityError, SQLiteIntegrityError):
|
|
695
|
+
raise EmailAlreadyInUse(f"An account for {user_info.email} is already in use.")
|
|
415
696
|
return user
|
|
416
697
|
|
|
417
698
|
|
|
@@ -445,9 +726,19 @@ async def _create_user(
|
|
|
445
726
|
*,
|
|
446
727
|
oauth2_client_id: str,
|
|
447
728
|
user_info: UserInfo,
|
|
729
|
+
role_name: AssignableUserRoleName,
|
|
448
730
|
) -> models.User:
|
|
449
731
|
"""
|
|
450
732
|
Creates a new user with the user info from the IDP.
|
|
733
|
+
|
|
734
|
+
Args:
|
|
735
|
+
session: The database session
|
|
736
|
+
oauth2_client_id: The ID of the OAuth2 client
|
|
737
|
+
user_info: User information from the OAuth2 provider
|
|
738
|
+
role_name: The Phoenix role name to assign (ADMIN, MEMBER, VIEWER)
|
|
739
|
+
|
|
740
|
+
Returns:
|
|
741
|
+
The created user
|
|
451
742
|
"""
|
|
452
743
|
email_exists, username_exists = await _email_and_username_exist(
|
|
453
744
|
session,
|
|
@@ -456,16 +747,12 @@ async def _create_user(
|
|
|
456
747
|
)
|
|
457
748
|
if email_exists:
|
|
458
749
|
raise EmailAlreadyInUse(f"An account for {email} is already in use.")
|
|
459
|
-
|
|
460
|
-
select(models.UserRole.id)
|
|
461
|
-
.where(models.UserRole.name == UserRole.MEMBER.value)
|
|
462
|
-
.scalar_subquery()
|
|
463
|
-
)
|
|
750
|
+
role_id = select(models.UserRole.id).where(models.UserRole.name == role_name).scalar_subquery()
|
|
464
751
|
user_id = await session.scalar(
|
|
465
752
|
insert(models.User)
|
|
466
753
|
.returning(models.User.id)
|
|
467
754
|
.values(
|
|
468
|
-
user_role_id=
|
|
755
|
+
user_role_id=role_id,
|
|
469
756
|
oauth2_client_id=oauth2_client_id,
|
|
470
757
|
oauth2_user_id=user_info.idp_user_id,
|
|
471
758
|
username=_with_random_suffix(username) if username and username_exists else username,
|
|
@@ -483,26 +770,6 @@ async def _create_user(
|
|
|
483
770
|
return user
|
|
484
771
|
|
|
485
772
|
|
|
486
|
-
async def _update_user_email(session: AsyncSession, /, *, user_id: int, email: str) -> models.User:
|
|
487
|
-
"""
|
|
488
|
-
Updates an existing user's email.
|
|
489
|
-
"""
|
|
490
|
-
try:
|
|
491
|
-
await session.execute(
|
|
492
|
-
update(models.User)
|
|
493
|
-
.where(models.User.id == user_id)
|
|
494
|
-
.values(email=email)
|
|
495
|
-
.options(joinedload(models.User.role))
|
|
496
|
-
)
|
|
497
|
-
except (PostgreSQLIntegrityError, SQLiteIntegrityError):
|
|
498
|
-
raise EmailAlreadyInUse(f"An account for {email} is already in use.")
|
|
499
|
-
user = await session.scalar(
|
|
500
|
-
select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
|
|
501
|
-
) # query user again for joined load
|
|
502
|
-
assert isinstance(user, models.User)
|
|
503
|
-
return user
|
|
504
|
-
|
|
505
|
-
|
|
506
773
|
async def _email_and_username_exist(
|
|
507
774
|
session: AsyncSession, /, *, email: str, username: Optional[str]
|
|
508
775
|
) -> tuple[bool, bool]:
|
|
@@ -514,7 +781,7 @@ async def _email_and_username_exist(
|
|
|
514
781
|
select(
|
|
515
782
|
cast(
|
|
516
783
|
func.coalesce(
|
|
517
|
-
func.max(case((models.User.email == email, 1), else_=0)),
|
|
784
|
+
func.max(case((func.lower(models.User.email) == email, 1), else_=0)),
|
|
518
785
|
0,
|
|
519
786
|
),
|
|
520
787
|
Boolean,
|
|
@@ -526,7 +793,7 @@ async def _email_and_username_exist(
|
|
|
526
793
|
),
|
|
527
794
|
Boolean,
|
|
528
795
|
).label("username_exists"),
|
|
529
|
-
).where(or_(models.User.email == email, models.User.username == username))
|
|
796
|
+
).where(or_(func.lower(models.User.email) == email, models.User.username == username))
|
|
530
797
|
)
|
|
531
798
|
).all()
|
|
532
799
|
return email_exists, username_exists
|
|
@@ -544,49 +811,48 @@ class NotInvited(Exception):
|
|
|
544
811
|
pass
|
|
545
812
|
|
|
546
813
|
|
|
547
|
-
|
|
814
|
+
class MissingEmailScope(Exception):
|
|
815
|
+
"""
|
|
816
|
+
Raised when the OIDC provider does not return the email scope.
|
|
817
|
+
"""
|
|
818
|
+
|
|
819
|
+
pass
|
|
820
|
+
|
|
821
|
+
|
|
822
|
+
class InvalidUserInfo(Exception):
|
|
823
|
+
"""
|
|
824
|
+
Raised when the OIDC user info is malformed or missing required claims.
|
|
825
|
+
"""
|
|
826
|
+
|
|
827
|
+
pass
|
|
828
|
+
|
|
829
|
+
|
|
830
|
+
def _redirect_to_login(*, request: Request, error: AuthErrorCode) -> RedirectResponse:
|
|
548
831
|
"""
|
|
549
|
-
Creates a RedirectResponse to the login page to display an error
|
|
832
|
+
Creates a RedirectResponse to the login page to display an error code.
|
|
833
|
+
The error code will be validated and mapped to a user-friendly message on the frontend.
|
|
550
834
|
"""
|
|
551
835
|
# TODO: this needs some cleanup
|
|
552
|
-
login_path =
|
|
553
|
-
request
|
|
836
|
+
login_path = prepend_root_path(
|
|
837
|
+
request.scope, "/login" if not get_env_disable_basic_auth() else "/logout"
|
|
554
838
|
)
|
|
555
839
|
url = URL(login_path).include_query_params(error=error)
|
|
556
840
|
response = RedirectResponse(url=url)
|
|
557
841
|
response = delete_oauth2_state_cookie(response)
|
|
558
842
|
response = delete_oauth2_nonce_cookie(response)
|
|
843
|
+
response = delete_oauth2_code_verifier_cookie(response)
|
|
559
844
|
return response
|
|
560
845
|
|
|
561
846
|
|
|
562
|
-
def _prepend_root_path_if_exists(*, request: Request, path: str) -> str:
|
|
563
|
-
"""
|
|
564
|
-
If a root path is configured, prepends it to the input path.
|
|
565
|
-
"""
|
|
566
|
-
if not path.startswith("/"):
|
|
567
|
-
raise ValueError("path must start with a forward slash")
|
|
568
|
-
root_path = _get_root_path(request=request)
|
|
569
|
-
if root_path.endswith("/"):
|
|
570
|
-
root_path = root_path.rstrip("/")
|
|
571
|
-
return root_path + path
|
|
572
|
-
|
|
573
|
-
|
|
574
847
|
def _append_root_path_if_exists(*, request: Request, base_url: str) -> str:
|
|
575
848
|
"""
|
|
576
849
|
If a root path is configured, appends it to the input base url.
|
|
577
850
|
"""
|
|
578
|
-
if not (root_path :=
|
|
851
|
+
if not (root_path := get_root_path(request.scope)):
|
|
579
852
|
return base_url
|
|
580
853
|
return str(URLPath(root_path).make_absolute_url(base_url=base_url))
|
|
581
854
|
|
|
582
855
|
|
|
583
|
-
def _get_root_path(*, request: Request) -> str:
|
|
584
|
-
"""
|
|
585
|
-
Gets the root path from the request.
|
|
586
|
-
"""
|
|
587
|
-
return str(request.scope.get("root_path", ""))
|
|
588
|
-
|
|
589
|
-
|
|
590
856
|
def _get_create_tokens_endpoint(*, request: Request, origin_url: str, idp_name: str) -> str:
|
|
591
857
|
"""
|
|
592
858
|
Gets the endpoint for create tokens route.
|
|
@@ -664,7 +930,4 @@ def _is_oauth2_state_payload(maybe_state_payload: Any) -> TypeGuard[_OAuth2State
|
|
|
664
930
|
|
|
665
931
|
|
|
666
932
|
_JWT_ALGORITHM = "HS256"
|
|
667
|
-
_INVALID_OAUTH2_STATE_MESSAGE = (
|
|
668
|
-
"Received invalid state parameter during OAuth2 authorization code flow for IDP {idp_name}."
|
|
669
|
-
)
|
|
670
933
|
_RELATIVE_URL_PATTERN = re.compile(r"^/($|\w)")
|