aru-code 0.19.2__tar.gz → 0.20.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {aru_code-0.19.2/aru_code.egg-info → aru_code-0.20.0}/PKG-INFO +1 -1
- aru_code-0.20.0/aru/__init__.py +1 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/agents/base.py +6 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/cli.py +32 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/config.py +10 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/context.py +42 -13
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/permissions.py +49 -1
- aru_code-0.20.0/aru/plugins/__init__.py +12 -0
- aru_code-0.20.0/aru/plugins/custom_tools.py +350 -0
- aru_code-0.20.0/aru/plugins/hooks.py +134 -0
- aru_code-0.20.0/aru/plugins/manager.py +330 -0
- aru_code-0.20.0/aru/plugins/tool_api.py +54 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/runtime.py +3 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/tools/codebase.py +35 -7
- {aru_code-0.19.2 → aru_code-0.20.0/aru_code.egg-info}/PKG-INFO +1 -1
- {aru_code-0.19.2 → aru_code-0.20.0}/aru_code.egg-info/SOURCES.txt +6 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/pyproject.toml +1 -1
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_cli_advanced.py +5 -3
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_context.py +20 -17
- aru_code-0.20.0/tests/test_plugins.py +447 -0
- aru_code-0.19.2/aru/__init__.py +0 -1
- {aru_code-0.19.2 → aru_code-0.20.0}/LICENSE +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/README.md +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/agent_factory.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/agents/__init__.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/agents/executor.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/agents/planner.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/cache_patch.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/commands.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/completers.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/display.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/history_blocks.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/providers.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/runner.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/session.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/tools/__init__.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/tools/ast_tools.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/tools/gitignore.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/tools/mcp_client.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/tools/ranker.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru/tools/tasklist.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru_code.egg-info/dependency_links.txt +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru_code.egg-info/entry_points.txt +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru_code.egg-info/requires.txt +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/aru_code.egg-info/top_level.txt +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/setup.cfg +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_agents_base.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_cli.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_cli_base.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_cli_completers.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_cli_new.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_cli_run_cli.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_cli_session.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_cli_shell.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_codebase.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_confabulation_regression.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_config.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_executor.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_gitignore.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_main.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_mcp_client.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_permissions.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_planner.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_providers.py +0 -0
- {aru_code-0.19.2 → aru_code-0.20.0}/tests/test_ranker.py +0 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.20.0"
|
|
@@ -21,6 +21,12 @@ Examples of ideal responses:
|
|
|
21
21
|
- user: "what command lists files?" → assistant: "ls"
|
|
22
22
|
- user: "fix the typo in line 5" → [call edit_file immediately, no narration]
|
|
23
23
|
|
|
24
|
+
## Permission denials — CRITICAL
|
|
25
|
+
|
|
26
|
+
When a tool returns "PERMISSION DENIED", the user intentionally refused the action. \
|
|
27
|
+
NEVER retry the same operation. Do NOT try alternative approaches to achieve the same edit. \
|
|
28
|
+
Instead, stop immediately and ask the user what they would like you to do instead.
|
|
29
|
+
|
|
24
30
|
## Scope rules
|
|
25
31
|
|
|
26
32
|
NEVER create documentation files (*.md) unless the user explicitly asks for them.
|
|
@@ -208,6 +208,38 @@ async def run_cli(skip_permissions: bool = False, resume_id: str | None = None):
|
|
|
208
208
|
paste_state = PasteState()
|
|
209
209
|
prompt_session = _create_prompt_session(paste_state, config)
|
|
210
210
|
|
|
211
|
+
# Load custom tools (synchronous — fast, no network)
|
|
212
|
+
from aru.plugins.custom_tools import discover_custom_tools, register_custom_tools
|
|
213
|
+
_disabled_tools = config.disabled_tools if hasattr(config, "disabled_tools") else []
|
|
214
|
+
_custom_tool_descs = discover_custom_tools(disabled=_disabled_tools)
|
|
215
|
+
if _custom_tool_descs:
|
|
216
|
+
_ct_count = register_custom_tools(_custom_tool_descs)
|
|
217
|
+
console.print(f"[dim]Loaded {_ct_count} custom tool(s): {', '.join(d['name'] for d in _custom_tool_descs)}[/dim]")
|
|
218
|
+
|
|
219
|
+
# Load plugins (local imports only, no network)
|
|
220
|
+
from aru.plugins.manager import PluginManager
|
|
221
|
+
from aru.plugins.hooks import PluginInput
|
|
222
|
+
_plugin_mgr = PluginManager()
|
|
223
|
+
ctx.plugin_manager = _plugin_mgr
|
|
224
|
+
|
|
225
|
+
try:
|
|
226
|
+
_p_input = PluginInput(
|
|
227
|
+
directory=os.getcwd(),
|
|
228
|
+
config_path="aru.json" if os.path.isfile("aru.json") else "",
|
|
229
|
+
model_ref=session.model_ref,
|
|
230
|
+
)
|
|
231
|
+
_plugin_specs = config.plugin_specs if hasattr(config, "plugin_specs") else []
|
|
232
|
+
_plugin_count = await _plugin_mgr.load_all(_p_input, plugin_specs=_plugin_specs)
|
|
233
|
+
if _plugin_count:
|
|
234
|
+
plugin_tools = _plugin_mgr.get_plugin_tools()
|
|
235
|
+
if plugin_tools:
|
|
236
|
+
_pt_count = register_custom_tools(plugin_tools)
|
|
237
|
+
console.print(f"[dim]Loaded {_plugin_count} plugin(s): {', '.join(_plugin_mgr.plugin_names)} ({_pt_count} tool(s))[/dim]")
|
|
238
|
+
else:
|
|
239
|
+
console.print(f"[dim]Loaded {_plugin_count} plugin(s): {', '.join(_plugin_mgr.plugin_names)}[/dim]")
|
|
240
|
+
except Exception as exc:
|
|
241
|
+
console.print(f"[dim yellow]Warning: plugin loading failed: {exc}[/dim yellow]")
|
|
242
|
+
|
|
211
243
|
# Startup: load MCP tools in background (don't block REPL)
|
|
212
244
|
async def _load_mcp_background():
|
|
213
245
|
from aru.tools.codebase import load_mcp_tools
|
|
@@ -161,6 +161,8 @@ class AgentConfig:
|
|
|
161
161
|
custom_agents: dict[str, CustomAgent] = field(default_factory=dict)
|
|
162
162
|
plan_reviewer: bool = True
|
|
163
163
|
tree_depth: int = 2 # max depth for directory tree in system prompt
|
|
164
|
+
disabled_tools: list[str] = field(default_factory=list) # tools to skip loading
|
|
165
|
+
plugin_specs: list = field(default_factory=list) # plugin specs from aru.json
|
|
164
166
|
|
|
165
167
|
@property
|
|
166
168
|
def has_instructions(self) -> bool:
|
|
@@ -515,6 +517,14 @@ def load_config(cwd: str | None = None) -> AgentConfig:
|
|
|
515
517
|
td = data["tree_depth"]
|
|
516
518
|
if isinstance(td, int) and 0 <= td <= 5:
|
|
517
519
|
config.tree_depth = td
|
|
520
|
+
# Plugin specs
|
|
521
|
+
if "plugins" in data and isinstance(data["plugins"], list):
|
|
522
|
+
config.plugin_specs = data["plugins"]
|
|
523
|
+
# Custom tools config
|
|
524
|
+
if "tools" in data and isinstance(data["tools"], dict):
|
|
525
|
+
tools_cfg = data["tools"]
|
|
526
|
+
if "disabled" in tools_cfg and isinstance(tools_cfg["disabled"], list):
|
|
527
|
+
config.disabled_tools = [str(t) for t in tools_cfg["disabled"]]
|
|
518
528
|
# Resolve instructions (local files, globs, URLs)
|
|
519
529
|
if "instructions" in data and isinstance(data["instructions"], list):
|
|
520
530
|
entries = [str(e) for e in data["instructions"] if isinstance(e, str)]
|
|
@@ -420,34 +420,52 @@ def truncate_output(
|
|
|
420
420
|
# Save full output to disk before truncating (like OpenCode)
|
|
421
421
|
saved_path = _save_truncated_output(text)
|
|
422
422
|
|
|
423
|
-
# Truncate by lines
|
|
423
|
+
# Truncate by lines — keep head + tail so summaries at the end are visible
|
|
424
424
|
if line_count > TRUNCATE_MAX_LINES:
|
|
425
425
|
head = lines[:TRUNCATE_KEEP_START]
|
|
426
|
-
|
|
426
|
+
tail = lines[-TRUNCATE_KEEP_END:]
|
|
427
|
+
omitted = line_count - TRUNCATE_KEEP_START - TRUNCATE_KEEP_END
|
|
427
428
|
hint = _build_truncation_hint(source_file, source_tool, TRUNCATE_KEEP_START, saved_path)
|
|
428
429
|
return (
|
|
429
430
|
"".join(head)
|
|
430
431
|
+ f"\n\n[... {omitted:,} lines omitted ({line_count:,} total)]\n"
|
|
431
|
-
+ hint + "\n"
|
|
432
|
+
+ hint + "\n\n"
|
|
433
|
+
+ "".join(tail)
|
|
432
434
|
)
|
|
433
435
|
|
|
434
436
|
# Truncate by bytes (lines fit but total bytes too large)
|
|
435
|
-
|
|
436
|
-
|
|
437
|
+
# Reserve ~20% of budget for tail lines
|
|
438
|
+
head_budget = int(TRUNCATE_MAX_BYTES * 0.75)
|
|
439
|
+
tail_budget = TRUNCATE_MAX_BYTES - head_budget
|
|
440
|
+
|
|
441
|
+
head_lines: list[str] = []
|
|
442
|
+
head_bytes = 0
|
|
437
443
|
for line in lines:
|
|
438
444
|
line_bytes = len(line.encode("utf-8", errors="replace"))
|
|
439
|
-
if
|
|
445
|
+
if head_bytes + line_bytes > head_budget:
|
|
440
446
|
break
|
|
441
|
-
|
|
442
|
-
|
|
447
|
+
head_lines.append(line)
|
|
448
|
+
head_bytes += line_bytes
|
|
443
449
|
|
|
444
|
-
|
|
445
|
-
|
|
450
|
+
# Collect tail lines within tail budget
|
|
451
|
+
tail_lines: list[str] = []
|
|
452
|
+
tail_bytes = 0
|
|
453
|
+
for line in reversed(lines[len(head_lines):]):
|
|
454
|
+
line_bytes = len(line.encode("utf-8", errors="replace"))
|
|
455
|
+
if tail_bytes + line_bytes > tail_budget:
|
|
456
|
+
break
|
|
457
|
+
tail_lines.append(line)
|
|
458
|
+
tail_bytes += line_bytes
|
|
459
|
+
tail_lines.reverse()
|
|
460
|
+
|
|
461
|
+
omitted = line_count - len(head_lines) - len(tail_lines)
|
|
462
|
+
hint = _build_truncation_hint(source_file, source_tool, len(head_lines), saved_path)
|
|
446
463
|
return (
|
|
447
|
-
"".join(
|
|
464
|
+
"".join(head_lines)
|
|
448
465
|
+ f"\n\n[... truncated at ~{TRUNCATE_MAX_BYTES // 1024}KB — "
|
|
449
|
-
f"{
|
|
450
|
-
+ hint + "\n"
|
|
466
|
+
f"{omitted:,} lines omitted]\n"
|
|
467
|
+
+ hint + "\n\n"
|
|
468
|
+
+ "".join(tail_lines)
|
|
451
469
|
)
|
|
452
470
|
|
|
453
471
|
|
|
@@ -660,6 +678,17 @@ async def compact_conversation(
|
|
|
660
678
|
"""
|
|
661
679
|
from aru.providers import create_model
|
|
662
680
|
|
|
681
|
+
# Fire session.compact hook — plugins can pre-process history.
|
|
682
|
+
# Import is lazy here to avoid circular dependency (context ← runtime).
|
|
683
|
+
try:
|
|
684
|
+
from aru.runtime import get_ctx # noqa: lazy to avoid circular dep
|
|
685
|
+
mgr = get_ctx().plugin_manager
|
|
686
|
+
if mgr is not None and mgr.loaded:
|
|
687
|
+
event = await mgr.fire("session.compact", {"history": history})
|
|
688
|
+
history = event.data.get("history", history)
|
|
689
|
+
except (LookupError, AttributeError, ImportError):
|
|
690
|
+
pass # no plugin manager available — proceed without hooks
|
|
691
|
+
|
|
663
692
|
prompt = build_compaction_prompt(history, plan_task, model_id=model_id)
|
|
664
693
|
|
|
665
694
|
try:
|
|
@@ -413,6 +413,42 @@ def resolve_permission(
|
|
|
413
413
|
# Permission gate (user-facing prompt)
|
|
414
414
|
# ---------------------------------------------------------------------------
|
|
415
415
|
|
|
416
|
+
def _run_hook_coroutine_blocking(awaitable, timeout: float):
    """Drive *awaitable* to completion from synchronous code and return its result.

    When no event loop is running in the calling thread (e.g. a tool worker
    thread), the awaitable runs on a fresh private loop via ``asyncio.run``.
    When a loop *is* running in this thread, blocking on it would deadlock, so
    the awaitable runs on a fresh loop in a helper thread instead.

    Args:
        awaitable: Coroutine (or other awaitable) returned by a hook handler.
        timeout: Maximum seconds to wait for completion.

    Returns:
        Whatever the awaitable resolves to.

    Raises:
        Whatever the awaitable raises, or a timeout error
        (``asyncio.TimeoutError`` / ``concurrent.futures.TimeoutError``) if it
        does not finish within *timeout* seconds.
    """
    import asyncio
    import concurrent.futures

    async def _bounded():
        return await asyncio.wait_for(awaitable, timeout)

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: safe to spin up a private one.
        return asyncio.run(_bounded())

    # A loop is running in this thread; hop to a helper thread with its own loop.
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return pool.submit(asyncio.run, _bounded()).result(timeout=timeout)
    finally:
        # wait=False so a handler that overruns the timeout cannot keep us blocked here.
        pool.shutdown(wait=False)


def _fire_permission_hook(mgr, category: str, subject: str, timeout: float = 5.0) -> bool | None:
    """Fire the ``permission.ask`` hook through every registered plugin handler.

    Each handler receives a ``HookEvent`` whose ``data`` holds ``category`` and
    ``subject``; a handler expresses a decision by setting
    ``event.data["allow"]``. Handlers may be plain functions, coroutine
    functions, or callables returning any awaitable; awaitables are run to
    completion synchronously (bounded by *timeout*). A handler that raises or
    times out is skipped.

    Args:
        mgr: Plugin manager; each object in ``mgr._hooks`` supplies handlers
            via ``get_handlers("permission.ask")``.
        category: Permission category being checked.
        subject: Subject of the permission check.
        timeout: Maximum seconds to wait for an async handler.

    Returns:
        ``True``/``False`` as soon as a handler has set ``event.data["allow"]``,
        or ``None`` if no handler overrode the decision.
    """
    import inspect

    from aru.plugins.hooks import HookEvent

    evt = HookEvent(hook="permission.ask", data={"category": category, "subject": subject})

    for hooks_obj in mgr._hooks:
        for handler in hooks_obj.get_handlers("permission.ask"):
            try:
                outcome = handler(evt)
                if inspect.isawaitable(outcome):
                    _run_hook_coroutine_blocking(outcome, timeout)
            except Exception:
                continue  # skip broken handlers

            if "allow" in evt.data:
                return bool(evt.data["allow"])

    return None  # no handler overrode
|
|
450
|
+
|
|
451
|
+
|
|
416
452
|
def check_permission(
|
|
417
453
|
category: str,
|
|
418
454
|
subject: str,
|
|
@@ -429,8 +465,20 @@ def check_permission(
|
|
|
429
465
|
if action == "deny":
|
|
430
466
|
return False
|
|
431
467
|
|
|
432
|
-
#
|
|
468
|
+
# Fire permission.ask hook — plugins can override the decision.
|
|
469
|
+
# check_permission runs in a sync context (called from tool threads),
|
|
470
|
+
# so we fire sync handlers directly and async handlers via the event loop.
|
|
433
471
|
ctx = get_ctx()
|
|
472
|
+
mgr = getattr(ctx, "plugin_manager", None)
|
|
473
|
+
if mgr is not None and getattr(mgr, "loaded", False):
|
|
474
|
+
try:
|
|
475
|
+
override = _fire_permission_hook(mgr, category, subject)
|
|
476
|
+
if override is not None:
|
|
477
|
+
return override
|
|
478
|
+
except Exception:
|
|
479
|
+
pass # never let plugin errors block permissions
|
|
480
|
+
|
|
481
|
+
# action == "ask" -> prompt user
|
|
434
482
|
with ctx.permission_lock:
|
|
435
483
|
# Re-check after acquiring lock (another thread may have resolved it)
|
|
436
484
|
action2, pattern2 = resolve_permission(category, subject)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Aru plugin system — custom tools, hooks, and OpenCode TS bridge.
|
|
2
|
+
|
|
3
|
+
Public API for plugin authors:
|
|
4
|
+
|
|
5
|
+
from aru.plugins import tool # @tool decorator for custom tools
|
|
6
|
+
from aru.plugins import PluginInput, Hooks # Full plugin API (Phase 2)
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from aru.plugins.tool_api import tool
|
|
10
|
+
from aru.plugins.hooks import Hooks, HookEvent, PluginInput
|
|
11
|
+
|
|
12
|
+
__all__ = ["tool", "Hooks", "HookEvent", "PluginInput"]
|
|
@@ -0,0 +1,350 @@
|
|
|
1
|
+
"""Discover and load custom tool files (.py) from tool directories.
|
|
2
|
+
|
|
3
|
+
Discovery paths (later overrides earlier):
|
|
4
|
+
1. ~/.agents/tools/*.py, ~/.aru/tools/*.py (global)
|
|
5
|
+
2. .agents/tools/*.py, .aru/tools/*.py (project-local)
|
|
6
|
+
|
|
7
|
+
Tool files can define tools in two ways:
|
|
8
|
+
- @tool-decorated functions (from aru.plugins import tool)
|
|
9
|
+
- Bare functions with -> str return annotation
|
|
10
|
+
|
|
11
|
+
Naming convention (mirrors OpenCode):
|
|
12
|
+
- File deploy.py with def deploy -> tool name: "deploy"
|
|
13
|
+
- File ci.py with def build + def test -> tool names: "ci_build", "ci_test"
|
|
14
|
+
- Single function in file -> filename only (no prefix)
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import asyncio
|
|
20
|
+
import importlib.util
|
|
21
|
+
import inspect
|
|
22
|
+
import logging
|
|
23
|
+
import sys
|
|
24
|
+
import types
|
|
25
|
+
from pathlib import Path
|
|
26
|
+
from typing import Any, Callable
|
|
27
|
+
|
|
28
|
+
from aru.plugins.tool_api import get_tool_meta
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger("aru.plugins")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _extract_tools_from_module(mod: types.ModuleType) -> list[tuple[str, Callable, dict[str, Any] | None]]:
    """Extract tool functions from a loaded tool module.

    A public (non-underscore) attribute of *mod* is treated as a tool when it is
    a plain Python function and either:

    * carries ``@tool`` metadata (``get_tool_meta`` returns non-None), or
    * is a bare function defined in *mod* itself whose return annotation is
      ``str`` (the class, or the string ``"str"`` produced by
      ``from __future__ import annotations``).

    Bare functions that were only imported into the tool file (for example
    ``from helpers import fmt``) are ignored, so helpers are not exposed to the
    model as extra tools.

    Args:
        mod: Module produced by importing a tool file.

    Returns:
        List of ``(export_name, function, tool_meta_or_None)`` tuples, in
        ``dir()`` order.
    """
    found: list[tuple[str, Callable, dict[str, Any] | None]] = []

    for name in dir(mod):
        if name.startswith("_"):
            continue
        obj = getattr(mod, name)
        if not inspect.isfunction(obj):
            continue

        meta = get_tool_meta(obj)
        if meta is not None:
            # @tool-decorated function
            found.append((name, obj, meta))
            continue

        # Bare function: must be declared in this file, not imported into it.
        if getattr(obj, "__module__", None) != mod.__name__:
            continue

        ret = getattr(obj, "__annotations__", {}).get("return")
        if ret is str or (isinstance(ret, str) and ret == "str"):
            found.append((name, obj, None))

    return found
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _json_schema_type(annotation: Any) -> str:
    """Map a Python type annotation to a JSON Schema ``type`` name.

    Handles the builtin scalars and containers, their parameterised forms
    (``list[str]`` -> ``"array"``, ``typing.Dict[str, int]`` -> ``"object"``),
    ``Optional[X]`` / ``X | None`` (mapped as ``X``), and unresolved string
    annotations such as ``"int"`` or ``"list[str]"``. Anything unrecognised,
    including a missing annotation, falls back to ``"string"``.
    """
    import typing

    type_map = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
        list: "array",
        dict: "object",
    }

    if annotation is inspect.Parameter.empty:
        return "string"

    if isinstance(annotation, str):
        # Unresolved annotation text: match on the bare type name.
        by_name = {t.__name__: t for t in type_map}
        annotation = by_name.get(annotation.split("[", 1)[0].strip())
        if annotation is None:
            return "string"

    origin = typing.get_origin(annotation)
    union_type = getattr(types, "UnionType", None)  # `X | Y` on Python 3.10+
    if origin is typing.Union or (union_type is not None and origin is union_type):
        members = [a for a in typing.get_args(annotation) if a is not type(None)]
        return _json_schema_type(members[0]) if len(members) == 1 else "string"
    if origin is not None:
        annotation = origin  # list[str] -> list, dict[str, int] -> dict

    if isinstance(annotation, type):
        return type_map.get(annotation, "string")
    return "string"


def _build_parameters_from_function(fn: Callable) -> dict[str, Any]:
    """Build a JSON Schema ``parameters`` object from a function's signature.

    Every named parameter except ``self``/``cls`` becomes a property. Its
    ``type`` is derived from the annotation: string annotations are resolved
    with ``typing.get_type_hints`` when possible, and unknown types map to
    ``"string"``. Its ``description`` comes from the docstring's ``Args:``
    section. Parameters without a default are listed under ``required``;
    defaults are copied into the schema as-is.

    ``*args`` / ``**kwargs`` are omitted: the tool entrypoint calls the
    function with keyword arguments only, so they could never be supplied.
    """
    import typing

    sig = inspect.signature(fn)

    # Parse argument descriptions from docstring Args section
    arg_descriptions = _parse_arg_descriptions(fn.__doc__ or "")

    # Resolve stringified annotations in the function's own globals; if that
    # fails (unresolvable forward refs, etc.) fall back to raw annotations.
    try:
        resolved = typing.get_type_hints(fn)
    except Exception:
        resolved = {}

    properties: dict[str, Any] = {}
    required: list[str] = []

    for pname, param in sig.parameters.items():
        if pname in ("self", "cls"):
            continue
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue

        prop: dict[str, Any] = {"type": _json_schema_type(resolved.get(pname, param.annotation))}
        desc = arg_descriptions.get(pname)
        if desc:
            prop["description"] = desc

        # Identity check: `!=` would misbehave on array-like defaults.
        if param.default is inspect.Parameter.empty:
            required.append(pname)
        else:
            prop["default"] = param.default

        properties[pname] = prop

    schema: dict[str, Any] = {
        "type": "object",
        "properties": properties,
    }
    if required:
        schema["required"] = required
    return schema
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _parse_arg_descriptions(docstring: str) -> dict[str, str]:
    """Parse the ``Args:`` section of a Google-style docstring.

    The section is opened by a line that is exactly ``Args:``, ``Arguments:``
    or ``Parameters:`` (case-insensitive). It recognises entries of the forms
    ``name: description`` and ``name (type): description``, optionally
    bulleted with ``-``.

    A line indented deeper than the entry lines continues the previous
    entry's description, even if it contains a colon (URLs, ``e.g.: ...``).
    In flat, unindented docstrings, a colon-free line also continues the
    previous entry. The section ends at a dedent below the entry column or at
    a single-word header such as ``Returns:``.

    Args:
        docstring: The raw docstring (may be empty).

    Returns:
        Mapping of parameter name to its whitespace-joined description.
    """
    section_headers = ("args:", "arguments:", "parameters:")
    descriptions: dict[str, str] = {}
    in_args = False
    entry_indent: int | None = None  # column of the "name: ..." lines
    current_arg = ""
    current_desc_parts: list[str] = []

    def flush() -> None:
        nonlocal current_arg, current_desc_parts
        if current_arg:
            descriptions[current_arg] = " ".join(current_desc_parts)
        current_arg = ""
        current_desc_parts = []

    for line in docstring.expandtabs().splitlines():
        stripped = line.strip()

        if not in_args:
            # Only an exact header line opens the section, so a parameter
            # literally named "args" is never mistaken for it.
            if stripped.lower() in section_headers:
                in_args = True
            continue

        if not stripped:
            flush()
            continue

        indent = len(line) - len(line.lstrip())
        if entry_indent is None:
            entry_indent = indent

        # Continuation: deeper-indented text (may contain ':'), or a plain
        # colon-free line in flat docstrings.
        if current_arg and (
            indent > entry_indent
            or (":" not in stripped and not stripped.startswith("-"))
        ):
            current_desc_parts.append(stripped)
            continue

        if indent < entry_indent:
            break  # dedented out of the Args block
        if stripped.endswith(":") and not stripped.startswith("-") and " " not in stripped.rstrip(":"):
            break  # next section header (Returns:, Raises:, ...)

        # Parse "param_name: description" or "param_name (type): description"
        if ":" in stripped:
            flush()
            name_part, _, desc = stripped.partition(":")
            arg_name = name_part.strip().lstrip("-").strip()
            if "(" in arg_name:  # drop "(type)" annotations
                arg_name = arg_name[: arg_name.index("(")].strip()
            current_arg = arg_name
            desc = desc.strip()
            current_desc_parts = [desc] if desc else []

    flush()
    return descriptions
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def _load_module_from_path(filepath: Path) -> types.ModuleType | None:
    """Dynamically import a Python tool file as a module.

    The module is created under the name ``aru_custom_tool_<stem>`` and
    registered in ``sys.modules`` *before* it executes, as the importlib
    "import a source file directly" recipe requires. Without this, code that
    looks its own module up while being defined fails: for example, a
    ``@dataclass`` in a file using ``from __future__ import annotations``.
    If execution fails, the previous ``sys.modules`` entry (if any) is
    restored.

    The file's directory is placed on ``sys.path`` only while the module
    executes, so it can import sibling helper modules.

    Args:
        filepath: Path to the ``.py`` tool file.

    Returns:
        The executed module, or ``None`` (after logging a warning) if it could
        not be loaded.
    """
    module_name = f"aru_custom_tool_{filepath.stem}"
    try:
        spec = importlib.util.spec_from_file_location(module_name, str(filepath))
        if spec is None or spec.loader is None:
            logger.warning("Cannot load tool file (no spec): %s", filepath)
            return None
        mod = importlib.util.module_from_spec(spec)

        # Add the file's directory to sys.path temporarily so sibling imports work
        parent = str(filepath.parent)
        added = parent not in sys.path
        if added:
            sys.path.insert(0, parent)

        previous = sys.modules.get(module_name)
        sys.modules[module_name] = mod
        try:
            spec.loader.exec_module(mod)
        except BaseException:
            # Don't leave a half-initialised module registered.
            if previous is not None:
                sys.modules[module_name] = previous
            else:
                sys.modules.pop(module_name, None)
            raise
        finally:
            if added and parent in sys.path:
                sys.path.remove(parent)
        return mod
    except Exception as e:
        logger.warning("Failed to load custom tool %s: %s", filepath, e)
        return None
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _make_tool_entrypoint(fn: Callable, tool_name: str) -> Callable:
    """Wrap a sync or async tool function in an async, string-returning entrypoint.

    The result of *fn* is awaited if it is awaitable, then converted with
    ``str()``; ``None`` becomes ``""``. The target is bound through the
    keyword-only ``_fn`` default, keeping the same wrapper signature as
    before. The wrapper takes *tool_name* as its ``__name__`` and *fn*'s
    docstring.
    """

    async def entrypoint(*, _fn=fn, **kwargs) -> str:
        result = _fn(**kwargs)
        if inspect.isawaitable(result):
            result = await result
        return str(result) if result is not None else ""

    entrypoint.__name__ = tool_name
    entrypoint.__doc__ = fn.__doc__
    return entrypoint


def discover_custom_tools(
    search_roots: list[Path] | None = None,
    disabled: list[str] | None = None,
) -> list[dict[str, Any]]:
    """Discover custom tool files and return tool descriptors.

    Every ``*.py`` file not starting with ``_`` in each ``<root>/tools``
    directory is imported and its tool functions extracted. Naming works as
    follows:

    * A file exporting a single tool names that tool after the file.
    * With several tools, each is named ``<file>_<function>``.
    * A function named after the file, or ``default``, keeps the bare file
      name.

    A file that fails to import or inspect, or a function whose parameter
    schema cannot be built, is logged and skipped rather than aborting
    discovery.

    Args:
        search_roots: Directories to scan for tools/ subdirectories.
            If None, uses default paths (global + project-local).
        disabled: List of tool names to skip.

    Returns:
        List of dicts with keys: name, description, parameters, entrypoint,
        source, override. Later entries override earlier ones with the same
        name (project-local wins with the default roots).
    """
    if search_roots is None:
        search_roots = _default_search_roots()

    disabled_set = set(disabled or [])
    tools_by_name: dict[str, dict[str, Any]] = {}

    for root in search_roots:
        tools_dir = root / "tools"
        if not tools_dir.is_dir():
            continue

        for filepath in sorted(tools_dir.glob("*.py")):
            if filepath.name.startswith("_"):
                continue

            mod = _load_module_from_path(filepath)
            if mod is None:
                continue

            try:
                extracted = _extract_tools_from_module(mod)
            except Exception as e:
                logger.warning("Failed to inspect custom tool file %s: %s", filepath, e)
                continue

            file_stem = filepath.stem
            # Naming: single function -> filename; multiple -> filename_exportname
            use_prefix = len(extracted) > 1

            for export_name, fn, meta in extracted:
                if use_prefix and export_name not in (file_stem, "default"):
                    tool_name = f"{file_stem}_{export_name}"
                else:
                    tool_name = file_stem

                if tool_name in disabled_set:
                    logger.debug("Skipping disabled custom tool: %s", tool_name)
                    continue

                doc = (fn.__doc__ or "").strip()
                if meta and meta.get("description"):
                    description = meta["description"]
                elif doc:
                    description = doc.split("\n")[0]  # first docstring line
                else:
                    description = f"Custom tool: {tool_name}"

                try:
                    parameters = _build_parameters_from_function(fn)
                except Exception as e:
                    logger.warning("Skipping custom tool %s from %s: %s", tool_name, filepath, e)
                    continue

                tools_by_name[tool_name] = {
                    "name": tool_name,
                    "description": description,
                    "parameters": parameters,
                    "entrypoint": _make_tool_entrypoint(fn, tool_name),
                    "source": str(filepath),
                    "override": bool(meta and meta.get("override")),
                }

    return list(tools_by_name.values())
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def register_custom_tools(tool_descriptors: list[dict[str, Any]]) -> int:
    """Register custom tool descriptors in aru's global tool registry and tool lists.

    Each descriptor's ``name``, ``description``, ``parameters`` and
    ``entrypoint`` are wrapped in an ``agno.tools.Function``.

    * If ``TOOL_REGISTRY`` already has a tool with that name, the new
      function replaces the first same-named entry in each of ``ALL_TOOLS``,
      ``GENERAL_TOOLS`` and ``EXECUTOR_TOOLS``, and a warning names both
      sources.
    * Otherwise it is appended to all three lists.

    In both cases the function is tagged with ``_aru_source`` and stored in
    ``TOOL_REGISTRY``.

    Returns:
        Number of tools registered.
    """
    from agno.tools import Function

    from aru.tools.codebase import (
        ALL_TOOLS,
        EXECUTOR_TOOLS,
        GENERAL_TOOLS,
        TOOL_REGISTRY,
    )

    target_lists = (ALL_TOOLS, GENERAL_TOOLS, EXECUTOR_TOOLS)

    def _tool_name(candidate):
        return getattr(candidate, "__name__", None) or getattr(candidate, "name", None)

    registered = 0
    for descriptor in tool_descriptors:
        tool_name = descriptor["name"]
        wrapped = Function(
            name=tool_name,
            description=descriptor["description"],
            parameters=descriptor["parameters"],
            entrypoint=descriptor["entrypoint"],
        )

        previous = TOOL_REGISTRY.get(tool_name)
        if previous is None:
            for target in target_lists:
                target.append(wrapped)
        else:
            # Remember where the replaced tool came from, for the warning.
            previous_source = getattr(previous, "_aru_source", "built-in")
            for target in target_lists:
                slot = next(
                    (idx for idx, item in enumerate(target) if _tool_name(item) == tool_name),
                    None,
                )
                if slot is not None:
                    target[slot] = wrapped
            logger.warning("Tool '%s' from %s overrides %s", tool_name, descriptor["source"], previous_source)

        wrapped._aru_source = descriptor["source"]  # tag for collision logging
        TOOL_REGISTRY[tool_name] = wrapped
        registered += 1
        logger.debug("Registered custom tool: %s (from %s)", tool_name, descriptor["source"])

    return registered
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def _default_search_roots() -> list[Path]:
    """Return default tool search roots: global dirs first, then project-local.

    Candidates are ``~/.agents``, ``~/.aru``, ``<cwd>/.agents`` and
    ``<cwd>/.aru``, in that order; only existing directories are kept. A
    directory reachable both ways (e.g. running aru from the home directory)
    is listed once, so its tool files are not imported and executed twice.
    """
    import os

    home = Path.home()
    cwd = Path(os.getcwd())
    candidates = [base / dirname for base in (home, cwd) for dirname in (".agents", ".aru")]

    roots: list[Path] = []
    seen: set[Path] = set()
    for candidate in candidates:
        if not candidate.is_dir():
            continue
        key = candidate.resolve()
        if key in seen:
            continue
        seen.add(key)
        roots.append(candidate)

    return roots
|