oauth-codex 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,45 @@
1
+ from .client import CodexOAuthLLM
2
+ from .errors import (
3
+ AuthRequiredError,
4
+ CodexOAuthLLMError,
5
+ LLMRequestError,
6
+ ModelValidationError,
7
+ OAuthCallbackParseError,
8
+ OAuthStateMismatchError,
9
+ TokenExchangeError,
10
+ TokenRefreshError,
11
+ ToolCallRequiredError,
12
+ )
13
+ from .types import (
14
+ GenerateResult,
15
+ Message,
16
+ OAuthConfig,
17
+ OAuthTokens,
18
+ StreamEvent,
19
+ TokenUsage,
20
+ ToolCall,
21
+ ToolInput,
22
+ ToolResult,
23
+ )
24
+
25
# Public API of the package: the client class, the error classes re-exported
# from .errors, and the data types re-exported from .types. Kept alphabetical.
__all__ = [
    "AuthRequiredError",
    "CodexOAuthLLM",
    "CodexOAuthLLMError",
    "GenerateResult",
    "LLMRequestError",
    "Message",
    "ModelValidationError",
    "OAuthCallbackParseError",
    "OAuthConfig",
    "OAuthStateMismatchError",
    "OAuthTokens",
    "StreamEvent",
    "TokenExchangeError",
    "TokenRefreshError",
    "TokenUsage",
    "ToolCall",
    "ToolCallRequiredError",
    "ToolInput",
    "ToolResult",
]
oauth_codex/auth.py ADDED
@@ -0,0 +1,310 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import hashlib
5
+ import json
6
+ import os
7
+ import secrets
8
+ import time
9
+ from typing import Any
10
+ from urllib.parse import parse_qs, urlencode, urlparse
11
+
12
+ import httpx
13
+
14
+ from .errors import (
15
+ OAuthCallbackParseError,
16
+ OAuthStateMismatchError,
17
+ TokenExchangeError,
18
+ TokenRefreshError,
19
+ )
20
+ from .types import OAuthConfig, OAuthTokens
21
+
22
+
23
def load_oauth_config(override: OAuthConfig | None = None) -> OAuthConfig:
    """Build the effective OAuth configuration from env vars layered over a base.

    Each ``CODEX_OAUTH_*`` environment variable that is set to a non-empty value
    overrides the matching field of ``override`` (or of a default
    ``OAuthConfig()`` when no override is given).

    Args:
        override: Base configuration to start from; ``None`` means the
            ``OAuthConfig()`` defaults.

    Returns:
        A new ``OAuthConfig``; the base object is not modified.
    """
    base = override or OAuthConfig()

    def env(name: str, default: str | None) -> str | None:
        # An exported-but-empty variable (e.g. ``FOO=`` in a .env file) is
        # treated as unset so it cannot blank out a required field such as
        # client_id or token_endpoint.
        value = os.getenv(name)
        return value if value else default

    return OAuthConfig(
        client_id=env("CODEX_OAUTH_CLIENT_ID", base.client_id),
        scope=env("CODEX_OAUTH_SCOPE", base.scope),
        # Audience is optional: an explicitly empty CODEX_OAUTH_AUDIENCE clears
        # it, and "" is normalised to None so callers can test truthiness.
        audience=os.getenv("CODEX_OAUTH_AUDIENCE", base.audience or "") or None,
        redirect_uri=env("CODEX_OAUTH_REDIRECT_URI", base.redirect_uri),
        discovery_url=env("CODEX_OAUTH_DISCOVERY_URL", base.discovery_url),
        authorization_endpoint=env(
            "CODEX_OAUTH_AUTHORIZATION_ENDPOINT", base.authorization_endpoint
        ),
        token_endpoint=env("CODEX_OAUTH_TOKEN_ENDPOINT", base.token_endpoint),
        originator=env("CODEX_OAUTH_ORIGINATOR", base.originator),
    )
37
+
38
+
39
+ def _b64url(data: bytes) -> str:
40
+ return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
41
+
42
+
43
def generate_pkce_pair() -> tuple[str, str]:
    """Create a PKCE ``(code_verifier, code_challenge)`` pair for the S256 method.

    The verifier is unpadded URL-safe base64 of 64 random bytes, kept within the
    43-128 character range used by PKCE (RFC 7636). The challenge is the
    unpadded URL-safe base64 of the SHA-256 digest of the ASCII verifier.
    """
    raw_verifier = _b64url(secrets.token_bytes(64))
    # Pad short values with "A" up to 43 chars, then cap at 128 chars.
    # 64 random bytes always encode to 86 chars, so neither adjustment
    # normally changes the value; they only guard the length contract.
    code_verifier = raw_verifier.ljust(43, "A")[:128]

    sha = hashlib.sha256(code_verifier.encode("ascii"))
    code_challenge = _b64url(sha.digest())
    return code_verifier, code_challenge
53
+
54
+
55
def generate_state() -> str:
    """Return a random URL-safe ``state`` value (24 random bytes, 32 characters).

    The same value is later compared against the callback by
    ``parse_callback_url``.
    """
    random_bytes = secrets.token_bytes(24)
    return _b64url(random_bytes)
57
+
58
+
59
def build_authorize_url(config: OAuthConfig, state: str, code_challenge: str) -> str:
    """Build the browser URL that starts the authorization-code + PKCE flow.

    Args:
        config: Supplies client_id, redirect_uri, scope, originator, the
            optional audience, and the authorization endpoint.
        state: Opaque value echoed back on the callback and checked by
            ``parse_callback_url``.
        code_challenge: S256 PKCE challenge (see ``generate_pkce_pair``).

    Returns:
        ``config.authorization_endpoint`` with the request parameters
        form-encoded. If the endpoint already carries a query string, the new
        parameters are appended to it instead of starting a second ``?``.
    """
    query = {
        "response_type": "code",
        "client_id": config.client_id,
        "redirect_uri": config.redirect_uri,
        "scope": config.scope,
        "state": state,
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",
        # Codex-specific flags; presumably they mirror what the Codex CLI sends
        # to its authorization server — confirm before changing.
        "id_token_add_organizations": "true",
        "codex_cli_simplified_flow": "true",
        "originator": config.originator,
    }
    if config.audience:
        query["audience"] = config.audience

    endpoint = config.authorization_endpoint
    if "?" not in endpoint:
        separator = "?"
    elif endpoint.endswith(("?", "&")):
        # Endpoint already ends with a query delimiter; append directly.
        separator = ""
    else:
        separator = "&"
    return f"{endpoint}{separator}{urlencode(query)}"
75
+
76
+
77
def discover_endpoints(client: httpx.Client, config: OAuthConfig) -> OAuthConfig:
    """Refresh ``config``'s endpoints from its discovery document (best-effort).

    GETs ``config.discovery_url`` and reads ``authorization_endpoint`` and
    ``token_endpoint`` from the returned JSON object. A key that is missing,
    empty, or not a string keeps the value already in ``config``.

    Args:
        client: Synchronous HTTP client used for the discovery request.
        config: Provides the discovery URL and the fallback endpoints.

    Returns:
        ``config`` itself when the request fails, the body is not JSON, or the
        JSON is not an object; otherwise a new ``OAuthConfig`` with every other
        field copied from ``config``.
    """
    try:
        response = client.get(config.discovery_url)
        response.raise_for_status()
        payload = response.json()
    except Exception:
        # Deliberately broad: discovery is optional and the configured
        # endpoints remain a usable fallback for any network/HTTP/parse failure.
        return config

    # A discovery document must be a JSON object; anything else (list, string,
    # null) would otherwise fail on ``.get`` below.
    if not isinstance(payload, dict):
        return config

    def endpoint(key: str, fallback: str) -> str:
        value = payload.get(key)
        return value if isinstance(value, str) and value else fallback

    return OAuthConfig(
        client_id=config.client_id,
        scope=config.scope,
        audience=config.audience,
        redirect_uri=config.redirect_uri,
        discovery_url=config.discovery_url,
        authorization_endpoint=endpoint(
            "authorization_endpoint", config.authorization_endpoint
        ),
        token_endpoint=endpoint("token_endpoint", config.token_endpoint),
        originator=config.originator,
    )
98
+
99
+
100
async def discover_endpoints_async(client: httpx.AsyncClient, config: OAuthConfig) -> OAuthConfig:
    """Async counterpart of ``discover_endpoints`` (best-effort endpoint discovery).

    Awaits a GET of ``config.discovery_url`` and reads ``authorization_endpoint``
    and ``token_endpoint`` from the returned JSON object. A key that is
    missing, empty, or not a string keeps the value already in ``config``.

    Args:
        client: Asynchronous HTTP client used for the discovery request.
        config: Provides the discovery URL and the fallback endpoints.

    Returns:
        ``config`` itself when the request fails, the body is not JSON, or the
        JSON is not an object; otherwise a new ``OAuthConfig`` with every other
        field copied from ``config``.
    """
    try:
        response = await client.get(config.discovery_url)
        response.raise_for_status()
        payload = response.json()
    except Exception:
        # Deliberately broad: discovery is optional and the configured
        # endpoints remain a usable fallback for any network/HTTP/parse failure.
        return config

    # A discovery document must be a JSON object; anything else (list, string,
    # null) would otherwise fail on ``.get`` below.
    if not isinstance(payload, dict):
        return config

    def endpoint(key: str, fallback: str) -> str:
        value = payload.get(key)
        return value if isinstance(value, str) and value else fallback

    return OAuthConfig(
        client_id=config.client_id,
        scope=config.scope,
        audience=config.audience,
        redirect_uri=config.redirect_uri,
        discovery_url=config.discovery_url,
        authorization_endpoint=endpoint(
            "authorization_endpoint", config.authorization_endpoint
        ),
        token_endpoint=endpoint("token_endpoint", config.token_endpoint),
        originator=config.originator,
    )
121
+
122
+
123
def parse_callback_url(callback_url: str, expected_state: str) -> str:
    """Extract the authorization code from the OAuth redirect URL.

    Args:
        callback_url: The full redirect URL; surrounding whitespace is ignored.
        expected_state: The ``state`` value sent with the authorization request.

    Returns:
        The first ``code`` query parameter.

    Raises:
        OAuthCallbackParseError: The URL cannot be parsed, carries an ``error``
            parameter, or has no (non-empty) ``code``.
        OAuthStateMismatchError: The ``state`` parameter differs from
            ``expected_state``.
    """
    try:
        parsed = urlparse(callback_url.strip())
    except Exception as exc:
        raise OAuthCallbackParseError("Failed to parse callback URL") from exc

    params = parse_qs(parsed.query)

    def first(name: str, default: str | None = None) -> str | None:
        # parse_qs maps each key to a list of values; only the first matters.
        values = params.get(name)
        return values[0] if values else default

    if "error" in params:
        err = first("error", "")
        desc = first("error_description", "")
        raise OAuthCallbackParseError(f"OAuth callback returned error: {err} {desc}".strip())

    code = first("code")
    if not code:
        raise OAuthCallbackParseError("OAuth callback is missing authorization code")
    if first("state") != expected_state:
        raise OAuthStateMismatchError("OAuth callback state mismatch")
    return code
145
+
146
+
147
def _build_tokens(payload: dict[str, Any]) -> OAuthTokens:
    """Convert a token-endpoint JSON response into an ``OAuthTokens`` record.

    Expects ``payload["access_token"]`` to be present (the callers in this
    module check it first); a missing key raises ``KeyError``.

    ``expires_in`` (seconds) becomes an absolute ``expires_at`` timestamp. It
    may be a JSON number or a numeric string; anything else leaves
    ``expires_at`` as ``None``. The ChatGPT account id comes from the payload's
    ``account_id`` or, failing that, from the id_token and then the
    access_token JWT claims.
    """
    # Single clock reading so expires_at and last_refresh are consistent.
    now = time.time()

    expires_at = None
    expires_in = payload.get("expires_in")
    # bool is a subclass of int and must not be read as a lifetime; numeric
    # strings such as "3600" are tolerated since some servers emit them.
    if isinstance(expires_in, (int, float, str)) and not isinstance(expires_in, bool):
        try:
            expires_at = now + float(expires_in)
        except ValueError:
            expires_at = None

    access_token = payload["access_token"]
    id_token = payload.get("id_token")
    account_id = (
        payload.get("account_id")
        or _extract_chatgpt_account_id(id_token)
        or _extract_chatgpt_account_id(access_token)
    )

    return OAuthTokens(
        access_token=access_token,
        api_key=None,
        refresh_token=payload.get("refresh_token"),
        id_token=id_token,
        token_type=payload.get("token_type", "Bearer"),
        scope=payload.get("scope"),
        expires_at=expires_at,
        account_id=account_id,
        last_refresh=now,
    )
172
+
173
+
174
+ def _extract_chatgpt_account_id(jwt_token: str | None) -> str | None:
175
+ if not jwt_token or "." not in jwt_token:
176
+ return None
177
+ parts = jwt_token.split(".")
178
+ if len(parts) < 2:
179
+ return None
180
+ payload_b64 = parts[1]
181
+ padding = "=" * (-len(payload_b64) % 4)
182
+ try:
183
+ raw = base64.urlsafe_b64decode(payload_b64 + padding)
184
+ payload = json.loads(raw.decode("utf-8"))
185
+ except Exception:
186
+ return None
187
+
188
+ if not isinstance(payload, dict):
189
+ return None
190
+
191
+ auth_claims = payload.get("https://api.openai.com/auth")
192
+ if isinstance(auth_claims, dict):
193
+ account = auth_claims.get("chatgpt_account_id") or auth_claims.get("account_id")
194
+ if isinstance(account, str) and account:
195
+ return account
196
+ return None
197
+
198
+
199
def exchange_code_for_tokens(
    client: httpx.Client,
    config: OAuthConfig,
    code: str,
    code_verifier: str,
) -> OAuthTokens:
    """Redeem an authorization code (plus its PKCE verifier) for tokens.

    Sends a form-encoded ``authorization_code`` grant to
    ``config.token_endpoint``. ``audience`` is included only when configured.

    Args:
        client: Synchronous HTTP client used for the token request.
        config: Supplies client_id, redirect_uri, audience and token_endpoint.
        code: Authorization code returned on the callback.
        code_verifier: PKCE verifier matching the challenge sent earlier.

    Returns:
        The parsed ``OAuthTokens``.

    Raises:
        TokenExchangeError: The endpoint answered with HTTP >= 400, or the
            success body is not a JSON object containing ``access_token``.
    """
    data = {
        "grant_type": "authorization_code",
        "code": code,
        "client_id": config.client_id,
        "redirect_uri": config.redirect_uri,
        "code_verifier": code_verifier,
    }
    if config.audience:
        data["audience"] = config.audience

    response = client.post(config.token_endpoint, data=data)
    if response.status_code >= 400:
        detail = _extract_oauth_error(response)
        raise TokenExchangeError(f"OAuth token exchange failed: {detail}")

    try:
        payload = response.json()
    except ValueError as exc:
        # A non-JSON 2xx body (proxy/captive-portal page, ...) surfaces as the
        # library's own error type instead of a raw JSONDecodeError.
        raise TokenExchangeError(
            "OAuth token exchange failed: response is not valid JSON"
        ) from exc
    if not isinstance(payload, dict) or "access_token" not in payload:
        raise TokenExchangeError("OAuth token exchange failed: access_token missing")
    return _build_tokens(payload)
225
+
226
+
227
def refresh_tokens(
    client: httpx.Client,
    config: OAuthConfig,
    tokens: OAuthTokens,
) -> OAuthTokens:
    """Exchange ``tokens.refresh_token`` for a fresh token set.

    The request body is sent as JSON (unlike the form-encoded code exchange)
    with a fixed ``openid profile email`` scope. NOTE(review): this appears to
    mirror what the Codex CLI sends; confirm before switching to form encoding.

    Fields the server omits from the refresh response (refresh_token,
    account_id, id_token, api_key) are carried over from ``tokens``.

    Raises:
        TokenRefreshError: No refresh token is available, the endpoint answered
            with HTTP >= 400, or the success body is not a JSON object
            containing ``access_token``.
    """
    if not tokens.refresh_token:
        raise TokenRefreshError("No refresh_token available")

    request_body = {
        "grant_type": "refresh_token",
        "refresh_token": tokens.refresh_token,
        "client_id": config.client_id,
        "scope": "openid profile email",
    }

    response = client.post(config.token_endpoint, json=request_body)
    if response.status_code >= 400:
        detail = _extract_oauth_error(response)
        raise TokenRefreshError(f"OAuth token refresh failed: {detail}")

    try:
        payload = response.json()
    except ValueError as exc:
        # Surface a non-JSON 2xx body as the library's own error type.
        raise TokenRefreshError(
            "OAuth token refresh failed: response is not valid JSON"
        ) from exc
    if not isinstance(payload, dict) or "access_token" not in payload:
        raise TokenRefreshError("OAuth token refresh failed: access_token missing")

    new_tokens = _build_tokens(payload)
    # Refresh responses may omit these; keep the previous values.
    if not new_tokens.refresh_token:
        new_tokens.refresh_token = tokens.refresh_token
    if not new_tokens.account_id:
        new_tokens.account_id = tokens.account_id
    if not new_tokens.id_token:
        new_tokens.id_token = tokens.id_token
    if not new_tokens.api_key:
        new_tokens.api_key = tokens.api_key
    return new_tokens
261
+
262
+
263
async def refresh_tokens_async(
    client: httpx.AsyncClient,
    config: OAuthConfig,
    tokens: OAuthTokens,
) -> OAuthTokens:
    """Async counterpart of ``refresh_tokens``.

    The request body is sent as JSON with a fixed ``openid profile email``
    scope. NOTE(review): this appears to mirror what the Codex CLI sends;
    confirm before switching to form encoding.

    Fields the server omits from the refresh response (refresh_token,
    account_id, id_token, api_key) are carried over from ``tokens``.

    Raises:
        TokenRefreshError: No refresh token is available, the endpoint answered
            with HTTP >= 400, or the success body is not a JSON object
            containing ``access_token``.
    """
    if not tokens.refresh_token:
        raise TokenRefreshError("No refresh_token available")

    request_body = {
        "grant_type": "refresh_token",
        "refresh_token": tokens.refresh_token,
        "client_id": config.client_id,
        "scope": "openid profile email",
    }

    response = await client.post(config.token_endpoint, json=request_body)
    if response.status_code >= 400:
        detail = _extract_oauth_error(response)
        raise TokenRefreshError(f"OAuth token refresh failed: {detail}")

    try:
        payload = response.json()
    except ValueError as exc:
        # Surface a non-JSON 2xx body as the library's own error type.
        raise TokenRefreshError(
            "OAuth token refresh failed: response is not valid JSON"
        ) from exc
    if not isinstance(payload, dict) or "access_token" not in payload:
        raise TokenRefreshError("OAuth token refresh failed: access_token missing")

    new_tokens = _build_tokens(payload)
    # Refresh responses may omit these; keep the previous values.
    if not new_tokens.refresh_token:
        new_tokens.refresh_token = tokens.refresh_token
    if not new_tokens.account_id:
        new_tokens.account_id = tokens.account_id
    if not new_tokens.id_token:
        new_tokens.id_token = tokens.id_token
    if not new_tokens.api_key:
        new_tokens.api_key = tokens.api_key
    return new_tokens
297
+
298
+
299
+ def _extract_oauth_error(response: httpx.Response) -> str:
300
+ try:
301
+ payload = response.json()
302
+ except Exception:
303
+ return f"status={response.status_code}"
304
+
305
+ if isinstance(payload, dict):
306
+ error = payload.get("error")
307
+ desc = payload.get("error_description") or payload.get("message")
308
+ text = " ".join([part for part in [str(error) if error else "", str(desc) if desc else ""] if part])
309
+ return text or f"status={response.status_code}"
310
+ return f"status={response.status_code}"