axion-code 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- axion/__init__.py +3 -0
- axion/api/__init__.py +0 -0
- axion/api/anthropic.py +460 -0
- axion/api/client.py +259 -0
- axion/api/error.py +161 -0
- axion/api/ollama.py +597 -0
- axion/api/openai_compat.py +805 -0
- axion/api/openai_responses.py +627 -0
- axion/api/prompt_cache.py +31 -0
- axion/api/sse.py +98 -0
- axion/api/types.py +451 -0
- axion/cli/__init__.py +0 -0
- axion/cli/init_cmd.py +50 -0
- axion/cli/input.py +290 -0
- axion/cli/main.py +2953 -0
- axion/cli/render.py +489 -0
- axion/cli/tui.py +766 -0
- axion/commands/__init__.py +0 -0
- axion/commands/handlers/__init__.py +0 -0
- axion/commands/handlers/agents.py +51 -0
- axion/commands/handlers/builtin_commands.py +367 -0
- axion/commands/handlers/mcp.py +59 -0
- axion/commands/handlers/models.py +75 -0
- axion/commands/handlers/plugins.py +55 -0
- axion/commands/handlers/skills.py +61 -0
- axion/commands/parsing.py +317 -0
- axion/commands/registry.py +166 -0
- axion/compat_harness/__init__.py +0 -0
- axion/compat_harness/extractor.py +145 -0
- axion/plugins/__init__.py +0 -0
- axion/plugins/hooks.py +22 -0
- axion/plugins/manager.py +391 -0
- axion/plugins/manifest.py +270 -0
- axion/runtime/__init__.py +0 -0
- axion/runtime/bash.py +388 -0
- axion/runtime/bootstrap.py +39 -0
- axion/runtime/claude_subscription.py +300 -0
- axion/runtime/compact.py +233 -0
- axion/runtime/config.py +397 -0
- axion/runtime/conversation.py +1073 -0
- axion/runtime/file_ops.py +613 -0
- axion/runtime/git.py +213 -0
- axion/runtime/hooks.py +235 -0
- axion/runtime/image.py +212 -0
- axion/runtime/lanes.py +282 -0
- axion/runtime/lsp.py +425 -0
- axion/runtime/mcp/__init__.py +0 -0
- axion/runtime/mcp/client.py +76 -0
- axion/runtime/mcp/lifecycle.py +96 -0
- axion/runtime/mcp/stdio.py +318 -0
- axion/runtime/mcp/tool_bridge.py +79 -0
- axion/runtime/memory.py +196 -0
- axion/runtime/oauth.py +329 -0
- axion/runtime/openai_subscription.py +346 -0
- axion/runtime/permissions.py +247 -0
- axion/runtime/plan_mode.py +96 -0
- axion/runtime/policy_engine.py +259 -0
- axion/runtime/prompt.py +586 -0
- axion/runtime/recovery.py +261 -0
- axion/runtime/remote.py +28 -0
- axion/runtime/sandbox.py +68 -0
- axion/runtime/scheduler.py +231 -0
- axion/runtime/session.py +365 -0
- axion/runtime/sharing.py +159 -0
- axion/runtime/skills.py +124 -0
- axion/runtime/tasks.py +258 -0
- axion/runtime/usage.py +241 -0
- axion/runtime/workers.py +186 -0
- axion/telemetry/__init__.py +0 -0
- axion/telemetry/events.py +67 -0
- axion/telemetry/profile.py +49 -0
- axion/telemetry/sink.py +60 -0
- axion/telemetry/tracer.py +95 -0
- axion/tools/__init__.py +0 -0
- axion/tools/lane_completion.py +33 -0
- axion/tools/registry.py +853 -0
- axion/tools/tool_search.py +226 -0
- axion_code-1.0.0.dist-info/METADATA +709 -0
- axion_code-1.0.0.dist-info/RECORD +82 -0
- axion_code-1.0.0.dist-info/WHEEL +4 -0
- axion_code-1.0.0.dist-info/entry_points.txt +2 -0
- axion_code-1.0.0.dist-info/licenses/LICENSE +21 -0
axion/runtime/memory.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
"""Persistent memory system.
|
|
2
|
+
|
|
3
|
+
Maps to: rust/crates/runtime/src/memory.rs
|
|
4
|
+
|
|
5
|
+
Provides a file-backed memory store where entries are saved as .md files
|
|
6
|
+
with YAML frontmatter. An index file (MEMORY.md) tracks all entries.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import enum
|
|
12
|
+
import logging
|
|
13
|
+
import re
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
from datetime import datetime, timezone
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
_DEFAULT_MEMORY_DIR = Path.home() / ".axion" / "memory"
|
|
21
|
+
|
|
22
|
+
# Reuse the lightweight frontmatter parser from skills
|
|
23
|
+
_FRONTMATTER_RE = re.compile(
|
|
24
|
+
r"\A---\s*\n(.*?)\n---\s*\n(.*)",
|
|
25
|
+
re.DOTALL,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MemoryType(enum.Enum):
|
|
30
|
+
USER = "user"
|
|
31
|
+
FEEDBACK = "feedback"
|
|
32
|
+
PROJECT = "project"
|
|
33
|
+
REFERENCE = "reference"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
|
|
37
|
+
class MemoryEntry:
|
|
38
|
+
"""A single memory entry."""
|
|
39
|
+
|
|
40
|
+
name: str
|
|
41
|
+
description: str
|
|
42
|
+
type: MemoryType
|
|
43
|
+
content: str
|
|
44
|
+
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
|
45
|
+
updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# ---------------------------------------------------------------------------
|
|
49
|
+
# Frontmatter helpers
|
|
50
|
+
# ---------------------------------------------------------------------------
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _parse_frontmatter(text: str) -> tuple[dict[str, str], str]:
|
|
54
|
+
match = _FRONTMATTER_RE.match(text)
|
|
55
|
+
if not match:
|
|
56
|
+
return {}, text
|
|
57
|
+
meta: dict[str, str] = {}
|
|
58
|
+
for line in match.group(1).splitlines():
|
|
59
|
+
line = line.strip()
|
|
60
|
+
if not line or line.startswith("#"):
|
|
61
|
+
continue
|
|
62
|
+
if ":" in line:
|
|
63
|
+
key, _, value = line.partition(":")
|
|
64
|
+
meta[key.strip()] = value.strip()
|
|
65
|
+
return meta, match.group(2)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _build_frontmatter(entry: MemoryEntry) -> str:
|
|
69
|
+
lines = [
|
|
70
|
+
"---",
|
|
71
|
+
f"name: {entry.name}",
|
|
72
|
+
f"description: {entry.description}",
|
|
73
|
+
f"type: {entry.type.value}",
|
|
74
|
+
f"created_at: {entry.created_at}",
|
|
75
|
+
f"updated_at: {entry.updated_at}",
|
|
76
|
+
"---",
|
|
77
|
+
"",
|
|
78
|
+
entry.content,
|
|
79
|
+
]
|
|
80
|
+
return "\n".join(lines)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
# ---------------------------------------------------------------------------
|
|
84
|
+
# MemoryStore
|
|
85
|
+
# ---------------------------------------------------------------------------
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class MemoryStore:
|
|
89
|
+
"""File-backed memory store.
|
|
90
|
+
|
|
91
|
+
Each entry is stored as a .md file under *memory_dir*. An index
|
|
92
|
+
file (MEMORY.md) provides a human-readable listing.
|
|
93
|
+
"""
|
|
94
|
+
|
|
95
|
+
def __init__(self, memory_dir: Path | None = None) -> None:
|
|
96
|
+
self.memory_dir = memory_dir or _DEFAULT_MEMORY_DIR
|
|
97
|
+
|
|
98
|
+
def _ensure_dir(self) -> None:
|
|
99
|
+
self.memory_dir.mkdir(parents=True, exist_ok=True)
|
|
100
|
+
|
|
101
|
+
def _entry_path(self, name: str) -> Path:
|
|
102
|
+
safe_name = re.sub(r"[^\w\-.]", "_", name)
|
|
103
|
+
return self.memory_dir / f"{safe_name}.md"
|
|
104
|
+
|
|
105
|
+
# -- CRUD ---------------------------------------------------------------
|
|
106
|
+
|
|
107
|
+
def save(self, entry: MemoryEntry) -> Path:
|
|
108
|
+
"""Write an entry as a .md file with YAML frontmatter."""
|
|
109
|
+
self._ensure_dir()
|
|
110
|
+
entry.updated_at = datetime.now(timezone.utc).isoformat()
|
|
111
|
+
path = self._entry_path(entry.name)
|
|
112
|
+
path.write_text(_build_frontmatter(entry), encoding="utf-8")
|
|
113
|
+
logger.debug("Saved memory entry '%s' to %s", entry.name, path)
|
|
114
|
+
return path
|
|
115
|
+
|
|
116
|
+
def load(self, name: str) -> MemoryEntry | None:
|
|
117
|
+
"""Read a single memory file by name."""
|
|
118
|
+
path = self._entry_path(name)
|
|
119
|
+
if not path.is_file():
|
|
120
|
+
return None
|
|
121
|
+
return self._read_entry(path)
|
|
122
|
+
|
|
123
|
+
def load_all(self) -> list[MemoryEntry]:
|
|
124
|
+
"""Read all memory files in the store."""
|
|
125
|
+
if not self.memory_dir.is_dir():
|
|
126
|
+
return []
|
|
127
|
+
entries: list[MemoryEntry] = []
|
|
128
|
+
for path in sorted(self.memory_dir.glob("*.md")):
|
|
129
|
+
if path.name == "MEMORY.md":
|
|
130
|
+
continue
|
|
131
|
+
entry = self._read_entry(path)
|
|
132
|
+
if entry is not None:
|
|
133
|
+
entries.append(entry)
|
|
134
|
+
return entries
|
|
135
|
+
|
|
136
|
+
def remove(self, name: str) -> bool:
|
|
137
|
+
"""Delete a memory file. Returns True if the file existed."""
|
|
138
|
+
path = self._entry_path(name)
|
|
139
|
+
if path.is_file():
|
|
140
|
+
path.unlink()
|
|
141
|
+
logger.debug("Removed memory entry '%s'", name)
|
|
142
|
+
return True
|
|
143
|
+
return False
|
|
144
|
+
|
|
145
|
+
# -- Index --------------------------------------------------------------
|
|
146
|
+
|
|
147
|
+
def load_index(self) -> str | None:
|
|
148
|
+
"""Read the MEMORY.md index file."""
|
|
149
|
+
index_path = self.memory_dir / "MEMORY.md"
|
|
150
|
+
if index_path.is_file():
|
|
151
|
+
return index_path.read_text(encoding="utf-8")
|
|
152
|
+
return None
|
|
153
|
+
|
|
154
|
+
def save_index(self, entries: list[MemoryEntry] | None = None) -> Path:
|
|
155
|
+
"""Write a MEMORY.md index listing all entries."""
|
|
156
|
+
self._ensure_dir()
|
|
157
|
+
if entries is None:
|
|
158
|
+
entries = self.load_all()
|
|
159
|
+
|
|
160
|
+
lines = ["# Memory Index", ""]
|
|
161
|
+
for entry in entries:
|
|
162
|
+
lines.append(
|
|
163
|
+
f"- **{entry.name}** ({entry.type.value}): {entry.description}"
|
|
164
|
+
)
|
|
165
|
+
lines.append("")
|
|
166
|
+
|
|
167
|
+
index_path = self.memory_dir / "MEMORY.md"
|
|
168
|
+
index_path.write_text("\n".join(lines), encoding="utf-8")
|
|
169
|
+
logger.debug("Saved memory index with %d entries", len(entries))
|
|
170
|
+
return index_path
|
|
171
|
+
|
|
172
|
+
# -- Internal -----------------------------------------------------------
|
|
173
|
+
|
|
174
|
+
@staticmethod
|
|
175
|
+
def _read_entry(path: Path) -> MemoryEntry | None:
|
|
176
|
+
try:
|
|
177
|
+
text = path.read_text(encoding="utf-8")
|
|
178
|
+
except OSError as exc:
|
|
179
|
+
logger.warning("Failed to read memory file %s: %s", path, exc)
|
|
180
|
+
return None
|
|
181
|
+
|
|
182
|
+
meta, content = _parse_frontmatter(text)
|
|
183
|
+
|
|
184
|
+
try:
|
|
185
|
+
mem_type = MemoryType(meta.get("type", "reference"))
|
|
186
|
+
except ValueError:
|
|
187
|
+
mem_type = MemoryType.REFERENCE
|
|
188
|
+
|
|
189
|
+
return MemoryEntry(
|
|
190
|
+
name=meta.get("name", path.stem),
|
|
191
|
+
description=meta.get("description", ""),
|
|
192
|
+
type=mem_type,
|
|
193
|
+
content=content.strip(),
|
|
194
|
+
created_at=meta.get("created_at", ""),
|
|
195
|
+
updated_at=meta.get("updated_at", ""),
|
|
196
|
+
)
|
axion/runtime/oauth.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
"""OAuth PKCE flow with HTTP callback server, token refresh, and browser launch.
|
|
2
|
+
|
|
3
|
+
Maps to: rust/crates/runtime/src/oauth.rs
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import base64
|
|
9
|
+
import hashlib
|
|
10
|
+
import http.server
|
|
11
|
+
import json
|
|
12
|
+
import logging
|
|
13
|
+
import os
|
|
14
|
+
import platform
|
|
15
|
+
import secrets
|
|
16
|
+
import subprocess
|
|
17
|
+
import threading
|
|
18
|
+
import time
|
|
19
|
+
import urllib.parse
|
|
20
|
+
from dataclasses import dataclass, field
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
DEFAULT_OAUTH_CALLBACK_PORT = 4545
|
|
27
|
+
DEFAULT_CALLBACK_PATH = "/oauth/callback"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
|
|
31
|
+
class OAuthTokenSet:
|
|
32
|
+
access_token: str
|
|
33
|
+
refresh_token: str | None = None
|
|
34
|
+
expires_at: int | None = None
|
|
35
|
+
scopes: list[str] = field(default_factory=list)
|
|
36
|
+
|
|
37
|
+
def is_expired(self) -> bool:
|
|
38
|
+
if self.expires_at is None:
|
|
39
|
+
return False
|
|
40
|
+
return int(time.time()) >= self.expires_at
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
|
|
44
|
+
class OAuthConfig:
|
|
45
|
+
client_id: str = ""
|
|
46
|
+
authorize_url: str = "https://console.anthropic.com/oauth/authorize"
|
|
47
|
+
token_url: str = "https://console.anthropic.com/oauth/token"
|
|
48
|
+
callback_port: int = DEFAULT_OAUTH_CALLBACK_PORT
|
|
49
|
+
scopes: list[str] = field(default_factory=lambda: ["user:inference"])
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass
|
|
53
|
+
class PkceCodePair:
|
|
54
|
+
code_verifier: str
|
|
55
|
+
code_challenge: str
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@dataclass
|
|
59
|
+
class OAuthCallbackParams:
|
|
60
|
+
code: str | None = None
|
|
61
|
+
state: str | None = None
|
|
62
|
+
error: str | None = None
|
|
63
|
+
error_description: str | None = None
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# ---------------------------------------------------------------------------
|
|
67
|
+
# PKCE helpers
|
|
68
|
+
# ---------------------------------------------------------------------------
|
|
69
|
+
|
|
70
|
+
def generate_pkce_pair() -> PkceCodePair:
|
|
71
|
+
"""Generate a PKCE code verifier and S256 challenge."""
|
|
72
|
+
verifier = secrets.token_urlsafe(64)
|
|
73
|
+
digest = hashlib.sha256(verifier.encode("ascii")).digest()
|
|
74
|
+
challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
|
|
75
|
+
return PkceCodePair(code_verifier=verifier, code_challenge=challenge)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def generate_state() -> str:
|
|
79
|
+
"""Generate a random OAuth state parameter."""
|
|
80
|
+
return secrets.token_urlsafe(32)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
# ---------------------------------------------------------------------------
|
|
84
|
+
# Credential persistence
|
|
85
|
+
# ---------------------------------------------------------------------------
|
|
86
|
+
|
|
87
|
+
def _credentials_path(provider: str) -> Path:
|
|
88
|
+
return Path.home() / ".axion" / "credentials" / f"{provider}.json"
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def save_oauth_credentials(provider: str, token_set: OAuthTokenSet) -> None:
|
|
92
|
+
"""Save OAuth credentials to disk."""
|
|
93
|
+
path = _credentials_path(provider)
|
|
94
|
+
path.parent.mkdir(parents=True, exist_ok=True)
|
|
95
|
+
data = {
|
|
96
|
+
"access_token": token_set.access_token,
|
|
97
|
+
"refresh_token": token_set.refresh_token,
|
|
98
|
+
"expires_at": token_set.expires_at,
|
|
99
|
+
"scopes": token_set.scopes,
|
|
100
|
+
}
|
|
101
|
+
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
|
102
|
+
try:
|
|
103
|
+
os.chmod(path, 0o600)
|
|
104
|
+
except OSError:
|
|
105
|
+
pass
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def load_oauth_credentials(provider: str) -> OAuthTokenSet | None:
|
|
109
|
+
"""Load OAuth credentials from disk."""
|
|
110
|
+
path = _credentials_path(provider)
|
|
111
|
+
if not path.exists():
|
|
112
|
+
return None
|
|
113
|
+
try:
|
|
114
|
+
data = json.loads(path.read_text(encoding="utf-8"))
|
|
115
|
+
return OAuthTokenSet(
|
|
116
|
+
access_token=data["access_token"],
|
|
117
|
+
refresh_token=data.get("refresh_token"),
|
|
118
|
+
expires_at=data.get("expires_at"),
|
|
119
|
+
scopes=data.get("scopes", []),
|
|
120
|
+
)
|
|
121
|
+
except (json.JSONDecodeError, KeyError, OSError):
|
|
122
|
+
return None
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def clear_oauth_credentials(provider: str) -> None:
|
|
126
|
+
"""Remove OAuth credentials from disk."""
|
|
127
|
+
path = _credentials_path(provider)
|
|
128
|
+
if path.exists():
|
|
129
|
+
path.unlink()
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
# ---------------------------------------------------------------------------
|
|
133
|
+
# Browser launch (platform-specific)
|
|
134
|
+
# ---------------------------------------------------------------------------
|
|
135
|
+
|
|
136
|
+
def open_browser(url: str) -> bool:
|
|
137
|
+
"""Open a URL in the default browser. Returns True on success.
|
|
138
|
+
|
|
139
|
+
On Windows, `cmd /C start` interprets `&` in URLs as command separators,
|
|
140
|
+
truncating OAuth URLs with multiple query params. We use os.startfile
|
|
141
|
+
(Windows) and webbrowser.open (cross-platform) which handle this correctly.
|
|
142
|
+
"""
|
|
143
|
+
system = platform.system().lower()
|
|
144
|
+
try:
|
|
145
|
+
if system == "windows":
|
|
146
|
+
# os.startfile preserves & in URLs; cmd start does not
|
|
147
|
+
os.startfile(url)
|
|
148
|
+
return True
|
|
149
|
+
elif system == "darwin":
|
|
150
|
+
subprocess.Popen(["open", url], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
|
151
|
+
return True
|
|
152
|
+
elif system == "linux":
|
|
153
|
+
subprocess.Popen(
|
|
154
|
+
["xdg-open", url],
|
|
155
|
+
stdout=subprocess.DEVNULL,
|
|
156
|
+
stderr=subprocess.DEVNULL,
|
|
157
|
+
)
|
|
158
|
+
return True
|
|
159
|
+
except (FileNotFoundError, OSError, AttributeError) as exc:
|
|
160
|
+
logger.warning("Native browser launcher failed: %s, falling back to webbrowser module", exc)
|
|
161
|
+
|
|
162
|
+
# Fallback: Python stdlib webbrowser (handles all platforms safely)
|
|
163
|
+
try:
|
|
164
|
+
import webbrowser
|
|
165
|
+
return webbrowser.open(url)
|
|
166
|
+
except Exception as exc:
|
|
167
|
+
logger.warning("Failed to open browser: %s", exc)
|
|
168
|
+
return False
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# ---------------------------------------------------------------------------
|
|
172
|
+
# OAuth callback HTTP server
|
|
173
|
+
# ---------------------------------------------------------------------------
|
|
174
|
+
|
|
175
|
+
class _OAuthCallbackHandler(http.server.BaseHTTPRequestHandler):
|
|
176
|
+
"""HTTP handler that captures the OAuth callback parameters."""
|
|
177
|
+
|
|
178
|
+
callback_result: OAuthCallbackParams | None = None
|
|
179
|
+
|
|
180
|
+
def do_GET(self) -> None:
|
|
181
|
+
parsed = urllib.parse.urlparse(self.path)
|
|
182
|
+
params = urllib.parse.parse_qs(parsed.query)
|
|
183
|
+
|
|
184
|
+
result = OAuthCallbackParams(
|
|
185
|
+
code=params.get("code", [None])[0],
|
|
186
|
+
state=params.get("state", [None])[0],
|
|
187
|
+
error=params.get("error", [None])[0],
|
|
188
|
+
error_description=params.get("error_description", [None])[0],
|
|
189
|
+
)
|
|
190
|
+
_OAuthCallbackHandler.callback_result = result
|
|
191
|
+
|
|
192
|
+
# Send response
|
|
193
|
+
if result.error:
|
|
194
|
+
body = f"<h1>OAuth Error</h1><p>{result.error}: {result.error_description}</p>"
|
|
195
|
+
self.send_response(400)
|
|
196
|
+
else:
|
|
197
|
+
body = "<h1>Success!</h1><p>You can close this window and return to Axion Code.</p>"
|
|
198
|
+
self.send_response(200)
|
|
199
|
+
|
|
200
|
+
self.send_header("Content-Type", "text/html")
|
|
201
|
+
self.end_headers()
|
|
202
|
+
self.wfile.write(body.encode())
|
|
203
|
+
|
|
204
|
+
def log_message(self, format: str, *args: Any) -> None:
|
|
205
|
+
# Suppress default logging
|
|
206
|
+
pass
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def wait_for_oauth_callback(
|
|
210
|
+
port: int = DEFAULT_OAUTH_CALLBACK_PORT,
|
|
211
|
+
timeout: float = 120.0,
|
|
212
|
+
) -> OAuthCallbackParams | None:
|
|
213
|
+
"""Start a local HTTP server and wait for the OAuth callback.
|
|
214
|
+
|
|
215
|
+
Returns the callback parameters or None on timeout.
|
|
216
|
+
"""
|
|
217
|
+
_OAuthCallbackHandler.callback_result = None
|
|
218
|
+
|
|
219
|
+
server = http.server.HTTPServer(("127.0.0.1", port), _OAuthCallbackHandler)
|
|
220
|
+
server.timeout = timeout
|
|
221
|
+
|
|
222
|
+
# Run in a thread with timeout
|
|
223
|
+
result: OAuthCallbackParams | None = None
|
|
224
|
+
|
|
225
|
+
def serve() -> None:
|
|
226
|
+
nonlocal result
|
|
227
|
+
server.handle_request() # Handle exactly one request
|
|
228
|
+
result = _OAuthCallbackHandler.callback_result
|
|
229
|
+
|
|
230
|
+
thread = threading.Thread(target=serve, daemon=True)
|
|
231
|
+
thread.start()
|
|
232
|
+
thread.join(timeout=timeout)
|
|
233
|
+
|
|
234
|
+
server.server_close()
|
|
235
|
+
return result
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
# ---------------------------------------------------------------------------
|
|
239
|
+
# Token exchange and refresh
|
|
240
|
+
# ---------------------------------------------------------------------------
|
|
241
|
+
|
|
242
|
+
async def exchange_authorization_code(
|
|
243
|
+
token_url: str,
|
|
244
|
+
code: str,
|
|
245
|
+
code_verifier: str,
|
|
246
|
+
client_id: str,
|
|
247
|
+
redirect_uri: str,
|
|
248
|
+
) -> OAuthTokenSet:
|
|
249
|
+
"""Exchange an authorization code for tokens."""
|
|
250
|
+
import httpx
|
|
251
|
+
|
|
252
|
+
async with httpx.AsyncClient() as client:
|
|
253
|
+
response = await client.post(
|
|
254
|
+
token_url,
|
|
255
|
+
data={
|
|
256
|
+
"grant_type": "authorization_code",
|
|
257
|
+
"code": code,
|
|
258
|
+
"code_verifier": code_verifier,
|
|
259
|
+
"client_id": client_id,
|
|
260
|
+
"redirect_uri": redirect_uri,
|
|
261
|
+
},
|
|
262
|
+
)
|
|
263
|
+
response.raise_for_status()
|
|
264
|
+
data = response.json()
|
|
265
|
+
|
|
266
|
+
expires_in = data.get("expires_in")
|
|
267
|
+
expires_at = int(time.time()) + expires_in if expires_in else None
|
|
268
|
+
|
|
269
|
+
return OAuthTokenSet(
|
|
270
|
+
access_token=data["access_token"],
|
|
271
|
+
refresh_token=data.get("refresh_token"),
|
|
272
|
+
expires_at=expires_at,
|
|
273
|
+
scopes=data.get("scope", "").split() if data.get("scope") else [],
|
|
274
|
+
)
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
async def refresh_token(
|
|
278
|
+
token_url: str,
|
|
279
|
+
refresh_token_str: str,
|
|
280
|
+
client_id: str,
|
|
281
|
+
) -> OAuthTokenSet:
|
|
282
|
+
"""Refresh an expired OAuth token."""
|
|
283
|
+
import httpx
|
|
284
|
+
|
|
285
|
+
async with httpx.AsyncClient() as client:
|
|
286
|
+
response = await client.post(
|
|
287
|
+
token_url,
|
|
288
|
+
data={
|
|
289
|
+
"grant_type": "refresh_token",
|
|
290
|
+
"refresh_token": refresh_token_str,
|
|
291
|
+
"client_id": client_id,
|
|
292
|
+
},
|
|
293
|
+
)
|
|
294
|
+
response.raise_for_status()
|
|
295
|
+
data = response.json()
|
|
296
|
+
|
|
297
|
+
expires_in = data.get("expires_in")
|
|
298
|
+
expires_at = int(time.time()) + expires_in if expires_in else None
|
|
299
|
+
|
|
300
|
+
return OAuthTokenSet(
|
|
301
|
+
access_token=data["access_token"],
|
|
302
|
+
refresh_token=data.get("refresh_token", refresh_token_str),
|
|
303
|
+
expires_at=expires_at,
|
|
304
|
+
scopes=data.get("scope", "").split() if data.get("scope") else [],
|
|
305
|
+
)
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
# ---------------------------------------------------------------------------
|
|
309
|
+
# Full login flow
|
|
310
|
+
# ---------------------------------------------------------------------------
|
|
311
|
+
|
|
312
|
+
def build_authorization_url(
|
|
313
|
+
config: OAuthConfig,
|
|
314
|
+
pkce: PkceCodePair,
|
|
315
|
+
state: str,
|
|
316
|
+
) -> str:
|
|
317
|
+
"""Build the full OAuth authorization URL."""
|
|
318
|
+
params = {
|
|
319
|
+
"client_id": config.client_id,
|
|
320
|
+
"response_type": "code",
|
|
321
|
+
"redirect_uri": f"http://127.0.0.1:{config.callback_port}{DEFAULT_CALLBACK_PATH}",
|
|
322
|
+
"state": state,
|
|
323
|
+
"code_challenge": pkce.code_challenge,
|
|
324
|
+
"code_challenge_method": "S256",
|
|
325
|
+
}
|
|
326
|
+
if config.scopes:
|
|
327
|
+
params["scope"] = " ".join(config.scopes)
|
|
328
|
+
|
|
329
|
+
return f"{config.authorize_url}?{urllib.parse.urlencode(params)}"
|