oauth-codex 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oauth_codex/__init__.py +45 -0
- oauth_codex/auth.py +310 -0
- oauth_codex/client.py +783 -0
- oauth_codex/errors.py +46 -0
- oauth_codex/store.py +181 -0
- oauth_codex/tooling.py +168 -0
- oauth_codex/types.py +103 -0
- oauth_codex-0.2.3.dist-info/METADATA +151 -0
- oauth_codex-0.2.3.dist-info/RECORD +11 -0
- oauth_codex-0.2.3.dist-info/WHEEL +5 -0
- oauth_codex-0.2.3.dist-info/top_level.txt +1 -0
oauth_codex/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from .client import CodexOAuthLLM
|
|
2
|
+
from .errors import (
|
|
3
|
+
AuthRequiredError,
|
|
4
|
+
CodexOAuthLLMError,
|
|
5
|
+
LLMRequestError,
|
|
6
|
+
ModelValidationError,
|
|
7
|
+
OAuthCallbackParseError,
|
|
8
|
+
OAuthStateMismatchError,
|
|
9
|
+
TokenExchangeError,
|
|
10
|
+
TokenRefreshError,
|
|
11
|
+
ToolCallRequiredError,
|
|
12
|
+
)
|
|
13
|
+
from .types import (
|
|
14
|
+
GenerateResult,
|
|
15
|
+
Message,
|
|
16
|
+
OAuthConfig,
|
|
17
|
+
OAuthTokens,
|
|
18
|
+
StreamEvent,
|
|
19
|
+
TokenUsage,
|
|
20
|
+
ToolCall,
|
|
21
|
+
ToolInput,
|
|
22
|
+
ToolResult,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
__all__ = [
|
|
26
|
+
"AuthRequiredError",
|
|
27
|
+
"CodexOAuthLLM",
|
|
28
|
+
"CodexOAuthLLMError",
|
|
29
|
+
"GenerateResult",
|
|
30
|
+
"LLMRequestError",
|
|
31
|
+
"Message",
|
|
32
|
+
"ModelValidationError",
|
|
33
|
+
"OAuthCallbackParseError",
|
|
34
|
+
"OAuthConfig",
|
|
35
|
+
"OAuthStateMismatchError",
|
|
36
|
+
"OAuthTokens",
|
|
37
|
+
"StreamEvent",
|
|
38
|
+
"TokenExchangeError",
|
|
39
|
+
"TokenRefreshError",
|
|
40
|
+
"TokenUsage",
|
|
41
|
+
"ToolCall",
|
|
42
|
+
"ToolCallRequiredError",
|
|
43
|
+
"ToolInput",
|
|
44
|
+
"ToolResult",
|
|
45
|
+
]
|
oauth_codex/auth.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import hashlib
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import secrets
|
|
8
|
+
import time
|
|
9
|
+
from typing import Any
|
|
10
|
+
from urllib.parse import parse_qs, urlencode, urlparse
|
|
11
|
+
|
|
12
|
+
import httpx
|
|
13
|
+
|
|
14
|
+
from .errors import (
|
|
15
|
+
OAuthCallbackParseError,
|
|
16
|
+
OAuthStateMismatchError,
|
|
17
|
+
TokenExchangeError,
|
|
18
|
+
TokenRefreshError,
|
|
19
|
+
)
|
|
20
|
+
from .types import OAuthConfig, OAuthTokens
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def load_oauth_config(override: OAuthConfig | None = None) -> OAuthConfig:
    """Build the effective OAuth configuration.

    The starting point is ``override`` (or a default-constructed
    ``OAuthConfig`` when ``override`` is falsy). Each field is then replaced
    by its ``CODEX_OAUTH_*`` environment variable when that variable is set.

    Args:
        override: Optional configuration supplying the fallback values.

    Returns:
        A new ``OAuthConfig``; ``override`` itself is not modified.
    """
    defaults = override if override else OAuthConfig()
    env = os.environ.get

    # An empty audience, whether from the environment or the fallback,
    # collapses to None.
    audience = env("CODEX_OAUTH_AUDIENCE", defaults.audience or "")
    if not audience:
        audience = None

    resolved = {
        "client_id": env("CODEX_OAUTH_CLIENT_ID", defaults.client_id),
        "scope": env("CODEX_OAUTH_SCOPE", defaults.scope),
        "audience": audience,
        "redirect_uri": env("CODEX_OAUTH_REDIRECT_URI", defaults.redirect_uri),
        "discovery_url": env("CODEX_OAUTH_DISCOVERY_URL", defaults.discovery_url),
        "authorization_endpoint": env(
            "CODEX_OAUTH_AUTHORIZATION_ENDPOINT", defaults.authorization_endpoint
        ),
        "token_endpoint": env("CODEX_OAUTH_TOKEN_ENDPOINT", defaults.token_endpoint),
        "originator": env("CODEX_OAUTH_ORIGINATOR", defaults.originator),
    }
    return OAuthConfig(**resolved)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _b64url(data: bytes) -> str:
    """Encode ``data`` as URL-safe base64 text with the ``=`` padding removed."""
    encoded = base64.urlsafe_b64encode(data)
    text = encoded.decode("ascii")
    return text.rstrip("=")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def generate_pkce_pair() -> tuple[str, str]:
    """Create a PKCE code verifier and its matching S256 code challenge.

    Returns:
        ``(verifier, challenge)``. ``verifier`` is 64 random bytes encoded
        with ``_b64url``; ``challenge`` is the ``_b64url`` encoding of the
        SHA-256 digest of the ASCII verifier.
    """
    # 64 bytes always encode to exactly 86 unpadded base64url characters,
    # which is already inside the 43-128 character window the previous code
    # clamped to. Those clamping branches could never run and were removed.
    verifier = _b64url(secrets.token_bytes(64))

    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    challenge = _b64url(digest)
    return verifier, challenge
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def generate_state() -> str:
    """Return a fresh random ``state`` value: 24 random bytes run through ``_b64url``."""
    entropy = secrets.token_bytes(24)
    return _b64url(entropy)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def build_authorize_url(config: OAuthConfig, state: str, code_challenge: str) -> str:
    """Compose the authorization URL for the code + PKCE (S256) flow.

    Args:
        config: Supplies the client id, redirect URI, scope, originator,
            optional audience and the authorization endpoint.
        state: Value sent as the ``state`` query parameter.
        code_challenge: Value sent as the ``code_challenge`` query parameter.

    Returns:
        ``config.authorization_endpoint`` followed by ``?`` and the
        URL-encoded query string.
    """
    params = [
        ("response_type", "code"),
        ("client_id", config.client_id),
        ("redirect_uri", config.redirect_uri),
        ("scope", config.scope),
        ("state", state),
        ("code_challenge", code_challenge),
        ("code_challenge_method", "S256"),
        ("id_token_add_organizations", "true"),
        ("codex_cli_simplified_flow", "true"),
        ("originator", config.originator),
    ]

    # The audience parameter is only sent when one is configured.
    audience = config.audience
    if audience:
        params.append(("audience", audience))

    return "{}?{}".format(config.authorization_endpoint, urlencode(params))
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def discover_endpoints(client: httpx.Client, config: OAuthConfig) -> OAuthConfig:
    """Update the authorization/token endpoints from the discovery document.

    Discovery is best-effort: if the request fails, the status is an error,
    the body cannot be decoded, or the decoded body is not a JSON object,
    ``config`` is returned unchanged.

    Args:
        client: HTTP client used to GET ``config.discovery_url``.
        config: Current configuration; also supplies the fallback endpoints.

    Returns:
        ``config`` itself when discovery yields nothing usable. Otherwise a
        new ``OAuthConfig`` whose ``authorization_endpoint`` and
        ``token_endpoint`` come from the document when present and non-empty,
        with every other field copied from ``config``.
    """
    try:
        response = client.get(config.discovery_url)
        response.raise_for_status()
        payload = response.json()
    except Exception:
        # Deliberately broad: discovery is optional, so any failure falls
        # back to the endpoints already configured.
        return config

    # A valid but non-object JSON body (list, string, null) has no ``.get``;
    # treat it as a failed discovery rather than raising AttributeError.
    if not isinstance(payload, dict):
        return config

    auth_endpoint = payload.get("authorization_endpoint") or config.authorization_endpoint
    token_endpoint = payload.get("token_endpoint") or config.token_endpoint

    return OAuthConfig(
        client_id=config.client_id,
        scope=config.scope,
        audience=config.audience,
        redirect_uri=config.redirect_uri,
        discovery_url=config.discovery_url,
        authorization_endpoint=auth_endpoint,
        token_endpoint=token_endpoint,
        originator=config.originator,
    )
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
async def discover_endpoints_async(client: httpx.AsyncClient, config: OAuthConfig) -> OAuthConfig:
    """Update the authorization/token endpoints from the discovery document (async).

    Discovery is best-effort: if the request fails, the status is an error,
    the body cannot be decoded, or the decoded body is not a JSON object,
    ``config`` is returned unchanged.

    Args:
        client: Async HTTP client used to GET ``config.discovery_url``.
        config: Current configuration; also supplies the fallback endpoints.

    Returns:
        ``config`` itself when discovery yields nothing usable. Otherwise a
        new ``OAuthConfig`` whose ``authorization_endpoint`` and
        ``token_endpoint`` come from the document when present and non-empty,
        with every other field copied from ``config``.
    """
    try:
        response = await client.get(config.discovery_url)
        response.raise_for_status()
        payload = response.json()
    except Exception:
        # Deliberately broad: discovery is optional, so any failure falls
        # back to the endpoints already configured.
        return config

    # A valid but non-object JSON body (list, string, null) has no ``.get``;
    # treat it as a failed discovery rather than raising AttributeError.
    if not isinstance(payload, dict):
        return config

    auth_endpoint = payload.get("authorization_endpoint") or config.authorization_endpoint
    token_endpoint = payload.get("token_endpoint") or config.token_endpoint

    return OAuthConfig(
        client_id=config.client_id,
        scope=config.scope,
        audience=config.audience,
        redirect_uri=config.redirect_uri,
        discovery_url=config.discovery_url,
        authorization_endpoint=auth_endpoint,
        token_endpoint=token_endpoint,
        originator=config.originator,
    )
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def parse_callback_url(callback_url: str, expected_state: str) -> str:
    """Extract the authorization code from an OAuth redirect URL.

    Args:
        callback_url: The redirect URL; surrounding whitespace is stripped
            before parsing.
        expected_state: The ``state`` value the callback must carry.

    Returns:
        The first ``code`` query parameter.

    Raises:
        OAuthCallbackParseError: The URL cannot be parsed, it carries an
            ``error`` parameter, or it has no ``code``.
        OAuthStateMismatchError: The ``state`` parameter differs from
            ``expected_state``.
    """
    try:
        url = urlparse(callback_url.strip())
    except Exception as exc:
        raise OAuthCallbackParseError("Failed to parse callback URL") from exc

    params = parse_qs(url.query)

    def first(name, default):
        # parse_qs maps each key to a list; only the first value matters here.
        return params.get(name, [default])[0]

    # An error reported in the callback takes precedence over everything else.
    if "error" in params:
        err = first("error", "")
        desc = first("error_description", "")
        message = f"OAuth callback returned error: {err} {desc}"
        raise OAuthCallbackParseError(message.strip())

    code = first("code", None)
    if not code:
        raise OAuthCallbackParseError("OAuth callback is missing authorization code")

    if first("state", None) != expected_state:
        raise OAuthStateMismatchError("OAuth callback state mismatch")

    return code
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _build_tokens(payload: dict[str, Any]) -> OAuthTokens:
    """Convert a token-endpoint response body into an ``OAuthTokens`` record.

    Args:
        payload: Decoded response body; must contain ``access_token``.

    Returns:
        ``OAuthTokens`` with ``api_key`` set to ``None``, ``token_type``
        defaulting to ``"Bearer"``, and ``last_refresh`` set to the current
        time.

    Raises:
        KeyError: ``payload`` has no ``access_token`` key.
    """
    lifetime = payload.get("expires_in")
    # Only a numeric lifetime yields an absolute expiry; anything else
    # (missing, string, ...) leaves it unknown.
    expires_at = (time.time() + float(lifetime)) if isinstance(lifetime, (int, float)) else None

    id_token = payload.get("id_token")

    # Prefer an explicit account id, then the id_token claims, then the
    # access_token claims.
    account_id = (
        payload.get("account_id")
        or _extract_chatgpt_account_id(id_token)
        or _extract_chatgpt_account_id(payload.get("access_token"))
    )

    fields = {
        "access_token": payload["access_token"],
        "api_key": None,
        "refresh_token": payload.get("refresh_token"),
        "id_token": id_token,
        "token_type": payload.get("token_type", "Bearer"),
        "scope": payload.get("scope"),
        "expires_at": expires_at,
        "account_id": account_id,
        "last_refresh": time.time(),
    }
    return OAuthTokens(**fields)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def _extract_chatgpt_account_id(jwt_token: str | None) -> str | None:
    """Read the account id from a JWT's claims without verifying the token.

    The second dot-separated segment is base64url-decoded and parsed as JSON.
    The id is taken from the ``https://api.openai.com/auth`` claim, preferring
    ``chatgpt_account_id`` over ``account_id``.

    Returns:
        The account id as a non-empty string, or ``None`` when the token is
        missing, malformed, or carries no such claim.
    """
    if not jwt_token or "." not in jwt_token:
        return None

    # Containing a "." guarantees at least two segments; the claims are the second.
    segment = jwt_token.split(".")[1]
    # Restore the base64 padding that JWT encoding strips.
    segment += "=" * (-len(segment) % 4)

    try:
        claims = json.loads(base64.urlsafe_b64decode(segment).decode("utf-8"))
    except Exception:
        return None

    auth = claims.get("https://api.openai.com/auth") if isinstance(claims, dict) else None
    if not isinstance(auth, dict):
        return None

    candidate = auth.get("chatgpt_account_id") or auth.get("account_id")
    return candidate if isinstance(candidate, str) and candidate else None
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def exchange_code_for_tokens(
    client: httpx.Client,
    config: OAuthConfig,
    code: str,
    code_verifier: str,
) -> OAuthTokens:
    """Exchange an authorization code for tokens at the token endpoint.

    Args:
        client: HTTP client used to POST to ``config.token_endpoint``.
        config: Supplies the client id, redirect URI, optional audience and
            the token endpoint.
        code: Authorization code taken from the callback.
        code_verifier: PKCE verifier matching the challenge sent earlier.

    Returns:
        Tokens built from the response body by ``_build_tokens``.

    Raises:
        TokenExchangeError: The endpoint answers with status >= 400, the body
            is not valid JSON, or it is not an object containing
            ``access_token``.
    """
    form = {
        "grant_type": "authorization_code",
        "code": code,
        "client_id": config.client_id,
        "redirect_uri": config.redirect_uri,
        "code_verifier": code_verifier,
    }
    if config.audience:
        form["audience"] = config.audience

    response = client.post(config.token_endpoint, data=form)
    if response.status_code >= 400:
        detail = _extract_oauth_error(response)
        raise TokenExchangeError(f"OAuth token exchange failed: {detail}")

    try:
        payload = response.json()
    except ValueError as exc:
        # JSON and text decoding errors are ValueError subclasses; report them
        # as an exchange failure instead of leaking a raw decode error.
        raise TokenExchangeError(
            "OAuth token exchange failed: response is not valid JSON"
        ) from exc

    # A non-object body (list, string, null) cannot carry an access token.
    if not isinstance(payload, dict) or "access_token" not in payload:
        raise TokenExchangeError("OAuth token exchange failed: access_token missing")
    return _build_tokens(payload)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def refresh_tokens(
    client: httpx.Client,
    config: OAuthConfig,
    tokens: OAuthTokens,
) -> OAuthTokens:
    """Obtain a new token set using the stored refresh token.

    Args:
        client: HTTP client used to POST to ``config.token_endpoint``.
        config: Supplies the client id and the token endpoint.
        tokens: Current tokens; ``refresh_token`` is required.

    Returns:
        Tokens built from the response. ``refresh_token``, ``account_id``,
        ``id_token`` and ``api_key`` are carried over from ``tokens`` when
        the new set has no value for them.

    Raises:
        TokenRefreshError: No refresh token is stored, the endpoint answers
            with status >= 400, the body is not valid JSON, or it is not an
            object containing ``access_token``.
    """
    if not tokens.refresh_token:
        raise TokenRefreshError("No refresh_token available")

    request_body = {
        "grant_type": "refresh_token",
        "refresh_token": tokens.refresh_token,
        "client_id": config.client_id,
        "scope": "openid profile email",
    }

    # NOTE(review): sent as a JSON body with a fixed scope -- presumably what
    # this token endpoint expects for refreshes; confirm before changing.
    response = client.post(config.token_endpoint, json=request_body)
    if response.status_code >= 400:
        detail = _extract_oauth_error(response)
        raise TokenRefreshError(f"OAuth token refresh failed: {detail}")

    try:
        payload = response.json()
    except ValueError as exc:
        # JSON and text decoding errors are ValueError subclasses; report them
        # as a refresh failure instead of leaking a raw decode error.
        raise TokenRefreshError(
            "OAuth token refresh failed: response is not valid JSON"
        ) from exc

    # A non-object body (list, string, null) cannot carry an access token.
    if not isinstance(payload, dict) or "access_token" not in payload:
        raise TokenRefreshError("OAuth token refresh failed: access_token missing")

    new_tokens = _build_tokens(payload)
    # Keep the previous value for any of these that the new set leaves empty.
    for field in ("refresh_token", "account_id", "id_token", "api_key"):
        if not getattr(new_tokens, field):
            setattr(new_tokens, field, getattr(tokens, field))
    return new_tokens
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
async def refresh_tokens_async(
    client: httpx.AsyncClient,
    config: OAuthConfig,
    tokens: OAuthTokens,
) -> OAuthTokens:
    """Obtain a new token set using the stored refresh token (async).

    Args:
        client: Async HTTP client used to POST to ``config.token_endpoint``.
        config: Supplies the client id and the token endpoint.
        tokens: Current tokens; ``refresh_token`` is required.

    Returns:
        Tokens built from the response. ``refresh_token``, ``account_id``,
        ``id_token`` and ``api_key`` are carried over from ``tokens`` when
        the new set has no value for them.

    Raises:
        TokenRefreshError: No refresh token is stored, the endpoint answers
            with status >= 400, the body is not valid JSON, or it is not an
            object containing ``access_token``.
    """
    if not tokens.refresh_token:
        raise TokenRefreshError("No refresh_token available")

    request_body = {
        "grant_type": "refresh_token",
        "refresh_token": tokens.refresh_token,
        "client_id": config.client_id,
        "scope": "openid profile email",
    }

    # NOTE(review): sent as a JSON body with a fixed scope -- presumably what
    # this token endpoint expects for refreshes; confirm before changing.
    response = await client.post(config.token_endpoint, json=request_body)
    if response.status_code >= 400:
        detail = _extract_oauth_error(response)
        raise TokenRefreshError(f"OAuth token refresh failed: {detail}")

    try:
        payload = response.json()
    except ValueError as exc:
        # JSON and text decoding errors are ValueError subclasses; report them
        # as a refresh failure instead of leaking a raw decode error.
        raise TokenRefreshError(
            "OAuth token refresh failed: response is not valid JSON"
        ) from exc

    # A non-object body (list, string, null) cannot carry an access token.
    if not isinstance(payload, dict) or "access_token" not in payload:
        raise TokenRefreshError("OAuth token refresh failed: access_token missing")

    new_tokens = _build_tokens(payload)
    # Keep the previous value for any of these that the new set leaves empty.
    for field in ("refresh_token", "account_id", "id_token", "api_key"):
        if not getattr(new_tokens, field):
            setattr(new_tokens, field, getattr(tokens, field))
    return new_tokens
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def _extract_oauth_error(response: httpx.Response) -> str:
    """Summarise an error response from the token endpoint as one line of text.

    When the body is a JSON object, the result is its ``error`` value and its
    ``error_description`` (or, failing that, ``message``) joined by a space,
    with empty parts skipped. In every other case, or when both parts are
    empty, the result is ``status=<status code>``.
    """
    try:
        body = response.json()
    except Exception:
        body = None

    summary = ""
    if isinstance(body, dict):
        candidates = (
            body.get("error"),
            body.get("error_description") or body.get("message"),
        )
        rendered = [str(value) for value in candidates if value]
        summary = " ".join(piece for piece in rendered if piece)

    return summary if summary else f"status={response.status_code}"
|