aizen-ai-cli 2.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aizen/main.py ADDED
@@ -0,0 +1,616 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Aizen AI Agent — A professional-grade AI coding assistant for your terminal.
4
+ """
5
+
6
+ import argparse
7
+ import asyncio
8
+ import json
9
+ import os
10
+ import random
11
+ import re
12
+ import subprocess
13
+ import sys
14
+ from typing import Any
15
+
16
+ from openai import APIConnectionError as OpenAIConnectionError
17
+ from openai import APITimeoutError, AsyncOpenAI, AuthenticationError
18
+ from openai import RateLimitError as OpenAIRateLimitError
19
+ from prompt_toolkit import PromptSession
20
+ from prompt_toolkit.filters import completion_is_selected, has_completions
21
+ from prompt_toolkit.formatted_text import HTML
22
+ from prompt_toolkit.key_binding import KeyBindings
23
+ from rich.live import Live
24
+ from rich.markdown import Markdown
25
+ from rich.panel import Panel
26
+ from rich.text import Text
27
+
28
+ from .commands import AizenCompleter, handle_slash_command
29
+ from .config import (
30
+ AIZEN_ASCII,
31
+ VERSION,
32
+ build_system_prompt,
33
+ check_for_updates,
34
+ console,
35
+ fetch_openrouter_models_bg,
36
+ get_active_model,
37
+ get_api_key,
38
+ get_mcp_servers,
39
+ load_config,
40
+ save_config,
41
+ set_active_model,
42
+ )
43
+ from .context import ContextManager
44
+ from .logging_config import logger, setup_logging
45
+ from .mcp import MCPManager
46
+ from .plugins import plugin_manager
47
+ from .retry import retry_with_backoff
48
+ from .session import save_session
49
+ from .tools import backup_manager, execute_tool, tools
50
+ from .utils import Struct, TokenTracker, fetch_url_content, generate_directory_tree, truncate_output
51
+
52
+
53
# Inline shell command: @cmd:"...", @cmd:'...', or @cmd:token-without-spaces.
_CMD_PATTERN = re.compile(r"(?:^|\s)@cmd:(?:\"([^\"]+)\"|\'([^\']+)\'|([^\s]+))")
# File / directory / URL mention. The "@" must start the input or follow
# whitespace (so e-mail addresses are ignored), and "@cmd:" is excluded.
_MENTION_PATTERN = re.compile(r"(?:^|\s)@(?!(?:cmd:))([a-zA-Z0-9_\-\./:?&=]+)")
# Characters the mention pattern accepts that are usually sentence
# punctuation when they end a token, e.g. "Look at @main.py."
_TRAILING_PUNCTUATION = ".:?"


def _normalize_mention(token: str) -> str:
    """Drop trailing sentence punctuation captured by the mention pattern.

    URLs are always stripped of trailing ``.``, ``:`` and ``?``. A local
    path is stripped only when it does not exist as written but does exist
    without the punctuation, so real names such as ``.`` or ``..`` survive.
    """
    if token.startswith(("http://", "https://")):
        return token.rstrip(_TRAILING_PUNCTUATION)
    if not os.path.exists(token):
        stripped = token.rstrip(_TRAILING_PUNCTUATION)
        if stripped and os.path.exists(stripped):
            return stripped
    return token


def _mention_targets(user_input: str) -> list[str]:
    """Return the normalized ``@`` mentions in *user_input*.

    Duplicates are removed while keeping first-seen order, so the context
    blocks appear in the order the user wrote them.
    """
    raw_tokens = _MENTION_PATTERN.findall(user_input)
    return list(dict.fromkeys(_normalize_mention(t) for t in raw_tokens))


def _command_context_block(cmd: str, timeout: float = 30):
    """Run *cmd* through the shell and wrap its output in ``<command_context>``.

    The command string comes from the local user's own input. ``shell=True``
    is deliberate so pipes and globbing work; never feed this untrusted text.

    stderr is appended after a separator. A non-zero exit status is reported
    explicitly rather than being presented as success.

    Returns:
        The context block string, or ``None`` if the command could not be
        run (for example on timeout). A warning is printed in that case.
    """
    console.print(f" [dim]⚡ Executing: {cmd}[/dim]")
    try:
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=timeout)
    except Exception as e:
        console.print(f" [dim yellow]⚠️ Command failed: {e}[/dim yellow]")
        return None

    output = result.stdout
    if result.stderr:
        output += "\n--- STDERR ---\n" + result.stderr
    if result.returncode != 0:
        status = f"[Command exited with status {result.returncode}]"
        output = f"{output.rstrip()}\n{status}" if output.strip() else status
    elif not output.strip():
        output = "[Command executed successfully with no output]"
    return f'<command_context cmd="{cmd}">\n{output}\n</command_context>'


def _mention_context_block(item: str):
    """Build the context block for one URL, file or directory mention.

    Returns:
        The ``<url_context>``, ``<file_context>`` or ``<directory_context>``
        string, or ``None`` after printing a warning if the item could not
        be fetched, read, or found.
    """
    if item.startswith(("http://", "https://")):
        console.print(f" [dim]🌐 Fetching: {item}[/dim]")
        content = fetch_url_content(item)
        # fetch_url_content signals failure through this message prefix.
        if content.startswith("Error fetching URL:"):
            console.print(f" [dim yellow]⚠️ {content}[/dim yellow]")
            return None
        return f'<url_context url="{item}">\n{content}\n</url_context>'

    if os.path.isfile(item):
        try:
            with open(item, encoding="utf-8", errors="ignore") as f:
                content = f.read()
        except Exception as e:
            console.print(f" [dim yellow]⚠️ Failed to read {item}: {e}[/dim yellow]")
            return None
        console.print(f" [dim]📎 Attached: {item}[/dim]")
        return f'<file_context path="{item}">\n{content}\n</file_context>'

    if os.path.isdir(item):
        try:
            tree_output = generate_directory_tree(item)
        except Exception as e:
            console.print(f" [dim yellow]⚠️ Failed to read directory {item}: {e}[/dim yellow]")
            return None
        console.print(f" [dim]📂 Attached directory tree: {item}[/dim]")
        return f'<directory_context path="{item}">\n{tree_output}\n</directory_context>'

    console.print(f" [dim yellow]⚠️ File not found: {item}[/dim yellow]")
    return None


def inject_file_context(user_input: str) -> str:
    """Expand ``@`` references in *user_input* into attached context blocks.

    Supported references:
      * ``@cmd:"..."`` / ``@cmd:'...'`` / ``@cmd:word`` runs a shell command
        (30 s timeout) and attaches its output and exit status.
      * ``@http(s)://...`` fetches a URL.
      * ``@path`` attaches a file's text or a directory tree.

    Commands run first, then mentions in first-seen order. Blocks are
    appended after a blank line. The original text is kept unchanged at the
    front, and it is returned unchanged if nothing was attached.
    """
    context_blocks = []

    for match in _CMD_PATTERN.finditer(user_input):
        cmd = match.group(1) or match.group(2) or match.group(3)
        if cmd:
            block = _command_context_block(cmd)
            if block is not None:
                context_blocks.append(block)

    targets = _mention_targets(user_input)
    if not targets and not context_blocks:
        return user_input

    for item in targets:
        block = _mention_context_block(item)
        if block is not None:
            context_blocks.append(block)

    if context_blocks:
        user_input += "\n\n" + "\n".join(context_blocks)
    return user_input
121
+
122
+
123
def parse_args(argv=None):
    """Parse Aizen's command-line flags.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, as before. Passing a list lets
            tests or embedding code drive the parser directly.

    Returns:
        An ``argparse.Namespace`` with the attributes ``version``, ``model``,
        ``reset_key``, ``set_base_url``, ``yolo`` and ``verbose``.
    """
    parser = argparse.ArgumentParser(
        description="Aizen AI Agent — A professional-grade AI coding assistant."
    )
    parser.add_argument("--version", action="store_true", help="Show version.")
    parser.add_argument("--model", type=str, help="Override the default model.")
    parser.add_argument(
        "--reset-key", action="store_true", help="Reset the saved API key."
    )
    parser.add_argument(
        "--set-base-url", type=str, help="Set custom API base URL."
    )
    parser.add_argument(
        "--yolo",
        action="store_true",
        help="Auto-approve all tool operations (no confirmations).",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging output to console.",
    )
    return parser.parse_args(argv)
146
+
147
@retry_with_backoff(max_retries=3, backoff_base=2.0)
async def _create_api_stream(client, messages, model, active_tools):
    """Open a streaming chat-completion request.

    Transient failures while creating the request are retried by
    ``@retry_with_backoff`` (3 retries, backoff base 2.0, with jitter).
    Only request creation is retried. Errors raised later, while the caller
    iterates the returned stream, are not retried here.

    ``tools`` and ``tool_choice`` are sent only when *active_tools* is
    non-empty. OpenAI's API rejects an empty ``tools`` array, and
    ``tool_choice`` has no meaning without tools.

    ``stream_options={"include_usage": True}`` asks the server to report
    token usage on the final chunk.

    Returns:
        The async stream of chat-completion chunks produced by *client*.
    """
    request: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "stream": True,
        "stream_options": {"include_usage": True},
    }
    if active_tools:
        request["tools"] = active_tools
        request["tool_choice"] = "auto"
    return await client.chat.completions.create(**request)
161
# Captions for the "waiting for the model" spinner; one is picked at random per request.
_SPINNER_LABELS = (
    "Thinking...",
    "Analyzing...",
    "Reasoning...",
    "Processing...",
    "Considering...",
    "Exploring...",
)

_PANEL_TITLE = "[bold magenta]✦ Aizen[/bold magenta]"


def _compact_history(messages: list[dict[str, Any]], keep_recent: int = 4) -> int:
    """Fold the middle of *messages* into a short summary, in place.

    ``messages[0]`` (the system prompt) and at least the last *keep_recent*
    messages are kept verbatim. Everything in between is replaced by a
    synthetic user/assistant exchange that quotes the first 100 characters
    of up to five earlier user requests.

    The split point is moved backwards past any ``"tool"`` messages. This
    keeps each tool result together with the assistant message whose
    ``tool_calls`` it answers; chat-completion APIs reject orphaned tool
    messages.

    Args:
        messages: Conversation history. It is mutated by slice assignment,
            so every holder of the list sees the compacted version.
        keep_recent: Minimum number of trailing messages kept verbatim.

    Returns:
        How many messages were folded into the summary (0 if none).
    """
    start = len(messages) - keep_recent
    while 1 < start < len(messages) and messages[start].get("role") == "tool":
        start -= 1
    if start <= 1:
        return 0

    middle = messages[1:start]
    topics = [
        m["content"][:100]
        for m in middle
        if m.get("role") == "user" and isinstance(m.get("content"), str) and m["content"]
    ]
    if topics:
        summary = (
            "The user and assistant discussed "
            + "; ".join(topics[:5])
            + ". The assistant helped with these requests using code analysis and editing tools."
        )
    else:
        summary = "The assistant helped the user using code analysis and editing tools."

    messages[:] = [
        messages[0],
        {"role": "user", "content": f"Previous conversation summary:\n{summary}"},
        {
            "role": "assistant",
            "content": "Understood. I have the context from our previous discussion. How can I continue helping?",
        },
        *messages[start:],
    ]
    return len(middle)


def _merge_tool_call_deltas(accumulated: dict[int, dict[str, str]], tool_call_deltas) -> None:
    """Fold streamed tool-call fragments into *accumulated*, keyed by ``index``.

    Name and argument fragments are concatenated in arrival order. A
    non-empty ``id`` overwrites the stored id.
    """
    for tc in tool_call_deltas:
        slot = accumulated.setdefault(
            tc.index,
            {"id": "", "name": "", "arguments": "", "type": "function"},
        )
        if tc.id:
            slot["id"] = tc.id
        if tc.function:
            if tc.function.name:
                slot["name"] += tc.function.name
            if tc.function.arguments:
                slot["arguments"] += tc.function.arguments


def _build_tool_calls_list(accumulated: dict[int, dict[str, str]]) -> list[dict[str, Any]]:
    """Convert accumulated fragments into API-shaped ``tool_calls``, ordered by index."""
    return [
        {
            "id": accumulated[idx]["id"],
            "type": "function",
            "function": {
                "name": accumulated[idx]["name"],
                "arguments": accumulated[idx]["arguments"],
            },
        }
        for idx in sorted(accumulated)
    ]


def _parse_tool_arguments(raw: str) -> Any:
    """Decode a tool call's JSON argument string.

    An empty or whitespace-only string means "no arguments" and becomes
    ``{}``; models send this for zero-parameter tools.

    Raises:
        json.JSONDecodeError: If *raw* is non-empty and not valid JSON.
    """
    if not raw or not raw.strip():
        return {}
    return json.loads(raw)


async def _run_tool_call(tc_dict: dict[str, Any], mcp_manager, auto_approve: bool) -> dict[str, Any]:
    """Execute one tool call and return the ``"tool"`` message that answers it.

    Names starting with ``mcp_`` go to *mcp_manager*. All other names go to
    the built-in ``execute_tool``, which runs in a worker thread.

    Any ``Exception`` raised by the tool becomes an error string in the
    result. This guarantees every tool call in the assistant message gets an
    answer; the API rejects a history containing an unanswered tool call.
    """
    func_name = tc_dict["function"]["name"]
    try:
        if func_name.startswith("mcp_"):
            try:
                tool_args = _parse_tool_arguments(tc_dict["function"]["arguments"])
            except json.JSONDecodeError:
                result = f"Error: Invalid JSON arguments for {func_name}."
            else:
                result = await mcp_manager.call_tool(func_name, tool_args)
        else:
            tc_struct = Struct(
                id=tc_dict["id"],
                type=tc_dict["type"],
                function=Struct(**tc_dict["function"]),
            )
            result = await asyncio.to_thread(execute_tool, tc_struct, auto_approve)
    except Exception as e:
        logger.exception("Tool call %s failed", func_name)
        result = f"Error: Tool {func_name} failed: {e}"

    return {
        "role": "tool",
        "tool_call_id": tc_dict["id"],
        "name": func_name,
        "content": truncate_output(result),
    }


async def _execute_tool_calls(tool_calls_list: list[dict[str, Any]], mcp_manager, auto_approve: bool) -> list[dict[str, Any]]:
    """Run every requested tool call and return the result messages in request order.

    With *auto_approve* (``--yolo``) the calls run concurrently. Otherwise
    they run one at a time: without ``--yolo`` a tool may ask the user for
    confirmation, and prompts from several worker threads at once would
    interleave on the terminal.
    """
    if auto_approve:
        return list(
            await asyncio.gather(
                *(_run_tool_call(tc, mcp_manager, auto_approve) for tc in tool_calls_list)
            )
        )
    results = []
    for tc in tool_calls_list:
        results.append(await _run_tool_call(tc, mcp_manager, auto_approve))
    return results


def _response_panel(content: str):
    """Wrap streamed text in the Aizen panel as Markdown, falling back to plain text."""
    try:
        return Panel(Markdown(content), title=_PANEL_TITLE, border_style="magenta", padding=(1, 2))
    except Exception:
        # Incomplete Markdown mid-stream can fail to parse.
        return Panel(Text(content), title=_PANEL_TITLE, border_style="magenta", padding=(1, 2))


def _report_api_error(exc: Exception) -> None:
    """Log *exc* and print a user-facing explanation and recovery hint."""
    if isinstance(exc, AuthenticationError):
        logger.error("Authentication failed — invalid API key")
        console.print("\n[bold red]Authentication Error:[/bold red] Invalid API key.")
        console.print("[dim]Hint: Run with --reset-key to enter a new key.[/dim]")
    elif isinstance(exc, OpenAIRateLimitError):
        logger.warning("Rate limited by API")
        console.print("\n[bold red]Rate Limited:[/bold red] Too many requests.")
        console.print("[dim]Hint: Wait a moment and try again, or switch to a different model.[/dim]")
    elif isinstance(exc, APITimeoutError):
        # Checked before the connection branch: in the openai SDK,
        # APITimeoutError subclasses APIConnectionError.
        logger.warning("API request timed out")
        console.print("\n[bold red]Timeout:[/bold red] The request timed out.")
        console.print("[dim]Hint: Check your internet connection and try again.[/dim]")
    elif isinstance(exc, OpenAIConnectionError):
        logger.warning("API connection failed")
        console.print("\n[bold red]Connection Error:[/bold red] Could not reach the API.")
        console.print("[dim]Hint: Check your internet connection or API base URL.[/dim]")
    else:
        logger.error("Unexpected API error: %s", exc, exc_info=exc)
        console.print(f"\n[bold red]API Error:[/bold red] {exc}")
        # Best-effort hints for providers whose errors arrive as other types.
        error_str = str(exc).lower()
        if "401" in error_str or "unauthorized" in error_str:
            console.print("[dim]Hint: API key may be invalid. Run with --reset-key[/dim]")
        elif "429" in error_str or "rate" in error_str:
            console.print("[dim]Hint: Rate limited. Wait a moment and retry.[/dim]")
        elif "timeout" in error_str:
            console.print("[dim]Hint: Request timed out. Check your connection.[/dim]")


async def _run_agent_turn(client, messages, active_tools, token_tracker, context_manager, mcp_manager, auto_approve: bool) -> None:
    """Query the model and run requested tools until it replies without tool calls.

    Each iteration streams one completion into a live panel, records token
    usage, and appends the assistant message to *messages*. If the model
    requested tools, they are executed and their results appended before the
    model is asked again.

    An API error or Ctrl+C while streaming ends the turn after printing a
    message; no partial assistant reply is added to *messages*.
    """
    while True:
        full_content = ""
        accumulated_tool_calls: dict[int, dict[str, str]] = {}
        api_usage = None

        spinner_display = Text()
        spinner_display.append(" ✦ ", style="bold magenta")
        spinner_display.append(random.choice(_SPINNER_LABELS), style="dim italic")

        try:
            with Live(spinner_display, console=console, refresh_per_second=8) as live:
                stream = await _create_api_stream(client, messages, get_active_model(), active_tools)
                async for chunk in stream:
                    # API-reported usage arrives on the final chunk (include_usage).
                    if getattr(chunk, "usage", None):
                        api_usage = chunk.usage

                    delta = chunk.choices[0].delta if chunk.choices else None
                    if not delta:
                        continue

                    if delta.content:
                        full_content += delta.content
                        live.update(_response_panel(full_content))

                    if delta.tool_calls:
                        _merge_tool_call_deltas(accumulated_tool_calls, delta.tool_calls)
                        names = [c["name"] for c in accumulated_tool_calls.values() if c["name"]]
                        if names and not full_content:
                            tool_text = Text()
                            tool_text.append(" ⚙️ ", style="magenta")
                            tool_text.append(f"Preparing: {', '.join(names)}", style="dim italic")
                            live.update(tool_text)
        except (asyncio.CancelledError, KeyboardInterrupt):
            logger.warning("Generation cancelled by user")
            console.print("\n[yellow]Generation cancelled.[/yellow]")
            return
        except Exception as e:
            _report_api_error(e)
            return

        # Prefer API-reported usage; otherwise estimate from the text.
        if api_usage and hasattr(api_usage, "prompt_tokens"):
            prompt_tokens = api_usage.prompt_tokens or 0
            completion_tokens = api_usage.completion_tokens or 0
            token_tracker.add_api_usage(prompt_tokens, completion_tokens)
            context_manager.update(prompt_tokens + completion_tokens)
        elif full_content:
            estimated_input = token_tracker.estimate_tokens(json.dumps(messages[-1]) if messages else "")
            token_tracker.add_usage(estimated_input, token_tracker.estimate_tokens(full_content))

        tool_calls_list = _build_tool_calls_list(accumulated_tool_calls)
        assistant_msg: dict[str, Any] = {"role": "assistant", "content": full_content}
        if tool_calls_list:
            assistant_msg["tool_calls"] = tool_calls_list
        messages.append(assistant_msg)

        if not tool_calls_list:
            return

        # One answer per tool call, then loop so the model sees the results.
        messages.extend(await _execute_tool_calls(tool_calls_list, mcp_manager, auto_approve))


def _print_footer(token_tracker, context_manager) -> None:
    """Print the dim status line: tokens (with cost when known), messages, model, context."""
    cost = token_tracker.get_estimated_cost(get_active_model())
    cost_display = f" (${cost:.3f})" if cost > 0 else ""
    console.print(
        f"[dim] tokens: ~{token_tracker.total_tokens:,}{cost_display} │ "
        f"messages: {token_tracker.message_count} │ "
        f"model: {get_active_model()} │ "
        f"{context_manager.get_footer_text()}[/dim]\n"
    )


async def _shutdown(messages, token_tracker, mcp_manager, reason: str, saved_prefix: str = "") -> None:
    """Auto-save a non-trivial session, stop MCP servers and say goodbye.

    Failures are logged, not raised, so shutdown always completes. *reason*
    ("exit" or "interrupt") is used only in the log messages.
    """
    if len(messages) > 2:
        try:
            save_session(messages, token_tracker=token_tracker)
            console.print(f"{saved_prefix}[dim]Session auto-saved.[/dim]")
        except Exception:
            logger.exception("Failed to auto-save session on %s", reason)
    try:
        await mcp_manager.stop()
    except Exception:
        logger.exception("Failed to stop MCP manager on %s", reason)
    console.print("[yellow]Goodbye! 👋[/yellow]")


async def _read_user_input(session) -> str:
    """Prompt for one message; a line ending in a backslash continues on the next line."""
    lines = [
        await session.prompt_async(
            HTML(
                "<ansimagenta>╭─</ansimagenta> <ansimagenta><b>👤 You</b></ansimagenta>\n"
                "<ansimagenta>╰─❯</ansimagenta> "
            )
        )
    ]
    while lines[-1].rstrip().endswith("\\"):
        lines[-1] = lines[-1].rstrip()[:-1]  # drop the continuation backslash
        lines.append(await session.prompt_async(HTML("<ansimagenta> ⋮ </ansimagenta> ")))
    return "\n".join(lines)


async def main_loop():
    """Run the interactive Aizen REPL until the user quits.

    Start-up:
      * Parse CLI flags; ``--version`` and ``--set-base-url`` exit at once.
      * Configure logging and resolve the API key and model.
      * Create the OpenAI-compatible client and start any configured MCP
        servers.
      * Print the banner.

    Each iteration then:
      * Reads a message and handles slash commands.
      * Attaches ``@`` context.
      * Compacts the history when ``ContextManager`` asks for it.
      * Runs an agent turn.

    ``exit``/``quit``, or Ctrl+C/EOF, auto-saves the session and stops the
    MCP servers.
    """
    args = parse_args()

    if args.version:
        print(f"Aizen v{VERSION}")
        sys.exit(0)

    # Structured logging (file + optional console).
    setup_logging(verbose=getattr(args, "verbose", False))
    logger.info("Aizen starting v%s", VERSION)

    config = load_config()

    if args.set_base_url:
        config["API_BASE_URL"] = args.set_base_url
        save_config(config)
        print(f"✓ API base URL set to: {args.set_base_url}")
        sys.exit(0)

    api_key = get_api_key(config, reset=args.reset_key)

    if args.model:
        set_active_model(args.model)
    elif config.get("DEFAULT_MODEL"):
        set_active_model(config["DEFAULT_MODEL"])

    api_base = config.get("API_BASE_URL", "https://openrouter.ai/api/v1")
    auto_approve = args.yolo
    client = AsyncOpenAI(base_url=api_base, api_key=api_key)

    token_tracker = TokenTracker()
    context_manager = ContextManager(get_active_model())

    backup_manager.cleanup()
    # Both are non-blocking background checks with a 24h cache.
    check_for_updates(config)
    fetch_openrouter_models_bg()

    mcp_servers_config = get_mcp_servers(config)
    mcp_manager = MCPManager(mcp_servers_config)
    if mcp_servers_config:
        console.print("[dim]Initializing MCP servers...[/dim]")
        await mcp_manager.start()

    active_tools = tools + mcp_manager.get_tools() + plugin_manager.get_tools()

    # ── Header ──
    console.print(AIZEN_ASCII)
    header = Text()
    header.append(f"v{VERSION}", style="bold magenta")
    header.append(" │ ", style="dim")
    header.append(get_active_model(), style="cyan")
    if auto_approve:
        header.append(" │ ", style="dim")
        header.append("YOLO MODE", style="bold red")
    console.print(header)
    console.print("[dim]Type /help for commands • @file to attach • exit to quit[/dim]\n")

    # ── Keybindings ──
    kb = KeyBindings()

    @kb.add("enter", filter=has_completions & completion_is_selected)
    def _(event):
        # Enter on a selected completion closes the menu instead of submitting.
        event.current_buffer.complete_state = None

    session: PromptSession = PromptSession(completer=AizenCompleter(), key_bindings=kb)
    messages = [{"role": "system", "content": build_system_prompt(config)}]

    while True:
        try:
            user_input = await _read_user_input(session)
            stripped = user_input.strip()

            if stripped.lower() in ("exit", "quit"):
                await _shutdown(messages, token_tracker, mcp_manager, "exit")
                break
            if not stripped:
                continue

            if stripped.startswith("/"):
                should_retry = await handle_slash_command(
                    stripped, messages, token_tracker, mcp_manager, client
                )
                # Fall through to the agent only for a retry that left a pending user message.
                if not (should_retry and messages and messages[-1]["role"] == "user"):
                    continue
            else:
                messages.append({"role": "user", "content": inject_file_context(user_input)})

            # ── Context Window Check ──
            estimated_total = context_manager.estimate_messages_tokens(
                messages, token_tracker.estimate_tokens
            )
            context_manager.update(estimated_total)
            warning = context_manager.check_and_warn()
            if warning:
                console.print(f"[yellow]{warning}[/yellow]\n")

            if context_manager.needs_auto_compact() and len(messages) > 6:
                console.print("[dim yellow]⚡ Auto-compacting conversation to stay within context limits...[/dim yellow]")
                compacted = _compact_history(messages)
                if compacted:
                    console.print(
                        f"[green]✓ Auto-compacted {compacted} messages into a summary.[/green]\n"
                    )
                    estimated_total = context_manager.estimate_messages_tokens(
                        messages, token_tracker.estimate_tokens
                    )
                    context_manager.update(estimated_total)

            await _run_agent_turn(
                client, messages, active_tools, token_tracker, context_manager, mcp_manager, auto_approve
            )
            _print_footer(token_tracker, context_manager)

        except (KeyboardInterrupt, EOFError):
            await _shutdown(messages, token_tracker, mcp_manager, "interrupt", saved_prefix="\n")
            break
        except Exception as e:
            logger.exception("Unhandled error in main loop: %s", e)
            console.print(f"\n[bold red]Error:[/bold red] {e}")
612
def main():
    """Console entry point: run the async REPL until the user exits.

    The REPL handles Ctrl+C itself only inside its prompt loop. A Ctrl+C
    that arrives earlier (during start-up, before that loop is reached)
    escapes ``asyncio.run`` as ``KeyboardInterrupt``. It is caught here so
    the program exits with the conventional status 130 instead of printing
    a traceback.
    """
    try:
        asyncio.run(main_loop())
    except KeyboardInterrupt:
        console.print("\n[yellow]Interrupted.[/yellow]")
        sys.exit(130)


if __name__ == "__main__":
    main()