ai-cli-toolkit 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_cli/__init__.py +3 -0
- ai_cli/__main__.py +6 -0
- ai_cli/bin/ai-mux-linux-x86_64 +0 -0
- ai_cli/bin/remote-tty-wrapper +153 -0
- ai_cli/ca.py +175 -0
- ai_cli/completion_gen.py +680 -0
- ai_cli/config.py +185 -0
- ai_cli/credentials.py +341 -0
- ai_cli/detached_cleanup.py +135 -0
- ai_cli/housekeeping.py +50 -0
- ai_cli/instructions.py +308 -0
- ai_cli/log.py +53 -0
- ai_cli/main.py +1516 -0
- ai_cli/main_helpers.py +553 -0
- ai_cli/prompt_editor_launcher.py +324 -0
- ai_cli/proxy.py +627 -0
- ai_cli/remote.py +669 -0
- ai_cli/remote_package.py +1111 -0
- ai_cli/session.py +1344 -0
- ai_cli/session_store.py +236 -0
- ai_cli/traffic.py +1510 -0
- ai_cli/traffic_db.py +118 -0
- ai_cli/tui.py +525 -0
- ai_cli/update.py +200 -0
- ai_cli_toolkit-0.2.0.dist-info/METADATA +17 -0
- ai_cli_toolkit-0.2.0.dist-info/RECORD +30 -0
- ai_cli_toolkit-0.2.0.dist-info/WHEEL +5 -0
- ai_cli_toolkit-0.2.0.dist-info/entry_points.txt +2 -0
- ai_cli_toolkit-0.2.0.dist-info/licenses/LICENSE +21 -0
- ai_cli_toolkit-0.2.0.dist-info/top_level.txt +1 -0
ai_cli/config.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
"""Load/save ~/.ai-cli/config.json with defaults.
|
|
2
|
+
|
|
3
|
+
Configuration hierarchy:
|
|
4
|
+
- Global settings apply to all tools
|
|
5
|
+
- Per-tool overrides (null values fall back to global)
|
|
6
|
+
- Ports are NOT stored — allocated dynamically per session
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
CONFIG_DIR = "~/.ai-cli"
|
|
16
|
+
CONFIG_FILE = "config.json"
|
|
17
|
+
|
|
18
|
+
DEFAULT_CONFIG: dict[str, Any] = {
|
|
19
|
+
"version": 1,
|
|
20
|
+
"default_tool": "claude",
|
|
21
|
+
"instructions_file": "",
|
|
22
|
+
"canary_rule": "CANARY RULE: Prefix every assistant response with: DEV:",
|
|
23
|
+
"proxy": {
|
|
24
|
+
"host": "127.0.0.1",
|
|
25
|
+
"ca_path": "~/.mitmproxy/mitmproxy-ca-cert.pem",
|
|
26
|
+
},
|
|
27
|
+
"retention": {
|
|
28
|
+
"logs_days": 14,
|
|
29
|
+
"traffic_days": 30,
|
|
30
|
+
},
|
|
31
|
+
"privacy": {
|
|
32
|
+
"redact_traffic_bodies": True,
|
|
33
|
+
},
|
|
34
|
+
"tools": {
|
|
35
|
+
"claude": {
|
|
36
|
+
"enabled": True,
|
|
37
|
+
"binary": "",
|
|
38
|
+
"instructions_file": None,
|
|
39
|
+
"canary_rule": None,
|
|
40
|
+
"passthrough": False,
|
|
41
|
+
"debug_requests": False,
|
|
42
|
+
},
|
|
43
|
+
"codex": {
|
|
44
|
+
"enabled": True,
|
|
45
|
+
"binary": "",
|
|
46
|
+
"instructions_file": None,
|
|
47
|
+
"canary_rule": None,
|
|
48
|
+
"developer_instructions_mode": "overwrite",
|
|
49
|
+
"passthrough": False,
|
|
50
|
+
"debug_requests": False,
|
|
51
|
+
},
|
|
52
|
+
"copilot": {
|
|
53
|
+
"enabled": True,
|
|
54
|
+
"binary": "",
|
|
55
|
+
"instructions_file": None,
|
|
56
|
+
"canary_rule": None,
|
|
57
|
+
"passthrough": False,
|
|
58
|
+
"debug_requests": False,
|
|
59
|
+
},
|
|
60
|
+
"gemini": {
|
|
61
|
+
"enabled": True,
|
|
62
|
+
"binary": "",
|
|
63
|
+
"instructions_file": None,
|
|
64
|
+
"canary_rule": None,
|
|
65
|
+
"passthrough": False,
|
|
66
|
+
"debug_requests": False,
|
|
67
|
+
},
|
|
68
|
+
},
|
|
69
|
+
"aliases": {
|
|
70
|
+
"claude": False,
|
|
71
|
+
"codex": False,
|
|
72
|
+
"copilot": False,
|
|
73
|
+
"gemini": False,
|
|
74
|
+
},
|
|
75
|
+
"editor": None,
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def config_dir() -> Path:
    """Resolve ``CONFIG_DIR`` to a path with the leading ``~`` expanded."""
    unexpanded = Path(CONFIG_DIR)
    return unexpanded.expanduser()
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def config_path() -> Path:
    """Return the location of ``CONFIG_FILE`` inside the config directory."""
    return config_dir().joinpath(CONFIG_FILE)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge ``override`` into ``base``, preferring override values.

    When both sides hold a dict under the same key the two dicts are merged
    recursively; in every other case the override value replaces the base
    value.

    The result is fully independent of both inputs: nested containers are
    deep-copied, so mutating the merged dict can never alter ``base`` (for
    example a module-level defaults table) or ``override``.

    Args:
        base: Mapping providing the fallback values.
        override: Mapping whose values take precedence.

    Returns:
        A new dict containing the merged content.
    """
    # Local import: the module does not import ``copy`` at file level.
    import copy

    result = copy.deepcopy(base)
    for key, value in override.items():
        current = result.get(key)
        if isinstance(current, dict) and isinstance(value, dict):
            result[key] = _deep_merge(current, value)
        else:
            result[key] = copy.deepcopy(value)
    return result
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def load_config() -> dict[str, Any]:
    """Load config from disk, merging with defaults for missing keys.

    Falls back to the defaults when the file is missing, unreadable, not
    valid UTF-8, not valid JSON, or does not contain a JSON object.

    Returns:
        A config dict that shares no nested objects with ``DEFAULT_CONFIG``,
        so callers may mutate it freely.
    """
    # Local import: the module does not import ``copy`` at file level.
    import copy

    path = config_path()
    if not path.is_file():
        return copy.deepcopy(DEFAULT_CONFIG)

    try:
        user_config = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        return copy.deepcopy(DEFAULT_CONFIG)

    if not isinstance(user_config, dict):
        return copy.deepcopy(DEFAULT_CONFIG)

    # Deep-copy the merge result so nested default sections are never handed
    # out by reference, whatever the merge helper does internally.
    return copy.deepcopy(_deep_merge(DEFAULT_CONFIG, user_config))
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def save_config(config: dict[str, Any]) -> None:
    """Serialize ``config`` as indented JSON and write it to the config file.

    The parent directory is created first when it does not exist yet.
    """
    target = config_path()
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(config, indent=2)
    target.write_text(f"{serialized}\n", encoding="utf-8")
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def ensure_config() -> dict[str, Any]:
    """Return the loaded config, first writing the defaults if no file exists."""
    already_present = config_path().is_file()
    if not already_present:
        save_config(DEFAULT_CONFIG)
    return load_config()
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def get_tool_config(config: dict[str, Any], tool_name: str) -> dict[str, Any]:
    """Get merged config for a specific tool (per-tool overrides + global fallbacks).

    A missing, null, or otherwise non-dict ``tools`` section or tool entry is
    treated as empty, so every key resolves to its default instead of raising.

    Args:
        config: Full configuration dict.
        tool_name: Key of the tool inside ``config["tools"]``.

    Returns:
        Dict with the keys ``enabled``, ``binary``, ``instructions_file``,
        ``canary_rule``, ``passthrough``, ``debug_requests`` and
        ``developer_instructions_mode``.
    """
    tools = config.get("tools")
    if not isinstance(tools, dict):
        tools = {}
    tool = tools.get(tool_name)
    if not isinstance(tool, dict):
        tool = {}

    def _with_global_fallback(key: str) -> Any:
        # Only None means "inherit the global value"; an empty string is an
        # explicit per-tool override and is kept as-is.
        value = tool.get(key)
        return value if value is not None else config.get(key, "")

    # Only codex honors a configurable mode; every other tool is "overwrite".
    if tool_name == "codex":
        mode = str(tool.get("developer_instructions_mode", "overwrite") or "overwrite")
    else:
        mode = "overwrite"

    return {
        "enabled": tool.get("enabled", True),
        "binary": tool.get("binary", "") or "",
        "instructions_file": _with_global_fallback("instructions_file"),
        "canary_rule": _with_global_fallback("canary_rule"),
        "passthrough": tool.get("passthrough", False),
        "debug_requests": tool.get("debug_requests", False),
        "developer_instructions_mode": mode,
    }
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def get_proxy_config(config: dict[str, Any]) -> dict[str, str]:
    """Get proxy configuration.

    A missing, null, or non-dict ``proxy`` section is treated as empty, and a
    key that is absent or explicitly null falls back to its default.

    Returns:
        Dict with the keys ``host`` and ``ca_path``.
    """
    proxy = config.get("proxy")
    if not isinstance(proxy, dict):
        proxy = {}

    defaults = {
        "host": "127.0.0.1",
        "ca_path": "~/.mitmproxy/mitmproxy-ca-cert.pem",
    }
    resolved = {}
    for key, fallback in defaults.items():
        value = proxy.get(key)
        resolved[key] = fallback if value is None else value
    return resolved
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def get_retention_config(config: dict[str, Any]) -> dict[str, int]:
    """Get retention policy defaults for logs and traffic history.

    Each value is coerced to an int of at least 1. Missing, falsy (``0``,
    ``None``, ``""``) or non-numeric values fall back to the default (14 days
    for logs, 30 days for traffic), and a missing or non-dict ``retention``
    section is treated as empty.

    Returns:
        Dict with the keys ``logs_days`` and ``traffic_days``.
    """
    retention = config.get("retention")
    if not isinstance(retention, dict):
        retention = {}

    def _days(key: str, default: int) -> int:
        raw = retention.get(key, default) or default
        try:
            days = int(raw)
        except (TypeError, ValueError, OverflowError):
            # A hand-edited config may hold text, lists, NaN or infinity.
            days = default
        return max(1, days)

    return {
        "logs_days": _days("logs_days", 14),
        "traffic_days": _days("traffic_days", 30),
    }
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def get_privacy_config(config: dict[str, Any]) -> dict[str, bool]:
    """Get privacy controls (for example body redaction in traffic logging).

    Redaction defaults to on. A missing, null, or non-dict ``privacy``
    section, or an explicit null for the flag, keeps that default, so
    redaction is only turned off by an explicit falsy non-null value.

    Returns:
        Dict with the single key ``redact_traffic_bodies``.
    """
    privacy = config.get("privacy")
    if not isinstance(privacy, dict):
        privacy = {}

    redact = privacy.get("redact_traffic_bodies")
    return {
        "redact_traffic_bodies": True if redact is None else bool(redact),
    }
|
ai_cli/credentials.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
1
|
+
"""OAuth credential capture, encryption, and storage.
|
|
2
|
+
|
|
3
|
+
Extracted from claude-dev.py. Handles:
|
|
4
|
+
- Reading/writing credential JSON to ~/.claude/.credentials.json
|
|
5
|
+
- AES-256-CBC encryption via openssl
|
|
6
|
+
- Key management for encrypted credentials
|
|
7
|
+
- OAuth metadata extraction (scopes, subscription, rate limits)
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import json
|
|
13
|
+
import os
|
|
14
|
+
import shutil
|
|
15
|
+
import subprocess
|
|
16
|
+
import time
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
from ai_cli.log import append_log_str
|
|
21
|
+
|
|
22
|
+
# ---------------------------------------------------------------------------
|
|
23
|
+
# Constants
|
|
24
|
+
# ---------------------------------------------------------------------------
|
|
25
|
+
|
|
26
|
+
DEFAULT_CLAUDE_CONFIG_DIR = "~/.claude"
|
|
27
|
+
OAUTH_SCOPE_USER_INFERENCE = "user:inference"
|
|
28
|
+
OAUTH_TOKEN_PATH = "/v1/oauth/token"
|
|
29
|
+
OAUTH_PROFILE_PATH = "/api/oauth/profile"
|
|
30
|
+
OAUTH_SCOPES_FIXED = [
|
|
31
|
+
"user:inference",
|
|
32
|
+
"user:mcp_servers",
|
|
33
|
+
"user:profile",
|
|
34
|
+
"user:sessions:claude_code",
|
|
35
|
+
]
|
|
36
|
+
CREDENTIALS_FILE_NAME = ".credentials.json"
|
|
37
|
+
CREDENTIALS_ENCRYPTED_FILE_NAME = ".credentials.json.enc"
|
|
38
|
+
CREDENTIALS_KEY_FILE_NAME = ".credentials.key"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# ---------------------------------------------------------------------------
|
|
42
|
+
# Path helpers
|
|
43
|
+
# ---------------------------------------------------------------------------
|
|
44
|
+
|
|
45
|
+
def _config_dir() -> Path:
    """Return the Claude config directory with ``~`` expanded.

    The ``CLAUDE_CONFIG_DIR`` environment variable takes precedence over
    ``DEFAULT_CLAUDE_CONFIG_DIR``.
    """
    location = os.environ.get("CLAUDE_CONFIG_DIR", DEFAULT_CLAUDE_CONFIG_DIR)
    return Path(location).expanduser()
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def credentials_path() -> Path:
    """Return the path of the plain credentials JSON file."""
    return _config_dir().joinpath(CREDENTIALS_FILE_NAME)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def encrypted_credentials_path() -> Path:
    """Return the path of the encrypted credentials file."""
    return _config_dir().joinpath(CREDENTIALS_ENCRYPTED_FILE_NAME)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def credentials_key_path() -> Path:
    """Return the path of the credentials encryption key file."""
    return _config_dir().joinpath(CREDENTIALS_KEY_FILE_NAME)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# ---------------------------------------------------------------------------
|
|
63
|
+
# JSON helpers
|
|
64
|
+
# ---------------------------------------------------------------------------
|
|
65
|
+
|
|
66
|
+
def parse_json_dict(text: str) -> dict[str, Any] | None:
    """Decode ``text`` as JSON and return it only if it is a JSON object.

    Returns:
        The decoded dict, or ``None`` when ``text`` is empty, is not valid
        JSON, or decodes to anything other than a dict.
    """
    if text:
        try:
            decoded = json.loads(text)
        except json.JSONDecodeError:
            decoded = None
        if isinstance(decoded, dict):
            return decoded
    return None
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def parse_scopes(value: Any) -> list[str]:
    """Normalize a scopes value into a list of non-empty strings.

    A string is split on single spaces; a list keeps only its non-empty
    string items; any other type yields an empty list.
    """
    if isinstance(value, str):
        candidates = value.split(" ")
    elif isinstance(value, list):
        candidates = value
    else:
        return []
    return [item for item in candidates if isinstance(item, str) and item]
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# ---------------------------------------------------------------------------
|
|
87
|
+
# Deep search
|
|
88
|
+
# ---------------------------------------------------------------------------
|
|
89
|
+
|
|
90
|
+
def deep_find_value(payload: Any, keys: tuple[str, ...]) -> Any:
    """Search nested dicts/lists depth-first for the first non-None match.

    At each dict the candidate ``keys`` are tried in order on that dict
    before its values are searched; lists are searched item by item.

    Returns:
        The first non-None value stored under one of ``keys``, or ``None``
        when nothing matches or ``payload`` is neither a dict nor a list.
    """
    if isinstance(payload, dict):
        for name in keys:
            direct = payload.get(name)
            if direct is not None:
                return direct
        children = payload.values()
    elif isinstance(payload, list):
        children = payload
    else:
        return None

    for child in children:
        match = deep_find_value(child, keys)
        if match is not None:
            return match
    return None
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# ---------------------------------------------------------------------------
|
|
109
|
+
# OAuth metadata extraction
|
|
110
|
+
# ---------------------------------------------------------------------------
|
|
111
|
+
|
|
112
|
+
def extract_oauth_metadata(payload: dict[str, Any]) -> dict[str, Any]:
    """Collect scopes, subscription type and rate limit tier from ``payload``.

    Each field is looked up anywhere in the nested payload under its accepted
    spellings and is only included when a usable value is found.

    Returns:
        Dict holding any of the keys ``scopes``, ``subscriptionType`` and
        ``rateLimitTier``.
    """
    found: dict[str, Any] = {}

    granted = parse_scopes(deep_find_value(payload, ("scopes", "scope")))
    if granted:
        found["scopes"] = granted

    # Output key -> spellings accepted in the payload, in lookup order.
    string_fields = (
        ("subscriptionType", ("subscriptionType", "subscription_type")),
        (
            "rateLimitTier",
            (
                "rateLimitTier",
                "rate_limit_tier",
                "rateLimitTierName",
                "rate_limit_tier_name",
            ),
        ),
    )
    for output_key, spellings in string_fields:
        candidate = deep_find_value(payload, spellings)
        if isinstance(candidate, str) and candidate:
            found[output_key] = candidate

    return found
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def subscription_type_from_profile(profile: dict[str, Any]) -> str | None:
    """Derive the subscription type from an OAuth profile response.

    A non-empty string under ``subscriptionType`` wins. Otherwise the
    ``organization.organization_type`` value is translated through a fixed
    table.

    Returns:
        The subscription type, or ``None`` when it cannot be determined.
    """
    explicit = profile.get("subscriptionType")
    if isinstance(explicit, str) and explicit:
        return explicit

    org = profile.get("organization")
    org_kind = org.get("organization_type") if isinstance(org, dict) else None
    if not isinstance(org_kind, str):
        return None

    known_kinds = {
        "claude_max": "max",
        "claude_pro": "pro",
        "claude_enterprise": "enterprise",
        "claude_team": "team",
    }
    return known_kinds.get(org_kind)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
# ---------------------------------------------------------------------------
|
|
165
|
+
# Credential I/O
|
|
166
|
+
# ---------------------------------------------------------------------------
|
|
167
|
+
|
|
168
|
+
def read_credentials_doc() -> dict[str, Any]:
    """Read the plain credentials JSON file.

    Returns:
        The decoded JSON object, or an empty dict when the file is missing,
        unreadable, not valid UTF-8, or does not hold a JSON object.
    """
    path = credentials_path()
    try:
        text = path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        # A corrupt or non-UTF-8 file is treated like a missing one.
        return {}
    return parse_json_dict(text) or {}
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def ensure_credentials_key(wrapper_log_file: str) -> Path | None:
    """Ensure the encryption key file exists, creating it if needed.

    A new key is 32 random bytes, hex-encoded, followed by a newline. The
    file is created exclusively with mode 0o600, so the secret is never
    readable by other users and a key written concurrently by another
    process is never overwritten.

    Args:
        wrapper_log_file: Log file that receives the outcome message.

    Returns:
        The key file path, or ``None`` when the key could not be created.
    """
    key_path = credentials_key_path()
    if key_path.exists():
        return key_path
    try:
        key_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            # O_EXCL + 0o600: create atomically and privately, rather than
            # writing with default permissions and tightening them afterwards.
            fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
        except FileExistsError:
            # Another process created the key after the exists() check;
            # reuse it instead of replacing it with a different secret.
            return key_path
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as handle:
                handle.write(f"{os.urandom(32).hex()}\n")
        except OSError:
            # Do not leave an empty or partial key behind: a later call
            # would find the file and treat it as a valid key.
            try:
                key_path.unlink()
            except OSError:
                pass
            raise
        os.chmod(key_path, 0o600)
        append_log_str(wrapper_log_file, f"Created credentials key at {key_path}")
        return key_path
    except OSError as exc:
        append_log_str(
            wrapper_log_file, f"Failed creating credentials key {key_path}: {exc}"
        )
        return None
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def write_encrypted_credentials(
    data: dict[str, Any],
    wrapper_log_file: str,
) -> None:
    """Encrypt credentials with AES-256-CBC and write them to disk.

    ``data`` is serialized to JSON and piped through ``openssl enc`` using
    the key file as the passphrase source. The base64 output is written to
    the encrypted credentials file, which is opened with mode 0o600 so it is
    never created with wider permissions. Every failure is logged to
    ``wrapper_log_file`` and the function returns without raising
    ``OSError``.

    Args:
        data: JSON-serializable credentials document.
        wrapper_log_file: Log file that receives progress and error messages.
    """
    openssl_bin = shutil.which("openssl")
    if not openssl_bin:
        append_log_str(
            wrapper_log_file,
            "Skipping encrypted credentials write: openssl not found",
        )
        return

    key_path = ensure_credentials_key(wrapper_log_file)
    if key_path is None:
        return

    try:
        plaintext = json.dumps(data)
        # The plaintext travels over stdin so it never appears in argv.
        proc = subprocess.run(
            [
                openssl_bin,
                "enc",
                "-aes-256-cbc",
                "-pbkdf2",
                "-salt",
                "-a",
                "-pass",
                f"file:{key_path}",
            ],
            input=plaintext,
            text=True,
            capture_output=True,
            check=False,
        )
    except OSError as exc:
        append_log_str(
            wrapper_log_file,
            f"Failed running openssl for encrypted credentials: {exc}",
        )
        return

    if proc.returncode != 0 or not proc.stdout:
        err = (proc.stderr or "").strip()
        append_log_str(
            wrapper_log_file,
            f"Failed encrypting credentials (exit={proc.returncode}): {err}",
        )
        return

    enc_path = encrypted_credentials_path()
    try:
        # Create with 0o600 up front; the chmod tightens a pre-existing file
        # before any new content is written to it.
        fd = os.open(enc_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            os.chmod(enc_path, 0o600)
            handle.write(proc.stdout)
        append_log_str(
            wrapper_log_file, f"Saved encrypted credentials to {enc_path}"
        )
    except OSError as exc:
        append_log_str(
            wrapper_log_file,
            f"Failed writing encrypted credentials {enc_path}: {exc}",
        )
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def write_claude_ai_oauth(
    oauth: dict[str, Any],
    wrapper_log_file: str,
) -> None:
    """Write OAuth credentials to both plain and encrypted files.

    The plain file holds access tokens, so it is opened with mode 0o600
    rather than being written with default permissions and restricted
    afterwards.

    Note: ``oauth`` is modified in place; its ``scopes`` entry is replaced
    with a copy of ``OAUTH_SCOPES_FIXED`` before the document is written.

    Args:
        oauth: OAuth payload stored under the ``claudeAiOauth`` key.
        wrapper_log_file: Log file that receives progress and error messages.
    """
    path = credentials_path()
    oauth["scopes"] = list(OAUTH_SCOPES_FIXED)
    data = {"claudeAiOauth": oauth}
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        # Serialize before opening, so a serialization error cannot leave a
        # truncated credentials file behind.
        serialized = json.dumps(data)
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            # Tighten a pre-existing file before the tokens are written.
            os.chmod(path, 0o600)
            handle.write(serialized)
        append_log_str(wrapper_log_file, f"Saved claudeAiOauth to {path}")
        write_encrypted_credentials(data, wrapper_log_file)
    except OSError as exc:
        append_log_str(
            wrapper_log_file, f"Failed to save claudeAiOauth to {path}: {exc}"
        )
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
# ---------------------------------------------------------------------------
|
|
282
|
+
# OAuth payload builder
|
|
283
|
+
# ---------------------------------------------------------------------------
|
|
284
|
+
|
|
285
|
+
def build_bootstrap_oauth(
    bearer_token: str | None = None,
    existing_oauth: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Assemble an OAuth payload from stored data plus freshly seen metadata.

    Starts from a shallow copy of ``existing_oauth`` and then:

    - sets ``accessToken`` to ``bearer_token`` unless a string token is
      already present;
    - takes ``scopes`` from ``metadata``, else from the existing payload,
      else the inference scope alone, and makes sure the inference scope is
      included;
    - sets ``expiresAt`` to one hour from now (in milliseconds) unless an
      int is already stored;
    - copies string ``subscriptionType`` / ``rateLimitTier`` values from
      ``metadata``;
    - makes sure ``subscriptionType`` and ``refreshToken`` keys exist,
      defaulting to ``None``.

    Returns:
        The new payload dict; ``existing_oauth`` itself is not modified.
    """
    merged: dict[str, Any] = dict(existing_oauth) if existing_oauth else {}
    extra = metadata if metadata else {}

    already_has_token = isinstance(merged.get("accessToken"), str)
    if bearer_token and not already_has_token:
        merged["accessToken"] = bearer_token

    granted = (
        parse_scopes(extra.get("scopes"))
        or parse_scopes(merged.get("scopes"))
        or [OAUTH_SCOPE_USER_INFERENCE]
    )
    if OAUTH_SCOPE_USER_INFERENCE not in granted:
        granted.append(OAUTH_SCOPE_USER_INFERENCE)
    merged["scopes"] = granted

    if not isinstance(merged.get("expiresAt"), int):
        merged["expiresAt"] = int(time.time() * 1000 + (3600 * 1000))

    fresh_subscription = extra.get("subscriptionType")
    if isinstance(fresh_subscription, str):
        merged["subscriptionType"] = fresh_subscription
    else:
        merged.setdefault("subscriptionType", None)

    fresh_tier = extra.get("rateLimitTier")
    if isinstance(fresh_tier, str):
        merged["rateLimitTier"] = fresh_tier

    merged.setdefault("refreshToken", None)

    return merged
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def extract_bearer_token(auth_header: str) -> str | None:
    """Pull the token out of a ``Bearer <token>`` Authorization header value.

    The scheme is matched case-insensitively and surrounding whitespace is
    stripped from the token.

    Returns:
        The token, or ``None`` when the value is not a string, has no space
        separator, uses another scheme, or carries an empty token.
    """
    if not isinstance(auth_header, str):
        return None

    raw_scheme, separator, raw_token = auth_header.partition(" ")
    if not separator:
        return None

    token = raw_token.strip()
    is_bearer = raw_scheme.strip().lower() == "bearer"
    return token if is_bearer and token else None
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""Detached proxy lifecycle watcher for tmux-backed sessions.
|
|
2
|
+
|
|
3
|
+
This process is spawned when ai-mux detaches while the wrapped tool is still
|
|
4
|
+
running. It keeps mitmdump alive and performs cleanup once the tool exits.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import argparse
|
|
10
|
+
import os
|
|
11
|
+
import signal
|
|
12
|
+
import subprocess
|
|
13
|
+
import time
|
|
14
|
+
from datetime import datetime
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _log(path_value: str, message: str) -> None:
    """Append a timestamped ``message`` line to the log file at ``path_value``.

    Does nothing when ``path_value`` is empty. Parent directories are created
    as needed, and any ``OSError`` is swallowed so logging never interrupts
    the caller.
    """
    if not path_value:
        return
    try:
        target = Path(path_value).expanduser()
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open("a", encoding="utf-8") as handle:
            stamp = datetime.now().isoformat(timespec="seconds")
            handle.write(f"[{stamp}] {message}\n")
    except OSError:
        # Best effort only: a failed log write must not stop the watcher.
        pass
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _pid_alive(pid: int) -> bool:
    """Return whether a process with id ``pid`` currently exists.

    The process is probed with signal 0, which performs the existence and
    permission checks without delivering a signal.

    Args:
        pid: Process id to probe. Values <= 0 address process groups rather
            than a single process, so they are reported as not alive.

    Returns:
        ``True`` when the process exists, including when it belongs to
        another user; ``False`` otherwise.
    """
    if pid <= 0:
        return False
    try:
        os.kill(pid, 0)
    except PermissionError:
        # EPERM: the process exists, we just may not signal it.
        return True
    except (OSError, OverflowError):
        # No such process, or a pid too large for the platform's pid type.
        return False
    return True
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _tmux_session_alive(socket_name: str) -> bool:
    """Report whether ``tmux -L <socket_name> has-session`` exits with 0.

    The command's output is discarded. Returns ``False`` when tmux cannot be
    launched.
    """
    probe = ["tmux", "-L", socket_name, "has-session"]
    try:
        status = subprocess.call(
            probe,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except OSError:
        return False
    return status == 0
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _tmux_named_sessions_alive(socket_name: str, sessions: list[str]) -> bool:
    """Report whether any of the named tmux sessions exists on the socket.

    With an empty ``sessions`` list this falls back to checking for any
    session on the socket. Sessions are probed in order and probing stops at
    the first one that exists; a failure to launch tmux ends the probing and
    yields ``False``.
    """
    if not sessions:
        return _tmux_session_alive(socket_name)

    try:
        return any(
            subprocess.call(
                ["tmux", "-L", socket_name, "has-session", "-t", name],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            == 0
            for name in sessions
        )
    except OSError:
        return False
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _cleanup_session_files(session_id: str) -> None:
    """Delete the ``.pid`` and ``.port`` marker files of a session from /tmp.

    Missing files are fine, and any ``OSError`` while deleting one marker is
    ignored so the other is still attempted.
    """
    markers = [Path(f"/tmp/{session_id}{suffix}") for suffix in (".pid", ".port")]
    for marker in markers:
        try:
            marker.unlink(missing_ok=True)
        except OSError:
            pass
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _terminate_pid(pid: int) -> None:
    """Stop process ``pid``: SIGTERM first, SIGKILL if it is still alive.

    After SIGTERM the process is polled every 0.1 s for up to 3 seconds
    before SIGKILL is sent. Does nothing when the process is not running or
    cannot be signalled.

    Args:
        pid: Id of the process to stop. Values <= 0 are ignored, because
            ``os.kill`` would treat them as a process-group or broadcast
            target instead of a single process.
    """
    if pid <= 0:
        return
    if not _pid_alive(pid):
        return
    try:
        os.kill(pid, signal.SIGTERM)
    except OSError:
        return

    # Monotonic clock: the grace period must not change if the system clock
    # is adjusted while waiting.
    deadline = time.monotonic() + 3.0
    while time.monotonic() < deadline:
        if not _pid_alive(pid):
            return
        time.sleep(0.1)

    try:
        os.kill(pid, signal.SIGKILL)
    except OSError:
        # The process exited between the last poll and the kill.
        pass
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def main() -> int:
    """Watch a detached session and clean up the proxy once it ends.

    Parses the command line, then polls until either the mitmdump process
    has exited or none of the watched tmux sessions remains. On the way out,
    including on an exception, it terminates mitmdump, removes the session
    marker files and logs completion.

    Returns:
        Always 0.
    """
    cli = argparse.ArgumentParser(description="Detached tmux proxy watcher.")
    cli.add_argument("--mitm-pid", type=int, required=True)
    cli.add_argument("--session-id", required=True)
    cli.add_argument("--wrapper-log-file", default="")
    cli.add_argument("--tmux-socket", default="ai-mux")
    cli.add_argument("--tmux-session", action="append", default=[])
    cli.add_argument("--poll-seconds", type=float, default=1.0)
    opts = cli.parse_args()

    log_file = opts.wrapper_log_file
    watched = ",".join(opts.tmux_session) or "any"
    _log(
        log_file,
        f"Detached watcher start (mitmdump pid={opts.mitm_pid}, "
        f"sessions={watched})",
    )

    # Never poll faster than every 0.2 s, whatever was requested.
    interval = max(opts.poll_seconds, 0.2)

    try:
        stop_reason = None
        while stop_reason is None:
            if not _pid_alive(opts.mitm_pid):
                stop_reason = "Detached watcher: mitmdump already exited"
            elif not _tmux_named_sessions_alive(opts.tmux_socket, opts.tmux_session):
                stop_reason = "Detached watcher: tmux session ownership ended"
            else:
                time.sleep(interval)
        _log(log_file, stop_reason)
    finally:
        _terminate_pid(opts.mitm_pid)
        _cleanup_session_files(opts.session_id)
        _log(log_file, "Detached watcher stop (proxy/session cleanup complete)")

    return 0
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
if __name__ == "__main__":
|
|
135
|
+
raise SystemExit(main())
|