arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221)
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,6 @@
1
+ import logging
1
2
  import re
2
- from dataclasses import dataclass
3
+ from dataclasses import dataclass, field
3
4
  from datetime import timedelta
4
5
  from random import randrange
5
6
  from typing import Any, Optional, TypedDict
@@ -10,7 +11,7 @@ from authlib.integrations.starlette_client import OAuthError
10
11
  from authlib.jose import jwt
11
12
  from authlib.jose.errors import JoseError
12
13
  from fastapi import APIRouter, Cookie, Depends, Path, Query, Request
13
- from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, update
14
+ from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select
14
15
  from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
15
16
  from sqlalchemy.ext.asyncio import AsyncSession
16
17
  from sqlalchemy.orm import joinedload
@@ -18,27 +19,32 @@ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore
18
19
  from starlette.datastructures import URL, Secret, URLPath
19
20
  from starlette.responses import RedirectResponse
20
21
  from starlette.routing import Router
21
- from starlette.status import HTTP_302_FOUND
22
22
  from typing_extensions import Annotated, NotRequired, TypeGuard
23
23
 
24
24
  from phoenix.auth import (
25
25
  DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES,
26
+ PHOENIX_OAUTH2_CODE_VERIFIER_COOKIE_NAME,
26
27
  PHOENIX_OAUTH2_NONCE_COOKIE_NAME,
27
28
  PHOENIX_OAUTH2_STATE_COOKIE_NAME,
29
+ delete_oauth2_code_verifier_cookie,
28
30
  delete_oauth2_nonce_cookie,
29
31
  delete_oauth2_state_cookie,
30
32
  sanitize_email,
31
33
  set_access_token_cookie,
34
+ set_oauth2_code_verifier_cookie,
32
35
  set_oauth2_nonce_cookie,
33
36
  set_oauth2_state_cookie,
34
37
  set_refresh_token_cookie,
35
38
  )
36
39
  from phoenix.config import (
40
+ AssignableUserRoleName,
37
41
  get_env_disable_basic_auth,
38
42
  get_env_disable_rate_limit,
39
43
  )
40
44
  from phoenix.db import models
45
+ from phoenix.server.api.auth_messages import AuthErrorCode
41
46
  from phoenix.server.bearer_auth import create_access_and_refresh_tokens
47
+ from phoenix.server.ldap import is_ldap_user
42
48
  from phoenix.server.oauth2 import OAuth2Client
43
49
  from phoenix.server.rate_limiters import (
44
50
  ServerRateLimiter,
@@ -46,9 +52,12 @@ from phoenix.server.rate_limiters import (
46
52
  fastapi_route_rate_limiter,
47
53
  )
48
54
  from phoenix.server.types import TokenStore
55
+ from phoenix.server.utils import get_root_path, prepend_root_path
49
56
 
50
57
  _LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+"
51
58
 
59
+ logger = logging.getLogger(__name__)
60
+
52
61
  login_rate_limiter = fastapi_ip_rate_limiter(
53
62
  ServerRateLimiter(
54
63
  per_second_rate_limit=0.2,
@@ -87,11 +96,12 @@ async def login(
87
96
  idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
88
97
  return_url: Optional[str] = Query(default=None, alias="returnUrl"),
89
98
  ) -> RedirectResponse:
99
+ # Security Note: Query parameters should be treated as untrusted user input. Never display
100
+ # these values directly to users as they could be manipulated for XSS, phishing, or social
101
+ # engineering attacks.
102
+ if (oauth2_client := request.app.state.oauth2_clients.get_client(idp_name)) is None:
103
+ return _redirect_to_login(request=request, error="unknown_idp")
90
104
  secret = request.app.state.get_secret()
91
- if not isinstance(
92
- oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
93
- ):
94
- return _redirect_to_login(request=request, error=f"Unknown IDP: {idp_name}.")
95
105
  if (referer := request.headers.get("referer")) is not None:
96
106
  # if the referer header is present, use it as the origin URL
97
107
  parsed_url = urlparse(referer)
@@ -112,7 +122,7 @@ async def login(
112
122
  assert isinstance(authorization_url := authorization_url_data.get("url"), str)
113
123
  assert isinstance(state := authorization_url_data.get("state"), str)
114
124
  assert isinstance(nonce := authorization_url_data.get("nonce"), str)
115
- response = RedirectResponse(url=authorization_url, status_code=HTTP_302_FOUND)
125
+ response = RedirectResponse(url=authorization_url, status_code=302)
116
126
  response = set_oauth2_state_cookie(
117
127
  response=response,
118
128
  state=state,
@@ -123,6 +133,12 @@ async def login(
123
133
  nonce=nonce,
124
134
  max_age=timedelta(minutes=DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES),
125
135
  )
136
+ if code_verifier := authorization_url_data.get("code_verifier"):
137
+ response = set_oauth2_code_verifier_cookie(
138
+ response=response,
139
+ code_verifier=code_verifier,
140
+ max_age=timedelta(minutes=DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES),
141
+ )
126
142
  return response
127
143
 
128
144
 
@@ -130,50 +146,106 @@ async def login(
130
146
  async def create_tokens(
131
147
  request: Request,
132
148
  idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
133
- state: str = Query(),
134
- authorization_code: str = Query(alias="code"),
149
+ state: str = Query(), # RFC 6749 §4.1.1: CSRF protection via state parameter
150
+ authorization_code: Optional[str] = Query(default=None, alias="code"), # RFC 6749 §4.1.2
151
+ error: Optional[str] = Query(default=None), # RFC 6749 §4.1.2.1: Error response
152
+ error_description: Optional[str] = Query(default=None),
135
153
  stored_state: str = Cookie(alias=PHOENIX_OAUTH2_STATE_COOKIE_NAME),
136
- stored_nonce: str = Cookie(alias=PHOENIX_OAUTH2_NONCE_COOKIE_NAME),
154
+ stored_nonce: str = Cookie(alias=PHOENIX_OAUTH2_NONCE_COOKIE_NAME), # OIDC Core §3.1.2.1
155
+ code_verifier: Optional[str] = Cookie(
156
+ default=None, alias=PHOENIX_OAUTH2_CODE_VERIFIER_COOKIE_NAME
157
+ ), # RFC 7636 §4.1
137
158
  ) -> RedirectResponse:
159
+ # Security Note: Query parameters should be treated as untrusted user input. Never display
160
+ # these values directly to users as they could be manipulated for XSS, phishing, or social
161
+ # engineering attacks.
162
+ if (oauth2_client := request.app.state.oauth2_clients.get_client(idp_name)) is None:
163
+ return _redirect_to_login(request=request, error="unknown_idp")
164
+ if error or error_description:
165
+ logger.error(
166
+ "OAuth2 authentication failed for IDP %s: error=%s, description=%s",
167
+ idp_name,
168
+ error,
169
+ error_description,
170
+ )
171
+ return _redirect_to_login(request=request, error="auth_failed")
172
+ if authorization_code is None:
173
+ logger.error("OAuth2 callback missing authorization code for IDP %s", idp_name)
174
+ return _redirect_to_login(request=request, error="auth_failed")
138
175
  secret = request.app.state.get_secret()
176
+ # RFC 6749 §10.12: CSRF protection - validate state parameter
139
177
  if state != stored_state:
140
- return _redirect_to_login(request=request, error=_INVALID_OAUTH2_STATE_MESSAGE)
178
+ return _redirect_to_login(request=request, error="invalid_state")
141
179
  try:
142
180
  payload = _parse_state_payload(secret=secret, state=state)
143
181
  except JoseError:
144
- return _redirect_to_login(request=request, error=_INVALID_OAUTH2_STATE_MESSAGE)
182
+ return _redirect_to_login(request=request, error="invalid_state")
145
183
  if (return_url := payload.get("return_url")) is not None and not _is_relative_url(
146
184
  unquote(return_url)
147
185
  ):
148
- return _redirect_to_login(request=request, error="Attempting login with unsafe return URL.")
186
+ return _redirect_to_login(request=request, error="unsafe_return_url")
149
187
  assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta)
150
188
  assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta)
151
189
  token_store: TokenStore = request.app.state.get_token_store()
152
- if not isinstance(
153
- oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
154
- ):
155
- return _redirect_to_login(request=request, error=f"Unknown IDP: {idp_name}.")
156
190
  try:
157
- token_data = await oauth2_client.fetch_access_token(
191
+ # RFC 6749 §4.1.3: Token request - exchange authorization code for tokens
192
+ fetch_kwargs: dict[str, Any] = dict(
158
193
  state=state,
159
194
  code=authorization_code,
160
- redirect_uri=_get_create_tokens_endpoint(
195
+ redirect_uri=_get_create_tokens_endpoint( # RFC 6749 §3.1.2
161
196
  request=request, origin_url=payload["origin_url"], idp_name=idp_name
162
197
  ),
163
198
  )
164
- except OAuthError as error:
165
- return _redirect_to_login(request=request, error=str(error))
199
+ # PKCE validation: code_verifier is required when PKCE is enabled (RFC 7636 §4.5)
200
+ if oauth2_client.use_pkce:
201
+ if not code_verifier:
202
+ logger.error(
203
+ "PKCE enabled but code_verifier cookie missing for IDP %s. "
204
+ "This may indicate a cookie issue, CORS misconfiguration, or "
205
+ "browser compatibility problem.",
206
+ idp_name,
207
+ )
208
+ return _redirect_to_login(request=request, error="auth_failed")
209
+ fetch_kwargs["code_verifier"] = code_verifier
210
+ token_data = await oauth2_client.fetch_access_token(**fetch_kwargs)
211
+ except OAuthError as e:
212
+ logger.error("OAuth2 error for IDP %s: %s", idp_name, e)
213
+ return _redirect_to_login(request=request, error="oauth_error")
166
214
  _validate_token_data(token_data)
167
215
  if "id_token" not in token_data:
168
- return _redirect_to_login(
169
- request=request,
170
- error=f"OAuth2 IDP {idp_name} does not appear to support OpenID Connect.",
216
+ logger.error("OAuth2 IDP %s does not appear to support OpenID Connect", idp_name)
217
+ return _redirect_to_login(request=request, error="no_oidc_support")
218
+
219
+ try:
220
+ id_token_claims = await oauth2_client.parse_id_token(token_data, nonce=stored_nonce)
221
+ except JoseError as e:
222
+ logger.error("ID token validation failed for IDP %s: %s", idp_name, e)
223
+ return _redirect_to_login(request=request, error="auth_failed")
224
+
225
+ if oauth2_client.has_sufficient_claims(id_token_claims):
226
+ user_claims = id_token_claims
227
+ else:
228
+ user_claims = await _fetch_and_merge_userinfo_claims(
229
+ oauth2_client, token_data, id_token_claims
171
230
  )
172
- user_info = await oauth2_client.parse_id_token(token_data, nonce=stored_nonce)
231
+
232
+ try:
233
+ user_info = _parse_user_info(user_claims)
234
+ except (MissingEmailScope, InvalidUserInfo) as e:
235
+ logger.error("Error parsing user info for IDP %s: %s", idp_name, e)
236
+ return _redirect_to_login(request=request, error="missing_email_scope")
237
+
238
+ # Validate access and extract role from claims
239
+ # Both validate_access and extract_and_map_role may raise PermissionError
173
240
  try:
174
- user_info = _parse_user_info(user_info)
175
- except MissingEmailScope as error:
176
- return _redirect_to_login(request=request, error=str(error))
241
+ oauth2_client.validate_access(user_info.claims)
242
+ # Extract and map role from claims
243
+ # Returns None if role mapping not configured (preserves existing user roles)
244
+ # Raises PermissionError if strict mode enabled and role validation fails
245
+ role_name = oauth2_client.extract_and_map_role(user_info.claims)
246
+ except PermissionError as e:
247
+ logger.error("Access validation failed for IDP %s: %s", idp_name, e)
248
+ return _redirect_to_login(request=request, error="auth_failed")
177
249
 
178
250
  try:
179
251
  async with request.app.state.db() as session:
@@ -182,19 +254,24 @@ async def create_tokens(
182
254
  oauth2_client_id=str(oauth2_client.client_id),
183
255
  user_info=user_info,
184
256
  allow_sign_up=oauth2_client.allow_sign_up,
257
+ role_name=role_name,
185
258
  )
186
- except (EmailAlreadyInUse, SignInNotAllowed) as error:
187
- return _redirect_to_login(request=request, error=str(error))
259
+ except EmailAlreadyInUse as e:
260
+ logger.error("Email already in use for IDP %s: %s", idp_name, e)
261
+ return _redirect_to_login(request=request, error="email_in_use")
262
+ except SignInNotAllowed as e:
263
+ logger.error("Sign in not allowed for IDP %s: %s", idp_name, e)
264
+ return _redirect_to_login(request=request, error="sign_in_not_allowed")
188
265
  access_token, refresh_token = await create_access_and_refresh_tokens(
189
266
  user=user,
190
267
  token_store=token_store,
191
268
  access_token_expiry=access_token_expiry,
192
269
  refresh_token_expiry=refresh_token_expiry,
193
270
  )
194
- redirect_path = _prepend_root_path_if_exists(request=request, path=return_url or "/")
271
+ redirect_path = prepend_root_path(request.scope, return_url or "/")
195
272
  response = RedirectResponse(
196
273
  url=redirect_path,
197
- status_code=HTTP_302_FOUND,
274
+ status_code=302,
198
275
  )
199
276
  response = set_access_token_cookie(
200
277
  response=response, access_token=access_token, max_age=access_token_expiry
@@ -204,6 +281,7 @@ async def create_tokens(
204
281
  )
205
282
  response = delete_oauth2_state_cookie(response)
206
283
  response = delete_oauth2_nonce_cookie(response)
284
+ response = delete_oauth2_code_verifier_cookie(response)
207
285
  return response
208
286
 
209
287
 
@@ -213,6 +291,7 @@ class UserInfo:
213
291
  email: str
214
292
  username: Optional[str] = None
215
293
  profile_picture_url: Optional[str] = None
294
+ claims: dict[str, Any] = field(default_factory=dict)
216
295
 
217
296
  def __post_init__(self) -> None:
218
297
  if not (idp_user_id := (self.idp_user_id or "").strip()):
@@ -227,9 +306,64 @@ class UserInfo:
227
306
  object.__setattr__(self, "profile_picture_url", profile_picture_url)
228
307
 
229
308
 
309
+ async def _fetch_and_merge_userinfo_claims(
310
+ oauth2_client: OAuth2Client,
311
+ token_data: dict[str, Any],
312
+ id_token_claims: dict[str, Any],
313
+ ) -> dict[str, Any]:
314
+ """
315
+ Fetch claims from UserInfo endpoint and merge with ID token claims.
316
+
317
+ Why this is necessary (OIDC Core §5.4, §5.5):
318
+ When claims are requested via scopes (e.g., "profile", "email"), OIDC Core §5.4
319
+ specifies which claims are "REQUESTED" but does not mandate WHERE they must be
320
+ returned. Similarly, §5.5 allows requesting specific claims via the "claims"
321
+ parameter, but providers have discretion on whether to return them in the ID token
322
+ or UserInfo response. In practice, providers often return certain claims (especially
323
+ large ones like groups) only via UserInfo to keep ID tokens compact.
324
+
325
+ The UserInfo endpoint (OIDC Core §5.3) provides additional claims beyond what's
326
+ in the ID token, such as group memberships or custom attributes. This function:
327
+
328
+ 1. Calls the UserInfo endpoint using the access token (OIDC Core §5.3.1, RFC 6750)
329
+ 2. Merges userinfo claims with ID token claims
330
+ 3. ID token claims override userinfo claims when both contain the same claim
331
+
332
+ Why ID token takes precedence (OIDC Core §5.3.2):
333
+ - ID tokens are signed JWTs that have been cryptographically verified
334
+ - UserInfo responses may be unsigned
335
+ - Signed claims are the authoritative source when present in both
336
+
337
+ Fallback behavior:
338
+ If the UserInfo request fails, returns only ID token claims. The returned claims
339
+ may be incomplete (missing email or groups), but subsequent validation will catch this:
340
+ - Missing email: _parse_user_info() raises MissingEmailScope
341
+ - Missing groups: validate_access() raises PermissionError if access is denied
342
+
343
+ Args:
344
+ oauth2_client: The OAuth2 client to use for fetching userinfo
345
+ token_data: Token response containing the access token (RFC 6749 §5.1)
346
+ id_token_claims: Claims from the verified ID token (OIDC Core §3.1.3.3)
347
+
348
+ Returns:
349
+ Merged claims dictionary with ID token claims overriding userinfo claims
350
+ """
351
+ try:
352
+ # OIDC Core §5.3.1: UserInfo request authenticated with access token
353
+ userinfo_claims = await oauth2_client.userinfo(token=token_data)
354
+ # ID token claims take precedence (signed and verified)
355
+ return {**userinfo_claims, **id_token_claims}
356
+ except Exception:
357
+ # Fallback: ID token has essential claims for authentication
358
+ return id_token_claims
359
+
360
+
230
361
  def _validate_token_data(token_data: dict[str, Any]) -> None:
231
362
  """
232
363
  Performs basic validations on the token data returned by the IDP.
364
+
365
+ RFC 6749 §5.1: Successful response must include access_token and token_type.
366
+ RFC 6750 §1.1: Bearer token type for HTTP authentication.
233
367
  """
234
368
  assert isinstance(token_data.get("access_token"), str)
235
369
  assert isinstance(token_type := token_data.get("token_type"), str)
@@ -239,25 +373,107 @@ def _validate_token_data(token_data: dict[str, Any]) -> None:
239
373
  def _parse_user_info(user_info: dict[str, Any]) -> UserInfo:
240
374
  """
241
375
  Parses user info from the IDP's ID token.
242
- """
243
- assert isinstance(subject := user_info.get("sub"), (str, int))
244
- idp_user_id = str(subject)
376
+
377
+ Validates required OIDC claims and extracts user information according to the
378
+ OpenID Connect Core 1.0 specification.
379
+
380
+ Args:
381
+ user_info: Claims from the ID token (validated JWT payload)
382
+
383
+ Returns:
384
+ UserInfo object with validated user data
385
+
386
+ Raises:
387
+ InvalidUserInfo: If required claims are missing or malformed
388
+ MissingEmailScope: If email claim is missing or invalid
389
+
390
+ ID Token Required Claims (OIDC Core §2, §3.1.3.3):
391
+ - iss (issuer): Identifier for the OpenID Provider
392
+ - sub (subject): Unique identifier for the End-User at the Issuer
393
+ - aud (audience): Client ID this ID token is intended for
394
+ - exp (expiration): Expiration time
395
+ - iat (issued at): Time the JWT was issued
396
+ - nonce (if sent in auth request): Value sent in the Authentication Request
397
+
398
+ Application-Required Claims:
399
+ - email: Required by this application for user identification
400
+
401
+ Optional Standard Claims (OIDC Core §5.1):
402
+ - name: Full name
403
+ - picture: Profile picture URL
404
+ - Other profile, email, address, and phone claims
405
+
406
+ Note: While iss, sub, aud, exp, iat are REQUIRED in all ID tokens per spec,
407
+ other claims like email, name, groups are optional and may appear in the ID token,
408
+ UserInfo response, or both depending on what was requested and provider implementation.
409
+ """
410
+ # Validate 'sub' claim (OIDC required, MUST be a string per spec)
411
+ subject = user_info.get("sub")
412
+ if subject is None:
413
+ raise InvalidUserInfo(
414
+ "Missing required 'sub' claim in ID token. "
415
+ "Please check your OIDC provider configuration."
416
+ )
417
+
418
+ # OIDC spec: sub MUST be a string, but some IDPs send integers
419
+ # Convert to string for compatibility
420
+ if isinstance(subject, (str, int)):
421
+ idp_user_id = str(subject).strip()
422
+ else:
423
+ raise InvalidUserInfo(
424
+ f"Invalid 'sub' claim type: {type(subject).__name__}. Expected string or integer."
425
+ )
426
+
427
+ if not idp_user_id:
428
+ raise InvalidUserInfo("The 'sub' claim cannot be empty.")
429
+
430
+ # Validate 'email' claim (application requirement)
245
431
  email = user_info.get("email")
246
- if not isinstance(email, str):
432
+ if not isinstance(email, str) or not email.strip():
247
433
  raise MissingEmailScope(
248
- "Please ensure your OIDC provider is configured to use the 'email' scope."
434
+ "Missing or invalid 'email' claim. "
435
+ "Please ensure your OIDC provider is configured to include the 'email' scope."
249
436
  )
437
+ email = email.strip()
438
+
439
+ # Optional: 'name' claim (Full name)
440
+ username = user_info.get("name")
441
+ if username is not None:
442
+ if not isinstance(username, str):
443
+ # Some IDPs might send unexpected types; ignore gracefully
444
+ username = None
445
+ else:
446
+ username = username.strip() or None
447
+
448
+ # Optional: 'picture' claim (Profile picture URL)
449
+ profile_picture_url = user_info.get("picture")
450
+ if profile_picture_url is not None:
451
+ if not isinstance(profile_picture_url, str):
452
+ # Some IDPs might send unexpected types; ignore gracefully
453
+ profile_picture_url = None
454
+ else:
455
+ profile_picture_url = profile_picture_url.strip() or None
456
+
457
+ # Keep only non-empty claim values for downstream processing
458
+ def _has_value(v: Any) -> bool:
459
+ """Check if a claim value is considered non-empty."""
460
+ if v is None:
461
+ return False
462
+ if isinstance(v, str):
463
+ return bool(v.strip())
464
+ if isinstance(v, (list, dict, set, tuple)):
465
+ return len(v) > 0
466
+ # Include all other types (numbers, booleans, etc.)
467
+ return True
468
+
469
+ filtered_claims = {k: v for k, v in user_info.items() if _has_value(v)}
250
470
 
251
- assert isinstance(username := user_info.get("name"), str) or username is None
252
- assert (
253
- isinstance(profile_picture_url := user_info.get("picture"), str)
254
- or profile_picture_url is None
255
- )
256
471
  return UserInfo(
257
472
  idp_user_id=idp_user_id,
258
473
  email=email,
259
474
  username=username,
260
475
  profile_picture_url=profile_picture_url,
476
+ claims=filtered_claims,
261
477
  )
262
478
 
263
479
 
@@ -268,6 +484,7 @@ async def _process_oauth2_user(
268
484
  oauth2_client_id: str,
269
485
  user_info: UserInfo,
270
486
  allow_sign_up: bool,
487
+ role_name: Optional[AssignableUserRoleName],
271
488
  ) -> models.User:
272
489
  """
273
490
  Processes an OAuth2 user, either signing in an existing user or creating/updating one.
@@ -276,11 +493,12 @@ async def _process_oauth2_user(
276
493
  1. When sign-up is not allowed (allow_sign_up=False):
277
494
  - Checks if the user exists and can sign in with the given OAuth2 credentials
278
495
  - Updates placeholder OAuth2 credentials if needed (e.g., temporary IDs)
496
+ - Updates the user's role if role_name is provided (role mapping configured)
279
497
  - If the user doesn't exist or has a password set, raises SignInNotAllowed
280
498
  2. When sign-up is allowed (allow_sign_up=True):
281
499
  - Finds the user by OAuth2 credentials (client_id and user_id)
282
- - Creates a new user if one doesn't exist, with default member role
283
- - Updates the user's email if it has changed
500
+ - Creates a new user if one doesn't exist, with the provided role (or VIEWER if None)
501
+ - Updates the user's email and role (if role_name provided) if they have changed
284
502
  - Handles username conflicts by adding a random suffix if needed
285
503
 
286
504
  The allow_sign_up parameter is typically controlled by the PHOENIX_OAUTH2_{IDP_NAME}_ALLOW_SIGN_UP
@@ -291,6 +509,8 @@ async def _process_oauth2_user(
291
509
  oauth2_client_id: The ID of the OAuth2 client
292
510
  user_info: User information from the OAuth2 provider
293
511
  allow_sign_up: Whether to allow creating new users
512
+ role_name: The Phoenix role name to assign (ADMIN, MEMBER, VIEWER), or None to preserve
513
+ existing user roles (backward compatibility when role mapping not configured)
294
514
 
295
515
  Returns:
296
516
  The user object
@@ -300,24 +520,27 @@ async def _process_oauth2_user(
300
520
  EmailAlreadyInUse: When the email is already in use by another account
301
521
  """ # noqa: E501
302
522
  if not allow_sign_up:
303
- return await _get_existing_oauth2_user(
523
+ return await _sign_in_existing_oauth2_user(
304
524
  session,
305
525
  oauth2_client_id=oauth2_client_id,
306
526
  user_info=user_info,
527
+ role_name=role_name,
307
528
  )
308
529
  return await _create_or_update_user(
309
530
  session,
310
531
  oauth2_client_id=oauth2_client_id,
311
532
  user_info=user_info,
533
+ role_name=role_name,
312
534
  )
313
535
 
314
536
 
315
- async def _get_existing_oauth2_user(
537
+ async def _sign_in_existing_oauth2_user(
316
538
  session: AsyncSession,
317
539
  /,
318
540
  *,
319
541
  oauth2_client_id: str,
320
542
  user_info: UserInfo,
543
+ role_name: Optional[AssignableUserRoleName],
321
544
  ) -> models.User:
322
545
  """Signs in an existing user with OAuth2 credentials.
323
546
 
@@ -339,6 +562,7 @@ async def _get_existing_oauth2_user(
339
562
  Profile Updates:
340
563
  - Email: Updated if different from IDP info
341
564
  - Profile Picture: Updated if provided in user_info
565
+ - Role: Updated ONLY if role_name is provided (role mapping configured)
342
566
  - Username: Never updated (remains unchanged)
343
567
  - OAuth2 Credentials: Updated based on the three cases above
344
568
 
@@ -346,6 +570,8 @@ async def _get_existing_oauth2_user(
346
570
  session: The database session
347
571
  oauth2_client_id: The ID of the OAuth2 client
348
572
  user_info: User information from the OAuth2 provider
573
+ role_name: The Phoenix role name to assign (ADMIN, MEMBER, VIEWER), or None to preserve
574
+ existing role (backward compatibility when role mapping not configured)
349
575
 
350
576
  Returns:
351
577
  The signed-in user
@@ -374,6 +600,11 @@ async def _get_existing_oauth2_user(
374
600
  user = await session.scalar(stmt.where(func.lower(models.User.email) == email))
375
601
  if user is None or not isinstance(user, models.OAuth2User):
376
602
  raise SignInNotAllowed("Sign in is not allowed.")
603
+ # Security: Prevent OIDC from hijacking LDAP users
604
+ # LDAP users are identified by the special Unicode marker in oauth2_client_id
605
+ # Use generic error message to avoid revealing auth method (username enumeration)
606
+ if is_ldap_user(user.oauth2_client_id):
607
+ raise SignInNotAllowed("Sign in is not allowed.")
377
608
  # Case 1: Different OAuth2 client - update both client and user IDs
378
609
  if oauth2_client_id != user.oauth2_client_id:
379
610
  user.oauth2_client_id = oauth2_client_id
@@ -386,6 +617,16 @@ async def _get_existing_oauth2_user(
386
617
  raise SignInNotAllowed("Sign in is not allowed.")
387
618
  if profile_picture_url != user.profile_picture_url:
388
619
  user.profile_picture_url = profile_picture_url
620
+
621
+ # Update role ONLY if role mapping is configured (role_name is not None)
622
+ # This preserves existing user roles when role mapping is not configured
623
+ if role_name is not None and user.role.name != role_name:
624
+ role = await session.scalar(
625
+ select(models.UserRole).where(models.UserRole.name == role_name)
626
+ )
627
+ if role is not None:
628
+ user.role = role
629
+
389
630
  if user in session.dirty:
390
631
  await session.flush()
391
632
  return user
@@ -397,6 +638,7 @@ async def _create_or_update_user(
397
638
  *,
398
639
  oauth2_client_id: str,
399
640
  user_info: UserInfo,
641
+ role_name: Optional[AssignableUserRoleName],
400
642
  ) -> models.User:
401
643
  """
402
644
  Creates a new user or updates an existing one with OAuth2 credentials.
@@ -405,6 +647,8 @@ async def _create_or_update_user(
405
647
  session: The database session
406
648
  oauth2_client_id: The ID of the OAuth2 client
407
649
  user_info: User information from the OAuth2 provider
650
+ role_name: The Phoenix role name to assign (ADMIN, MEMBER, VIEWER), or None to use
651
+ VIEWER for new users and preserve existing users' roles (backward compatibility)
408
652
 
409
653
  Returns:
410
654
  The created or updated user
@@ -418,9 +662,37 @@ async def _create_or_update_user(
418
662
  idp_user_id=user_info.idp_user_id,
419
663
  )
420
664
  if user is None:
421
- user = await _create_user(session, oauth2_client_id=oauth2_client_id, user_info=user_info)
422
- elif user.email != user_info.email:
423
- user = await _update_user_email(session, user_id=user.id, email=user_info.email)
665
+ # New user: use provided role_name, or default to VIEWER if role mapping not configured
666
+ user = await _create_user(
667
+ session,
668
+ oauth2_client_id=oauth2_client_id,
669
+ user_info=user_info,
670
+ role_name=role_name or "VIEWER", # Default for new users
671
+ )
672
+ else:
673
+ # Existing user: update email, profile picture, and/or role if changed
674
+ if user.email != user_info.email:
675
+ user.email = user_info.email
676
+
677
+ # Update profile picture if changed
678
+ if user.profile_picture_url != user_info.profile_picture_url:
679
+ user.profile_picture_url = user_info.profile_picture_url
680
+
681
+ # Update role ONLY if role mapping is configured (role_name is not None)
682
+ # This preserves existing user roles when role mapping is not configured
683
+ if role_name is not None and user.role.name != role_name:
684
+ role = await session.scalar(
685
+ select(models.UserRole).where(models.UserRole.name == role_name)
686
+ )
687
+ if role is not None:
688
+ user.role = role
689
+
690
+ # Flush to execute the UPDATE and catch any email conflicts
691
+ if user in session.dirty:
692
+ try:
693
+ await session.flush()
694
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError):
695
+ raise EmailAlreadyInUse(f"An account for {user_info.email} is already in use.")
424
696
  return user
425
697
 
426
698
 
@@ -454,9 +726,19 @@ async def _create_user(
454
726
  *,
455
727
  oauth2_client_id: str,
456
728
  user_info: UserInfo,
729
+ role_name: AssignableUserRoleName,
457
730
  ) -> models.User:
458
731
  """
459
732
  Creates a new user with the user info from the IDP.
733
+
734
+ Args:
735
+ session: The database session
736
+ oauth2_client_id: The ID of the OAuth2 client
737
+ user_info: User information from the OAuth2 provider
738
+ role_name: The Phoenix role name to assign (ADMIN, MEMBER, VIEWER)
739
+
740
+ Returns:
741
+ The created user
460
742
  """
461
743
  email_exists, username_exists = await _email_and_username_exist(
462
744
  session,
@@ -465,14 +747,12 @@ async def _create_user(
465
747
  )
466
748
  if email_exists:
467
749
  raise EmailAlreadyInUse(f"An account for {email} is already in use.")
468
- member_role_id = (
469
- select(models.UserRole.id).where(models.UserRole.name == "MEMBER").scalar_subquery()
470
- )
750
+ role_id = select(models.UserRole.id).where(models.UserRole.name == role_name).scalar_subquery()
471
751
  user_id = await session.scalar(
472
752
  insert(models.User)
473
753
  .returning(models.User.id)
474
754
  .values(
475
- user_role_id=member_role_id,
755
+ user_role_id=role_id,
476
756
  oauth2_client_id=oauth2_client_id,
477
757
  oauth2_user_id=user_info.idp_user_id,
478
758
  username=_with_random_suffix(username) if username and username_exists else username,
@@ -490,26 +770,6 @@ async def _create_user(
490
770
  return user
491
771
 
492
772
 
493
- async def _update_user_email(session: AsyncSession, /, *, user_id: int, email: str) -> models.User:
494
- """
495
- Updates an existing user's email.
496
- """
497
- try:
498
- await session.execute(
499
- update(models.User)
500
- .where(models.User.id == user_id)
501
- .values(email=email)
502
- .options(joinedload(models.User.role))
503
- )
504
- except (PostgreSQLIntegrityError, SQLiteIntegrityError):
505
- raise EmailAlreadyInUse(f"An account for {email} is already in use.")
506
- user = await session.scalar(
507
- select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
508
- ) # query user again for joined load
509
- assert isinstance(user, models.User)
510
- return user
511
-
512
-
513
773
  async def _email_and_username_exist(
514
774
  session: AsyncSession, /, *, email: str, username: Optional[str]
515
775
  ) -> tuple[bool, bool]:
@@ -559,49 +819,40 @@ class MissingEmailScope(Exception):
559
819
  pass
560
820
 
561
821
 
562
- def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse:
822
+ class InvalidUserInfo(Exception):
823
+ """
824
+ Raised when the OIDC user info is malformed or missing required claims.
825
+ """
826
+
827
+ pass
828
+
829
+
830
+ def _redirect_to_login(*, request: Request, error: AuthErrorCode) -> RedirectResponse:
563
831
  """
564
- Creates a RedirectResponse to the login page to display an error message.
832
+ Creates a RedirectResponse to the login page to display an error code.
833
+ The error code will be validated and mapped to a user-friendly message on the frontend.
565
834
  """
566
835
  # TODO: this needs some cleanup
567
- login_path = _prepend_root_path_if_exists(
568
- request=request, path="/login" if not get_env_disable_basic_auth() else "/logout"
836
+ login_path = prepend_root_path(
837
+ request.scope, "/login" if not get_env_disable_basic_auth() else "/logout"
569
838
  )
570
839
  url = URL(login_path).include_query_params(error=error)
571
840
  response = RedirectResponse(url=url)
572
841
  response = delete_oauth2_state_cookie(response)
573
842
  response = delete_oauth2_nonce_cookie(response)
843
+ response = delete_oauth2_code_verifier_cookie(response)
574
844
  return response
575
845
 
576
846
 
577
- def _prepend_root_path_if_exists(*, request: Request, path: str) -> str:
578
- """
579
- If a root path is configured, prepends it to the input path.
580
- """
581
- if not path.startswith("/"):
582
- raise ValueError("path must start with a forward slash")
583
- root_path = _get_root_path(request=request)
584
- if root_path.endswith("/"):
585
- root_path = root_path.rstrip("/")
586
- return root_path + path
587
-
588
-
589
847
  def _append_root_path_if_exists(*, request: Request, base_url: str) -> str:
590
848
  """
591
849
  If a root path is configured, appends it to the input base url.
592
850
  """
593
- if not (root_path := _get_root_path(request=request)):
851
+ if not (root_path := get_root_path(request.scope)):
594
852
  return base_url
595
853
  return str(URLPath(root_path).make_absolute_url(base_url=base_url))
596
854
 
597
855
 
598
- def _get_root_path(*, request: Request) -> str:
599
- """
600
- Gets the root path from the request.
601
- """
602
- return str(request.scope.get("root_path", ""))
603
-
604
-
605
856
  def _get_create_tokens_endpoint(*, request: Request, origin_url: str, idp_name: str) -> str:
606
857
  """
607
858
  Gets the endpoint for create tokens route.
@@ -679,7 +930,4 @@ def _is_oauth2_state_payload(maybe_state_payload: Any) -> TypeGuard[_OAuth2State
679
930
 
680
931
 
681
932
  _JWT_ALGORITHM = "HS256"
682
- _INVALID_OAUTH2_STATE_MESSAGE = (
683
- "Received invalid state parameter during OAuth2 authorization code flow for IDP {idp_name}."
684
- )
685
933
  _RELATIVE_URL_PATTERN = re.compile(r"^/($|\w)")