pdo-agent 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pdo/tools/git.py ADDED
@@ -0,0 +1,44 @@
1
+ """Git tool: run git commands in the working directory."""
2
+ from __future__ import annotations
3
+
4
+ import shlex
5
+ import subprocess
6
+ from typing import Any
7
+
8
+ from .base import Tool, truncate
9
+ from .registry import register_tool
10
+
11
+
12
@register_tool
class GitTool(Tool):
    """Tool that runs a ``git`` subcommand in the process's working directory.

    The argument string is tokenised with :func:`shlex.split` and executed
    without a shell, so shell metacharacters inside ``args`` reach git as
    literal text.
    """

    name = "git"
    description = (
        "Run a git command in the working directory and return its output "
        "(e.g. 'status --short', 'diff', 'log --oneline -10', "
        "'commit -m \"message\"'). Network commands like push use your "
        "existing credentials."
    )
    parameters = {
        "type": "object",
        "properties": {
            "args": {
                "type": "string",
                "description": "Arguments passed to git, e.g. 'status --short'.",
            }
        },
        "required": ["args"],
    }

    def run(self, args: str, **_: Any) -> str:
        """Execute ``git <args>`` and report its exit code and output.

        Args:
            args: Everything that follows ``git`` on the command line, as a
                single shell-quoted string.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            ``"[exit N]"`` on the first line followed by stdout and stderr
            concatenated (passed through ``truncate``), or an ``"Error..."``
            message when the arguments cannot be parsed, git is missing, or
            the command exceeds the 60 second timeout.
        """
        try:
            argv = ["git", *shlex.split(args)]
        except ValueError as exc:
            # shlex raises ValueError for e.g. an unterminated quote.
            return f"Error parsing git arguments: {exc}"
        try:
            completed = subprocess.run(
                argv,
                capture_output=True,
                text=True,
                # Diffs and logs can contain bytes that are not valid in the
                # locale encoding; replace them instead of raising
                # UnicodeDecodeError out of the tool.
                errors="replace",
                # Output is captured, so nobody can answer an interactive
                # prompt; give git an empty stdin so it fails fast instead of
                # blocking until the timeout.
                stdin=subprocess.DEVNULL,
                timeout=60,
            )
        except FileNotFoundError:
            return "Error: git is not installed or not on PATH."
        except subprocess.TimeoutExpired:
            return "Error: git command timed out."
        output = ((completed.stdout or "") + (completed.stderr or "")).strip()
        return f"[exit {completed.returncode}]\n{truncate(output or '(no output)')}"
pdo/tools/memory.py ADDED
@@ -0,0 +1,70 @@
1
+ """Memory tools: save, search and delete facts in the local JSON store.
2
+
3
+ These thin wrappers expose :class:`pdo.agent.memory.MemoryStore` to the model so
4
+ it can remember useful facts and preferences across turns and sessions.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ from typing import Any
9
+
10
+ from ..agent.memory import get_memory_store
11
+ from .base import Tool
12
+ from .registry import register_tool
13
+
14
+
15
@register_tool
class MemorySaveTool(Tool):
    """Tool that hands one fact, plus optional tags, to the memory store."""

    name = "memory_save"
    description = "Save a useful fact or user preference to long-term memory."
    parameters = {
        "type": "object",
        "properties": {
            "text": {
                "type": "string",
                "description": "The fact to remember.",
            },
            "tags": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Optional tags to aid later search.",
            },
        },
        "required": ["text"],
    }

    def run(self, text: str, tags: list[str] | None = None, **_: Any) -> str:
        """Save ``text`` via the store's ``save_fact`` and report the new id.

        Args:
            text: The fact to store.
            tags: Optional tags; a missing or empty value is passed on as an
                empty list.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            A confirmation message containing the value ``save_fact`` returned.
        """
        store = get_memory_store()
        labels = tags if tags else []
        saved_id = store.save_fact(text, labels)
        return f"Saved memory {saved_id}."
35
+
36
+
37
@register_tool
class MemorySearchTool(Tool):
    """Tool that looks up stored facts by keyword query."""

    name = "memory_search"
    description = "Search long-term memory for facts matching a keyword query."
    parameters = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Keyword(s) to search for.",
            }
        },
        "required": ["query"],
    }

    def run(self, query: str, **_: Any) -> str:
        """Return one ``[id] text`` line per hit from the store's ``search_facts``.

        Args:
            query: The search string forwarded to the memory store.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            The formatted hits joined by newlines, or a fixed message when the
            store returns nothing.
        """
        matches = get_memory_store().search_facts(query)
        if matches:
            lines = [f"[{match['id']}] {match['text']}" for match in matches]
            return "\n".join(lines)
        return "No memories matched that query."
54
+
55
+
56
@register_tool
class MemoryDeleteTool(Tool):
    """Tool that removes a single stored fact identified by its id."""

    name = "memory_delete"
    description = "Delete a memory by its id."
    parameters = {
        "type": "object",
        "properties": {
            "id": {
                "type": "string",
                "description": "The id of the memory to delete.",
            }
        },
        "required": ["id"],
    }

    def run(self, id: str, **_: Any) -> str:  # noqa: A002 — matches the schema field name
        """Ask the store to delete ``id`` and report whether it did.

        Args:
            id: Identifier of the fact to delete.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            ``"Deleted."`` when the store's ``delete_fact`` returns a truthy
            value, otherwise a message saying no such memory exists.
        """
        if get_memory_store().delete_fact(id):
            return "Deleted."
        return "No memory with that id."
pdo/tools/rag.py ADDED
@@ -0,0 +1,60 @@
1
+ """Codebase retrieval tool backed by the BM25 index in :mod:`pdo.rag`."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ from ..rag import build_index, load_index, search
8
+ from .base import Tool, truncate
9
+ from .registry import register_tool
10
+
11
# Upper bound on the characters kept from each hit, so that a few large
# chunks cannot flood the model's context.
_SNIPPET_CHARS = 1200


@register_tool
class CodebaseSearchTool(Tool):
    """Tool that queries the codebase index for the current directory."""

    name = "codebase_search"
    description = (
        "Semantic-ish search over the indexed codebase in the current directory: "
        "returns the most relevant code/document chunks for a natural-language "
        "query, with path:line references. Builds the index on first use; the "
        "user can refresh it with /index."
    )
    parameters = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "What to look for (identifiers, phrases, concepts).",
            },
            "top_k": {
                "type": "integer",
                "description": "How many chunks to return (default 5).",
            },
        },
        "required": ["query"],
    }

    def run(self, query: str, top_k: int = 5, **_: Any) -> str:
        """Search the index for ``query`` and format the best chunks.

        The index is loaded for the current working directory and rebuilt when
        none is found or when its recorded root differs from the resolved
        working directory.

        Args:
            query: Text to search for.
            top_k: Requested number of chunks; clamped to the range 1..20.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            One block per hit (``[path:start-end] (score X)`` header plus a
            snippet capped at ``_SNIPPET_CHARS``), separated by ``---`` rules
            and truncated to 8000 characters, or a fixed message when the
            index is empty or nothing matches.
        """
        cwd = Path.cwd()
        index = load_index(cwd)
        needs_rebuild = index is None or index.root != str(cwd.resolve())
        if needs_rebuild:
            index = build_index(cwd)
        if not index.chunks:
            return "The index is empty — no indexable files found here."

        limit = max(1, min(top_k, 20))
        hits = search(index, query, top_k=limit)
        if not hits:
            return "No relevant chunks found for that query."

        def render(hit) -> str:
            # Header with location and score, followed by the capped text.
            chunk = hit.chunk
            body = chunk.text
            if len(body) > _SNIPPET_CHARS:
                body = body[:_SNIPPET_CHARS] + "\n… [truncated]"
            return f"[{chunk.path}:{chunk.start}-{chunk.end}] (score {hit.score:.1f})\n{body}"

        rendered = [render(hit) for hit in hits]
        return truncate("\n\n---\n\n".join(rendered), 8000)
pdo/tools/registry.py ADDED
@@ -0,0 +1,203 @@
1
+ """The single tool registry.
2
+
3
+ Tools auto-register here via the :func:`register_tool` class decorator when
4
+ their module is imported. :func:`get_registry` lazily imports the built-in tool
5
+ modules the first time it is called, so the rest of the app only ever sees a
6
+ fully populated registry and never imports individual tools.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import importlib.util
11
+ import inspect
12
+ import logging
13
+ import sys
14
+ from pathlib import Path
15
+ from threading import Lock
16
+
17
+ from .base import Tool
18
+
19
logger = logging.getLogger(__name__)

# Entry-point group under which third-party packages advertise their tools.
PLUGIN_ENTRYPOINT_GROUP = "pdo.plugins"

# Names of the tools that came from plugins (kept for display/introspection).
_plugin_tool_names: list[str] = []


class ToolRegistry:
    """A name-keyed collection of tools that remembers registration order."""

    def __init__(self) -> None:
        # Maps tool name to tool instance; dict order is registration order.
        self._tools: dict[str, Tool] = {}

    def register(self, tool: Tool) -> None:
        """Store ``tool`` under its ``name``.

        A tool registered under an existing name replaces the earlier one,
        with a warning logged.

        Raises:
            ValueError: If the tool's ``name`` is empty.
        """
        tool_name = tool.name
        if not tool_name:
            raise ValueError(f"{type(tool).__name__} must define a non-empty 'name'")
        already_present = tool_name in self._tools
        if already_present:
            logger.warning("Overwriting already-registered tool %r", tool_name)
        self._tools[tool_name] = tool

    def get(self, name: str) -> Tool:
        """Return the tool registered as ``name``.

        Raises:
            KeyError: If no tool has that name.
        """
        # A missing key makes the dict raise KeyError(name) on its own.
        return self._tools[name]

    def has(self, name: str) -> bool:
        """Tell whether a tool called ``name`` is registered."""
        return name in self._tools

    def all(self) -> list[Tool]:
        """Return every registered tool, in registration order."""
        return [*self._tools.values()]

    def names(self) -> list[str]:
        """Return the registered tool names, sorted."""
        return sorted(self._tools.keys())

    def schemas(self) -> list[dict]:
        """Collect each tool's JSON schema for the model's ``tools`` parameter."""
        collected = []
        for tool in self._tools.values():
            collected.append(tool.to_openai_schema())
        return collected
58
+
59
+
60
# The one shared registry. Tool modules add themselves to it when imported.
_registry = ToolRegistry()
# Becomes True once get_registry() has finished loading tools.
_loaded = False
# Serialises that first load.
_load_lock = Lock()


def register_tool(cls: type[Tool]) -> type[Tool]:
    """Class decorator that registers an instance of ``cls`` and returns ``cls``.

    The class is instantiated with no arguments, so any constructor
    parameters must have defaults.
    """
    instance = cls()
    _registry.register(instance)
    return cls
74
+
75
+
76
def get_registry() -> ToolRegistry:
    """Return the shared registry, loading built-in and plugin tools once.

    The first caller imports the built-in tool modules (whose
    ``@register_tool`` decorators fill the registry) and then runs plugin
    discovery. The loaded flag is checked again after taking the lock, so
    callers that waited on the lock do not repeat the load.
    """
    global _loaded
    if _loaded:
        return _registry

    with _load_lock:
        if _loaded:
            # Another caller completed the load while this one waited.
            return _registry

        # Imported purely for their registration side effect.
        from . import (  # noqa: F401
            code,
            data,
            edit,
            filesystem,
            git,
            memory,
            rag,
            search,
            shell,
            web,
        )

        load_plugins(_registry)
        _loaded = True

    return _registry
99
+
100
+
101
def plugin_tool_names() -> list[str]:
    """Return a copy of the list of plugin-contributed tool names."""
    return _plugin_tool_names[:]
104
+
105
+
106
def load_plugins(registry: ToolRegistry) -> None:
    """Run directory and entry-point plugin discovery against ``registry``.

    The two stages run in that order. An exception escaping either stage is
    logged with its traceback and does not stop the other stage or propagate
    to the caller.
    """
    from ..config import get_plugins_dir

    stages = (
        (
            lambda: discover_directory_plugins(registry, get_plugins_dir()),
            "Directory plugin discovery failed",
        ),
        (
            lambda: discover_entrypoint_plugins(registry),
            "Entry-point plugin discovery failed",
        ),
    )
    for stage, failure_message in stages:
        try:
            stage()
        except Exception:  # noqa: BLE001 — discovery must never crash startup
            logger.exception(failure_message)
118
+
119
+
120
def discover_directory_plugins(registry: ToolRegistry, plugins_dir: Path) -> None:
    """Load each public ``*.py`` file in ``plugins_dir`` and register its tools.

    Files are handled in sorted order and names starting with an underscore
    are ignored. Every file gets its own try/except: a failure is logged and
    the loop moves on to the next file. A ``plugins_dir`` that does not exist
    is a no-op.
    """
    if not plugins_dir.exists():
        return

    plugin_files = [
        candidate
        for candidate in sorted(plugins_dir.glob("*.py"))
        if not candidate.name.startswith("_")
    ]
    for path in plugin_files:
        try:
            loaded = _import_file(path)
            _register_module_tools(registry, loaded)
        except Exception:  # noqa: BLE001
            logger.exception("Failed to load plugin file %s", path)
136
+
137
+
138
def discover_entrypoint_plugins(registry: ToolRegistry) -> None:
    """Load tools advertised by installed packages under the plugin entry-point group.

    An entry point may resolve to a ``Tool`` subclass (registered directly) or a
    callable ``register(registry)`` that adds its own tools.

    Each entry point is isolated: a failure while loading it *or* while
    registering what it resolved to is logged and the remaining entry points
    are still processed. If the entry-point metadata itself cannot be read,
    that is logged and nothing is registered.
    """
    try:
        from importlib.metadata import entry_points

        eps = entry_points()
        group = (
            eps.select(group=PLUGIN_ENTRYPOINT_GROUP)
            if hasattr(eps, "select")
            else eps.get(PLUGIN_ENTRYPOINT_GROUP, [])  # Python <3.10 mapping API
        )
    except Exception:  # noqa: BLE001
        logger.exception("Could not read plugin entry points")
        return

    for entry_point in group:
        try:
            obj = entry_point.load()
        except Exception:  # noqa: BLE001
            logger.exception("Failed to load plugin entry point %s", entry_point.name)
            continue
        try:
            # A plugin-supplied register(registry) callable runs arbitrary
            # third-party code; it must not be able to abort the loop and
            # take the other plugins down with it.
            _register_plugin_object(registry, obj)
        except Exception:  # noqa: BLE001
            logger.exception(
                "Failed to register plugin entry point %s", entry_point.name
            )
164
+
165
+
166
def _import_file(path: Path):
    """Import a standalone .py file under a unique module name and return it.

    The module is published in ``sys.modules`` as ``pdo_plugin_<stem>`` before
    its code runs. If executing the file raises, that entry is removed again
    so no half-initialised module is left behind, and the exception is
    re-raised.

    Args:
        path: Location of the Python source file to load.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be created for ``path``.
    """
    name = f"pdo_plugin_{path.stem}"
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load plugin spec for {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Match the regular import machinery: a module whose body failed to
        # execute must not stay importable by name.
        sys.modules.pop(name, None)
        raise
    return module
176
+
177
+
178
def _register_module_tools(registry: ToolRegistry, module) -> None:
    """Register the concrete ``Tool`` subclasses that ``module`` itself defines.

    ``Tool`` itself, abstract classes, and classes whose ``__module__`` is not
    this module's name (i.e. ones it merely imported) are left alone.
    """
    for _name, candidate in inspect.getmembers(module, inspect.isclass):
        if not issubclass(candidate, Tool) or candidate is Tool:
            continue
        if inspect.isabstract(candidate):
            continue
        if candidate.__module__ != module.__name__:
            # Imported into the plugin, not defined by it.
            continue
        _register_plugin_object(registry, candidate)
188
+
189
+
190
def _register_plugin_object(registry: ToolRegistry, obj) -> None:
    """Register what a plugin supplied: a Tool subclass or a ``register(registry)`` callable.

    A proper ``Tool`` subclass is instantiated with no arguments; an
    instantiation failure is logged and the object is dropped. The instance is
    registered only when it has a non-empty name that the registry does not
    already hold, and its name is then recorded as plugin-provided. Any other
    callable is invoked with ``registry``. Everything else is ignored.
    """
    is_tool_subclass = inspect.isclass(obj) and issubclass(obj, Tool) and obj is not Tool
    if not is_tool_subclass:
        if callable(obj):
            obj(registry)
        return

    try:
        instance = obj()
    except Exception:  # noqa: BLE001
        logger.exception("Could not instantiate plugin tool %s", obj)
        return

    if not instance.name or registry.has(instance.name):
        # Unnamed tools and names that are already taken are skipped.
        return

    registry.register(instance)
    _plugin_tool_names.append(instance.name)
    logger.info("Loaded plugin tool %r", instance.name)
pdo/tools/search.py ADDED
@@ -0,0 +1,83 @@
1
+ """Code-search tools: glob for files and grep file contents."""
2
+ from __future__ import annotations
3
+
4
+ import re
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ from .base import Tool, truncate
9
+ from .registry import register_tool
10
+
11
# Directory names treated as noise by the search tools in this module.
_SKIP_DIRS = {".git", ".venv", "venv", "node_modules", "__pycache__", ".ruff_cache", ".pytest_cache"}
# Cap on the number of matching lines a content search collects.
_MAX_HITS = 200


def _skipped(path: Path) -> bool:
    """Return True when any component of ``path`` is named in ``_SKIP_DIRS``."""
    return not _SKIP_DIRS.isdisjoint(path.parts)
18
+
19
+
20
@register_tool
class GlobTool(Tool):
    """Tool that lists files under a directory matching a glob pattern."""

    name = "glob_files"
    description = "Find files matching a glob pattern (e.g. '**/*.py') under a directory."
    parameters = {
        "type": "object",
        "properties": {
            "pattern": {"type": "string", "description": "Glob pattern, e.g. '**/*.md'."},
            "path": {"type": "string", "description": "Base directory (default current)."},
        },
        "required": ["pattern"],
    }

    def run(self, pattern: str, path: str = ".", **_: Any) -> str:
        """Return up to 500 matching file paths, sorted, one per line.

        Args:
            pattern: Glob pattern evaluated relative to ``path``.
            path: Base directory; ``~`` is expanded.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            The matching paths (passed through ``truncate``),
            ``"No files matched."`` when there are none, or an ``"Error..."``
            message when the base path is missing or the pattern is one
            ``Path.glob`` rejects.
        """
        base = Path(path).expanduser()
        if not base.exists():
            return f"Error: path not found: {base}"
        try:
            # Path.glob can be lazy, so consume it inside the try: an empty
            # pattern raises ValueError and an absolute one NotImplementedError.
            candidates = sorted(base.glob(pattern))
        except (ValueError, NotImplementedError) as exc:
            return f"Error: invalid glob pattern {pattern!r}: {exc}"
        matches = [
            str(p)
            for p in candidates
            # Judge noise directories relative to the base, so a base that
            # itself sits under e.g. "venv/" does not hide every result.
            if p.is_file() and not _skipped(p.relative_to(base))
        ][:500]
        return truncate("\n".join(matches)) if matches else "No files matched."
41
+
42
+
43
@register_tool
class GrepTool(Tool):
    """Tool that searches file contents line by line with a regular expression."""

    name = "search_files"
    description = (
        "Search file contents for a regular expression and return matching "
        "'path:line: text'. Skips common noise directories (.git, .venv, …)."
    )
    parameters = {
        "type": "object",
        "properties": {
            "pattern": {"type": "string", "description": "Regular expression to search for."},
            "path": {"type": "string", "description": "Base directory (default current)."},
            "glob": {
                "type": "string",
                "description": "Limit to files matching this glob (default '**/*').",
            },
        },
        "required": ["pattern"],
    }

    def run(self, pattern: str, path: str = ".", glob: str = "**/*", **_: Any) -> str:
        """Return ``path:line: text`` for each line matching ``pattern``.

        Args:
            pattern: Regular expression applied to each line with ``search``.
            path: Base directory; ``~`` is expanded.
            glob: Glob, relative to ``path``, selecting the files to scan.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            The hits (each line stripped and cut to 200 characters, the whole
            result passed through ``truncate``). Collection stops at
            ``_MAX_HITS`` hits and a marker line is appended. Returns
            ``"No matches found."`` when nothing matches, or an error message
            for an invalid regex, a missing base path, or a glob that
            ``Path.glob`` rejects.
        """
        base = Path(path).expanduser()
        try:
            regex = re.compile(pattern)
        except re.error as exc:
            return f"Invalid regular expression: {exc}"
        if not base.exists():
            # Otherwise a mistyped path is indistinguishable from "no matches".
            return f"Error: path not found: {base}"

        hits: list[str] = []
        try:
            # The loop sits inside the try because Path.glob can be lazy: an
            # empty or absolute glob may only raise once iteration starts.
            for file in base.glob(glob):
                # Judge noise directories relative to the base, so a base
                # under e.g. "node_modules/" can still be searched.
                if not file.is_file() or _skipped(file.relative_to(base)):
                    continue
                try:
                    content = file.read_text("utf-8", "ignore")
                except OSError:
                    # Unreadable files are skipped rather than failing the search.
                    continue
                for lineno, line in enumerate(content.splitlines(), 1):
                    if regex.search(line):
                        hits.append(f"{file}:{lineno}: {line.strip()[:200]}")
                        if len(hits) >= _MAX_HITS:
                            return truncate("\n".join(hits) + "\n… [more matches omitted]")
        except (ValueError, NotImplementedError) as exc:
            return f"Error: invalid glob pattern {glob!r}: {exc}"
        return truncate("\n".join(hits)) if hits else "No matches found."
pdo/tools/shell.py ADDED
@@ -0,0 +1,125 @@
1
+ """Shell execution tool with a configurable dangerous-command detector.
2
+
3
+ The detector is a pure function (:func:`is_dangerous`) so it can be unit tested
4
+ in isolation. Anything it flags requires an explicit, typed confirmation before
5
+ the command runs.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ import re
11
+ import subprocess
12
+ from collections.abc import Callable, Sequence
13
+ from typing import Any
14
+
15
+ from .base import Tool, default_confirm
16
+ from .registry import register_tool
17
+
18
logger = logging.getLogger(__name__)

# Default number of seconds a shell command may run.
DEFAULT_TIMEOUT = 60

# Exact substrings that are unambiguously destructive. Kept separate from the
# regex patterns so operators can extend either independently.
DEFAULT_DENYLIST: tuple[str, ...] = (
    "rm -rf /",
    "rm -rf /*",
    ":(){:|:&};:",
    "mkfs",
    "dd if=",
    "> /dev/sda",
)

# (pattern, human-readable reason) pairs describing dangerous command shapes.
DANGEROUS_PATTERNS: tuple[tuple[str, str], ...] = (
    # [rR]: "-R" is the recursive flag as well, so "rm -R dir" and "rm -Rf dir"
    # must be caught just like "rm -r" / "rm -rf".
    (r"\brm\b\s+(?:-\S+\s+)*-\S*[rR]", "recursive delete (rm -r)"),
    (r"\brm\b.*\*", "wildcard delete (rm ... *)"),
    (r"\bsudo\b", "elevated privileges (sudo)"),
    (r"\bshutdown\b", "system shutdown"),
    (r"\breboot\b", "system reboot"),
    (r"\bhalt\b", "system halt"),
    (r"\bmkfs\b", "disk format (mkfs)"),
    (r"\bfdisk\b", "disk partitioning (fdisk)"),
    (r"\bmkswap\b", "swap creation (mkswap)"),
    (r"\bdd\b\s+if=", "raw disk write (dd)"),
    (r":\(\)\s*\{.*\}\s*;\s*:", "fork bomb"),
    (r">\s*/dev/sd[a-z]", "writing to a disk device"),
    (r"\bchmod\b\s+-R\s+0?00", "recursive permission wipe (chmod -R 000)"),
)


def is_dangerous(
    command: str, denylist: Sequence[str] | None = None
) -> tuple[bool, str | None]:
    """Classify ``command`` as dangerous or not.

    Two checks run in order and the first hit wins:

    1. Denylist: each entry is matched as a case-insensitive substring of the
       stripped command.
    2. ``DANGEROUS_PATTERNS``: each regex is searched, case-sensitively, in
       the stripped command.

    Args:
        command: The shell command line to inspect.
        denylist: Substrings to use instead of ``DEFAULT_DENYLIST``. ``None``
            selects the default; an empty sequence disables the substring
            check.

    Returns:
        A ``(dangerous, reason)`` tuple. ``reason`` is a short, human-readable
        explanation suitable for a confirmation prompt, or ``None`` when the
        command was not flagged.
    """
    text = command.strip()
    lowered = text.lower()

    for token in (denylist if denylist is not None else DEFAULT_DENYLIST):
        if token.lower() in lowered:
            return True, f"matches denylist entry {token!r}"

    for pattern, reason in DANGEROUS_PATTERNS:
        if re.search(pattern, text):
            return True, reason

    return False, None
71
+
72
+
73
@register_tool
class ShellTool(Tool):
    """Tool that runs a shell command, asking for confirmation when it looks dangerous."""

    name = "run_shell"
    description = (
        "Run a shell command and return its combined stdout/stderr and exit "
        "code. Commands detected as dangerous require explicit user "
        "confirmation before they run."
    )
    parameters = {
        "type": "object",
        "properties": {
            "command": {"type": "string", "description": "The shell command to execute."},
            "timeout": {
                "type": "integer",
                "description": f"Maximum seconds to wait (default {DEFAULT_TIMEOUT}).",
            },
        },
        "required": ["command"],
    }

    def __init__(
        self,
        confirm: Callable[[str], bool] = default_confirm,
        denylist: Sequence[str] | None = None,
    ) -> None:
        """Configure the confirmation callback and the denylist.

        Args:
            confirm: Called with a prompt string for a flagged command; the
                command runs only if it returns a truthy value.
            denylist: Passed to ``is_dangerous`` on every call; ``None`` lets
                that function use its default list.
        """
        self._confirm = confirm
        self._denylist = denylist

    def run(self, command: str, timeout: int = DEFAULT_TIMEOUT, **_: Any) -> str:
        """Run ``command`` through the shell and report exit code and output.

        Args:
            command: The shell command line to execute.
            timeout: Seconds to wait before giving up.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            ``"[exit N]"`` followed by stdout and stderr concatenated, a
            cancellation message when a flagged command is not confirmed, or
            an ``"Error..."`` message on timeout or when the command cannot
            be started.
        """
        dangerous, reason = is_dangerous(command, self._denylist)
        if dangerous:
            logger.warning("Dangerous command requested (%s): %s", reason, command)
            if not self._confirm(
                f"DANGEROUS command detected — {reason}:\n {command}\nProceed?"
            ):
                return "Cancelled: dangerous command was not confirmed by the user."

        try:
            completed = subprocess.run(
                command,
                shell=True,  # noqa: S602 — running shell commands is this tool's purpose
                capture_output=True,
                text=True,
                # Commands can emit bytes that are not valid in the locale
                # encoding (e.g. cat on a binary file). Strict decoding would
                # raise after the command had already run, and the handler
                # below would misreport that as "could not run command".
                errors="replace",
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            return f"Error: command timed out after {timeout}s."
        except Exception as exc:  # noqa: BLE001 — never crash the agent loop
            logger.exception("Shell command failed to start")
            return f"Error: could not run command: {exc}"

        output = ((completed.stdout or "") + (completed.stderr or "")).strip()
        return f"[exit {completed.returncode}]\n{output or '(no output)'}"