byoi 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- byoi/__init__.py +233 -0
- byoi/__main__.py +228 -0
- byoi/cache.py +346 -0
- byoi/config.py +349 -0
- byoi/dependencies.py +144 -0
- byoi/errors.py +360 -0
- byoi/models.py +451 -0
- byoi/pkce.py +144 -0
- byoi/providers.py +434 -0
- byoi/py.typed +0 -0
- byoi/repositories.py +252 -0
- byoi/service.py +723 -0
- byoi/telemetry.py +352 -0
- byoi/tokens.py +340 -0
- byoi/types.py +130 -0
- byoi-0.1.0a1.dist-info/METADATA +504 -0
- byoi-0.1.0a1.dist-info/RECORD +20 -0
- byoi-0.1.0a1.dist-info/WHEEL +4 -0
- byoi-0.1.0a1.dist-info/entry_points.txt +3 -0
- byoi-0.1.0a1.dist-info/licenses/LICENSE +21 -0
byoi/models.py
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
"""Pydantic models for the BYOI library.
|
|
2
|
+
|
|
3
|
+
These models define the input and output schemas for BYOI operations.
|
|
4
|
+
They can be used directly in FastAPI route definitions.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
12
|
+
|
|
13
|
+
# Public API of this module, grouped by model family. Order matches the
# order in which the classes are declared below.
__all__ = (
    # Enums
    "ClientType",
    # Provider models
    "OIDCProviderConfig", "OIDCProviderInfo",
    # Authorization-flow models
    "AuthorizationRequest", "AuthorizationResponse",
    "TokenExchangeRequest", "TokenExchangeResponse",
    "TokenRefreshRequest", "TokenRefreshResponse",
    # Identity models
    "IdentityInfo", "LinkedIdentityInfo",
    "LinkIdentityRequest", "UnlinkIdentityRequest",
    # User models
    "AuthenticatedUser", "UserIdentities",
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ClientType(str, Enum):
    """Kind of client that starts an authentication flow.

    Subclasses ``str`` so each member compares equal to its plain string
    value (e.g. ``ClientType.WEB == "web"``).
    """

    WEB = "web"  # browser-based client
    DESKTOP = "desktop"  # desktop application client
    MOBILE = "mobile"  # mobile application client
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# =============================================================================
|
|
46
|
+
# Provider Models
|
|
47
|
+
# =============================================================================
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class OIDCProviderConfig(BaseModel):
    """Configuration for an OIDC provider.

    This model is used to register a provider with BYOI. Instances are
    immutable (``frozen=True``). The ``*_override`` fields, when set, replace
    the corresponding value obtained from discovery.
    """

    model_config = ConfigDict(frozen=True)

    name: str = Field(
        ...,
        description="Unique name for this provider (e.g., 'google', 'microsoft'). "
        "Must start with a letter and contain only lowercase letters, numbers, underscores, or hyphens.",
        min_length=1,
        max_length=64,
    )
    display_name: str = Field(
        ...,
        description="Human-readable name for display purposes",
        min_length=1,
        max_length=128,
    )
    issuer: str = Field(
        ...,
        description="The OIDC issuer URL (used to discover endpoints)",
    )
    client_id: str = Field(
        ...,
        description="OAuth client ID",
        min_length=1,
    )
    client_secret: str | None = Field(
        default=None,
        description="OAuth client secret (not required for public clients)",
    )
    # Pydantic copies this mutable default per instance, so the list is not
    # shared between models.
    scopes: list[str] = Field(
        default=["openid", "email", "profile"],
        description="OAuth scopes to request",
    )
    extra_auth_params: dict[str, str] = Field(
        default_factory=dict,
        description="Extra parameters to include in the authorization request",
    )
    jwks_uri_override: str | None = Field(
        default=None,
        description="Override the JWKS URI from discovery",
    )
    authorization_endpoint_override: str | None = Field(
        default=None,
        description="Override the authorization endpoint from discovery",
    )
    token_endpoint_override: str | None = Field(
        default=None,
        description="Override the token endpoint from discovery",
    )
    userinfo_endpoint_override: str | None = Field(
        default=None,
        description="Override the userinfo endpoint from discovery",
    )
    audience: str | None = Field(
        default=None,
        description="Expected audience claim (defaults to client_id)",
    )
    clock_skew_seconds: int = Field(
        default=60,
        description="Allowed clock skew for token validation (seconds)",
        ge=0,
        le=300,
    )

    @field_validator("name", mode="before")
    @classmethod
    def validate_name(cls, v: Any) -> str:
        """Validate and normalize the provider name.

        Runs in ``before`` mode, i.e. ahead of pydantic's own ``str`` and
        length checks, so ``v`` may be any raw input value.

        Args:
            v: The raw value supplied for ``name``.

        Returns:
            The name converted to lowercase. It is guaranteed to be a lowercase
            ASCII letter followed by lowercase ASCII letters, digits,
            underscores, or hyphens.

        Raises:
            ValueError: If ``v`` is not a string or does not have that format.
        """
        import re

        if not isinstance(v, str):
            raise ValueError("Provider name must be a string")

        # Auto-lowercase for convenience
        v = v.lower()

        # Use fullmatch, not match with "$": "$" also matches just before a
        # trailing newline, which previously let names like "google\n" through.
        if re.fullmatch(r"[a-z][a-z0-9_-]*", v) is None:
            raise ValueError(
                f"Provider name '{v}' is invalid. "
                "Name must start with a letter and contain only lowercase letters, "
                "numbers, underscores, or hyphens (e.g., 'google', 'my-provider', 'auth_provider1')."
            )

        return v
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class OIDCProviderInfo(BaseModel):
    """Public summary of a configured OIDC provider.

    Carries only the provider's name, display name, and issuer URL. No client
    credentials are included. Instances are immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    # All three fields are required (no default given).
    name: str = Field(description="Unique name for this provider")
    display_name: str = Field(description="Human-readable name")
    issuer: str = Field(description="The OIDC issuer URL")
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
# =============================================================================
|
|
156
|
+
# Auth Flow Models
|
|
157
|
+
# =============================================================================
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class AuthorizationRequest(BaseModel):
    """Input for starting an authorization flow.

    Names the provider and the redirect URI. It optionally tags the request
    with a ``ClientType`` (``ClientType.WEB`` when omitted) and carries
    arbitrary ``extra_data`` to be stored with the auth state.

    Instances are immutable (``frozen=True``). This blocks reassigning fields,
    but the ``extra_data`` dict itself remains a mutable object.
    """

    model_config = ConfigDict(frozen=True)

    provider_name: str = Field(description="Name of the provider to use")
    redirect_uri: str = Field(description="URI to redirect to after authorization")
    client_type: ClientType = Field(
        ClientType.WEB,
        description="Type of client initiating the request",
    )
    # default_factory gives each instance its own empty dict.
    extra_data: dict[str, Any] = Field(
        default_factory=dict,
        description="Extra data to store with the auth state",
    )
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class AuthorizationResponse(BaseModel):
    """Result of starting an authorization flow.

    Contains:
    - the URL to send the user to;
    - the ``state`` value the caller should keep for later validation;
    - the provider name;
    - the time at which the request expires.

    Instances are immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    authorization_url: str = Field(
        description="URL to redirect the user to for authorization",
    )
    state: str = Field(description="State parameter (store this for validation)")
    provider_name: str = Field(description="Name of the provider")
    expires_at: datetime = Field(description="When this authorization request expires")
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class TokenExchangeRequest(BaseModel):
    """Input for exchanging an authorization code for tokens.

    ``code`` and ``state`` are the values received on the callback and must
    be non-empty. ``redirect_uri`` is the URI that was used in the
    authorization request. Instances are immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    code: str = Field(
        description="The authorization code from the callback",
        min_length=1,
    )
    state: str = Field(
        description="The state parameter from the callback",
        min_length=1,
    )
    redirect_uri: str = Field(
        description="The redirect URI used in the authorization request",
    )
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
class TokenExchangeResponse(BaseModel):
    """Result of a successful authorization-code exchange.

    Bundles the identity extracted from the ID token with the raw tokens.
    ``token_type`` defaults to ``"Bearer"``. ``expires_in`` and
    ``refresh_token`` may be absent. Instances are immutable
    (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    # String (forward) reference: IdentityInfo is declared later in this module.
    identity: "IdentityInfo" = Field(description="Identity information from the ID token")
    access_token: str = Field(description="The access token")
    token_type: str = Field("Bearer", description="The token type")
    expires_in: int | None = Field(None, description="Token expiration time in seconds")
    refresh_token: str | None = Field(None, description="The refresh token (if provided)")
    id_token: str = Field(description="The raw ID token")
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
class TokenRefreshRequest(BaseModel):
    """Input for obtaining a new access token from a refresh token.

    ``provider_name`` and ``refresh_token`` must be non-empty. ``scope`` is
    optional. Instances are immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    provider_name: str = Field(
        description="Name of the provider that issued the token",
        min_length=1,
    )
    refresh_token: str = Field(description="The refresh token", min_length=1)
    scope: str | None = Field(
        None,
        description="Optional scope to request (defaults to original scope)",
    )
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
class TokenRefreshResponse(BaseModel):
    """Result of a successful token refresh.

    Only ``access_token`` is required. ``token_type`` defaults to
    ``"Bearer"``, and every other field defaults to ``None``. Instances are
    immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    access_token: str = Field(description="The new access token")
    token_type: str = Field("Bearer", description="The token type")
    expires_in: int | None = Field(None, description="Token expiration time in seconds")
    refresh_token: str | None = Field(None, description="The new refresh token (if rotated)")
    scope: str | None = Field(None, description="The scope of the access token")
    id_token: str | None = Field(None, description="New ID token (if provided by the provider)")
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
# =============================================================================
|
|
311
|
+
# Identity Models
|
|
312
|
+
# =============================================================================
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
class IdentityInfo(BaseModel):
    """Identity claims taken from an ID token.

    Only ``provider_name`` and ``subject`` are required. ``email_verified``
    defaults to ``False``, and the optional profile fields default to
    ``None``. Claims without a dedicated field are kept in ``extra_claims``.
    Instances are immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    provider_name: str = Field(description="Name of the provider")
    subject: str = Field(description="Subject identifier (unique per provider)")
    email: str | None = Field(None, description="Email address")
    email_verified: bool = Field(False, description="Whether the email is verified")
    name: str | None = Field(None, description="Full name")
    given_name: str | None = Field(None, description="Given/first name")
    family_name: str | None = Field(None, description="Family/last name")
    picture: str | None = Field(None, description="URL to profile picture")
    locale: str | None = Field(None, description="User's locale")
    # default_factory gives each instance its own empty dict.
    extra_claims: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional claims from the ID token",
    )
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
class LinkedIdentityInfo(BaseModel):
    """Summary of one identity linked to a user.

    ``email`` and ``last_used_at`` are optional. Instances are immutable
    (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    # Field is named ``id`` as part of the public schema (shadows the builtin
    # only as an attribute name).
    id: str = Field(description="Unique identifier for this link")
    provider_name: str = Field(description="Name of the provider")
    provider_display_name: str = Field(description="Display name of the provider")
    email: str | None = Field(None, description="Email associated with this identity")
    created_at: datetime = Field(description="When this link was created")
    last_used_at: datetime | None = Field(None, description="When this identity was last used")
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
class LinkIdentityRequest(BaseModel):
    """Request to link a new identity to an existing user.

    Contains the ID of the user the identity should be attached to, plus
    the authorization callback values (``code``, ``state``) and the
    ``redirect_uri`` used in the authorization request. Instances are
    immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    user_id: str = Field(
        ...,
        description="ID of the user to link the identity to",
    )
    # ``code`` and ``state`` must be non-empty, matching the constraints that
    # TokenExchangeRequest places on the same callback fields. Previously an
    # empty value passed model validation here.
    code: str = Field(
        ...,
        description="The authorization code from the callback",
        min_length=1,
    )
    state: str = Field(
        ...,
        description="The state parameter from the callback",
        min_length=1,
    )
    redirect_uri: str = Field(
        ...,
        description="The redirect URI used in the authorization request",
    )
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
class UnlinkIdentityRequest(BaseModel):
    """Input for removing one linked identity from a user.

    Both IDs are required. Instances are immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    user_id: str = Field(description="ID of the user")
    identity_id: str = Field(description="ID of the linked identity to remove")
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
# =============================================================================
|
|
417
|
+
# User Models
|
|
418
|
+
# =============================================================================
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
class AuthenticatedUser(BaseModel):
    """Outcome of authenticating a user.

    Contains:
    - the user's ID and optional email;
    - whether the user was newly created (``False`` by default);
    - the identity and the linked-identity ID used to authenticate.

    Instances are immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    user_id: str = Field(description="Unique identifier for the user")
    email: str | None = Field(None, description="User's email address")
    is_new_user: bool = Field(False, description="Whether this is a newly created user")
    identity: IdentityInfo = Field(description="Identity information used for authentication")
    linked_identity_id: str = Field(
        description="ID of the linked identity used for authentication",
    )
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
class UserIdentities(BaseModel):
    """A user's ID together with every identity linked to it.

    ``identities`` defaults to a fresh empty list per instance. Instances are
    immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    user_id: str = Field(description="User's unique identifier")
    identities: list[LinkedIdentityInfo] = Field(
        default_factory=list,
        description="List of linked identities",
    )
|
byoi/pkce.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
"""PKCE (Proof Key for Code Exchange) utilities.
|
|
2
|
+
|
|
3
|
+
Implements RFC 7636 for preventing authorization code interception attacks.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import base64
|
|
7
|
+
import hashlib
|
|
8
|
+
import secrets
|
|
9
|
+
from typing import NamedTuple
|
|
10
|
+
|
|
11
|
+
# Public API: the result tuple type followed by the generator helpers.
__all__ = (
    "PKCEChallenge",
    "generate_code_verifier", "generate_code_challenge", "generate_pkce_pair",
    "generate_state", "generate_nonce",
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Immutable (verifier, challenge, method) triple produced by generate_pkce_pair.
PKCEChallenge = NamedTuple(
    "PKCEChallenge",
    [
        ("code_verifier", str),  # secret kept by the client
        ("code_challenge", str),  # value derived from the verifier
        ("code_challenge_method", str),  # "S256" or "plain"
    ],
)
PKCEChallenge.__doc__ = "A PKCE code verifier and challenge pair."
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def generate_code_verifier(length: int = 64) -> str:
    """Return a random PKCE code verifier of exactly ``length`` characters.

    The verifier is drawn from the base64url alphabet (A-Z, a-z, 0-9, "-"
    and "_"). This is a subset of the RFC 7636 unreserved characters.
    RFC 7636 limits the verifier to 43-128 characters.

    Args:
        length: Number of characters to return (43-128).

    Returns:
        The random code verifier.

    Raises:
        ValueError: If ``length`` is outside 43-128.
    """
    if not 43 <= length <= 128:
        raise ValueError("Code verifier length must be between 43 and 128 characters")

    # base64 turns every 3 bytes into 4 characters. This byte count always
    # produces more than `length` unpadded characters.
    byte_count = (length * 3) // 4 + 1
    # token_urlsafe() base64url-encodes the random bytes and strips the "="
    # padding; keep only the first `length` characters.
    return secrets.token_urlsafe(byte_count)[:length]
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def generate_code_challenge(code_verifier: str, method: str = "S256") -> str:
    """Derive the PKCE code challenge for ``code_verifier``.

    Args:
        code_verifier: The code verifier. It must be ASCII-encodable for the
            "S256" method.
        method: "S256" (recommended) or "plain". Matching is case-sensitive.

    Returns:
        For "S256", the unpadded base64url encoding of SHA-256 over the
        ASCII bytes of the verifier. For "plain", the verifier unchanged.

    Raises:
        ValueError: If ``method`` is neither "S256" nor "plain".
    """
    if method == "S256":
        sha = hashlib.sha256(code_verifier.encode("ascii"))
        encoded = base64.urlsafe_b64encode(sha.digest()).decode("ascii")
        return encoded.rstrip("=")  # PKCE uses base64url without padding
    if method == "plain":
        return code_verifier
    raise ValueError(f"Unsupported code challenge method: {method}")
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def generate_pkce_pair(verifier_length: int = 64, method: str = "S256") -> PKCEChallenge:
    """Create a fresh code verifier and its matching challenge.

    Args:
        verifier_length: Verifier length in characters (43-128). This value is
            passed to ``generate_code_verifier``.
        method: "S256" (recommended) or "plain". This value is passed to
            ``generate_code_challenge``.

    Returns:
        A ``PKCEChallenge`` holding the verifier, the derived challenge, and
        ``method``.

    Raises:
        ValueError: Propagated from either helper for an invalid length or
            method.
    """
    verifier = generate_code_verifier(verifier_length)
    return PKCEChallenge(
        code_verifier=verifier,
        code_challenge=generate_code_challenge(verifier, method),
        code_challenge_method=method,
    )
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def generate_state(length: int = 32) -> str:
    """Generate a cryptographically random state parameter.

    The state parameter is used to prevent CSRF attacks and maintain state
    between the authorization request and callback.

    Args:
        length: Number of characters in the returned state (minimum 16).

    Returns:
        A random URL-safe (base64url alphabet) string of exactly ``length``
        characters. Each character carries 6 bits of entropy.

    Raises:
        ValueError: If length is less than 16.
    """
    if length < 16:
        raise ValueError("State length must be at least 16 characters")

    # secrets.token_urlsafe(n) takes a *byte* count and returns about 1.33*n
    # characters. Requesting `length` bytes always yields more than `length`
    # characters, so truncating gives exactly the documented length. This
    # matches how generate_code_verifier treats its `length` argument.
    return secrets.token_urlsafe(length)[:length]
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def generate_nonce(length: int = 32) -> str:
    """Generate a cryptographically random nonce.

    The nonce is used in ID token validation to prevent replay attacks.

    Args:
        length: Number of characters in the returned nonce (minimum 16).

    Returns:
        A random URL-safe (base64url alphabet) string of exactly ``length``
        characters. Each character carries 6 bits of entropy.

    Raises:
        ValueError: If length is less than 16.
    """
    if length < 16:
        raise ValueError("Nonce length must be at least 16 characters")

    # secrets.token_urlsafe(n) takes a *byte* count and returns about 1.33*n
    # characters. Requesting `length` bytes always yields more than `length`
    # characters, so truncating gives exactly the documented length.
    return secrets.token_urlsafe(length)[:length]
|