aizen-ai-cli 2.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aizen/__init__.py +4 -0
- aizen/commands.py +694 -0
- aizen/config.py +363 -0
- aizen/context.py +171 -0
- aizen/exceptions.py +46 -0
- aizen/logging_config.py +65 -0
- aizen/main.py +616 -0
- aizen/mcp.py +110 -0
- aizen/plugins.py +63 -0
- aizen/retry.py +133 -0
- aizen/session.py +137 -0
- aizen/tools.py +1035 -0
- aizen/utils.py +339 -0
- aizen_ai_cli-2.2.2.dist-info/METADATA +267 -0
- aizen_ai_cli-2.2.2.dist-info/RECORD +18 -0
- aizen_ai_cli-2.2.2.dist-info/WHEEL +5 -0
- aizen_ai_cli-2.2.2.dist-info/entry_points.txt +2 -0
- aizen_ai_cli-2.2.2.dist-info/top_level.txt +1 -0
aizen/mcp.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import contextlib
|
|
3
|
+
import os
|
|
4
|
+
import threading
|
|
5
|
+
|
|
6
|
+
from mcp import ClientSession, StdioServerParameters
|
|
7
|
+
from mcp.client.stdio import stdio_client
|
|
8
|
+
|
|
9
|
+
from .config import console
|
|
10
|
+
from .logging_config import logger
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MCPManager:
|
|
14
|
+
def __init__(self, mcp_servers_config: dict):
|
|
15
|
+
self.config = mcp_servers_config
|
|
16
|
+
self.sessions: dict[str, ClientSession] = {}
|
|
17
|
+
self.exit_stack: contextlib.AsyncExitStack | None = None
|
|
18
|
+
self._loop: asyncio.AbstractEventLoop | None = None
|
|
19
|
+
self._thread: threading.Thread | None = None
|
|
20
|
+
self._ready_event = threading.Event()
|
|
21
|
+
self.tools_cache: list[dict] = []
|
|
22
|
+
|
|
23
|
+
async def start(self):
|
|
24
|
+
if not self.config:
|
|
25
|
+
return
|
|
26
|
+
|
|
27
|
+
logger.info("Starting MCP servers: %s", list(self.config.keys()))
|
|
28
|
+
|
|
29
|
+
try:
|
|
30
|
+
await self._init_all_servers()
|
|
31
|
+
except Exception as e:
|
|
32
|
+
console.print(f"[dim yellow]⚠️ Error initializing MCP servers: {e}[/dim yellow]")
|
|
33
|
+
|
|
34
|
+
async def _init_all_servers(self):
|
|
35
|
+
self.exit_stack = contextlib.AsyncExitStack()
|
|
36
|
+
|
|
37
|
+
for name, server_config in self.config.items():
|
|
38
|
+
command = server_config.get("command")
|
|
39
|
+
args = server_config.get("args", [])
|
|
40
|
+
env = server_config.get("env")
|
|
41
|
+
|
|
42
|
+
if not command:
|
|
43
|
+
continue
|
|
44
|
+
|
|
45
|
+
try:
|
|
46
|
+
# Inherit environment, merging any custom env
|
|
47
|
+
merged_env = os.environ.copy()
|
|
48
|
+
if env:
|
|
49
|
+
merged_env.update(env)
|
|
50
|
+
|
|
51
|
+
server_params = StdioServerParameters(command=command, args=args, env=merged_env)
|
|
52
|
+
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
|
|
53
|
+
read, write = stdio_transport
|
|
54
|
+
|
|
55
|
+
session = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
|
56
|
+
await session.initialize()
|
|
57
|
+
|
|
58
|
+
self.sessions[name] = session
|
|
59
|
+
|
|
60
|
+
# Fetch tools and format them for OpenAI
|
|
61
|
+
response = await session.list_tools()
|
|
62
|
+
for tool in response.tools:
|
|
63
|
+
self.tools_cache.append({
|
|
64
|
+
"type": "function",
|
|
65
|
+
"function": {
|
|
66
|
+
# Prefix with mcp_serverName_ to avoid collisions
|
|
67
|
+
"name": f"mcp_{name}_{tool.name}",
|
|
68
|
+
"description": tool.description or f"MCP tool {tool.name} from {name}",
|
|
69
|
+
"parameters": tool.inputSchema,
|
|
70
|
+
}
|
|
71
|
+
})
|
|
72
|
+
logger.info("MCP server '%s' connected with %d tools", name, len(response.tools))
|
|
73
|
+
console.print(f" [dim green]✓ Connected to MCP server: {name}[/dim green]")
|
|
74
|
+
except Exception as e:
|
|
75
|
+
logger.exception("Failed to connect to MCP server '%s'", name)
|
|
76
|
+
console.print(f" [dim yellow]⚠️ Failed to connect to MCP server {name}: {e}[/dim yellow]")
|
|
77
|
+
|
|
78
|
+
async def stop(self):
|
|
79
|
+
if self.exit_stack:
|
|
80
|
+
await self.exit_stack.aclose()
|
|
81
|
+
self.sessions.clear()
|
|
82
|
+
|
|
83
|
+
def get_tools(self) -> list[dict]:
|
|
84
|
+
return self.tools_cache
|
|
85
|
+
|
|
86
|
+
async def call_tool(self, full_tool_name: str, arguments: dict) -> str:
|
|
87
|
+
logger.debug("MCP call_tool: %s args=%s", full_tool_name, list(arguments.keys()))
|
|
88
|
+
for server_name in self.sessions:
|
|
89
|
+
prefix = f"mcp_{server_name}_"
|
|
90
|
+
if full_tool_name.startswith(prefix):
|
|
91
|
+
tool_name = full_tool_name[len(prefix):]
|
|
92
|
+
session = self.sessions[server_name]
|
|
93
|
+
try:
|
|
94
|
+
result = await session.call_tool(tool_name, arguments=arguments)
|
|
95
|
+
|
|
96
|
+
if not result.content:
|
|
97
|
+
return "Tool executed successfully but returned no content."
|
|
98
|
+
|
|
99
|
+
output = []
|
|
100
|
+
for item in result.content:
|
|
101
|
+
if item.type == "text":
|
|
102
|
+
output.append(item.text)
|
|
103
|
+
else:
|
|
104
|
+
output.append(f"[Non-text content: {item.type}]")
|
|
105
|
+
|
|
106
|
+
return "\n".join(output)
|
|
107
|
+
except Exception as e:
|
|
108
|
+
return f"Error executing MCP tool {tool_name} on {server_name}: {e}"
|
|
109
|
+
|
|
110
|
+
return f"Error: MCP server for tool '{full_tool_name}' not found."
|
aizen/plugins.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import importlib.util
|
|
2
|
+
import os
|
|
3
|
+
import sys
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
|
|
6
|
+
from .logging_config import logger
|
|
7
|
+
|
|
8
|
+
# Directory scanned for user plugin scripts; "~" is expanded to the current
# user's home directory once, when this module is imported.
PLUGINS_DIR = os.path.expanduser("~/.aizen/plugins")
|
|
9
|
+
|
|
10
|
+
class PluginManager:
|
|
11
|
+
"""Manages loading and executing tools from user-provided Python scripts."""
|
|
12
|
+
def __init__(self):
|
|
13
|
+
self.plugins = {}
|
|
14
|
+
self.tools = []
|
|
15
|
+
self.handlers: dict[str, Callable] = {}
|
|
16
|
+
self._load_plugins()
|
|
17
|
+
|
|
18
|
+
def _load_plugins(self):
|
|
19
|
+
if not os.path.exists(PLUGINS_DIR):
|
|
20
|
+
try:
|
|
21
|
+
os.makedirs(PLUGINS_DIR, exist_ok=True)
|
|
22
|
+
except Exception as e:
|
|
23
|
+
logger.debug("Failed to create plugins directory: %s", e)
|
|
24
|
+
return
|
|
25
|
+
|
|
26
|
+
for filename in os.listdir(PLUGINS_DIR):
|
|
27
|
+
if filename.endswith(".py") and not filename.startswith("_"):
|
|
28
|
+
name = filename[:-3]
|
|
29
|
+
path = os.path.join(PLUGINS_DIR, filename)
|
|
30
|
+
try:
|
|
31
|
+
spec = importlib.util.spec_from_file_location(name, path)
|
|
32
|
+
if spec and spec.loader:
|
|
33
|
+
module = importlib.util.module_from_spec(spec)
|
|
34
|
+
# Add to sys.modules so plugins can import each other if needed
|
|
35
|
+
sys.modules[f"aizen_plugin_{name}"] = module
|
|
36
|
+
spec.loader.exec_module(module)
|
|
37
|
+
|
|
38
|
+
if hasattr(module, "get_tools") and hasattr(module, "execute_tool"):
|
|
39
|
+
plugin_tools = module.get_tools()
|
|
40
|
+
self.plugins[name] = module
|
|
41
|
+
self.tools.extend(plugin_tools)
|
|
42
|
+
for t in plugin_tools:
|
|
43
|
+
self.handlers[t["function"]["name"]] = module.execute_tool
|
|
44
|
+
logger.info("Loaded plugin '%s' with %d tools", name, len(plugin_tools))
|
|
45
|
+
except Exception as e:
|
|
46
|
+
logger.error("Failed to load plugin '%s': %s", filename, e)
|
|
47
|
+
|
|
48
|
+
def get_tools(self) -> list[dict]:
|
|
49
|
+
return self.tools
|
|
50
|
+
|
|
51
|
+
def execute_tool(self, tool_call, auto_approve: bool = False) -> str | None:
|
|
52
|
+
"""Executes a plugin tool. Returns None if tool is not handled by plugins."""
|
|
53
|
+
func_name = tool_call.function.name
|
|
54
|
+
if func_name in self.handlers:
|
|
55
|
+
try:
|
|
56
|
+
return self.handlers[func_name](tool_call, auto_approve)
|
|
57
|
+
except Exception as e:
|
|
58
|
+
logger.error("Plugin tool error: %s", e)
|
|
59
|
+
return f"Error executing plugin tool {func_name}: {e}"
|
|
60
|
+
return None
|
|
61
|
+
|
|
62
|
+
# Global instance. Constructing it scans the plugins directory and imports
# (executes) every eligible plugin, so this happens when this module is imported.
plugin_manager = PluginManager()
|
aizen/retry.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Retry logic with exponential backoff + jitter for transient API errors.
|
|
3
|
+
|
|
4
|
+
Supports both synchronous and asynchronous functions — the decorator
|
|
5
|
+
auto-detects coroutine functions and uses asyncio.sleep accordingly.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import functools
|
|
10
|
+
import inspect
|
|
11
|
+
import random
|
|
12
|
+
import time
|
|
13
|
+
|
|
14
|
+
from rich.text import Text
|
|
15
|
+
|
|
16
|
+
from .config import console
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _compute_delay(backoff_base: float, attempt: int, jitter: bool) -> float:
|
|
20
|
+
"""Calculate retry delay with optional jitter."""
|
|
21
|
+
delay = backoff_base ** attempt
|
|
22
|
+
if jitter:
|
|
23
|
+
delay *= 1.0 + random.uniform(-0.25, 0.25)
|
|
24
|
+
return delay
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _print_retry_message(exception: BaseException, delay: float, attempt: int, max_retries: int) -> None:
    """Show a one-line, styled "retrying" notice on the shared console.

    The line names the exception type, the upcoming delay (one decimal place)
    and the 1-based retry counter out of ``max_retries``.
    """
    segments = (
        (" ⏳ ", "yellow"),
        (f"{type(exception).__name__}. ", "dim"),
        (f"Retrying in {delay:.1f}s... ({attempt + 1}/{max_retries})", "dim italic"),
    )
    notice = Text()
    for segment_text, segment_style in segments:
        notice.append(segment_text, style=segment_style)
    console.print(notice)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _is_retryable_503(e: BaseException) -> bool:
|
|
40
|
+
"""Check if an exception represents a 503 Service Unavailable."""
|
|
41
|
+
return hasattr(e, "status_code") and e.status_code == 503
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def retry_with_backoff(
|
|
45
|
+
max_retries: int = 3,
|
|
46
|
+
backoff_base: float = 2.0,
|
|
47
|
+
retryable_exceptions: tuple | None = None,
|
|
48
|
+
jitter: bool = True,
|
|
49
|
+
):
|
|
50
|
+
"""
|
|
51
|
+
Decorator that retries a function on transient failures with exponential backoff.
|
|
52
|
+
|
|
53
|
+
Automatically detects async functions and uses ``asyncio.sleep`` instead of
|
|
54
|
+
``time.sleep`` so that the event loop is never blocked.
|
|
55
|
+
|
|
56
|
+
Args:
|
|
57
|
+
max_retries: Maximum number of retry attempts.
|
|
58
|
+
backoff_base: Base for exponential backoff (delay = base ** attempt).
|
|
59
|
+
retryable_exceptions: Tuple of exception types to retry on.
|
|
60
|
+
jitter: If True, adds random jitter (±25%) to prevent thundering herd.
|
|
61
|
+
"""
|
|
62
|
+
if retryable_exceptions is None:
|
|
63
|
+
# Import here to avoid circular imports — these are the standard transient errors
|
|
64
|
+
from openai import (
|
|
65
|
+
APIConnectionError as OpenAIConnectionError,
|
|
66
|
+
)
|
|
67
|
+
from openai import (
|
|
68
|
+
APITimeoutError,
|
|
69
|
+
)
|
|
70
|
+
from openai import (
|
|
71
|
+
RateLimitError as OpenAIRateLimitError,
|
|
72
|
+
)
|
|
73
|
+
retryable_exceptions = (
|
|
74
|
+
OpenAIRateLimitError,
|
|
75
|
+
APITimeoutError,
|
|
76
|
+
OpenAIConnectionError,
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
def decorator(func):
|
|
80
|
+
# ── Async wrapper ──
|
|
81
|
+
if inspect.iscoroutinefunction(func):
|
|
82
|
+
@functools.wraps(func)
|
|
83
|
+
async def async_wrapper(*args, **kwargs):
|
|
84
|
+
last_exception: BaseException = RuntimeError("Retry exhausted")
|
|
85
|
+
for attempt in range(max_retries + 1):
|
|
86
|
+
try:
|
|
87
|
+
return await func(*args, **kwargs)
|
|
88
|
+
except retryable_exceptions as e:
|
|
89
|
+
last_exception = e
|
|
90
|
+
if attempt < max_retries:
|
|
91
|
+
delay = _compute_delay(backoff_base, attempt, jitter)
|
|
92
|
+
_print_retry_message(e, delay, attempt, max_retries)
|
|
93
|
+
await asyncio.sleep(delay)
|
|
94
|
+
except Exception as e:
|
|
95
|
+
if _is_retryable_503(e):
|
|
96
|
+
last_exception = e
|
|
97
|
+
if attempt < max_retries:
|
|
98
|
+
delay = _compute_delay(backoff_base, attempt, jitter)
|
|
99
|
+
_print_retry_message(e, delay, attempt, max_retries)
|
|
100
|
+
await asyncio.sleep(delay)
|
|
101
|
+
continue
|
|
102
|
+
raise
|
|
103
|
+
raise last_exception
|
|
104
|
+
|
|
105
|
+
return async_wrapper
|
|
106
|
+
|
|
107
|
+
# ── Sync wrapper ──
|
|
108
|
+
@functools.wraps(func)
|
|
109
|
+
def sync_wrapper(*args, **kwargs):
|
|
110
|
+
last_exception: BaseException = RuntimeError("Retry exhausted")
|
|
111
|
+
for attempt in range(max_retries + 1):
|
|
112
|
+
try:
|
|
113
|
+
return func(*args, **kwargs)
|
|
114
|
+
except retryable_exceptions as e:
|
|
115
|
+
last_exception = e
|
|
116
|
+
if attempt < max_retries:
|
|
117
|
+
delay = _compute_delay(backoff_base, attempt, jitter)
|
|
118
|
+
_print_retry_message(e, delay, attempt, max_retries)
|
|
119
|
+
time.sleep(delay)
|
|
120
|
+
except Exception as e:
|
|
121
|
+
if _is_retryable_503(e):
|
|
122
|
+
last_exception = e
|
|
123
|
+
if attempt < max_retries:
|
|
124
|
+
delay = _compute_delay(backoff_base, attempt, jitter)
|
|
125
|
+
_print_retry_message(e, delay, attempt, max_retries)
|
|
126
|
+
time.sleep(delay)
|
|
127
|
+
continue
|
|
128
|
+
raise
|
|
129
|
+
raise last_exception
|
|
130
|
+
|
|
131
|
+
return sync_wrapper
|
|
132
|
+
|
|
133
|
+
return decorator
|
aizen/session.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
import sqlite3
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
|
|
7
|
+
from .config import SESSIONS_DIR
|
|
8
|
+
from .logging_config import logger
|
|
9
|
+
from .utils import TokenTracker
|
|
10
|
+
|
|
11
|
+
# ─── Singleton DB connection ────────────────────────────────────────────────────

# Shared connection returned by _get_db(); None until the first call opens it.
_db_connection: sqlite3.Connection | None = None
# DB path the cached connection was opened for; _get_db() reconnects when it differs.
_db_path_cached: str | None = None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _get_db() -> sqlite3.Connection:
    """Return a singleton SQLite connection, creating the schema if needed.

    The connection is cached together with the DB path it was opened for. If
    SESSIONS_DIR has changed (e.g. monkeypatched during testing), the old
    connection is closed and a new one is opened. The new connection is
    published as the singleton only after its schema has been created, so a
    failure never leaves a half-initialised connection cached.
    """
    global _db_connection, _db_path_cached

    os.makedirs(SESSIONS_DIR, exist_ok=True)
    db_path = os.path.join(SESSIONS_DIR, "aizen.db")

    # Reconnect if the path changed (supports monkeypatch in tests)
    if _db_connection is not None and _db_path_cached == db_path:
        return _db_connection

    if _db_connection is not None:
        try:
            _db_connection.close()
        except Exception as e:
            logger.debug("Failed to close previous session DB connection: %s", e)
        _db_connection = None
        _db_path_cached = None

    conn = sqlite3.connect(db_path)
    try:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS sessions (
                name TEXT PRIMARY KEY,
                saved_at TEXT,
                message_count INTEGER,
                messages TEXT,
                input_tokens INTEGER,
                output_tokens INTEGER
            )
        ''')
        conn.commit()
    except Exception:
        conn.close()
        raise

    _db_connection = conn
    _db_path_cached = db_path
    return conn
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _migrate_legacy_sessions():
    """Migrate old .json session files into the SQLite DB once.

    Each file is inserted with ``INSERT OR IGNORE`` (an existing row with the
    same name is kept), committed, and only then renamed to
    ``<file>.json.migrated``. A crash therefore never leaves a file marked as
    migrated whose row was not saved. Re-running on a file whose rename failed
    is harmless because the insert is ignored. Files that fail to parse or
    insert stay in place and are logged at DEBUG level.
    """
    if not os.path.exists(SESSIONS_DIR):
        return

    conn = _get_db()
    for f in sorted(os.listdir(SESSIONS_DIR)):
        if not f.endswith(".json"):
            continue
        filepath = os.path.join(SESSIONS_DIR, f)
        try:
            with open(filepath) as fh:
                data = json.load(fh)
            # Fall back to the filename when "name" is missing, null or empty;
            # a NULL name would otherwise be stored in the TEXT primary key.
            name = data.get("name") or f[:-5]
            msgs = data.get("messages") or []
            saved_at = data.get("saved_at", datetime.now().isoformat())

            conn.execute(
                "INSERT OR IGNORE INTO sessions (name, saved_at, message_count, messages, input_tokens, output_tokens) VALUES (?, ?, ?, ?, ?, ?)",
                (name, saved_at, len(msgs), json.dumps(msgs), 0, 0),
            )
            # Commit before renaming so the file is never marked migrated
            # while its row exists only in an uncommitted transaction.
            conn.commit()
        except Exception as e:
            logger.debug("Failed to migrate session file %s: %s", filepath, e)
            continue

        try:
            # Mark as migrated
            os.rename(filepath, filepath + ".migrated")
        except OSError as e:
            logger.debug("Failed to mark session file %s as migrated: %s", filepath, e)
|
|
82
|
+
|
|
83
|
+
# Run the legacy-file migration once, at import time. It is best-effort: an
# unwritable SESSIONS_DIR or an unopenable DB must not make importing this
# module fail. save/load calls will surface such errors when they are used.
try:
    _migrate_legacy_sessions()
except Exception:
    logger.warning("Legacy session migration failed", exc_info=True)
|
|
85
|
+
|
|
86
|
+
def save_session(
    messages: list, name: str | None = None, token_tracker: TokenTracker | None = None
) -> str:
    """Persist ``messages`` as JSON under a session name, overwriting any same-named row.

    Args:
        messages: Conversation messages; they are encoded with ``json.dumps``.
        name: Session name. A falsy name is replaced by a timestamp of the form
            ``session_YYYYMMDD_HHMMSS``. Every character other than word
            characters and ``-`` is replaced with ``_``.
        token_tracker: Optional source of the ``input_tokens``/``output_tokens``
            counts stored with the session (0 and 0 when omitted).

    Returns:
        A ``sqlite://<name>`` locator built from the sanitized name.
    """
    raw_name = name if name else datetime.now().strftime("session_%Y%m%d_%H%M%S")
    session_name = re.sub(r"[^\w\-]", "_", raw_name)

    if token_tracker:
        input_toks = token_tracker.input_tokens
        output_toks = token_tracker.output_tokens
    else:
        input_toks = output_toks = 0
    saved_at = datetime.now().isoformat()

    conn = _get_db()
    conn.execute(
        "REPLACE INTO sessions (name, saved_at, message_count, messages, input_tokens, output_tokens) VALUES (?, ?, ?, ?, ?, ?)",
        (session_name, saved_at, len(messages), json.dumps(messages), input_toks, output_toks),
    )
    conn.commit()
    return f"sqlite://{session_name}"
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def load_session(name: str) -> list | None:
|
|
109
|
+
# If the user passed the legacy filename by accident
|
|
110
|
+
if name.endswith(".json"):
|
|
111
|
+
name = name[:-5]
|
|
112
|
+
|
|
113
|
+
conn = _get_db()
|
|
114
|
+
cur = conn.execute("SELECT messages FROM sessions WHERE name = ?", (name,))
|
|
115
|
+
row = cur.fetchone()
|
|
116
|
+
if row:
|
|
117
|
+
try:
|
|
118
|
+
return json.loads(row[0])
|
|
119
|
+
except json.JSONDecodeError:
|
|
120
|
+
return None
|
|
121
|
+
return None
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def list_sessions() -> list:
    """Summarise saved sessions, most recently saved first.

    Returns:
        A list of dicts with keys ``name``, ``saved_at`` and ``messages``. Note
        that ``messages`` holds the stored message *count*, not the messages
        themselves. Returns ``[]`` without opening the database when
        SESSIONS_DIR does not exist.
    """
    if not os.path.exists(SESSIONS_DIR):
        return []

    query = "SELECT name, saved_at, message_count FROM sessions ORDER BY saved_at DESC"
    rows = _get_db().execute(query).fetchall()
    return [
        {"name": session_name, "saved_at": saved_at, "messages": count}
        for session_name, saved_at, count in rows
    ]
|