ai-cli-toolkit 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ai_cli/config.py ADDED
@@ -0,0 +1,185 @@
1
+ """Load/save ~/.ai-cli/config.json with defaults.
2
+
3
+ Configuration hierarchy:
4
+ - Global settings apply to all tools
5
+ - Per-tool overrides (null values fall back to global)
6
+ - Ports are NOT stored — allocated dynamically per session
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import json
12
+ from pathlib import Path
13
+ from typing import Any
14
+
15
# Directory holding ai-cli state; "~" is expanded by config_dir().
CONFIG_DIR = "~/.ai-cli"
# File name of the JSON config inside CONFIG_DIR (see config_path()).
CONFIG_FILE = "config.json"

# Baseline configuration. load_config() deep-merges the user's config.json on
# top of this, so keys missing from the file keep the values below.
# Per-tool entries whose value is None inherit the matching global key
# (resolved in get_tool_config()).
DEFAULT_CONFIG: dict[str, Any] = {
    # NOTE(review): presumably the config file format version; not read in this module.
    "version": 1,
    # NOTE(review): presumably the tool launched when none is named; not read here.
    "default_tool": "claude",
    # Global fallbacks for the per-tool "instructions_file" / "canary_rule".
    "instructions_file": "",
    "canary_rule": "CANARY RULE: Prefix every assistant response with: DEV:",
    # Read by get_proxy_config(). Ports are intentionally absent: they are
    # allocated per session (see module docstring).
    "proxy": {
        "host": "127.0.0.1",
        "ca_path": "~/.mitmproxy/mitmproxy-ca-cert.pem",
    },
    # Read by get_retention_config(); values are days, clamped to >= 1 there.
    "retention": {
        "logs_days": 14,
        "traffic_days": 30,
    },
    # Read by get_privacy_config().
    "privacy": {
        "redact_traffic_bodies": True,
    },
    # Per-tool settings, resolved by get_tool_config().
    "tools": {
        "claude": {
            "enabled": True,
            "binary": "",
            "instructions_file": None,
            "canary_rule": None,
            "passthrough": False,
            "debug_requests": False,
        },
        "codex": {
            "enabled": True,
            "binary": "",
            "instructions_file": None,
            "canary_rule": None,
            # Only codex may configure this; get_tool_config() forces
            # "overwrite" for every other tool.
            "developer_instructions_mode": "overwrite",
            "passthrough": False,
            "debug_requests": False,
        },
        "copilot": {
            "enabled": True,
            "binary": "",
            "instructions_file": None,
            "canary_rule": None,
            "passthrough": False,
            "debug_requests": False,
        },
        "gemini": {
            "enabled": True,
            "binary": "",
            "instructions_file": None,
            "canary_rule": None,
            "passthrough": False,
            "debug_requests": False,
        },
    },
    # NOTE(review): per-tool alias toggles; their consumer is outside this module — confirm semantics.
    "aliases": {
        "claude": False,
        "codex": False,
        "copilot": False,
        "gemini": False,
    },
    # NOTE(review): presumably a preferred editor command; None = unset. Not read here.
    "editor": None,
}
77
+
78
+
79
def config_dir() -> Path:
    """Expand ``CONFIG_DIR`` (which may start with ``~``) into a concrete Path."""
    raw_dir = Path(CONFIG_DIR)
    return raw_dir.expanduser()
82
+
83
+
84
def config_path() -> Path:
    """Location of the JSON config file inside the ai-cli config directory."""
    return config_dir().joinpath(CONFIG_FILE)
87
+
88
+
89
def _deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge *override* into *base* and return a brand-new dict.

    Where a key maps to a dict in both inputs, the two dicts are merged
    recursively; otherwise the value from *override* wins. Neither input is
    modified, and the result shares no mutable containers with either input.
    This matters because callers merge onto module-level defaults
    (``DEFAULT_CONFIG``). With a shallow copy, every nested section not
    overridden would be the very same object as the default, and mutating
    the merged config would silently rewrite the defaults.
    """
    import copy

    result = copy.deepcopy(base)
    for key, value in override.items():
        current = result.get(key)
        if isinstance(current, dict) and isinstance(value, dict):
            result[key] = _deep_merge(current, value)
        else:
            result[key] = copy.deepcopy(value)
    return result
98
+
99
+
100
def load_config() -> dict[str, Any]:
    """Load config from disk, merging with defaults for missing keys.

    Always returns a fresh dict that shares no nested containers with
    ``DEFAULT_CONFIG``, so callers may mutate the result (e.g. toggle a
    tool) without corrupting the defaults seen by later calls.

    A missing, unreadable, non-UTF-8, malformed, or non-object config file
    yields the defaults instead of raising.
    """
    import copy

    defaults = copy.deepcopy(DEFAULT_CONFIG)
    path = config_path()
    if not path.is_file():
        return defaults

    try:
        text = path.read_text(encoding="utf-8")
        user_config = json.loads(text)
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError
        # (a corrupt / non-UTF-8 file), neither of which is an OSError.
        return defaults

    if not isinstance(user_config, dict):
        return defaults
    return _deep_merge(defaults, user_config)
114
+
115
+
116
def save_config(config: dict[str, Any]) -> None:
    """Serialize *config* as 2-space-indented JSON and write it to config_path().

    The parent directory is created on demand, and the file always ends with
    a trailing newline.
    """
    target = config_path()
    target.parent.mkdir(exist_ok=True, parents=True)
    serialized = json.dumps(config, indent=2)
    target.write_text(f"{serialized}\n", encoding="utf-8")
121
+
122
+
123
def ensure_config() -> dict[str, Any]:
    """Return the loaded config, writing the default file first if none exists."""
    if not config_path().is_file():
        save_config(DEFAULT_CONFIG)
    return load_config()
129
+
130
+
131
def get_tool_config(config: dict[str, Any], tool_name: str) -> dict[str, Any]:
    """Get merged config for a specific tool (per-tool overrides + global fallbacks).

    Args:
        config: Full config, as returned by load_config().
        tool_name: Key under ``config["tools"]`` (e.g. "claude", "codex").

    Returns:
        A dict with ``enabled``, ``binary``, ``instructions_file``,
        ``canary_rule``, ``passthrough``, ``debug_requests`` and
        ``developer_instructions_mode``.
        - ``instructions_file`` and ``canary_rule`` use the per-tool value
          unless it is None, in which case the top-level value is used.
        - ``developer_instructions_mode`` is only configurable for "codex";
          every other tool gets "overwrite".

    A ``tools`` section or tool entry that is not a JSON object (e.g.
    ``"tools": null`` in config.json, which the deep merge passes through)
    is treated as empty instead of raising AttributeError.
    """
    tools = config.get("tools")
    if not isinstance(tools, dict):
        tools = {}
    tool = tools.get(tool_name)
    if not isinstance(tool, dict):
        tool = {}

    def _override_or_global(key: str) -> Any:
        # None is the "inherit from global" marker used in DEFAULT_CONFIG.
        value = tool.get(key)
        return value if value is not None else config.get(key, "")

    if tool_name == "codex":
        mode = str(tool.get("developer_instructions_mode", "overwrite") or "overwrite")
    else:
        mode = "overwrite"

    return {
        "enabled": tool.get("enabled", True),
        "binary": tool.get("binary", "") or "",
        "instructions_file": _override_or_global("instructions_file"),
        "canary_rule": _override_or_global("canary_rule"),
        "passthrough": tool.get("passthrough", False),
        "debug_requests": tool.get("debug_requests", False),
        "developer_instructions_mode": mode,
    }
158
+
159
+
160
def get_proxy_config(config: dict[str, Any]) -> dict[str, str]:
    """Get proxy configuration.

    Returns ``host`` and ``ca_path`` from ``config["proxy"]``. Missing keys
    fall back to 127.0.0.1 and the default mitmproxy CA certificate path.
    A ``proxy`` section that is not a JSON object (e.g. ``"proxy": null``) is
    treated as empty rather than raising AttributeError.
    """
    proxy = config.get("proxy")
    if not isinstance(proxy, dict):
        proxy = {}
    return {
        "host": proxy.get("host", "127.0.0.1"),
        "ca_path": proxy.get("ca_path", "~/.mitmproxy/mitmproxy-ca-cert.pem"),
    }
167
+
168
+
169
def get_retention_config(config: dict[str, Any]) -> dict[str, int]:
    """Get retention policy (in days) for logs and traffic history.

    Each value is taken from ``config["retention"]``. A missing, falsy (0,
    "", None), or non-integer-convertible value falls back to the default
    (14 days for logs, 30 for traffic). The result is clamped to at least 1.
    A ``retention`` section that is not a JSON object is treated as empty,
    so a hand-edited config.json cannot crash the caller.
    """
    retention = config.get("retention")
    if not isinstance(retention, dict):
        retention = {}

    def _days(key: str, default: int) -> int:
        raw = retention.get(key, default)
        try:
            days = int(raw or default)
        except (TypeError, ValueError, OverflowError):
            # e.g. "two weeks", a list, NaN or Infinity in config.json.
            days = default
        return max(1, days)

    return {
        "logs_days": _days("logs_days", 14),
        "traffic_days": _days("traffic_days", 30),
    }
178
+
179
+
180
def get_privacy_config(config: dict[str, Any]) -> dict[str, bool]:
    """Get privacy controls (for example body redaction in traffic logging).

    ``redact_traffic_bodies`` defaults to True when absent. A ``privacy``
    section that is not a JSON object is treated as empty, so redaction
    stays on rather than the call raising AttributeError.
    """
    privacy = config.get("privacy")
    if not isinstance(privacy, dict):
        privacy = {}
    return {
        "redact_traffic_bodies": bool(privacy.get("redact_traffic_bodies", True)),
    }
ai_cli/credentials.py ADDED
@@ -0,0 +1,341 @@
1
+ """OAuth credential capture, encryption, and storage.
2
+
3
+ Extracted from claude-dev.py. Handles:
4
+ - Reading/writing credential JSON to ~/.claude/.credentials.json
5
+ - AES-256-CBC encryption via openssl
6
+ - Key management for encrypted credentials
7
+ - OAuth metadata extraction (scopes, subscription, rate limits)
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ import os
14
+ import shutil
15
+ import subprocess
16
+ import time
17
+ from pathlib import Path
18
+ from typing import Any
19
+
20
+ from ai_cli.log import append_log_str
21
+
22
+ # ---------------------------------------------------------------------------
23
+ # Constants
24
+ # ---------------------------------------------------------------------------
25
+
26
# Claude config directory used when $CLAUDE_CONFIG_DIR is not set (see _config_dir()).
DEFAULT_CLAUDE_CONFIG_DIR = "~/.claude"
# Scope that build_bootstrap_oauth() always ensures is present.
OAUTH_SCOPE_USER_INFERENCE = "user:inference"
# OAuth endpoint paths. NOTE(review): not referenced in this module;
# presumably matched against intercepted request paths elsewhere — confirm.
OAUTH_TOKEN_PATH = "/v1/oauth/token"
OAUTH_PROFILE_PATH = "/api/oauth/profile"
# Scope list written into every saved credential by write_claude_ai_oauth(),
# replacing whatever scopes the payload carried.
OAUTH_SCOPES_FIXED = [
    "user:inference",
    "user:mcp_servers",
    "user:profile",
    "user:sessions:claude_code",
]
# File names inside the Claude config directory: plaintext credentials,
# their openssl-encrypted copy, and the key file used as the openssl passphrase.
CREDENTIALS_FILE_NAME = ".credentials.json"
CREDENTIALS_ENCRYPTED_FILE_NAME = ".credentials.json.enc"
CREDENTIALS_KEY_FILE_NAME = ".credentials.key"
39
+
40
+
41
+ # ---------------------------------------------------------------------------
42
+ # Path helpers
43
+ # ---------------------------------------------------------------------------
44
+
45
def _config_dir() -> Path:
    """Return the Claude config directory with ``~`` expanded.

    Uses ``$CLAUDE_CONFIG_DIR`` when it is set to a non-empty value and
    ``DEFAULT_CLAUDE_CONFIG_DIR`` otherwise. An empty value is treated as
    unset: ``Path("")`` would resolve to the current working directory, so
    credentials and the encryption key would be read from and written into
    whatever directory the tool happened to be launched from.
    """
    configured = os.getenv("CLAUDE_CONFIG_DIR") or DEFAULT_CLAUDE_CONFIG_DIR
    return Path(configured).expanduser()
48
+
49
+
50
def credentials_path() -> Path:
    """Path of the plaintext credentials JSON inside the Claude config dir."""
    base = _config_dir()
    return base.joinpath(CREDENTIALS_FILE_NAME)
52
+
53
+
54
def encrypted_credentials_path() -> Path:
    """Path of the openssl-encrypted credentials copy inside the Claude config dir."""
    base = _config_dir()
    return base.joinpath(CREDENTIALS_ENCRYPTED_FILE_NAME)
56
+
57
+
58
def credentials_key_path() -> Path:
    """Path of the key file used as the openssl passphrase source."""
    base = _config_dir()
    return base.joinpath(CREDENTIALS_KEY_FILE_NAME)
60
+
61
+
62
+ # ---------------------------------------------------------------------------
63
+ # JSON helpers
64
+ # ---------------------------------------------------------------------------
65
+
66
+ def parse_json_dict(text: str) -> dict[str, Any] | None:
67
+ if not text:
68
+ return None
69
+ try:
70
+ payload = json.loads(text)
71
+ except json.JSONDecodeError:
72
+ return None
73
+ if not isinstance(payload, dict):
74
+ return None
75
+ return payload
76
+
77
+
78
def parse_scopes(value: Any) -> list[str]:
    """Normalize an OAuth scope value into a list of non-empty scope strings.

    A string is split on single spaces; a list keeps only its non-empty
    string items; any other type yields ``[]``.
    """
    if isinstance(value, str):
        candidates = value.split(" ")
    elif isinstance(value, list):
        candidates = [item for item in value if isinstance(item, str)]
    else:
        return []
    return [scope for scope in candidates if scope]
84
+
85
+
86
+ # ---------------------------------------------------------------------------
87
+ # Deep search
88
+ # ---------------------------------------------------------------------------
89
+
90
def deep_find_value(payload: Any, keys: tuple[str, ...]) -> Any:
    """Depth-first search of nested dicts/lists for the first non-None value under any of *keys*.

    At each dict, *keys* are tried in order before descending into the
    dict's values (in insertion order). Lists are searched item by item.
    Returns ``None`` when nothing matches.
    """
    if isinstance(payload, dict):
        hit = next(
            (payload[name] for name in keys if payload.get(name) is not None),
            None,
        )
        if hit is not None:
            return hit
        children = payload.values()
    elif isinstance(payload, list):
        children = payload
    else:
        return None

    for child in children:
        found = deep_find_value(child, keys)
        if found is not None:
            return found
    return None
106
+
107
+
108
+ # ---------------------------------------------------------------------------
109
+ # OAuth metadata extraction
110
+ # ---------------------------------------------------------------------------
111
+
112
def extract_oauth_metadata(payload: dict[str, Any]) -> dict[str, Any]:
    """Collect scopes, subscription type, and rate-limit tier found anywhere in *payload*.

    Only fields that are actually found (and non-empty) are included, under
    the keys ``scopes`` (list of str), ``subscriptionType`` and
    ``rateLimitTier``. Both camelCase and snake_case spellings are searched.
    """
    found: dict[str, Any] = {}

    scope_list = parse_scopes(deep_find_value(payload, ("scopes", "scope")))
    if scope_list:
        found["scopes"] = scope_list

    string_fields = (
        ("subscriptionType", ("subscriptionType", "subscription_type")),
        (
            "rateLimitTier",
            (
                "rateLimitTier",
                "rate_limit_tier",
                "rateLimitTierName",
                "rate_limit_tier_name",
            ),
        ),
    )
    for out_key, candidates in string_fields:
        value = deep_find_value(payload, candidates)
        if isinstance(value, str) and value:
            found[out_key] = value

    return found
140
+
141
+
142
+ def subscription_type_from_profile(profile: dict[str, Any]) -> str | None:
143
+ """Infer subscription type from an OAuth profile response."""
144
+ direct = profile.get("subscriptionType")
145
+ if isinstance(direct, str) and direct:
146
+ return direct
147
+
148
+ organization = profile.get("organization")
149
+ if not isinstance(organization, dict):
150
+ return None
151
+ organization_type = organization.get("organization_type")
152
+ if not isinstance(organization_type, str):
153
+ return None
154
+
155
+ mapping = {
156
+ "claude_max": "max",
157
+ "claude_pro": "pro",
158
+ "claude_enterprise": "enterprise",
159
+ "claude_team": "team",
160
+ }
161
+ return mapping.get(organization_type)
162
+
163
+
164
+ # ---------------------------------------------------------------------------
165
+ # Credential I/O
166
+ # ---------------------------------------------------------------------------
167
+
168
def read_credentials_doc() -> dict[str, Any]:
    """Read and parse the plaintext credentials JSON file.

    Returns the parsed top-level object, or ``{}`` when the file is missing,
    unreadable, not valid UTF-8, not valid JSON, or not a JSON object.
    ``UnicodeDecodeError`` is caught explicitly because it is a ValueError,
    not an OSError; a corrupt file would otherwise escape as an exception
    instead of being treated like a missing one.
    """
    path = credentials_path()
    try:
        text = path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        return {}
    return parse_json_dict(text) or {}
177
+
178
+
179
def ensure_credentials_key(wrapper_log_file: str) -> Path | None:
    """Ensure the encryption key file exists, creating it if needed.

    A new key is 32 bytes from ``os.urandom`` written as 64 hex characters
    plus a newline. openssl reads it through ``-pass file:``.

    The file is created atomically with ``O_CREAT | O_EXCL`` and mode 0o600.
    The secret is therefore never readable by other users, not even briefly
    between creation and chmod. If another process creates the file first,
    that key is reused rather than overwritten.

    Args:
        wrapper_log_file: Log destination passed to append_log_str.

    Returns:
        The key path, or ``None`` if the key could not be created.
    """
    key_path = credentials_key_path()
    if key_path.exists():
        return key_path

    def _fail(exc: OSError) -> None:
        append_log_str(
            wrapper_log_file, f"Failed creating credentials key {key_path}: {exc}"
        )

    try:
        key_path.parent.mkdir(parents=True, exist_ok=True)
    except OSError as exc:
        _fail(exc)
        return None

    try:
        fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    except FileExistsError:
        # A concurrent writer won the race: use its key instead of replacing it.
        return key_path
    except OSError as exc:
        _fail(exc)
        return None

    try:
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            handle.write(f"{os.urandom(32).hex()}\n")
        # Enforce exactly 0o600 regardless of the process umask.
        os.chmod(key_path, 0o600)
    except OSError as exc:
        # Don't leave an empty/partial key behind: the exists() check above
        # would hand it out on every later call.
        try:
            key_path.unlink(missing_ok=True)
        except OSError:
            pass
        _fail(exc)
        return None

    append_log_str(wrapper_log_file, f"Created credentials key at {key_path}")
    return key_path
195
+
196
+
197
def write_encrypted_credentials(
    data: dict[str, Any],
    wrapper_log_file: str,
) -> None:
    """Encrypt credentials with AES-256-CBC and write to disk.

    Steps:
    1. Serialize *data* to JSON.
    2. Pipe it on stdin (never on the command line) to
       ``openssl enc -aes-256-cbc -pbkdf2 -salt -a``, using the key file
       from ensure_credentials_key() as the passphrase source.
    3. Write openssl's base64 output to encrypted_credentials_path().

    The output file is created with mode 0o600 and chmod-ed before any
    bytes are written, so it is never group/world readable.

    Each failure (openssl missing, key unavailable, openssl error, write
    error) is logged to *wrapper_log_file*, and the function returns
    without raising.
    """
    openssl_bin = shutil.which("openssl")
    if not openssl_bin:
        append_log_str(
            wrapper_log_file,
            "Skipping encrypted credentials write: openssl not found",
        )
        return

    key_path = ensure_credentials_key(wrapper_log_file)
    if key_path is None:
        return

    try:
        plaintext = json.dumps(data)
        proc = subprocess.run(
            [
                openssl_bin,
                "enc",
                "-aes-256-cbc",
                "-pbkdf2",
                "-salt",
                "-a",
                "-pass",
                f"file:{key_path}",
            ],
            input=plaintext,
            text=True,
            capture_output=True,
            check=False,
        )
    except OSError as exc:
        append_log_str(
            wrapper_log_file,
            f"Failed running openssl for encrypted credentials: {exc}",
        )
        return

    if proc.returncode != 0 or not proc.stdout:
        err = (proc.stderr or "").strip()
        append_log_str(
            wrapper_log_file,
            f"Failed encrypting credentials (exit={proc.returncode}): {err}",
        )
        return

    enc_path = encrypted_credentials_path()
    try:
        # The mode applies only when the file is newly created; the chmod
        # also tightens a file left with looser permissions by an earlier write.
        fd = os.open(enc_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            os.chmod(enc_path, 0o600)
            handle.write(proc.stdout)
        append_log_str(
            wrapper_log_file, f"Saved encrypted credentials to {enc_path}"
        )
    except OSError as exc:
        append_log_str(
            wrapper_log_file,
            f"Failed writing encrypted credentials {enc_path}: {exc}",
        )
259
+
260
+
261
def write_claude_ai_oauth(
    oauth: dict[str, Any],
    wrapper_log_file: str,
) -> None:
    """Write OAuth credentials to both plain and encrypted files.

    *oauth* is modified in place: its ``scopes`` entry is replaced with a
    copy of ``OAUTH_SCOPES_FIXED``. The plaintext file is replaced wholesale
    with ``{"claudeAiOauth": oauth}``. NOTE(review): any other top-level
    keys previously stored in that file are dropped — confirm this is
    intended.

    The plaintext file holds OAuth tokens, so it is created with mode 0o600
    and chmod-ed before any bytes are written. It is never briefly readable
    with the default umask. An OSError while saving is logged, not raised.
    """
    path = credentials_path()
    oauth["scopes"] = list(OAUTH_SCOPES_FIXED)
    data = {"claudeAiOauth": oauth}
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        # Serialize before opening so a non-JSON-serializable payload cannot
        # leave the existing credentials file truncated.
        serialized = json.dumps(data)
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            os.chmod(path, 0o600)
            handle.write(serialized)
        append_log_str(wrapper_log_file, f"Saved claudeAiOauth to {path}")
        write_encrypted_credentials(data, wrapper_log_file)
    except OSError as exc:
        append_log_str(
            wrapper_log_file, f"Failed to save claudeAiOauth to {path}: {exc}"
        )
279
+
280
+
281
+ # ---------------------------------------------------------------------------
282
+ # OAuth payload builder
283
+ # ---------------------------------------------------------------------------
284
+
285
def build_bootstrap_oauth(
    bearer_token: str | None = None,
    existing_oauth: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Assemble a claudeAiOauth payload from prior data plus fresh metadata.

    Starting from a shallow copy of *existing_oauth*:

    - ``accessToken`` is set to *bearer_token* when one is given and the
      payload has no string access token yet.
    - ``scopes`` come from *metadata*, else from the existing payload, else
      default to ``[user:inference]``; ``user:inference`` is always present.
    - ``expiresAt`` becomes now + 1 hour (epoch milliseconds) unless it is
      already an int.
    - ``subscriptionType`` / ``rateLimitTier`` are copied from *metadata*
      when they are strings. ``subscriptionType`` and ``refreshToken``
      default to ``None`` if still absent.
    """
    payload: dict[str, Any] = dict(existing_oauth or {})
    meta = metadata or {}

    if bearer_token and not isinstance(payload.get("accessToken"), str):
        payload["accessToken"] = bearer_token

    scopes = (
        parse_scopes(meta.get("scopes"))
        or parse_scopes(payload.get("scopes"))
        or [OAUTH_SCOPE_USER_INFERENCE]
    )
    if OAUTH_SCOPE_USER_INFERENCE not in scopes:
        scopes.append(OAUTH_SCOPE_USER_INFERENCE)
    payload["scopes"] = scopes

    if not isinstance(payload.get("expiresAt"), int):
        payload["expiresAt"] = int(time.time() * 1000 + (3600 * 1000))

    subscription = meta.get("subscriptionType")
    if isinstance(subscription, str):
        payload["subscriptionType"] = subscription
    elif "subscriptionType" not in payload:
        payload["subscriptionType"] = None

    tier = meta.get("rateLimitTier")
    if isinstance(tier, str):
        payload["rateLimitTier"] = tier

    payload.setdefault("refreshToken", None)
    return payload
329
+
330
+
331
+ def extract_bearer_token(auth_header: str) -> str | None:
332
+ """Extract bearer token from an Authorization header value."""
333
+ if not isinstance(auth_header, str):
334
+ return None
335
+ parts = auth_header.split(" ", 1)
336
+ if len(parts) != 2:
337
+ return None
338
+ scheme, token = parts[0].strip(), parts[1].strip()
339
+ if scheme.lower() != "bearer" or not token:
340
+ return None
341
+ return token
@@ -0,0 +1,135 @@
1
+ """Detached proxy lifecycle watcher for tmux-backed sessions.
2
+
3
+ This process is spawned when ai-mux detaches while the wrapped tool is still
4
+ running. It keeps mitmdump alive and performs cleanup once the tool exits.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ import os
11
+ import signal
12
+ import subprocess
13
+ import time
14
+ from datetime import datetime
15
+ from pathlib import Path
16
+
17
+
18
def _log(path_value: str, message: str) -> None:
    """Append ``[<local ISO timestamp>] <message>`` to the file at *path_value*.

    This is a no-op when *path_value* is empty. Parent directories are
    created as needed. Filesystem errors are swallowed so that logging can
    never take the watcher down.
    """
    if path_value:
        try:
            log_file = Path(path_value).expanduser()
            log_file.parent.mkdir(parents=True, exist_ok=True)
            with log_file.open("a", encoding="utf-8") as stream:
                stamp = datetime.now().isoformat(timespec="seconds")
                stream.write(f"[{stamp}] {message}\n")
        except OSError:
            pass
29
+
30
+
31
def _pid_alive(pid: int) -> bool:
    """Return True if a process with *pid* currently exists.

    Uses signal 0, which runs the kernel's existence and permission checks
    without delivering anything.

    - ``PermissionError`` (EPERM) means the process exists but belongs to
      another user, so it counts as alive.
    - Non-positive PIDs are rejected up front. ``os.kill(0, ...)`` and
      ``os.kill(-n, ...)`` address process groups rather than one process,
      and would otherwise report True.
    """
    if pid <= 0:
        return False
    try:
        os.kill(pid, 0)
    except PermissionError:
        return True
    except OSError:
        return False
    return True
37
+
38
+
39
def _tmux_session_alive(socket_name: str) -> bool:
    """Return True if ``tmux -L <socket_name> has-session`` exits with status 0.

    Output is discarded. Failing to launch tmux at all (e.g. the binary is
    missing) counts as not alive.
    """
    command = ["tmux", "-L", socket_name, "has-session"]
    try:
        status = subprocess.call(
            command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
        )
    except OSError:
        return False
    return not status
49
+
50
+
51
def _tmux_named_sessions_alive(socket_name: str, sessions: list[str]) -> bool:
    """Return True if any of *sessions* exists on the tmux server *socket_name*.

    With an empty *sessions* list this defers to _tmux_session_alive().

    Names are passed to tmux as ``=name``, so only an exact name matches.
    A bare ``-t name`` falls back to prefix and fnmatch matching once the
    exact session is gone. Another session such as ``name2`` would then
    keep the watcher, and therefore mitmdump, running after the watched
    session ended. Targets that already start with ``=`` (exact match) or
    ``$`` (session ID) are passed through unchanged.

    A failure to launch tmux counts as "not alive".
    """
    if not sessions:
        return _tmux_session_alive(socket_name)
    for session in sessions:
        target = session if session.startswith(("=", "$")) else f"={session}"
        try:
            code = subprocess.call(
                ["tmux", "-L", socket_name, "has-session", "-t", target],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
        except OSError:
            return False
        if code == 0:
            return True
    return False
66
+
67
+
68
def _cleanup_session_files(session_id: str) -> None:
    """Best-effort removal of the ``/tmp/<session_id>.pid`` and ``.port`` files."""
    stale = [Path(f"/tmp/{session_id}{ext}") for ext in (".pid", ".port")]
    for marker in stale:
        try:
            marker.unlink(missing_ok=True)
        except OSError:
            pass
74
+
75
+
76
def _terminate_pid(pid: int, grace_seconds: float = 3.0) -> None:
    """Stop process *pid*: send SIGTERM, wait up to *grace_seconds*, then SIGKILL.

    Non-positive PIDs are ignored. ``os.kill(0, ...)`` would signal this
    process's entire process group (including the watcher itself), and
    negative values address other groups. A bad ``--mitm-pid`` must never
    reach os.kill.

    Errors from os.kill (process already gone, not permitted) end the
    attempt silently.

    NOTE(review): if the original process exited and its PID was reused,
    the new process would be signalled — confirm that risk is acceptable.
    """
    if pid <= 0:
        return
    if not _pid_alive(pid):
        return
    try:
        os.kill(pid, signal.SIGTERM)
    except OSError:
        return

    # monotonic() is immune to wall-clock jumps (NTP sync, manual changes)
    # that could otherwise shorten or stretch the grace period.
    deadline = time.monotonic() + grace_seconds
    while time.monotonic() < deadline:
        if not _pid_alive(pid):
            return
        time.sleep(0.1)

    try:
        os.kill(pid, signal.SIGKILL)
    except OSError:
        pass
94
+
95
+
96
def main() -> int:
    """Watch mitmdump and the owning tmux session(s); clean up when either ends.

    The watcher polls every ``--poll-seconds`` (minimum 0.2 s). The loop
    ends when either of these holds:

    - the ``--mitm-pid`` process is gone;
    - none of the ``--tmux-session`` names exist on the ``--tmux-socket``
      server (or no session at all, when no names are given).

    Cleanup runs in ``finally``: it terminates mitmdump and removes the
    session's /tmp marker files.

    SIGTERM and SIGHUP are converted into SystemExit so that this cleanup
    still runs when the watcher itself is killed or loses its terminal.
    Under the default handlers Python exits immediately, skipping
    ``finally`` and leaving mitmdump running with no owner.

    Returns:
        0 after a normal shutdown. A signal-triggered shutdown exits with
        status 128 + signal number instead.
    """
    parser = argparse.ArgumentParser(description="Detached tmux proxy watcher.")
    parser.add_argument("--mitm-pid", type=int, required=True)
    parser.add_argument("--session-id", required=True)
    parser.add_argument("--wrapper-log-file", default="")
    parser.add_argument("--tmux-socket", default="ai-mux")
    parser.add_argument("--tmux-session", action="append", default=[])
    parser.add_argument("--poll-seconds", type=float, default=1.0)
    args = parser.parse_args()

    if args.mitm_pid <= 0:
        # 0 / negative PIDs address process groups; never watch or kill those.
        parser.error("--mitm-pid must be a positive process id")

    _log(
        args.wrapper_log_file,
        f"Detached watcher start (mitmdump pid={args.mitm_pid}, "
        f"sessions={','.join(args.tmux_session) or 'any'})",
    )

    def _exit_on_signal(signum: int, _frame: object) -> None:
        # Ignore repeats of this signal so cleanup in `finally` isn't interrupted.
        signal.signal(signum, signal.SIG_IGN)
        _log(args.wrapper_log_file, f"Detached watcher: received signal {signum}")
        raise SystemExit(128 + signum)

    try:
        for sig_name in ("SIGTERM", "SIGHUP"):
            sig = getattr(signal, sig_name, None)  # SIGHUP is absent on Windows
            if sig is not None:
                signal.signal(sig, _exit_on_signal)

        while True:
            if not _pid_alive(args.mitm_pid):
                _log(args.wrapper_log_file, "Detached watcher: mitmdump already exited")
                break

            if not _tmux_named_sessions_alive(args.tmux_socket, args.tmux_session):
                _log(
                    args.wrapper_log_file,
                    "Detached watcher: tmux session ownership ended",
                )
                break

            time.sleep(max(args.poll_seconds, 0.2))
    finally:
        _terminate_pid(args.mitm_pid)
        _cleanup_session_files(args.session_id)
        _log(args.wrapper_log_file, "Detached watcher stop (proxy/session cleanup complete)")

    return 0
132
+
133
+
134
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())