aizen-ai-cli 2.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aizen/__init__.py +4 -0
- aizen/commands.py +694 -0
- aizen/config.py +363 -0
- aizen/context.py +171 -0
- aizen/exceptions.py +46 -0
- aizen/logging_config.py +65 -0
- aizen/main.py +616 -0
- aizen/mcp.py +110 -0
- aizen/plugins.py +63 -0
- aizen/retry.py +133 -0
- aizen/session.py +137 -0
- aizen/tools.py +1035 -0
- aizen/utils.py +339 -0
- aizen_ai_cli-2.2.2.dist-info/METADATA +267 -0
- aizen_ai_cli-2.2.2.dist-info/RECORD +18 -0
- aizen_ai_cli-2.2.2.dist-info/WHEEL +5 -0
- aizen_ai_cli-2.2.2.dist-info/entry_points.txt +2 -0
- aizen_ai_cli-2.2.2.dist-info/top_level.txt +1 -0
aizen/main.py
ADDED
|
@@ -0,0 +1,616 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Aizen AI Agent — A professional-grade AI coding assistant for your terminal.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import asyncio
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
import random
|
|
11
|
+
import re
|
|
12
|
+
import subprocess
|
|
13
|
+
import sys
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
from openai import APIConnectionError as OpenAIConnectionError
|
|
17
|
+
from openai import APITimeoutError, AsyncOpenAI, AuthenticationError
|
|
18
|
+
from openai import RateLimitError as OpenAIRateLimitError
|
|
19
|
+
from prompt_toolkit import PromptSession
|
|
20
|
+
from prompt_toolkit.filters import completion_is_selected, has_completions
|
|
21
|
+
from prompt_toolkit.formatted_text import HTML
|
|
22
|
+
from prompt_toolkit.key_binding import KeyBindings
|
|
23
|
+
from rich.live import Live
|
|
24
|
+
from rich.markdown import Markdown
|
|
25
|
+
from rich.panel import Panel
|
|
26
|
+
from rich.text import Text
|
|
27
|
+
|
|
28
|
+
from .commands import AizenCompleter, handle_slash_command
|
|
29
|
+
from .config import (
|
|
30
|
+
AIZEN_ASCII,
|
|
31
|
+
VERSION,
|
|
32
|
+
build_system_prompt,
|
|
33
|
+
check_for_updates,
|
|
34
|
+
console,
|
|
35
|
+
fetch_openrouter_models_bg,
|
|
36
|
+
get_active_model,
|
|
37
|
+
get_api_key,
|
|
38
|
+
get_mcp_servers,
|
|
39
|
+
load_config,
|
|
40
|
+
save_config,
|
|
41
|
+
set_active_model,
|
|
42
|
+
)
|
|
43
|
+
from .context import ContextManager
|
|
44
|
+
from .logging_config import logger, setup_logging
|
|
45
|
+
from .mcp import MCPManager
|
|
46
|
+
from .plugins import plugin_manager
|
|
47
|
+
from .retry import retry_with_backoff
|
|
48
|
+
from .session import save_session
|
|
49
|
+
from .tools import backup_manager, execute_tool, tools
|
|
50
|
+
from .utils import Struct, TokenTracker, fetch_url_content, generate_directory_tree, truncate_output
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def inject_file_context(user_input: str) -> str:
    """Expand ``@`` references in *user_input* into appended context blocks.

    Two kinds of reference are recognised, each only at the start of the
    text or directly after whitespace:

    * ``@cmd:"..."``, ``@cmd:'...'`` or ``@cmd:word`` - the command is run
      through the shell (30 second timeout) and its stdout/stderr is attached
      as a ``<command_context>`` block.
    * ``@<target>`` - an ``http(s)://`` URL is fetched, an existing file is
      read, and an existing directory is rendered as a tree; each becomes a
      ``<url_context>``, ``<file_context>`` or ``<directory_context>`` block.

    Failures are reported on the console and the reference is skipped; they
    never raise.

    Args:
        user_input: Raw text typed by the user.

    Returns:
        *user_input* unchanged when nothing was attached, otherwise
        *user_input* followed by a blank line and the collected blocks
        (command blocks first, then attachments in first-mention order).
    """
    context_blocks: list[str] = []

    # 1. Handle command injection (@cmd:"...")
    cmd_pattern = r"(?:^|\s)@cmd:(?:\"([^\"]+)\"|\'([^\']+)\'|([^\s]+))"
    for match in re.finditer(cmd_pattern, user_input):
        cmd = match.group(1) or match.group(2) or match.group(3)
        if not cmd:
            continue
        console.print(f" [dim]⚡ Executing: {cmd}[/dim]")
        try:
            # NOTE(review): shell=True is deliberate here - the text is the
            # command the user typed at their own prompt. Confirm this
            # function is never fed input from another source.
            result = subprocess.run(
                cmd, shell=True, capture_output=True, text=True, timeout=30
            )
            output = result.stdout
            if result.stderr:
                output += "\n--- STDERR ---\n" + result.stderr
            if not output.strip():
                output = "[Command executed successfully with no output]"
            context_blocks.append(
                f'<command_context cmd="{cmd}">\n{output}\n</command_context>'
            )
        except Exception as e:
            # Best effort: a failing or timed-out command must not abort the turn.
            console.print(f" [dim yellow]⚠️ Command failed: {e}[/dim yellow]")

    # 2. Handle standard file/url/directory injection
    pattern = r"(?:^|\s)@(?!(?:cmd:))([a-zA-Z0-9_\-\./:?&=]+)"
    # dict.fromkeys drops repeated references while keeping first-mention
    # order; iterating a set here made the block order vary between runs.
    references = list(dict.fromkeys(re.findall(pattern, user_input)))
    if not references and not context_blocks:
        return user_input

    for item in references:
        if item.startswith(("http://", "https://")):
            console.print(f" [dim]🌐 Fetching: {item}[/dim]")
            content = fetch_url_content(item)
            if content.startswith("Error fetching URL:"):
                console.print(f" [dim yellow]⚠️ {content}[/dim yellow]")
            else:
                context_blocks.append(
                    f'<url_context url="{item}">\n{content}\n</url_context>'
                )
        elif os.path.isfile(item):
            try:
                # errors="ignore" silently drops bytes that are not valid UTF-8.
                with open(item, encoding="utf-8", errors="ignore") as f:
                    content = f.read()
                context_blocks.append(
                    f'<file_context path="{item}">\n{content}\n</file_context>'
                )
                console.print(f" [dim]📎 Attached: {item}[/dim]")
            except Exception as e:
                console.print(
                    f" [dim yellow]⚠️ Failed to read {item}: {e}[/dim yellow]"
                )
        elif os.path.isdir(item):
            try:
                tree_output = generate_directory_tree(item)
                context_blocks.append(
                    f'<directory_context path="{item}">\n{tree_output}\n</directory_context>'
                )
                console.print(f" [dim]📂 Attached directory tree: {item}[/dim]")
            except Exception as e:
                console.print(
                    f" [dim yellow]⚠️ Failed to read directory {item}: {e}[/dim yellow]"
                )
        else:
            console.print(f" [dim yellow]⚠️ File not found: {item}[/dim yellow]")

    if context_blocks:
        user_input += "\n\n" + "\n".join(context_blocks)
    return user_input
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def parse_args():
    """Parse the command line (``sys.argv``) and return the resulting namespace.

    Recognised options: ``--version``, ``--model``, ``--reset-key``,
    ``--set-base-url``, ``--yolo`` and ``--verbose``.
    """
    cli = argparse.ArgumentParser(
        description="Aizen AI Agent — A professional-grade AI coding assistant."
    )

    # (flag, add_argument keyword options), registered in display order.
    option_table = (
        ("--version", {"action": "store_true", "help": "Show version."}),
        ("--model", {"type": str, "help": "Override the default model."}),
        ("--reset-key", {"action": "store_true", "help": "Reset the saved API key."}),
        ("--set-base-url", {"type": str, "help": "Set custom API base URL."}),
        (
            "--yolo",
            {
                "action": "store_true",
                "help": "Auto-approve all tool operations (no confirmations).",
            },
        ),
        (
            "--verbose",
            {
                "action": "store_true",
                "help": "Enable verbose logging output to console.",
            },
        ),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
|
|
146
|
+
|
|
147
|
+
@retry_with_backoff(max_retries=3, backoff_base=2.0)
async def _create_api_stream(client, messages, model, active_tools):
    """Start a streaming chat-completion request and return the stream.

    The request always enables automatic tool choice and asks the API to
    include token usage in the stream. Retrying is delegated entirely to
    the ``retry_with_backoff`` decorator applied above.
    """
    request_options = {
        "model": model,
        "messages": messages,
        "tools": active_tools,
        "tool_choice": "auto",
        "stream": True,
        "stream_options": {"include_usage": True},
    }
    completions_api = client.chat.completions
    return await completions_api.create(**request_options)
|
|
161
|
+
# Labels shown next to the spinner while waiting for the first streamed token.
_SPINNER_LABELS = (
    "Thinking...",
    "Analyzing...",
    "Reasoning...",
    "Processing...",
    "Considering...",
    "Exploring...",
)

# Number of trailing messages kept verbatim when the history is compacted.
_KEEP_RECENT_MESSAGES = 4


def _assistant_panel(body):
    """Wrap *body* (a rich renderable) in the magenta "Aizen" reply panel."""
    return Panel(
        body,
        title="[bold magenta]✦ Aizen[/bold magenta]",
        border_style="magenta",
        padding=(1, 2),
    )


async def _read_user_input(session) -> str:
    """Prompt for one user turn; a trailing backslash continues on a new line."""
    prompt_html = HTML(
        "<ansimagenta>╭─</ansimagenta> <ansimagenta><b>👤 You</b></ansimagenta>\n"
        "<ansimagenta>╰─❯</ansimagenta> "
    )
    lines = [await session.prompt_async(prompt_html)]

    # Continue reading if line ends with backslash
    while lines[-1].rstrip().endswith("\\"):
        lines[-1] = lines[-1].rstrip()[:-1]  # Remove trailing backslash
        lines.append(
            await session.prompt_async(HTML("<ansimagenta> ⋮ </ansimagenta> "))
        )

    return "\n".join(lines)


def _compact_history(messages: list[dict[str, Any]]) -> int:
    """Replace older turns with a short summary, editing *messages* in place.

    The system message (index 0) and the most recent messages are kept; the
    messages in between are folded into one summary exchange.

    Returns:
        The number of messages folded into the summary (0 if nothing changed).
    """
    split = max(1, len(messages) - _KEEP_RECENT_MESSAGES)
    # A "tool" message is only valid right after the assistant message whose
    # tool_calls it answers. Move the cut back so the kept tail never starts
    # with an orphaned tool result.
    while split > 1 and messages[split].get("role") == "tool":
        split -= 1

    middle = messages[1:split]
    if not middle:
        return 0
    recent = messages[split:]

    user_topics = [
        m["content"][:100]
        for m in middle
        if m["role"] == "user" and m.get("content")
    ]
    summary = (
        "Previous conversation summary: The user and assistant discussed "
        + "; ".join(user_topics[:5])
        + ". The assistant helped with these requests using code analysis and editing tools."
    )
    messages[:] = [
        messages[0],
        {"role": "user", "content": f"Previous conversation summary:\n{summary}"},
        {
            "role": "assistant",
            "content": "Understood. I have the context from our previous discussion. How can I continue helping?",
        },
    ] + recent
    return len(middle)


async def _stream_response(client, messages, active_tools):
    """Stream one model reply, live-rendering it to the console.

    Returns:
        ``(full_content, tool_calls, api_usage)`` where *full_content* is the
        concatenated text, *tool_calls* is the list of completed tool-call
        dicts ordered by stream index, and *api_usage* is the usage object
        from the stream (``None`` if the API sent none).

    Raises:
        Whatever the API call or stream iteration raises; the caller handles it.
    """
    full_content = ""
    pending_calls: dict[int, dict[str, str]] = {}
    api_usage = None

    # Build spinner text
    spinner_display = Text()
    spinner_display.append(" ✦ ", style="bold magenta")
    spinner_display.append(random.choice(_SPINNER_LABELS), style="dim italic")

    with Live(spinner_display, console=console, refresh_per_second=8) as live:
        stream = await _create_api_stream(
            client, messages, get_active_model(), active_tools
        )

        async for chunk in stream:
            # Parse API-reported usage from the final chunk
            usage = getattr(chunk, "usage", None)
            if usage:
                api_usage = usage

            delta = chunk.choices[0].delta if chunk.choices else None
            if not delta:
                continue

            # ── Content tokens ──
            if delta.content:
                full_content += delta.content
                try:
                    live.update(_assistant_panel(Markdown(full_content)))
                except Exception:
                    # Fallback for incomplete markdown
                    live.update(_assistant_panel(Text(full_content)))

            # ── Tool call tokens ──
            if delta.tool_calls:
                for tc in delta.tool_calls:
                    # Fragments of one call share an index; id, name and
                    # arguments arrive piecemeal and are concatenated.
                    slot = pending_calls.setdefault(
                        tc.index,
                        {"id": "", "name": "", "arguments": "", "type": "function"},
                    )
                    if tc.id:
                        slot["id"] = tc.id
                    if tc.function:
                        if tc.function.name:
                            slot["name"] += tc.function.name
                        if tc.function.arguments:
                            slot["arguments"] += tc.function.arguments

                # Update spinner with tool info
                names = [v["name"] for v in pending_calls.values() if v["name"]]
                if names and not full_content:
                    tool_text = Text()
                    tool_text.append(" ⚙️ ", style="magenta")
                    tool_text.append(
                        f"Preparing: {', '.join(names)}",
                        style="dim italic",
                    )
                    live.update(tool_text)

    tool_calls: list[dict[str, Any]] = []
    for idx in sorted(pending_calls):
        call = pending_calls[idx]
        tool_calls.append(
            {
                "id": call["id"],
                "type": "function",
                "function": {
                    "name": call["name"],
                    "arguments": call["arguments"],
                },
            }
        )
    return full_content, tool_calls, api_usage


async def _shutdown(messages, token_tracker, mcp_manager, *, reason, saved_notice):
    """Auto-save the session, stop the MCP manager and print the farewell.

    Save and stop failures are logged and otherwise ignored so that quitting
    always completes. *reason* only appears in the log messages.
    """
    # Only save when there is more than the system message plus one entry.
    if len(messages) > 2:
        try:
            save_session(messages, token_tracker=token_tracker)
            console.print(saved_notice)
        except Exception:
            logger.exception("Failed to auto-save session on %s", reason)
    try:
        await mcp_manager.stop()
    except Exception:
        logger.exception("Failed to stop MCP manager on %s", reason)
    console.print("[yellow]Goodbye! 👋[/yellow]")


async def main_loop():
    """Run the interactive Aizen session until the user quits.

    Handles the one-shot flags (``--version``, ``--set-base-url``) by exiting
    early, then sets up the API client, MCP servers and prompt, and loops:
    read input, dispatch slash commands or attach ``@`` context, call the
    model, execute any requested tools, and repeat until the model answers
    without tool calls. The session is auto-saved on ``exit``/``quit``,
    Ctrl-C or EOF at the prompt.
    """
    args = parse_args()

    if args.version:
        print(f"Aizen v{VERSION}")
        sys.exit(0)

    # Initialize structured logging (file + optional console)
    setup_logging(verbose=getattr(args, "verbose", False))
    logger.info("Aizen starting v%s", VERSION)

    config = load_config()

    if args.set_base_url:
        config["API_BASE_URL"] = args.set_base_url
        save_config(config)
        print(f"✓ API base URL set to: {args.set_base_url}")
        sys.exit(0)

    api_key = get_api_key(config, reset=args.reset_key)

    if args.model:
        set_active_model(args.model)
    elif config.get("DEFAULT_MODEL"):
        set_active_model(config["DEFAULT_MODEL"])

    api_base = config.get("API_BASE_URL", "https://openrouter.ai/api/v1")
    auto_approve = args.yolo

    client = AsyncOpenAI(base_url=api_base, api_key=api_key)

    token_tracker = TokenTracker()
    context_manager = ContextManager(get_active_model())

    # Cleanup old backups
    backup_manager.cleanup()

    # Non-blocking update check (background thread, 24h cache)
    check_for_updates(config)

    # Non-blocking models fetch (background thread, 24h cache)
    fetch_openrouter_models_bg()

    # Initialize MCP
    mcp_servers_config = get_mcp_servers(config)
    mcp_manager = MCPManager(mcp_servers_config)
    if mcp_servers_config:
        console.print("[dim]Initializing MCP servers...[/dim]")
        await mcp_manager.start()

    active_tools = tools + mcp_manager.get_tools() + plugin_manager.get_tools()

    # ── Header ──
    console.print(AIZEN_ASCII)
    header = Text()
    header.append(f"v{VERSION}", style="bold magenta")
    header.append(" │ ", style="dim")
    header.append(get_active_model(), style="cyan")
    if auto_approve:
        header.append(" │ ", style="dim")
        header.append("YOLO MODE", style="bold red")
    console.print(header)
    console.print(
        "[dim]Type /help for commands • @file to attach • exit to quit[/dim]\n"
    )

    # ── Keybindings ──
    kb = KeyBindings()

    @kb.add("enter", filter=has_completions & completion_is_selected)
    def _(event):
        # Enter on a highlighted completion only closes the menu; it does
        # not submit the line.
        event.current_buffer.complete_state = None

    session: PromptSession = PromptSession(completer=AizenCompleter(), key_bindings=kb)

    messages = [{"role": "system", "content": build_system_prompt(config)}]

    async def _exec_tool(tc_dict):
        """Run one tool call and return its ``tool`` message.

        Never raises: every tool_call id must get a reply, otherwise the
        history is left with unanswered tool calls and later requests fail.
        """
        func_name = tc_dict["function"]["name"]
        try:
            if func_name.startswith("mcp_"):
                try:
                    tool_args = json.loads(tc_dict["function"]["arguments"])
                except json.JSONDecodeError:
                    result = f"Error: Invalid JSON arguments for {func_name}."
                else:
                    result = await mcp_manager.call_tool(func_name, tool_args)
            else:
                func_struct = Struct(**tc_dict["function"])
                tc_struct = Struct(
                    id=tc_dict["id"],
                    type=tc_dict["type"],
                    function=func_struct,
                )
                # execute_tool is synchronous; run it off the event loop.
                result = await asyncio.to_thread(execute_tool, tc_struct, auto_approve)
            content = truncate_output(result)
        except Exception as exc:
            logger.exception("Tool %s failed", func_name)
            content = f"Error: Tool {func_name} failed: {exc}"

        return {
            "role": "tool",
            "tool_call_id": tc_dict["id"],
            "name": func_name,
            "content": content,
        }

    while True:
        try:
            # ── Multi-line Input ──
            user_input = await _read_user_input(session)

            if user_input.lower().strip() in ("exit", "quit"):
                # Auto-save on exit
                await _shutdown(
                    messages,
                    token_tracker,
                    mcp_manager,
                    reason="exit",
                    saved_notice="[dim]Session auto-saved.[/dim]",
                )
                break

            if not user_input.strip():
                continue

            # ── Slash Commands ──
            if user_input.strip().startswith("/"):
                should_retry = await handle_slash_command(
                    user_input.strip(), messages, token_tracker, mcp_manager, client
                )
                # Only fall through to the agent loop when the command asked
                # for a retry and the history ends with a user message.
                if not (should_retry and messages and messages[-1]["role"] == "user"):
                    continue
            else:
                user_input = inject_file_context(user_input)
                messages.append({"role": "user", "content": user_input})

            # ── Context Window Check ──
            estimated_total = context_manager.estimate_messages_tokens(
                messages, token_tracker.estimate_tokens
            )
            context_manager.update(estimated_total)
            warning = context_manager.check_and_warn()
            if warning:
                console.print(f"[yellow]{warning}[/yellow]\n")

            # ── Auto-compact if context is critically full (>90%) ──
            if context_manager.needs_auto_compact() and len(messages) > 6:
                console.print("[dim yellow]⚡ Auto-compacting conversation to stay within context limits...[/dim yellow]")
                folded = _compact_history(messages)
                if folded:
                    console.print(
                        f"[green]✓ Auto-compacted {folded} messages into a summary.[/green]\n"
                    )
                    # Recalculate token usage after compaction
                    estimated_total = context_manager.estimate_messages_tokens(
                        messages, token_tracker.estimate_tokens
                    )
                    context_manager.update(estimated_total)

            # ── Agent Loop ──────────────────────────────────────────────────
            while True:
                try:
                    full_content, tool_calls_list, api_usage = await _stream_response(
                        client, messages, active_tools
                    )
                except AuthenticationError:
                    logger.error("Authentication failed — invalid API key")
                    console.print(
                        "\n[bold red]Authentication Error:[/bold red] Invalid API key."
                    )
                    console.print(
                        "[dim]Hint: Run with --reset-key to enter a new key.[/dim]"
                    )
                    break
                except OpenAIRateLimitError:
                    logger.warning("Rate limited by API")
                    console.print(
                        "\n[bold red]Rate Limited:[/bold red] Too many requests."
                    )
                    console.print(
                        "[dim]Hint: Wait a moment and try again, or switch to a different model.[/dim]"
                    )
                    break
                except APITimeoutError:
                    logger.warning("API request timed out")
                    console.print(
                        "\n[bold red]Timeout:[/bold red] The request timed out."
                    )
                    console.print(
                        "[dim]Hint: Check your internet connection and try again.[/dim]"
                    )
                    break
                except OpenAIConnectionError:
                    logger.warning("API connection failed")
                    console.print(
                        "\n[bold red]Connection Error:[/bold red] Could not reach the API."
                    )
                    console.print(
                        "[dim]Hint: Check your internet connection or API base URL.[/dim]"
                    )
                    break
                except Exception as e:
                    logger.exception("Unexpected API error: %s", e)
                    console.print(f"\n[bold red]API Error:[/bold red] {e}")
                    error_str = str(e).lower()
                    if "401" in error_str or "unauthorized" in error_str:
                        console.print(
                            "[dim]Hint: API key may be invalid. Run with --reset-key[/dim]"
                        )
                    elif "429" in error_str or "rate" in error_str:
                        console.print(
                            "[dim]Hint: Rate limited. Wait a moment and retry.[/dim]"
                        )
                    elif "timeout" in error_str:
                        console.print(
                            "[dim]Hint: Request timed out. Check your connection.[/dim]"
                        )
                    break
                except (asyncio.CancelledError, KeyboardInterrupt):
                    logger.warning("Generation cancelled by user")
                    console.print("\n[yellow]Generation cancelled.[/yellow]")
                    break

                # Track tokens — prefer API-reported usage, fall back to estimation
                if api_usage and hasattr(api_usage, "prompt_tokens"):
                    prompt_tokens = api_usage.prompt_tokens or 0
                    completion_tokens = api_usage.completion_tokens or 0
                    token_tracker.add_api_usage(prompt_tokens, completion_tokens)
                    context_manager.update(prompt_tokens + completion_tokens)
                elif full_content:
                    estimated_input = token_tracker.estimate_tokens(
                        json.dumps(messages[-1]) if messages else ""
                    )
                    estimated_output = token_tracker.estimate_tokens(full_content)
                    token_tracker.add_usage(estimated_input, estimated_output)

                # Add assistant message to history
                assistant_msg: dict[str, Any] = {
                    "role": "assistant",
                    "content": full_content or "",
                }
                if tool_calls_list:
                    assistant_msg["tool_calls"] = tool_calls_list
                messages.append(assistant_msg)

                # If no tool calls, we're done
                if not tool_calls_list:
                    break

                # Execute tool calls in parallel
                tool_results = await asyncio.gather(
                    *[_exec_tool(tc) for tc in tool_calls_list]
                )
                messages.extend(tool_results)

                # Continue the loop — model processes tool results

            # ── Footer ──
            cost = token_tracker.get_estimated_cost(get_active_model())
            cost_display = f" (${cost:.3f})" if cost > 0 else ""
            console.print(
                f"[dim] tokens: ~{token_tracker.total_tokens:,}{cost_display} │ "
                f"messages: {token_tracker.message_count} │ "
                f"model: {get_active_model()} │ "
                f"{context_manager.get_footer_text()}[/dim]\n"
            )

        except (KeyboardInterrupt, EOFError):
            # Auto-save on interrupt
            await _shutdown(
                messages,
                token_tracker,
                mcp_manager,
                reason="interrupt",
                saved_notice="\n[dim]Session auto-saved.[/dim]",
            )
            break
        except Exception as e:
            # Any other failure ends this turn only; the prompt loop continues.
            logger.exception("Unhandled error in main loop: %s", e)
            console.print(f"\n[bold red]Error:[/bold red] {e}")
|
|
612
|
+
def main():
    """Synchronous entry point: drive ``main_loop`` on a fresh event loop."""
    session_coroutine = main_loop()
    asyncio.run(session_coroutine)


# Allow running the module directly as a script.
if __name__ == "__main__":
    main()
|