aizen-ai-cli 2.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aizen/mcp.py ADDED
@@ -0,0 +1,110 @@
1
+ import asyncio
2
+ import contextlib
3
+ import os
4
+ import threading
5
+
6
+ from mcp import ClientSession, StdioServerParameters
7
+ from mcp.client.stdio import stdio_client
8
+
9
+ from .config import console
10
+ from .logging_config import logger
11
+
12
+
13
+ class MCPManager:
14
+ def __init__(self, mcp_servers_config: dict):
15
+ self.config = mcp_servers_config
16
+ self.sessions: dict[str, ClientSession] = {}
17
+ self.exit_stack: contextlib.AsyncExitStack | None = None
18
+ self._loop: asyncio.AbstractEventLoop | None = None
19
+ self._thread: threading.Thread | None = None
20
+ self._ready_event = threading.Event()
21
+ self.tools_cache: list[dict] = []
22
+
23
+ async def start(self):
24
+ if not self.config:
25
+ return
26
+
27
+ logger.info("Starting MCP servers: %s", list(self.config.keys()))
28
+
29
+ try:
30
+ await self._init_all_servers()
31
+ except Exception as e:
32
+ console.print(f"[dim yellow]⚠️ Error initializing MCP servers: {e}[/dim yellow]")
33
+
34
+ async def _init_all_servers(self):
35
+ self.exit_stack = contextlib.AsyncExitStack()
36
+
37
+ for name, server_config in self.config.items():
38
+ command = server_config.get("command")
39
+ args = server_config.get("args", [])
40
+ env = server_config.get("env")
41
+
42
+ if not command:
43
+ continue
44
+
45
+ try:
46
+ # Inherit environment, merging any custom env
47
+ merged_env = os.environ.copy()
48
+ if env:
49
+ merged_env.update(env)
50
+
51
+ server_params = StdioServerParameters(command=command, args=args, env=merged_env)
52
+ stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
53
+ read, write = stdio_transport
54
+
55
+ session = await self.exit_stack.enter_async_context(ClientSession(read, write))
56
+ await session.initialize()
57
+
58
+ self.sessions[name] = session
59
+
60
+ # Fetch tools and format them for OpenAI
61
+ response = await session.list_tools()
62
+ for tool in response.tools:
63
+ self.tools_cache.append({
64
+ "type": "function",
65
+ "function": {
66
+ # Prefix with mcp_serverName_ to avoid collisions
67
+ "name": f"mcp_{name}_{tool.name}",
68
+ "description": tool.description or f"MCP tool {tool.name} from {name}",
69
+ "parameters": tool.inputSchema,
70
+ }
71
+ })
72
+ logger.info("MCP server '%s' connected with %d tools", name, len(response.tools))
73
+ console.print(f" [dim green]✓ Connected to MCP server: {name}[/dim green]")
74
+ except Exception as e:
75
+ logger.exception("Failed to connect to MCP server '%s'", name)
76
+ console.print(f" [dim yellow]⚠️ Failed to connect to MCP server {name}: {e}[/dim yellow]")
77
+
78
+ async def stop(self):
79
+ if self.exit_stack:
80
+ await self.exit_stack.aclose()
81
+ self.sessions.clear()
82
+
83
+ def get_tools(self) -> list[dict]:
84
+ return self.tools_cache
85
+
86
+ async def call_tool(self, full_tool_name: str, arguments: dict) -> str:
87
+ logger.debug("MCP call_tool: %s args=%s", full_tool_name, list(arguments.keys()))
88
+ for server_name in self.sessions:
89
+ prefix = f"mcp_{server_name}_"
90
+ if full_tool_name.startswith(prefix):
91
+ tool_name = full_tool_name[len(prefix):]
92
+ session = self.sessions[server_name]
93
+ try:
94
+ result = await session.call_tool(tool_name, arguments=arguments)
95
+
96
+ if not result.content:
97
+ return "Tool executed successfully but returned no content."
98
+
99
+ output = []
100
+ for item in result.content:
101
+ if item.type == "text":
102
+ output.append(item.text)
103
+ else:
104
+ output.append(f"[Non-text content: {item.type}]")
105
+
106
+ return "\n".join(output)
107
+ except Exception as e:
108
+ return f"Error executing MCP tool {tool_name} on {server_name}: {e}"
109
+
110
+ return f"Error: MCP server for tool '{full_tool_name}' not found."
aizen/plugins.py ADDED
@@ -0,0 +1,63 @@
1
+ import importlib.util
2
+ import os
3
+ import sys
4
+ from collections.abc import Callable
5
+
6
+ from .logging_config import logger
7
+
8
+ PLUGINS_DIR = os.path.expanduser("~/.aizen/plugins")
9
+
10
+ class PluginManager:
11
+ """Manages loading and executing tools from user-provided Python scripts."""
12
+ def __init__(self):
13
+ self.plugins = {}
14
+ self.tools = []
15
+ self.handlers: dict[str, Callable] = {}
16
+ self._load_plugins()
17
+
18
+ def _load_plugins(self):
19
+ if not os.path.exists(PLUGINS_DIR):
20
+ try:
21
+ os.makedirs(PLUGINS_DIR, exist_ok=True)
22
+ except Exception as e:
23
+ logger.debug("Failed to create plugins directory: %s", e)
24
+ return
25
+
26
+ for filename in os.listdir(PLUGINS_DIR):
27
+ if filename.endswith(".py") and not filename.startswith("_"):
28
+ name = filename[:-3]
29
+ path = os.path.join(PLUGINS_DIR, filename)
30
+ try:
31
+ spec = importlib.util.spec_from_file_location(name, path)
32
+ if spec and spec.loader:
33
+ module = importlib.util.module_from_spec(spec)
34
+ # Add to sys.modules so plugins can import each other if needed
35
+ sys.modules[f"aizen_plugin_{name}"] = module
36
+ spec.loader.exec_module(module)
37
+
38
+ if hasattr(module, "get_tools") and hasattr(module, "execute_tool"):
39
+ plugin_tools = module.get_tools()
40
+ self.plugins[name] = module
41
+ self.tools.extend(plugin_tools)
42
+ for t in plugin_tools:
43
+ self.handlers[t["function"]["name"]] = module.execute_tool
44
+ logger.info("Loaded plugin '%s' with %d tools", name, len(plugin_tools))
45
+ except Exception as e:
46
+ logger.error("Failed to load plugin '%s': %s", filename, e)
47
+
48
+ def get_tools(self) -> list[dict]:
49
+ return self.tools
50
+
51
+ def execute_tool(self, tool_call, auto_approve: bool = False) -> str | None:
52
+ """Executes a plugin tool. Returns None if tool is not handled by plugins."""
53
+ func_name = tool_call.function.name
54
+ if func_name in self.handlers:
55
+ try:
56
+ return self.handlers[func_name](tool_call, auto_approve)
57
+ except Exception as e:
58
+ logger.error("Plugin tool error: %s", e)
59
+ return f"Error executing plugin tool {func_name}: {e}"
60
+ return None
61
+
62
# Global instance
# Module-level singleton. Constructing it runs PluginManager._load_plugins(),
# which executes every eligible plugin script found in PLUGINS_DIR — so simply
# importing this module runs user plugin code.
plugin_manager = PluginManager()
aizen/retry.py ADDED
@@ -0,0 +1,133 @@
1
+ """
2
+ Retry logic with exponential backoff + jitter for transient API errors.
3
+
4
+ Supports both synchronous and asynchronous functions — the decorator
5
+ auto-detects coroutine functions and uses asyncio.sleep accordingly.
6
+ """
7
+
8
+ import asyncio
9
+ import functools
10
+ import inspect
11
+ import random
12
+ import time
13
+
14
+ from rich.text import Text
15
+
16
+ from .config import console
17
+
18
+
19
+ def _compute_delay(backoff_base: float, attempt: int, jitter: bool) -> float:
20
+ """Calculate retry delay with optional jitter."""
21
+ delay = backoff_base ** attempt
22
+ if jitter:
23
+ delay *= 1.0 + random.uniform(-0.25, 0.25)
24
+ return delay
25
+
26
+
27
def _print_retry_message(exception: BaseException, delay: float, attempt: int, max_retries: int) -> None:
    """Show a one-line "retrying" notice on the shared console.

    The line names the exception's class, the upcoming wait (one decimal) and
    the 1-based retry counter out of ``max_retries``.
    """
    segments = (
        (" ⏳ ", "yellow"),
        (f"{type(exception).__name__}. ", "dim"),
        (f"Retrying in {delay:.1f}s... ({attempt + 1}/{max_retries})", "dim italic"),
    )
    notice = Text()
    for chunk, chunk_style in segments:
        notice.append(chunk, style=chunk_style)
    console.print(notice)
37
+
38
+
39
+ def _is_retryable_503(e: BaseException) -> bool:
40
+ """Check if an exception represents a 503 Service Unavailable."""
41
+ return hasattr(e, "status_code") and e.status_code == 503
42
+
43
+
44
+ def retry_with_backoff(
45
+ max_retries: int = 3,
46
+ backoff_base: float = 2.0,
47
+ retryable_exceptions: tuple | None = None,
48
+ jitter: bool = True,
49
+ ):
50
+ """
51
+ Decorator that retries a function on transient failures with exponential backoff.
52
+
53
+ Automatically detects async functions and uses ``asyncio.sleep`` instead of
54
+ ``time.sleep`` so that the event loop is never blocked.
55
+
56
+ Args:
57
+ max_retries: Maximum number of retry attempts.
58
+ backoff_base: Base for exponential backoff (delay = base ** attempt).
59
+ retryable_exceptions: Tuple of exception types to retry on.
60
+ jitter: If True, adds random jitter (±25%) to prevent thundering herd.
61
+ """
62
+ if retryable_exceptions is None:
63
+ # Import here to avoid circular imports — these are the standard transient errors
64
+ from openai import (
65
+ APIConnectionError as OpenAIConnectionError,
66
+ )
67
+ from openai import (
68
+ APITimeoutError,
69
+ )
70
+ from openai import (
71
+ RateLimitError as OpenAIRateLimitError,
72
+ )
73
+ retryable_exceptions = (
74
+ OpenAIRateLimitError,
75
+ APITimeoutError,
76
+ OpenAIConnectionError,
77
+ )
78
+
79
+ def decorator(func):
80
+ # ── Async wrapper ──
81
+ if inspect.iscoroutinefunction(func):
82
+ @functools.wraps(func)
83
+ async def async_wrapper(*args, **kwargs):
84
+ last_exception: BaseException = RuntimeError("Retry exhausted")
85
+ for attempt in range(max_retries + 1):
86
+ try:
87
+ return await func(*args, **kwargs)
88
+ except retryable_exceptions as e:
89
+ last_exception = e
90
+ if attempt < max_retries:
91
+ delay = _compute_delay(backoff_base, attempt, jitter)
92
+ _print_retry_message(e, delay, attempt, max_retries)
93
+ await asyncio.sleep(delay)
94
+ except Exception as e:
95
+ if _is_retryable_503(e):
96
+ last_exception = e
97
+ if attempt < max_retries:
98
+ delay = _compute_delay(backoff_base, attempt, jitter)
99
+ _print_retry_message(e, delay, attempt, max_retries)
100
+ await asyncio.sleep(delay)
101
+ continue
102
+ raise
103
+ raise last_exception
104
+
105
+ return async_wrapper
106
+
107
+ # ── Sync wrapper ──
108
+ @functools.wraps(func)
109
+ def sync_wrapper(*args, **kwargs):
110
+ last_exception: BaseException = RuntimeError("Retry exhausted")
111
+ for attempt in range(max_retries + 1):
112
+ try:
113
+ return func(*args, **kwargs)
114
+ except retryable_exceptions as e:
115
+ last_exception = e
116
+ if attempt < max_retries:
117
+ delay = _compute_delay(backoff_base, attempt, jitter)
118
+ _print_retry_message(e, delay, attempt, max_retries)
119
+ time.sleep(delay)
120
+ except Exception as e:
121
+ if _is_retryable_503(e):
122
+ last_exception = e
123
+ if attempt < max_retries:
124
+ delay = _compute_delay(backoff_base, attempt, jitter)
125
+ _print_retry_message(e, delay, attempt, max_retries)
126
+ time.sleep(delay)
127
+ continue
128
+ raise
129
+ raise last_exception
130
+
131
+ return sync_wrapper
132
+
133
+ return decorator
aizen/session.py ADDED
@@ -0,0 +1,137 @@
1
+ import json
2
+ import os
3
+ import re
4
+ import sqlite3
5
+ from datetime import datetime
6
+
7
+ from .config import SESSIONS_DIR
8
+ from .logging_config import logger
9
+ from .utils import TokenTracker
10
+
11
# ─── Singleton DB connection ────────────────────────────────────────────────────

_db_connection: sqlite3.Connection | None = None
_db_path_cached: str | None = None


def _get_db() -> sqlite3.Connection:
    """Return a singleton SQLite connection, creating the schema if needed.

    Automatically reconnects if SESSIONS_DIR has changed (e.g. during testing).
    A connection is cached only after its schema has been created successfully,
    so a failed (re)connect never leaves a closed or half-initialised
    connection behind for later calls.
    """
    global _db_connection, _db_path_cached

    os.makedirs(SESSIONS_DIR, exist_ok=True)
    db_path = os.path.join(SESSIONS_DIR, "aizen.db")

    # Reconnect if the path changed (supports monkeypatch in tests)
    if _db_connection is not None and _db_path_cached == db_path:
        return _db_connection

    if _db_connection is not None:
        stale = _db_connection
        # Forget the old connection before closing it: if connecting below
        # fails, a later call must not be handed this closed object.
        _db_connection = None
        _db_path_cached = None
        try:
            stale.close()
        except Exception:
            pass  # Best effort; the old connection is being discarded anyway.

    conn = sqlite3.connect(db_path)
    try:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS sessions (
                name TEXT PRIMARY KEY,
                saved_at TEXT,
                message_count INTEGER,
                messages TEXT,
                input_tokens INTEGER,
                output_tokens INTEGER
            )
        ''')
        conn.commit()
    except BaseException:
        conn.close()
        raise

    _db_connection = conn
    _db_path_cached = db_path
    return conn
51
+
52
+
53
def _migrate_legacy_sessions():
    """Migrate old ``*.json`` session files into the SQLite DB once.

    Each file's row is inserted with ``INSERT OR IGNORE`` (an existing session
    of the same name is kept) and committed *before* the file is renamed to
    ``*.json.migrated``. Renaming only after the commit means a crash or a
    failed commit can never mark a file as migrated while its data is missing
    from the database; re-running is harmless thanks to ``OR IGNORE``. Files
    that fail to migrate are logged and left in place.
    """
    if not os.path.exists(SESSIONS_DIR):
        return

    conn = _get_db()
    for f in os.listdir(SESSIONS_DIR):
        if not f.endswith(".json"):
            continue
        filepath = os.path.join(SESSIONS_DIR, f)
        try:
            with open(filepath) as fh:
                data = json.load(fh)
            name = data.get("name", f[:-5])
            msgs = data.get("messages", [])
            saved_at = data.get("saved_at", datetime.now().isoformat())

            conn.execute(
                "INSERT OR IGNORE INTO sessions (name, saved_at, message_count, messages, input_tokens, output_tokens) VALUES (?, ?, ?, ?, ?, ?)",
                (name, saved_at, len(msgs), json.dumps(msgs), 0, 0),
            )
            conn.commit()
        except Exception as e:
            logger.debug("Failed to migrate session file %s: %s", filepath, e)
            continue

        # Mark as migrated only once the row is durably stored.
        try:
            os.rename(filepath, filepath + ".migrated")
        except OSError as e:
            logger.debug("Failed to migrate session file %s: %s", filepath, e)


# Run migration on import
_migrate_legacy_sessions()
85
+
86
def save_session(
    messages: list, name: str | None = None, token_tracker: TokenTracker | None = None
) -> str:
    """Persist *messages* under *name* and return a ``sqlite://<name>`` locator.

    An empty or missing *name* becomes ``session_YYYYmmdd_HHMMSS``. Every
    character that is not a word character or ``-`` is replaced by ``_``.
    An existing session with the same (sanitised) name is overwritten, along
    with the token counts taken from *token_tracker* (0 when it is absent).
    """
    raw_name = name or datetime.now().strftime("session_%Y%m%d_%H%M%S")
    session_name = re.sub(r"[^\w\-]", "_", raw_name)

    if token_tracker:
        input_toks, output_toks = token_tracker.input_tokens, token_tracker.output_tokens
    else:
        input_toks = output_toks = 0
    timestamp = datetime.now().isoformat()

    db = _get_db()
    row = (session_name, timestamp, len(messages), json.dumps(messages), input_toks, output_toks)
    db.execute(
        "REPLACE INTO sessions (name, saved_at, message_count, messages, input_tokens, output_tokens) VALUES (?, ?, ?, ?, ?, ?)",
        row,
    )
    db.commit()
    return f"sqlite://{session_name}"
106
+
107
+
108
def load_session(name: str) -> list | None:
    """Return the messages stored for session *name*, or None if unavailable.

    Besides the exact stored name, this accepts the spellings a user is likely
    to pass: a legacy ``<name>.json`` filename, the ``sqlite://<name>`` locator
    that save_session returns, or the unsanitised name originally given to
    save_session (which stores names with every character other than word
    characters and ``-`` replaced by ``_``). Candidates are tried in that
    order, exact name first, so any name that resolved before still resolves
    to the same row. Returns None when no row matches or the stored JSON
    cannot be decoded.
    """
    # If the user passed the legacy filename by accident
    name = name.removesuffix(".json")

    without_scheme = name.removeprefix("sqlite://")
    sanitized = re.sub(r"[^\w\-]", "_", without_scheme)
    # dict.fromkeys de-duplicates while keeping the priority order.
    candidates = dict.fromkeys((name, without_scheme, sanitized))

    conn = _get_db()
    for candidate in candidates:
        row = conn.execute("SELECT messages FROM sessions WHERE name = ?", (candidate,)).fetchone()
        if row:
            try:
                return json.loads(row[0])
            except (json.JSONDecodeError, TypeError):
                # TypeError covers a NULL messages cell.
                return None
    return None
122
+
123
+
124
def list_sessions() -> list:
    """Return summaries of all saved sessions, most recent ``saved_at`` first.

    Each summary is a dict with keys ``name``, ``saved_at`` (as stored) and
    ``messages`` (the stored message count). Returns ``[]`` when the sessions
    directory does not exist.
    """
    if not os.path.exists(SESSIONS_DIR):
        return []

    query = "SELECT name, saved_at, message_count FROM sessions ORDER BY saved_at DESC"
    rows = _get_db().execute(query).fetchall()
    return [
        {"name": session_name, "saved_at": saved_at, "messages": count}
        for session_name, saved_at, count in rows
    ]