arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (276) hide show
  1. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
  2. arize_phoenix-12.28.1.dist-info/RECORD +499 -0
  3. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +5 -4
  12. phoenix/auth.py +39 -2
  13. phoenix/config.py +1763 -91
  14. phoenix/datetime_utils.py +120 -2
  15. phoenix/db/README.md +595 -25
  16. phoenix/db/bulk_inserter.py +145 -103
  17. phoenix/db/engines.py +140 -33
  18. phoenix/db/enums.py +3 -12
  19. phoenix/db/facilitator.py +302 -35
  20. phoenix/db/helpers.py +1000 -65
  21. phoenix/db/iam_auth.py +64 -0
  22. phoenix/db/insertion/dataset.py +135 -2
  23. phoenix/db/insertion/document_annotation.py +9 -6
  24. phoenix/db/insertion/evaluation.py +2 -3
  25. phoenix/db/insertion/helpers.py +17 -2
  26. phoenix/db/insertion/session_annotation.py +176 -0
  27. phoenix/db/insertion/span.py +15 -11
  28. phoenix/db/insertion/span_annotation.py +3 -4
  29. phoenix/db/insertion/trace_annotation.py +3 -4
  30. phoenix/db/insertion/types.py +50 -20
  31. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  32. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  33. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  34. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  35. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  36. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  37. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  38. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  39. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  40. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  41. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  42. phoenix/db/models.py +669 -56
  43. phoenix/db/pg_config.py +10 -0
  44. phoenix/db/types/model_provider.py +4 -0
  45. phoenix/db/types/token_price_customization.py +29 -0
  46. phoenix/db/types/trace_retention.py +23 -15
  47. phoenix/experiments/evaluators/utils.py +3 -3
  48. phoenix/experiments/functions.py +160 -52
  49. phoenix/experiments/tracing.py +2 -2
  50. phoenix/experiments/types.py +1 -1
  51. phoenix/inferences/inferences.py +1 -2
  52. phoenix/server/api/auth.py +38 -7
  53. phoenix/server/api/auth_messages.py +46 -0
  54. phoenix/server/api/context.py +100 -4
  55. phoenix/server/api/dataloaders/__init__.py +79 -5
  56. phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
  57. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  58. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  59. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  60. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  61. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  62. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  63. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  64. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  65. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  66. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  67. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  68. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  69. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  70. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  71. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  72. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  73. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  74. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  75. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  76. phoenix/server/api/dataloaders/record_counts.py +37 -10
  77. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  78. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  79. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  80. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  81. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  82. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  83. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  84. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
  85. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  86. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  87. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  88. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
  89. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  90. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  91. phoenix/server/api/dataloaders/span_costs.py +29 -0
  92. phoenix/server/api/dataloaders/table_fields.py +2 -2
  93. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  94. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  95. phoenix/server/api/dataloaders/types.py +29 -0
  96. phoenix/server/api/exceptions.py +11 -1
  97. phoenix/server/api/helpers/dataset_helpers.py +5 -1
  98. phoenix/server/api/helpers/playground_clients.py +1243 -292
  99. phoenix/server/api/helpers/playground_registry.py +2 -2
  100. phoenix/server/api/helpers/playground_spans.py +8 -4
  101. phoenix/server/api/helpers/playground_users.py +26 -0
  102. phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
  103. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  104. phoenix/server/api/helpers/prompts/models.py +205 -22
  105. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  106. phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
  107. phoenix/server/api/input_types/CreateProjectInput.py +27 -0
  108. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  109. phoenix/server/api/input_types/DatasetFilter.py +17 -0
  110. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  111. phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
  112. phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
  113. phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
  114. phoenix/server/api/input_types/PromptFilter.py +14 -0
  115. phoenix/server/api/input_types/PromptVersionInput.py +52 -1
  116. phoenix/server/api/input_types/SpanSort.py +44 -7
  117. phoenix/server/api/input_types/TimeBinConfig.py +23 -0
  118. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  119. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  120. phoenix/server/api/mutations/__init__.py +10 -0
  121. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  122. phoenix/server/api/mutations/api_key_mutations.py +19 -23
  123. phoenix/server/api/mutations/chat_mutations.py +154 -47
  124. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  125. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  126. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  127. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  128. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  129. phoenix/server/api/mutations/model_mutations.py +210 -0
  130. phoenix/server/api/mutations/project_mutations.py +49 -10
  131. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  132. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  133. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  134. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  135. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  136. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  137. phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
  138. phoenix/server/api/mutations/trace_mutations.py +47 -3
  139. phoenix/server/api/mutations/user_mutations.py +66 -41
  140. phoenix/server/api/queries.py +768 -293
  141. phoenix/server/api/routers/__init__.py +2 -2
  142. phoenix/server/api/routers/auth.py +154 -88
  143. phoenix/server/api/routers/ldap.py +229 -0
  144. phoenix/server/api/routers/oauth2.py +369 -106
  145. phoenix/server/api/routers/v1/__init__.py +24 -4
  146. phoenix/server/api/routers/v1/annotation_configs.py +23 -31
  147. phoenix/server/api/routers/v1/annotations.py +481 -17
  148. phoenix/server/api/routers/v1/datasets.py +395 -81
  149. phoenix/server/api/routers/v1/documents.py +142 -0
  150. phoenix/server/api/routers/v1/evaluations.py +24 -31
  151. phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
  152. phoenix/server/api/routers/v1/experiment_runs.py +337 -59
  153. phoenix/server/api/routers/v1/experiments.py +479 -48
  154. phoenix/server/api/routers/v1/models.py +7 -0
  155. phoenix/server/api/routers/v1/projects.py +18 -49
  156. phoenix/server/api/routers/v1/prompts.py +54 -40
  157. phoenix/server/api/routers/v1/sessions.py +108 -0
  158. phoenix/server/api/routers/v1/spans.py +1091 -81
  159. phoenix/server/api/routers/v1/traces.py +132 -78
  160. phoenix/server/api/routers/v1/users.py +389 -0
  161. phoenix/server/api/routers/v1/utils.py +3 -7
  162. phoenix/server/api/subscriptions.py +305 -88
  163. phoenix/server/api/types/Annotation.py +90 -23
  164. phoenix/server/api/types/ApiKey.py +13 -17
  165. phoenix/server/api/types/AuthMethod.py +1 -0
  166. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  167. phoenix/server/api/types/CostBreakdown.py +12 -0
  168. phoenix/server/api/types/Dataset.py +226 -72
  169. phoenix/server/api/types/DatasetExample.py +88 -18
  170. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  171. phoenix/server/api/types/DatasetLabel.py +57 -0
  172. phoenix/server/api/types/DatasetSplit.py +98 -0
  173. phoenix/server/api/types/DatasetVersion.py +49 -4
  174. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  175. phoenix/server/api/types/Experiment.py +264 -59
  176. phoenix/server/api/types/ExperimentComparison.py +5 -10
  177. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  178. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  179. phoenix/server/api/types/ExperimentRun.py +169 -65
  180. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  181. phoenix/server/api/types/GenerativeModel.py +245 -3
  182. phoenix/server/api/types/GenerativeProvider.py +70 -11
  183. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  184. phoenix/server/api/types/ModelInterface.py +16 -0
  185. phoenix/server/api/types/PlaygroundModel.py +20 -0
  186. phoenix/server/api/types/Project.py +1278 -216
  187. phoenix/server/api/types/ProjectSession.py +188 -28
  188. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  189. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  190. phoenix/server/api/types/Prompt.py +119 -39
  191. phoenix/server/api/types/PromptLabel.py +42 -25
  192. phoenix/server/api/types/PromptVersion.py +11 -8
  193. phoenix/server/api/types/PromptVersionTag.py +65 -25
  194. phoenix/server/api/types/ServerStatus.py +6 -0
  195. phoenix/server/api/types/Span.py +167 -123
  196. phoenix/server/api/types/SpanAnnotation.py +189 -42
  197. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  198. phoenix/server/api/types/SpanCostSummary.py +10 -0
  199. phoenix/server/api/types/SystemApiKey.py +65 -1
  200. phoenix/server/api/types/TokenPrice.py +16 -0
  201. phoenix/server/api/types/TokenUsage.py +3 -3
  202. phoenix/server/api/types/Trace.py +223 -51
  203. phoenix/server/api/types/TraceAnnotation.py +149 -50
  204. phoenix/server/api/types/User.py +137 -32
  205. phoenix/server/api/types/UserApiKey.py +73 -26
  206. phoenix/server/api/types/node.py +10 -0
  207. phoenix/server/api/types/pagination.py +11 -2
  208. phoenix/server/app.py +290 -45
  209. phoenix/server/authorization.py +38 -3
  210. phoenix/server/bearer_auth.py +34 -24
  211. phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
  212. phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
  213. phoenix/server/cost_tracking/helpers.py +68 -0
  214. phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
  215. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  216. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  217. phoenix/server/daemons/__init__.py +0 -0
  218. phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
  219. phoenix/server/daemons/generative_model_store.py +103 -0
  220. phoenix/server/daemons/span_cost_calculator.py +99 -0
  221. phoenix/server/dml_event.py +17 -0
  222. phoenix/server/dml_event_handler.py +5 -0
  223. phoenix/server/email/sender.py +56 -3
  224. phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
  225. phoenix/server/email/types.py +11 -0
  226. phoenix/server/experiments/__init__.py +0 -0
  227. phoenix/server/experiments/utils.py +14 -0
  228. phoenix/server/grpc_server.py +11 -11
  229. phoenix/server/jwt_store.py +17 -15
  230. phoenix/server/ldap.py +1449 -0
  231. phoenix/server/main.py +26 -10
  232. phoenix/server/oauth2.py +330 -12
  233. phoenix/server/prometheus.py +66 -6
  234. phoenix/server/rate_limiters.py +4 -9
  235. phoenix/server/retention.py +33 -20
  236. phoenix/server/session_filters.py +49 -0
  237. phoenix/server/static/.vite/manifest.json +55 -51
  238. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  239. phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
  240. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  241. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  242. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  243. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  244. phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
  245. phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
  246. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  247. phoenix/server/templates/index.html +40 -6
  248. phoenix/server/thread_server.py +1 -2
  249. phoenix/server/types.py +14 -4
  250. phoenix/server/utils.py +74 -0
  251. phoenix/session/client.py +56 -3
  252. phoenix/session/data_extractor.py +5 -0
  253. phoenix/session/evaluation.py +14 -5
  254. phoenix/session/session.py +45 -9
  255. phoenix/settings.py +5 -0
  256. phoenix/trace/attributes.py +80 -13
  257. phoenix/trace/dsl/helpers.py +90 -1
  258. phoenix/trace/dsl/query.py +8 -6
  259. phoenix/trace/projects.py +5 -0
  260. phoenix/utilities/template_formatters.py +1 -1
  261. phoenix/version.py +1 -1
  262. arize_phoenix-10.0.4.dist-info/RECORD +0 -405
  263. phoenix/server/api/types/Evaluation.py +0 -39
  264. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  265. phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
  266. phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
  267. phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
  268. phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
  269. phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
  270. phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
  271. phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
  272. phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
  273. phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
  274. phoenix/utilities/deprecation.py +0 -31
  275. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  276. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,6 @@
1
+ import logging
1
2
  import re
2
- from dataclasses import dataclass
3
+ from dataclasses import dataclass, field
3
4
  from datetime import timedelta
4
5
  from random import randrange
5
6
  from typing import Any, Optional, TypedDict
@@ -10,7 +11,7 @@ from authlib.integrations.starlette_client import OAuthError
10
11
  from authlib.jose import jwt
11
12
  from authlib.jose.errors import JoseError
12
13
  from fastapi import APIRouter, Cookie, Depends, Path, Query, Request
13
- from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, update
14
+ from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select
14
15
  from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
15
16
  from sqlalchemy.ext.asyncio import AsyncSession
16
17
  from sqlalchemy.orm import joinedload
@@ -18,27 +19,32 @@ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore
18
19
  from starlette.datastructures import URL, Secret, URLPath
19
20
  from starlette.responses import RedirectResponse
20
21
  from starlette.routing import Router
21
- from starlette.status import HTTP_302_FOUND
22
22
  from typing_extensions import Annotated, NotRequired, TypeGuard
23
23
 
24
24
  from phoenix.auth import (
25
25
  DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES,
26
+ PHOENIX_OAUTH2_CODE_VERIFIER_COOKIE_NAME,
26
27
  PHOENIX_OAUTH2_NONCE_COOKIE_NAME,
27
28
  PHOENIX_OAUTH2_STATE_COOKIE_NAME,
29
+ delete_oauth2_code_verifier_cookie,
28
30
  delete_oauth2_nonce_cookie,
29
31
  delete_oauth2_state_cookie,
32
+ sanitize_email,
30
33
  set_access_token_cookie,
34
+ set_oauth2_code_verifier_cookie,
31
35
  set_oauth2_nonce_cookie,
32
36
  set_oauth2_state_cookie,
33
37
  set_refresh_token_cookie,
34
38
  )
35
39
  from phoenix.config import (
40
+ AssignableUserRoleName,
36
41
  get_env_disable_basic_auth,
37
42
  get_env_disable_rate_limit,
38
43
  )
39
44
  from phoenix.db import models
40
- from phoenix.db.enums import UserRole
45
+ from phoenix.server.api.auth_messages import AuthErrorCode
41
46
  from phoenix.server.bearer_auth import create_access_and_refresh_tokens
47
+ from phoenix.server.ldap import is_ldap_user
42
48
  from phoenix.server.oauth2 import OAuth2Client
43
49
  from phoenix.server.rate_limiters import (
44
50
  ServerRateLimiter,
@@ -46,9 +52,12 @@ from phoenix.server.rate_limiters import (
46
52
  fastapi_route_rate_limiter,
47
53
  )
48
54
  from phoenix.server.types import TokenStore
55
+ from phoenix.server.utils import get_root_path, prepend_root_path
49
56
 
50
57
  _LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+"
51
58
 
59
+ logger = logging.getLogger(__name__)
60
+
52
61
  login_rate_limiter = fastapi_ip_rate_limiter(
53
62
  ServerRateLimiter(
54
63
  per_second_rate_limit=0.2,
@@ -87,11 +96,12 @@ async def login(
87
96
  idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
88
97
  return_url: Optional[str] = Query(default=None, alias="returnUrl"),
89
98
  ) -> RedirectResponse:
99
+ # Security Note: Query parameters should be treated as untrusted user input. Never display
100
+ # these values directly to users as they could be manipulated for XSS, phishing, or social
101
+ # engineering attacks.
102
+ if (oauth2_client := request.app.state.oauth2_clients.get_client(idp_name)) is None:
103
+ return _redirect_to_login(request=request, error="unknown_idp")
90
104
  secret = request.app.state.get_secret()
91
- if not isinstance(
92
- oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
93
- ):
94
- return _redirect_to_login(request=request, error=f"Unknown IDP: {idp_name}.")
95
105
  if (referer := request.headers.get("referer")) is not None:
96
106
  # if the referer header is present, use it as the origin URL
97
107
  parsed_url = urlparse(referer)
@@ -112,7 +122,7 @@ async def login(
112
122
  assert isinstance(authorization_url := authorization_url_data.get("url"), str)
113
123
  assert isinstance(state := authorization_url_data.get("state"), str)
114
124
  assert isinstance(nonce := authorization_url_data.get("nonce"), str)
115
- response = RedirectResponse(url=authorization_url, status_code=HTTP_302_FOUND)
125
+ response = RedirectResponse(url=authorization_url, status_code=302)
116
126
  response = set_oauth2_state_cookie(
117
127
  response=response,
118
128
  state=state,
@@ -123,6 +133,12 @@ async def login(
123
133
  nonce=nonce,
124
134
  max_age=timedelta(minutes=DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES),
125
135
  )
136
+ if code_verifier := authorization_url_data.get("code_verifier"):
137
+ response = set_oauth2_code_verifier_cookie(
138
+ response=response,
139
+ code_verifier=code_verifier,
140
+ max_age=timedelta(minutes=DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES),
141
+ )
126
142
  return response
127
143
 
128
144
 
@@ -130,47 +146,107 @@ async def login(
130
146
  async def create_tokens(
131
147
  request: Request,
132
148
  idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
133
- state: str = Query(),
134
- authorization_code: str = Query(alias="code"),
149
+ state: str = Query(), # RFC 6749 §4.1.1: CSRF protection via state parameter
150
+ authorization_code: Optional[str] = Query(default=None, alias="code"), # RFC 6749 §4.1.2
151
+ error: Optional[str] = Query(default=None), # RFC 6749 §4.1.2.1: Error response
152
+ error_description: Optional[str] = Query(default=None),
135
153
  stored_state: str = Cookie(alias=PHOENIX_OAUTH2_STATE_COOKIE_NAME),
136
- stored_nonce: str = Cookie(alias=PHOENIX_OAUTH2_NONCE_COOKIE_NAME),
154
+ stored_nonce: str = Cookie(alias=PHOENIX_OAUTH2_NONCE_COOKIE_NAME), # OIDC Core §3.1.2.1
155
+ code_verifier: Optional[str] = Cookie(
156
+ default=None, alias=PHOENIX_OAUTH2_CODE_VERIFIER_COOKIE_NAME
157
+ ), # RFC 7636 §4.1
137
158
  ) -> RedirectResponse:
159
+ # Security Note: Query parameters should be treated as untrusted user input. Never display
160
+ # these values directly to users as they could be manipulated for XSS, phishing, or social
161
+ # engineering attacks.
162
+ if (oauth2_client := request.app.state.oauth2_clients.get_client(idp_name)) is None:
163
+ return _redirect_to_login(request=request, error="unknown_idp")
164
+ if error or error_description:
165
+ logger.error(
166
+ "OAuth2 authentication failed for IDP %s: error=%s, description=%s",
167
+ idp_name,
168
+ error,
169
+ error_description,
170
+ )
171
+ return _redirect_to_login(request=request, error="auth_failed")
172
+ if authorization_code is None:
173
+ logger.error("OAuth2 callback missing authorization code for IDP %s", idp_name)
174
+ return _redirect_to_login(request=request, error="auth_failed")
138
175
  secret = request.app.state.get_secret()
176
+ # RFC 6749 §10.12: CSRF protection - validate state parameter
139
177
  if state != stored_state:
140
- return _redirect_to_login(request=request, error=_INVALID_OAUTH2_STATE_MESSAGE)
178
+ return _redirect_to_login(request=request, error="invalid_state")
141
179
  try:
142
180
  payload = _parse_state_payload(secret=secret, state=state)
143
181
  except JoseError:
144
- return _redirect_to_login(request=request, error=_INVALID_OAUTH2_STATE_MESSAGE)
182
+ return _redirect_to_login(request=request, error="invalid_state")
145
183
  if (return_url := payload.get("return_url")) is not None and not _is_relative_url(
146
184
  unquote(return_url)
147
185
  ):
148
- return _redirect_to_login(request=request, error="Attempting login with unsafe return URL.")
186
+ return _redirect_to_login(request=request, error="unsafe_return_url")
149
187
  assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta)
150
188
  assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta)
151
189
  token_store: TokenStore = request.app.state.get_token_store()
152
- if not isinstance(
153
- oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
154
- ):
155
- return _redirect_to_login(request=request, error=f"Unknown IDP: {idp_name}.")
156
190
  try:
157
- token_data = await oauth2_client.fetch_access_token(
191
+ # RFC 6749 §4.1.3: Token request - exchange authorization code for tokens
192
+ fetch_kwargs: dict[str, Any] = dict(
158
193
  state=state,
159
194
  code=authorization_code,
160
- redirect_uri=_get_create_tokens_endpoint(
195
+ redirect_uri=_get_create_tokens_endpoint( # RFC 6749 §3.1.2
161
196
  request=request, origin_url=payload["origin_url"], idp_name=idp_name
162
197
  ),
163
198
  )
164
- except OAuthError as error:
165
- return _redirect_to_login(request=request, error=str(error))
199
+ # PKCE validation: code_verifier is required when PKCE is enabled (RFC 7636 §4.5)
200
+ if oauth2_client.use_pkce:
201
+ if not code_verifier:
202
+ logger.error(
203
+ "PKCE enabled but code_verifier cookie missing for IDP %s. "
204
+ "This may indicate a cookie issue, CORS misconfiguration, or "
205
+ "browser compatibility problem.",
206
+ idp_name,
207
+ )
208
+ return _redirect_to_login(request=request, error="auth_failed")
209
+ fetch_kwargs["code_verifier"] = code_verifier
210
+ token_data = await oauth2_client.fetch_access_token(**fetch_kwargs)
211
+ except OAuthError as e:
212
+ logger.error("OAuth2 error for IDP %s: %s", idp_name, e)
213
+ return _redirect_to_login(request=request, error="oauth_error")
166
214
  _validate_token_data(token_data)
167
215
  if "id_token" not in token_data:
168
- return _redirect_to_login(
169
- request=request,
170
- error=f"OAuth2 IDP {idp_name} does not appear to support OpenID Connect.",
216
+ logger.error("OAuth2 IDP %s does not appear to support OpenID Connect", idp_name)
217
+ return _redirect_to_login(request=request, error="no_oidc_support")
218
+
219
+ try:
220
+ id_token_claims = await oauth2_client.parse_id_token(token_data, nonce=stored_nonce)
221
+ except JoseError as e:
222
+ logger.error("ID token validation failed for IDP %s: %s", idp_name, e)
223
+ return _redirect_to_login(request=request, error="auth_failed")
224
+
225
+ if oauth2_client.has_sufficient_claims(id_token_claims):
226
+ user_claims = id_token_claims
227
+ else:
228
+ user_claims = await _fetch_and_merge_userinfo_claims(
229
+ oauth2_client, token_data, id_token_claims
171
230
  )
172
- user_info = await oauth2_client.parse_id_token(token_data, nonce=stored_nonce)
173
- user_info = _parse_user_info(user_info)
231
+
232
+ try:
233
+ user_info = _parse_user_info(user_claims)
234
+ except (MissingEmailScope, InvalidUserInfo) as e:
235
+ logger.error("Error parsing user info for IDP %s: %s", idp_name, e)
236
+ return _redirect_to_login(request=request, error="missing_email_scope")
237
+
238
+ # Validate access and extract role from claims
239
+ # Both validate_access and extract_and_map_role may raise PermissionError
240
+ try:
241
+ oauth2_client.validate_access(user_info.claims)
242
+ # Extract and map role from claims
243
+ # Returns None if role mapping not configured (preserves existing user roles)
244
+ # Raises PermissionError if strict mode enabled and role validation fails
245
+ role_name = oauth2_client.extract_and_map_role(user_info.claims)
246
+ except PermissionError as e:
247
+ logger.error("Access validation failed for IDP %s: %s", idp_name, e)
248
+ return _redirect_to_login(request=request, error="auth_failed")
249
+
174
250
  try:
175
251
  async with request.app.state.db() as session:
176
252
  user = await _process_oauth2_user(
@@ -178,19 +254,24 @@ async def create_tokens(
178
254
  oauth2_client_id=str(oauth2_client.client_id),
179
255
  user_info=user_info,
180
256
  allow_sign_up=oauth2_client.allow_sign_up,
257
+ role_name=role_name,
181
258
  )
182
- except (EmailAlreadyInUse, SignInNotAllowed) as error:
183
- return _redirect_to_login(request=request, error=str(error))
259
+ except EmailAlreadyInUse as e:
260
+ logger.error("Email already in use for IDP %s: %s", idp_name, e)
261
+ return _redirect_to_login(request=request, error="email_in_use")
262
+ except SignInNotAllowed as e:
263
+ logger.error("Sign in not allowed for IDP %s: %s", idp_name, e)
264
+ return _redirect_to_login(request=request, error="sign_in_not_allowed")
184
265
  access_token, refresh_token = await create_access_and_refresh_tokens(
185
266
  user=user,
186
267
  token_store=token_store,
187
268
  access_token_expiry=access_token_expiry,
188
269
  refresh_token_expiry=refresh_token_expiry,
189
270
  )
190
- redirect_path = _prepend_root_path_if_exists(request=request, path=return_url or "/")
271
+ redirect_path = prepend_root_path(request.scope, return_url or "/")
191
272
  response = RedirectResponse(
192
273
  url=redirect_path,
193
- status_code=HTTP_302_FOUND,
274
+ status_code=302,
194
275
  )
195
276
  response = set_access_token_cookie(
196
277
  response=response, access_token=access_token, max_age=access_token_expiry
@@ -200,6 +281,7 @@ async def create_tokens(
200
281
  )
201
282
  response = delete_oauth2_state_cookie(response)
202
283
  response = delete_oauth2_nonce_cookie(response)
284
+ response = delete_oauth2_code_verifier_cookie(response)
203
285
  return response
204
286
 
205
287
 
@@ -209,12 +291,13 @@ class UserInfo:
209
291
  email: str
210
292
  username: Optional[str] = None
211
293
  profile_picture_url: Optional[str] = None
294
+ claims: dict[str, Any] = field(default_factory=dict)
212
295
 
213
296
  def __post_init__(self) -> None:
214
297
  if not (idp_user_id := (self.idp_user_id or "").strip()):
215
298
  raise ValueError("idp_user_id cannot be empty")
216
299
  object.__setattr__(self, "idp_user_id", idp_user_id)
217
- if not (email := (self.email or "").strip()):
300
+ if not (email := sanitize_email(self.email or "")):
218
301
  raise ValueError("email cannot be empty")
219
302
  object.__setattr__(self, "email", email)
220
303
  if username := (self.username or "").strip():
@@ -223,9 +306,64 @@ class UserInfo:
223
306
  object.__setattr__(self, "profile_picture_url", profile_picture_url)
224
307
 
225
308
 
309
+ async def _fetch_and_merge_userinfo_claims(
310
+ oauth2_client: OAuth2Client,
311
+ token_data: dict[str, Any],
312
+ id_token_claims: dict[str, Any],
313
+ ) -> dict[str, Any]:
314
+ """
315
+ Fetch claims from UserInfo endpoint and merge with ID token claims.
316
+
317
+ Why this is necessary (OIDC Core §5.4, §5.5):
318
+ When claims are requested via scopes (e.g., "profile", "email"), OIDC Core §5.4
319
+ specifies which claims are "REQUESTED" but does not mandate WHERE they must be
320
+ returned. Similarly, §5.5 allows requesting specific claims via the "claims"
321
+ parameter, but providers have discretion on whether to return them in the ID token
322
+ or UserInfo response. In practice, providers often return certain claims (especially
323
+ large ones like groups) only via UserInfo to keep ID tokens compact.
324
+
325
+ The UserInfo endpoint (OIDC Core §5.3) provides additional claims beyond what's
326
+ in the ID token, such as group memberships or custom attributes. This function:
327
+
328
+ 1. Calls the UserInfo endpoint using the access token (OIDC Core §5.3.1, RFC 6750)
329
+ 2. Merges userinfo claims with ID token claims
330
+ 3. ID token claims override userinfo claims when both contain the same claim
331
+
332
+ Why ID token takes precedence (OIDC Core §5.3.2):
333
+ - ID tokens are signed JWTs that have been cryptographically verified
334
+ - UserInfo responses may be unsigned
335
+ - Signed claims are the authoritative source when present in both
336
+
337
+ Fallback behavior:
338
+ If the UserInfo request fails, returns only ID token claims. The returned claims
339
+ may be incomplete (missing email or groups), but subsequent validation will catch this:
340
+ - Missing email: _parse_user_info() raises MissingEmailScope
341
+ - Missing groups: validate_access() raises PermissionError if access is denied
342
+
343
+ Args:
344
+ oauth2_client: The OAuth2 client to use for fetching userinfo
345
+ token_data: Token response containing the access token (RFC 6749 §5.1)
346
+ id_token_claims: Claims from the verified ID token (OIDC Core §3.1.3.3)
347
+
348
+ Returns:
349
+ Merged claims dictionary with ID token claims overriding userinfo claims
350
+ """
351
+ try:
352
+ # OIDC Core §5.3.1: UserInfo request authenticated with access token
353
+ userinfo_claims = await oauth2_client.userinfo(token=token_data)
354
+ # ID token claims take precedence (signed and verified)
355
+ return {**userinfo_claims, **id_token_claims}
356
+ except Exception:
357
+ # Fallback: ID token has essential claims for authentication
358
+ return id_token_claims
359
+
360
+
226
361
  def _validate_token_data(token_data: dict[str, Any]) -> None:
227
362
  """
228
363
  Performs basic validations on the token data returned by the IDP.
364
+
365
+ RFC 6749 §5.1: Successful response must include access_token and token_type.
366
+ RFC 6750 §1.1: Bearer token type for HTTP authentication.
229
367
  """
230
368
  assert isinstance(token_data.get("access_token"), str)
231
369
  assert isinstance(token_type := token_data.get("token_type"), str)
@@ -235,20 +373,107 @@ def _validate_token_data(token_data: dict[str, Any]) -> None:
235
373
  def _parse_user_info(user_info: dict[str, Any]) -> UserInfo:
236
374
  """
237
375
  Parses user info from the IDP's ID token.
238
- """
239
- assert isinstance(subject := user_info.get("sub"), (str, int))
240
- idp_user_id = str(subject)
241
- assert isinstance(email := user_info.get("email"), str)
242
- assert isinstance(username := user_info.get("name"), str) or username is None
243
- assert (
244
- isinstance(profile_picture_url := user_info.get("picture"), str)
245
- or profile_picture_url is None
246
- )
376
+
377
+ Validates required OIDC claims and extracts user information according to the
378
+ OpenID Connect Core 1.0 specification.
379
+
380
+ Args:
381
+ user_info: Claims from the ID token (validated JWT payload)
382
+
383
+ Returns:
384
+ UserInfo object with validated user data
385
+
386
+ Raises:
387
+ InvalidUserInfo: If required claims are missing or malformed
388
+ MissingEmailScope: If email claim is missing or invalid
389
+
390
+ ID Token Required Claims (OIDC Core §2, §3.1.3.3):
391
+ - iss (issuer): Identifier for the OpenID Provider
392
+ - sub (subject): Unique identifier for the End-User at the Issuer
393
+ - aud (audience): Client ID this ID token is intended for
394
+ - exp (expiration): Expiration time
395
+ - iat (issued at): Time the JWT was issued
396
+ - nonce (if sent in auth request): Value sent in the Authentication Request
397
+
398
+ Application-Required Claims:
399
+ - email: Required by this application for user identification
400
+
401
+ Optional Standard Claims (OIDC Core §5.1):
402
+ - name: Full name
403
+ - picture: Profile picture URL
404
+ - Other profile, email, address, and phone claims
405
+
406
+ Note: While iss, sub, aud, exp, iat are REQUIRED in all ID tokens per spec,
407
+ other claims like email, name, groups are optional and may appear in the ID token,
408
+ UserInfo response, or both depending on what was requested and provider implementation.
409
+ """
410
+ # Validate 'sub' claim (OIDC required, MUST be a string per spec)
411
+ subject = user_info.get("sub")
412
+ if subject is None:
413
+ raise InvalidUserInfo(
414
+ "Missing required 'sub' claim in ID token. "
415
+ "Please check your OIDC provider configuration."
416
+ )
417
+
418
+ # OIDC spec: sub MUST be a string, but some IDPs send integers
419
+ # Convert to string for compatibility
420
+ if isinstance(subject, (str, int)):
421
+ idp_user_id = str(subject).strip()
422
+ else:
423
+ raise InvalidUserInfo(
424
+ f"Invalid 'sub' claim type: {type(subject).__name__}. Expected string or integer."
425
+ )
426
+
427
+ if not idp_user_id:
428
+ raise InvalidUserInfo("The 'sub' claim cannot be empty.")
429
+
430
+ # Validate 'email' claim (application requirement)
431
+ email = user_info.get("email")
432
+ if not isinstance(email, str) or not email.strip():
433
+ raise MissingEmailScope(
434
+ "Missing or invalid 'email' claim. "
435
+ "Please ensure your OIDC provider is configured to include the 'email' scope."
436
+ )
437
+ email = email.strip()
438
+
439
+ # Optional: 'name' claim (Full name)
440
+ username = user_info.get("name")
441
+ if username is not None:
442
+ if not isinstance(username, str):
443
+ # Some IDPs might send unexpected types; ignore gracefully
444
+ username = None
445
+ else:
446
+ username = username.strip() or None
447
+
448
+ # Optional: 'picture' claim (Profile picture URL)
449
+ profile_picture_url = user_info.get("picture")
450
+ if profile_picture_url is not None:
451
+ if not isinstance(profile_picture_url, str):
452
+ # Some IDPs might send unexpected types; ignore gracefully
453
+ profile_picture_url = None
454
+ else:
455
+ profile_picture_url = profile_picture_url.strip() or None
456
+
457
+ # Keep only non-empty claim values for downstream processing
458
+ def _has_value(v: Any) -> bool:
459
+ """Check if a claim value is considered non-empty."""
460
+ if v is None:
461
+ return False
462
+ if isinstance(v, str):
463
+ return bool(v.strip())
464
+ if isinstance(v, (list, dict, set, tuple)):
465
+ return len(v) > 0
466
+ # Include all other types (numbers, booleans, etc.)
467
+ return True
468
+
469
+ filtered_claims = {k: v for k, v in user_info.items() if _has_value(v)}
470
+
247
471
  return UserInfo(
248
472
  idp_user_id=idp_user_id,
249
473
  email=email,
250
474
  username=username,
251
475
  profile_picture_url=profile_picture_url,
476
+ claims=filtered_claims,
252
477
  )
253
478
 
254
479
 
@@ -259,6 +484,7 @@ async def _process_oauth2_user(
259
484
  oauth2_client_id: str,
260
485
  user_info: UserInfo,
261
486
  allow_sign_up: bool,
487
+ role_name: Optional[AssignableUserRoleName],
262
488
  ) -> models.User:
263
489
  """
264
490
  Processes an OAuth2 user, either signing in an existing user or creating/updating one.
@@ -267,11 +493,12 @@ async def _process_oauth2_user(
267
493
  1. When sign-up is not allowed (allow_sign_up=False):
268
494
  - Checks if the user exists and can sign in with the given OAuth2 credentials
269
495
  - Updates placeholder OAuth2 credentials if needed (e.g., temporary IDs)
496
+ - Updates the user's role if role_name is provided (role mapping configured)
270
497
  - If the user doesn't exist or has a password set, raises SignInNotAllowed
271
498
  2. When sign-up is allowed (allow_sign_up=True):
272
499
  - Finds the user by OAuth2 credentials (client_id and user_id)
273
- - Creates a new user if one doesn't exist, with default member role
274
- - Updates the user's email if it has changed
500
+ - Creates a new user if one doesn't exist, with the provided role (or VIEWER if None)
501
+ - Updates the user's email and role (if role_name provided) if they have changed
275
502
  - Handles username conflicts by adding a random suffix if needed
276
503
 
277
504
  The allow_sign_up parameter is typically controlled by the PHOENIX_OAUTH2_{IDP_NAME}_ALLOW_SIGN_UP
@@ -282,6 +509,8 @@ async def _process_oauth2_user(
282
509
  oauth2_client_id: The ID of the OAuth2 client
283
510
  user_info: User information from the OAuth2 provider
284
511
  allow_sign_up: Whether to allow creating new users
512
+ role_name: The Phoenix role name to assign (ADMIN, MEMBER, VIEWER), or None to preserve
513
+ existing user roles (backward compatibility when role mapping not configured)
285
514
 
286
515
  Returns:
287
516
  The user object
@@ -291,24 +520,27 @@ async def _process_oauth2_user(
291
520
  EmailAlreadyInUse: When the email is already in use by another account
292
521
  """ # noqa: E501
293
522
  if not allow_sign_up:
294
- return await _get_existing_oauth2_user(
523
+ return await _sign_in_existing_oauth2_user(
295
524
  session,
296
525
  oauth2_client_id=oauth2_client_id,
297
526
  user_info=user_info,
527
+ role_name=role_name,
298
528
  )
299
529
  return await _create_or_update_user(
300
530
  session,
301
531
  oauth2_client_id=oauth2_client_id,
302
532
  user_info=user_info,
533
+ role_name=role_name,
303
534
  )
304
535
 
305
536
 
306
- async def _get_existing_oauth2_user(
537
+ async def _sign_in_existing_oauth2_user(
307
538
  session: AsyncSession,
308
539
  /,
309
540
  *,
310
541
  oauth2_client_id: str,
311
542
  user_info: UserInfo,
543
+ role_name: Optional[AssignableUserRoleName],
312
544
  ) -> models.User:
313
545
  """Signs in an existing user with OAuth2 credentials.
314
546
 
@@ -330,6 +562,7 @@ async def _get_existing_oauth2_user(
330
562
  Profile Updates:
331
563
  - Email: Updated if different from IDP info
332
564
  - Profile Picture: Updated if provided in user_info
565
+ - Role: Updated ONLY if role_name is provided (role mapping configured)
333
566
  - Username: Never updated (remains unchanged)
334
567
  - OAuth2 Credentials: Updated based on the three cases above
335
568
 
@@ -337,6 +570,8 @@ async def _get_existing_oauth2_user(
337
570
  session: The database session
338
571
  oauth2_client_id: The ID of the OAuth2 client
339
572
  user_info: User information from the OAuth2 provider
573
+ role_name: The Phoenix role name to assign (ADMIN, MEMBER, VIEWER), or None to preserve
574
+ existing role (backward compatibility when role mapping not configured)
340
575
 
341
576
  Returns:
342
577
  The signed-in user
@@ -348,7 +583,7 @@ async def _get_existing_oauth2_user(
348
583
  - User has a password set
349
584
  - User has mismatched OAuth2 credentials
350
585
  """ # noqa: E501
351
- if not (email := (user_info.email or "").strip()):
586
+ if not (email := sanitize_email(user_info.email or "")):
352
587
  raise ValueError("Email is required.")
353
588
  if not (oauth2_user_id := (user_info.idp_user_id or "").strip()):
354
589
  raise ValueError("OAuth2 user ID is required.")
@@ -362,9 +597,14 @@ async def _get_existing_oauth2_user(
362
597
  if email and email != user.email:
363
598
  user.email = email
364
599
  else:
365
- user = await session.scalar(stmt.filter_by(email=email))
600
+ user = await session.scalar(stmt.where(func.lower(models.User.email) == email))
366
601
  if user is None or not isinstance(user, models.OAuth2User):
367
602
  raise SignInNotAllowed("Sign in is not allowed.")
603
+ # Security: Prevent OIDC from hijacking LDAP users
604
+ # LDAP users are identified by the special Unicode marker in oauth2_client_id
605
+ # Use generic error message to avoid revealing auth method (username enumeration)
606
+ if is_ldap_user(user.oauth2_client_id):
607
+ raise SignInNotAllowed("Sign in is not allowed.")
368
608
  # Case 1: Different OAuth2 client - update both client and user IDs
369
609
  if oauth2_client_id != user.oauth2_client_id:
370
610
  user.oauth2_client_id = oauth2_client_id
@@ -377,6 +617,16 @@ async def _get_existing_oauth2_user(
377
617
  raise SignInNotAllowed("Sign in is not allowed.")
378
618
  if profile_picture_url != user.profile_picture_url:
379
619
  user.profile_picture_url = profile_picture_url
620
+
621
+ # Update role ONLY if role mapping is configured (role_name is not None)
622
+ # This preserves existing user roles when role mapping is not configured
623
+ if role_name is not None and user.role.name != role_name:
624
+ role = await session.scalar(
625
+ select(models.UserRole).where(models.UserRole.name == role_name)
626
+ )
627
+ if role is not None:
628
+ user.role = role
629
+
380
630
  if user in session.dirty:
381
631
  await session.flush()
382
632
  return user
@@ -388,6 +638,7 @@ async def _create_or_update_user(
388
638
  *,
389
639
  oauth2_client_id: str,
390
640
  user_info: UserInfo,
641
+ role_name: Optional[AssignableUserRoleName],
391
642
  ) -> models.User:
392
643
  """
393
644
  Creates a new user or updates an existing one with OAuth2 credentials.
@@ -396,6 +647,8 @@ async def _create_or_update_user(
396
647
  session: The database session
397
648
  oauth2_client_id: The ID of the OAuth2 client
398
649
  user_info: User information from the OAuth2 provider
650
+ role_name: The Phoenix role name to assign (ADMIN, MEMBER, VIEWER), or None to use
651
+ VIEWER for new users and preserve existing users' roles (backward compatibility)
399
652
 
400
653
  Returns:
401
654
  The created or updated user
@@ -409,9 +662,37 @@ async def _create_or_update_user(
409
662
  idp_user_id=user_info.idp_user_id,
410
663
  )
411
664
  if user is None:
412
- user = await _create_user(session, oauth2_client_id=oauth2_client_id, user_info=user_info)
413
- elif user.email != user_info.email:
414
- user = await _update_user_email(session, user_id=user.id, email=user_info.email)
665
+ # New user: use provided role_name, or default to VIEWER if role mapping not configured
666
+ user = await _create_user(
667
+ session,
668
+ oauth2_client_id=oauth2_client_id,
669
+ user_info=user_info,
670
+ role_name=role_name or "VIEWER", # Default for new users
671
+ )
672
+ else:
673
+ # Existing user: update email, profile picture, and/or role if changed
674
+ if user.email != user_info.email:
675
+ user.email = user_info.email
676
+
677
+ # Update profile picture if changed
678
+ if user.profile_picture_url != user_info.profile_picture_url:
679
+ user.profile_picture_url = user_info.profile_picture_url
680
+
681
+ # Update role ONLY if role mapping is configured (role_name is not None)
682
+ # This preserves existing user roles when role mapping is not configured
683
+ if role_name is not None and user.role.name != role_name:
684
+ role = await session.scalar(
685
+ select(models.UserRole).where(models.UserRole.name == role_name)
686
+ )
687
+ if role is not None:
688
+ user.role = role
689
+
690
+ # Flush to execute the UPDATE and catch any email conflicts
691
+ if user in session.dirty:
692
+ try:
693
+ await session.flush()
694
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError):
695
+ raise EmailAlreadyInUse(f"An account for {user_info.email} is already in use.")
415
696
  return user
416
697
 
417
698
 
@@ -445,9 +726,19 @@ async def _create_user(
445
726
  *,
446
727
  oauth2_client_id: str,
447
728
  user_info: UserInfo,
729
+ role_name: AssignableUserRoleName,
448
730
  ) -> models.User:
449
731
  """
450
732
  Creates a new user with the user info from the IDP.
733
+
734
+ Args:
735
+ session: The database session
736
+ oauth2_client_id: The ID of the OAuth2 client
737
+ user_info: User information from the OAuth2 provider
738
+ role_name: The Phoenix role name to assign (ADMIN, MEMBER, VIEWER)
739
+
740
+ Returns:
741
+ The created user
451
742
  """
452
743
  email_exists, username_exists = await _email_and_username_exist(
453
744
  session,
@@ -456,16 +747,12 @@ async def _create_user(
456
747
  )
457
748
  if email_exists:
458
749
  raise EmailAlreadyInUse(f"An account for {email} is already in use.")
459
- member_role_id = (
460
- select(models.UserRole.id)
461
- .where(models.UserRole.name == UserRole.MEMBER.value)
462
- .scalar_subquery()
463
- )
750
+ role_id = select(models.UserRole.id).where(models.UserRole.name == role_name).scalar_subquery()
464
751
  user_id = await session.scalar(
465
752
  insert(models.User)
466
753
  .returning(models.User.id)
467
754
  .values(
468
- user_role_id=member_role_id,
755
+ user_role_id=role_id,
469
756
  oauth2_client_id=oauth2_client_id,
470
757
  oauth2_user_id=user_info.idp_user_id,
471
758
  username=_with_random_suffix(username) if username and username_exists else username,
@@ -483,26 +770,6 @@ async def _create_user(
483
770
  return user
484
771
 
485
772
 
486
- async def _update_user_email(session: AsyncSession, /, *, user_id: int, email: str) -> models.User:
487
- """
488
- Updates an existing user's email.
489
- """
490
- try:
491
- await session.execute(
492
- update(models.User)
493
- .where(models.User.id == user_id)
494
- .values(email=email)
495
- .options(joinedload(models.User.role))
496
- )
497
- except (PostgreSQLIntegrityError, SQLiteIntegrityError):
498
- raise EmailAlreadyInUse(f"An account for {email} is already in use.")
499
- user = await session.scalar(
500
- select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
501
- ) # query user again for joined load
502
- assert isinstance(user, models.User)
503
- return user
504
-
505
-
506
773
  async def _email_and_username_exist(
507
774
  session: AsyncSession, /, *, email: str, username: Optional[str]
508
775
  ) -> tuple[bool, bool]:
@@ -514,7 +781,7 @@ async def _email_and_username_exist(
514
781
  select(
515
782
  cast(
516
783
  func.coalesce(
517
- func.max(case((models.User.email == email, 1), else_=0)),
784
+ func.max(case((func.lower(models.User.email) == email, 1), else_=0)),
518
785
  0,
519
786
  ),
520
787
  Boolean,
@@ -526,7 +793,7 @@ async def _email_and_username_exist(
526
793
  ),
527
794
  Boolean,
528
795
  ).label("username_exists"),
529
- ).where(or_(models.User.email == email, models.User.username == username))
796
+ ).where(or_(func.lower(models.User.email) == email, models.User.username == username))
530
797
  )
531
798
  ).all()
532
799
  return email_exists, username_exists
@@ -544,49 +811,48 @@ class NotInvited(Exception):
544
811
  pass
545
812
 
546
813
 
547
- def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse:
814
+ class MissingEmailScope(Exception):
815
+ """
816
+ Raised when the OIDC provider does not return the email scope.
817
+ """
818
+
819
+ pass
820
+
821
+
822
+ class InvalidUserInfo(Exception):
823
+ """
824
+ Raised when the OIDC user info is malformed or missing required claims.
825
+ """
826
+
827
+ pass
828
+
829
+
830
+ def _redirect_to_login(*, request: Request, error: AuthErrorCode) -> RedirectResponse:
548
831
  """
549
- Creates a RedirectResponse to the login page to display an error message.
832
+ Creates a RedirectResponse to the login page to display an error code.
833
+ The error code will be validated and mapped to a user-friendly message on the frontend.
550
834
  """
551
835
  # TODO: this needs some cleanup
552
- login_path = _prepend_root_path_if_exists(
553
- request=request, path="/login" if not get_env_disable_basic_auth() else "/logout"
836
+ login_path = prepend_root_path(
837
+ request.scope, "/login" if not get_env_disable_basic_auth() else "/logout"
554
838
  )
555
839
  url = URL(login_path).include_query_params(error=error)
556
840
  response = RedirectResponse(url=url)
557
841
  response = delete_oauth2_state_cookie(response)
558
842
  response = delete_oauth2_nonce_cookie(response)
843
+ response = delete_oauth2_code_verifier_cookie(response)
559
844
  return response
560
845
 
561
846
 
562
- def _prepend_root_path_if_exists(*, request: Request, path: str) -> str:
563
- """
564
- If a root path is configured, prepends it to the input path.
565
- """
566
- if not path.startswith("/"):
567
- raise ValueError("path must start with a forward slash")
568
- root_path = _get_root_path(request=request)
569
- if root_path.endswith("/"):
570
- root_path = root_path.rstrip("/")
571
- return root_path + path
572
-
573
-
574
847
  def _append_root_path_if_exists(*, request: Request, base_url: str) -> str:
575
848
  """
576
849
  If a root path is configured, appends it to the input base url.
577
850
  """
578
- if not (root_path := _get_root_path(request=request)):
851
+ if not (root_path := get_root_path(request.scope)):
579
852
  return base_url
580
853
  return str(URLPath(root_path).make_absolute_url(base_url=base_url))
581
854
 
582
855
 
583
- def _get_root_path(*, request: Request) -> str:
584
- """
585
- Gets the root path from the request.
586
- """
587
- return str(request.scope.get("root_path", ""))
588
-
589
-
590
856
  def _get_create_tokens_endpoint(*, request: Request, origin_url: str, idp_name: str) -> str:
591
857
  """
592
858
  Gets the endpoint for create tokens route.
@@ -664,7 +930,4 @@ def _is_oauth2_state_payload(maybe_state_payload: Any) -> TypeGuard[_OAuth2State
664
930
 
665
931
 
666
932
  _JWT_ALGORITHM = "HS256"
667
- _INVALID_OAUTH2_STATE_MESSAGE = (
668
- "Received invalid state parameter during OAuth2 authorization code flow for IDP {idp_name}."
669
- )
670
933
  _RELATIVE_URL_PATTERN = re.compile(r"^/($|\w)")