pdo-agent 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pdo/__init__.py +21 -0
- pdo/agent/__init__.py +6 -0
- pdo/agent/core.py +275 -0
- pdo/agent/delegate.py +56 -0
- pdo/agent/executor.py +87 -0
- pdo/agent/memory.py +191 -0
- pdo/agent/messages.py +87 -0
- pdo/agent/planner.py +38 -0
- pdo/agent/reviewer.py +25 -0
- pdo/agent/router.py +37 -0
- pdo/api.py +65 -0
- pdo/banner.py +53 -0
- pdo/config.py +151 -0
- pdo/llm.py +211 -0
- pdo/logging_setup.py +34 -0
- pdo/main.py +961 -0
- pdo/mcp.py +264 -0
- pdo/prompts/system.md +46 -0
- pdo/providers.py +86 -0
- pdo/rag.py +191 -0
- pdo/serve.py +124 -0
- pdo/skills.py +59 -0
- pdo/theme.py +47 -0
- pdo/tools/__init__.py +6 -0
- pdo/tools/base.py +89 -0
- pdo/tools/code.py +48 -0
- pdo/tools/data.py +57 -0
- pdo/tools/edit.py +55 -0
- pdo/tools/filesystem.py +175 -0
- pdo/tools/git.py +44 -0
- pdo/tools/memory.py +70 -0
- pdo/tools/rag.py +60 -0
- pdo/tools/registry.py +203 -0
- pdo/tools/search.py +83 -0
- pdo/tools/shell.py +125 -0
- pdo/tools/web.py +163 -0
- pdo_agent-2.0.0.dist-info/METADATA +456 -0
- pdo_agent-2.0.0.dist-info/RECORD +42 -0
- pdo_agent-2.0.0.dist-info/WHEEL +5 -0
- pdo_agent-2.0.0.dist-info/entry_points.txt +2 -0
- pdo_agent-2.0.0.dist-info/licenses/LICENSE +21 -0
- pdo_agent-2.0.0.dist-info/top_level.txt +1 -0
pdo/mcp.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
"""Minimal Model Context Protocol (MCP) client.
|
|
2
|
+
|
|
3
|
+
Connects to MCP servers over the stdio transport (newline-delimited JSON-RPC
|
|
4
|
+
2.0), lists their tools, and exposes each as a PDO :class:`~pdo.tools.base.Tool`
|
|
5
|
+
so the agent can use them transparently. Implemented synchronously with the
|
|
6
|
+
standard library only — no extra dependencies, consistent with PDO's v1 design.
|
|
7
|
+
|
|
8
|
+
Servers are declared in ``<PDO_HOME>/mcp.json`` using the widely-used format::
|
|
9
|
+
|
|
10
|
+
{
|
|
11
|
+
"mcpServers": {
|
|
12
|
+
"filesystem": {
|
|
13
|
+
"command": "npx",
|
|
14
|
+
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/path"]
|
|
15
|
+
}
|
|
16
|
+
}
|
|
17
|
+
}
|
|
18
|
+
"""
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
import json
import logging
import os
import queue
import re
import select
import subprocess
import threading
import time
from pathlib import Path
from typing import Any

from . import __version__
from .tools.base import Tool, truncate
from .tools.registry import ToolRegistry
|
|
34
|
+
|
|
35
|
+
logger = logging.getLogger(__name__)

# MCP protocol revision announced to servers in the ``initialize`` handshake.
PROTOCOL_VERSION = "2024-11-05"

# How long to wait for a response, in seconds: ordinary requests first,
# then the longer allowance given to ``tools/call``.
_REQUEST_TIMEOUT = 30
_CALL_TIMEOUT = 120

# Every server started by this process, kept for introspection (e.g. /mcp).
_ACTIVE_SERVERS: list[MCPServer] = []
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class MCPError(RuntimeError):
    """Signals that an MCP server replied with an error, timed out, or is unreachable."""
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _sanitize(name: str) -> str:
    """Coerce ``name`` into the model API's tool-name alphabet.

    Any character other than an ASCII letter, ASCII digit, ``_`` or ``-`` is
    replaced by ``_``, and the result is cut to at most 64 characters, so it
    satisfies ``^[A-Za-z0-9_-]{1,64}$`` for any non-empty input.
    """
    allowed = (ch if ch.isascii() and (ch.isalnum() or ch in "_-") else "_" for ch in name)
    return "".join(allowed)[:64]
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class MCPServer:
    """A single MCP server subprocess, spoken to over stdio JSON-RPC 2.0.

    Messages are newline-delimited JSON in both directions. A daemon thread
    drains the server's stdout into a queue, which gives portable read
    timeouts and never misses a line. (``select`` on a buffered text pipe can
    report "nothing to read" while complete replies already sit in Python's
    read buffer.)
    """

    def __init__(
        self, name: str, command: str, args: list[str], env: dict[str, str] | None = None
    ) -> None:
        self.name = name
        self.command = command
        self.args = list(args)
        self.env = dict(env or {})
        self._proc: subprocess.Popen | None = None
        self._id = 0
        self._tools: list[dict[str, Any]] = []
        # Lines read from the server's stdout; ``None`` marks end-of-stream.
        self._lines: queue.Queue[str | None] = queue.Queue()
        self._reader: threading.Thread | None = None

    # --- lifecycle ---------------------------------------------------------- #
    def start(self) -> None:
        """Spawn the server, perform the MCP handshake, and cache its tools.

        Raises OSError if the command cannot be launched and MCPError if the
        handshake fails. On a handshake failure the subprocess is stopped
        before the error propagates, so a broken server is never left running.
        """
        self._spawn()
        try:
            self._request(
                "initialize",
                {
                    "protocolVersion": PROTOCOL_VERSION,
                    "capabilities": {},
                    "clientInfo": {"name": "pdo", "version": __version__},
                },
            )
            self._notify("notifications/initialized")
            self._tools = self._list_all_tools()
        except BaseException:
            self.stop()
            raise

    def stop(self) -> None:
        """Shut the server down and reap its process; safe to call repeatedly.

        Closes the server's stdin first, letting a well-behaved server exit on
        EOF. It then terminates the process and escalates to kill if the
        process has not exited within five seconds.
        """
        proc = self._proc
        if proc is None:
            return
        self._proc = None
        if proc.stdin is not None:
            try:
                proc.stdin.close()
            except OSError:
                pass  # broken pipe: the server is already gone
        try:
            proc.terminate()
            proc.wait(timeout=5)
        except Exception:  # noqa: BLE001 — force-kill if it won't exit cleanly
            try:
                proc.kill()
                proc.wait(timeout=5)  # reap it so no zombie is left behind
            except Exception as exc:  # noqa: BLE001 — best effort only
                logger.warning("MCP server %r did not exit: %s", self.name, exc)
        reader, self._reader = self._reader, None
        if reader is not None:
            reader.join(timeout=1)
        # Close stdout only once the reader thread is done with it. Closing it
        # under a still-blocked reader could block (e.g. if a grandchild process
        # still holds the pipe open).
        if proc.stdout is not None and (reader is None or not reader.is_alive()):
            proc.stdout.close()

    def _spawn(self) -> None:
        """Launch the subprocess and the daemon thread that drains its stdout."""
        self._proc = subprocess.Popen(  # noqa: S603 — launching a user-configured server
            [self.command, *self.args],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            # JSON-RPC traffic is UTF-8; don't depend on the locale's encoding.
            encoding="utf-8",
            errors="replace",
            bufsize=1,
            env={**os.environ, **self.env},
        )
        self._lines = queue.Queue()
        self._reader = threading.Thread(
            target=self._pump_stdout,
            args=(self._proc.stdout, self._lines),
            name=f"mcp-{self.name}-reader",
            daemon=True,
        )
        self._reader.start()

    @staticmethod
    def _pump_stdout(stream: Any, sink: queue.Queue[str | None]) -> None:
        """Copy lines from ``stream`` into ``sink``, then post ``None`` at EOF."""
        try:
            for line in stream:
                sink.put(line)
        except (OSError, ValueError):
            pass  # the pipe was closed underneath us; treat it as EOF
        finally:
            sink.put(None)

    def _list_all_tools(self) -> list[dict[str, Any]]:
        """Collect every page of ``tools/list``, following ``nextCursor``."""
        tools: list[dict[str, Any]] = []
        cursor: str | None = None
        seen: set[str] = set()
        while True:
            page = self._request("tools/list", {"cursor": cursor} if cursor else None)
            batch = page.get("tools")
            if isinstance(batch, list):
                tools.extend(batch)
            cursor = page.get("nextCursor")
            # Stop at the last page, or if a misbehaving server repeats a cursor.
            if not isinstance(cursor, str) or not cursor or cursor in seen:
                return tools
            seen.add(cursor)

    # --- public API --------------------------------------------------------- #
    def list_tools(self) -> list[dict[str, Any]]:
        """Return a copy of the tool descriptors cached by :meth:`start`."""
        return list(self._tools)

    def call_tool(self, name: str, arguments: dict[str, Any]) -> str:
        """Invoke remote tool ``name`` and flatten its result to text.

        Text content items are joined with newlines; any other item (image,
        resource, ...) is included as its JSON. The text is passed through
        ``truncate``. When the server flags ``isError``, the text is prefixed
        with "Error from MCP tool:". Raises MCPError on transport/RPC failure.
        """
        result = self._request(
            "tools/call", {"name": name, "arguments": arguments}, timeout=_CALL_TIMEOUT
        )
        parts = []
        for item in result.get("content") or []:
            if isinstance(item, dict) and item.get("type") == "text":
                parts.append(str(item.get("text") or ""))
            else:
                parts.append(json.dumps(item))
        text = "\n".join(parts) if parts else "(no content)"
        if result.get("isError"):
            return f"Error from MCP tool: {truncate(text)}"
        return truncate(text)

    # --- JSON-RPC plumbing -------------------------------------------------- #
    def _send(self, message: dict[str, Any]) -> None:
        """Write one JSON-RPC message as a single line to the server's stdin.

        ``json.dumps`` never emits raw newlines, so one message is one line.
        Raises MCPError if the server is not running or its stdin is broken.
        """
        proc = self._proc
        if proc is None or proc.stdin is None:
            raise MCPError(f"MCP server {self.name!r} is not running")
        try:
            proc.stdin.write(json.dumps(message) + "\n")
            proc.stdin.flush()
        except (OSError, ValueError) as exc:
            # BrokenPipeError when the server died; ValueError if stdin is closed.
            raise MCPError(f"MCP server {self.name!r} is not accepting input: {exc}") from exc

    def _notify(self, method: str, params: dict[str, Any] | None = None) -> None:
        """Send a JSON-RPC notification (no ``id``, so no response is expected)."""
        self._send({"jsonrpc": "2.0", "method": method, "params": params or {}})

    def _request(
        self, method: str, params: dict[str, Any] | None = None, timeout: int = _REQUEST_TIMEOUT
    ) -> dict[str, Any]:
        """Send request ``method`` and wait up to ``timeout`` seconds for its reply.

        Returns the response's ``result`` object ({} if absent or not an
        object). Raises MCPError on a JSON-RPC error response, on timeout, or
        if the server exits first. Non-JSON lines, notifications, and stale
        responses are skipped. Server-initiated requests are answered so the
        server is not left waiting.
        """
        self._id += 1
        request_id = self._id
        self._send({"jsonrpc": "2.0", "id": request_id, "method": method, "params": params or {}})

        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while (remaining := deadline - time.monotonic()) > 0:
            line = self._read_line(remaining)
            if not line:
                continue
            try:
                message = json.loads(line)
            except json.JSONDecodeError:
                continue  # e.g. stray log output written to stdout
            if not isinstance(message, dict):
                continue
            if "method" in message:
                # A notification or a request from the server, never our reply
                # (servers number their own requests, so ids can collide).
                self._answer_server_request(message)
                continue
            if message.get("id") != request_id:
                continue  # a late reply to an earlier, timed-out request
            error = message.get("error")
            if error is not None:
                detail = error.get("message", "MCP error") if isinstance(error, dict) else str(error)
                raise MCPError(detail)
            result = message.get("result")
            return result if isinstance(result, dict) else {}
        raise MCPError(f"timed out waiting for {method!r} from {self.name!r}")

    def _answer_server_request(self, message: dict[str, Any]) -> None:
        """Reply to a server-initiated request; notifications are ignored.

        ``ping`` gets an empty result. Anything else is declined with JSON-RPC
        "Method not found" (-32601), since this client advertises no
        capabilities.
        """
        if "id" not in message:
            return
        if message.get("method") == "ping":
            reply: dict[str, Any] = {"jsonrpc": "2.0", "id": message["id"], "result": {}}
        else:
            reply = {
                "jsonrpc": "2.0",
                "id": message["id"],
                "error": {"code": -32601, "message": "Method not found"},
            }
        try:
            self._send(reply)
        except MCPError:
            pass  # the server is gone; the pending read will report it

    def _read_line(self, timeout: float) -> str | None:
        """Return the next stdout line, or None if none arrives within ``timeout``.

        Raises MCPError if the server is not running or has closed its stdout.
        """
        if self._proc is None:
            raise MCPError(f"MCP server {self.name!r} is not running")
        try:
            line = self._lines.get(timeout=max(0.0, timeout))
        except queue.Empty:
            return None
        if line is None:
            # Re-post the EOF marker so every later read fails fast as well.
            self._lines.put(None)
            raise MCPError(f"MCP server {self.name!r} exited")
        return line
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class MCPTool(Tool):
    """Adapts a remote MCP tool to PDO's Tool interface.

    The tool is exposed as ``mcp__<server>__<tool>`` (sanitized for the model
    API), and every call is forwarded to the owning :class:`MCPServer`.
    """

    def __init__(self, server: MCPServer, mcp_name: str, description: str, schema: dict | None):
        self.name = _sanitize(f"mcp__{server.name}__{mcp_name}")
        self.description = description or f"MCP tool {mcp_name} from {server.name}."
        # Fall back to an argument-less object schema when the server gives
        # none, or something that is not a JSON object (_clean_args reads it).
        self.parameters = (
            schema if isinstance(schema, dict) and schema else {"type": "object", "properties": {}}
        )
        self._server = server
        self._mcp_name = mcp_name

    def run(self, **kwargs: Any) -> str:
        """Call the remote tool; failures are returned as ``"Error: ..."`` text."""
        try:
            return self._server.call_tool(self._mcp_name, self._clean_args(kwargs))
        except (MCPError, OSError) as exc:
            # OSError covers e.g. BrokenPipeError when the server process has
            # died; without this it would escape into the agent loop.
            return f"Error: {exc}"

    def _clean_args(self, kwargs: dict[str, Any]) -> dict[str, Any]:
        """Drop blank optional arguments.

        Smaller models often pad optional parameters with junk like ``" "``,
        ``"/"`` or ``[]``, which some MCP servers reject (e.g. Canva fails on a
        meaningless ``brand_kit_id``). For any argument not required by the tool's
        schema, we drop None, empty collections, and strings with no alphanumeric
        characters (a real id/value always contains letters or digits).
        """
        required_spec = self.parameters.get("required") or []
        # Tolerate malformed schemas ("required": true, nested objects, ...).
        required = (
            {key for key in required_spec if isinstance(key, str)}
            if isinstance(required_spec, list)
            else set()
        )
        cleaned: dict[str, Any] = {}
        for key, value in kwargs.items():
            if key in required:
                cleaned[key] = value
                continue
            if value is None:
                continue
            if isinstance(value, str) and not any(ch.isalnum() for ch in value):
                continue
            if isinstance(value, (list, dict)) and len(value) == 0:
                continue
            cleaned[key] = value
        return cleaned
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def load_mcp_config(path: Path) -> dict[str, dict]:
    """Read the MCP server mapping from ``path``.

    Accepts either ``{"mcpServers": {name: spec}}`` or a bare ``{name: spec}``
    object. Returns {} when the file is missing, unreadable, not valid UTF-8
    JSON, or not a JSON object. Entries whose spec is not a JSON object are
    dropped with a warning, so one typo cannot crash server startup.
    """
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as exc:
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        logger.warning("Could not read MCP config %s: %s", path, exc)
        return {}
    servers = data.get("mcpServers", data) if isinstance(data, dict) else {}
    if not isinstance(servers, dict):
        return {}
    valid: dict[str, dict] = {}
    for name, spec in servers.items():
        if isinstance(spec, dict):
            valid[name] = spec
        else:
            logger.warning("Ignoring MCP server %r: its config must be a JSON object", name)
    return valid
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _spec_error(spec: object) -> str | None:
    """Describe what is wrong with one server's config, or return None if usable."""
    if not isinstance(spec, dict):
        return "server config must be a JSON object"
    if not spec.get("command"):
        return "no 'command' specified"
    args = spec.get("args") or []
    if not isinstance(args, list) or not all(isinstance(arg, str) for arg in args):
        # A bare string would otherwise be split into single characters.
        return "'args' must be a list of strings"
    if not isinstance(spec.get("env") or {}, dict):
        return "'env' must be a JSON object"
    return None


def _register_tools(registry: ToolRegistry, server: MCPServer) -> int:
    """Register ``server``'s advertised tools; return how many were newly added."""
    count = 0
    for spec_tool in server.list_tools():
        if not isinstance(spec_tool, dict) or not isinstance(spec_tool.get("name"), str):
            logger.warning("MCP server %r advertised a malformed tool; skipping", server.name)
            continue
        schema = spec_tool.get("inputSchema")
        tool = MCPTool(
            server,
            spec_tool["name"],
            spec_tool.get("description") or "",
            schema if isinstance(schema, dict) else None,
        )
        # First registration wins; a duplicate sanitized name is not counted.
        if not registry.has(tool.name):
            registry.register(tool)
            count += 1
    return count


def start_servers(
    registry: ToolRegistry, servers_config: dict[str, dict]
) -> tuple[list[MCPServer], list[tuple[str, int, str | None]]]:
    """Start each configured server and register its tools.

    Returns the started servers and a per-server ``(name, tool_count, error)``
    summary, where ``error`` is None on success. A server that is
    misconfigured or fails to start is skipped, and any process it spawned is
    stopped; neither case is fatal.
    """
    started: list[MCPServer] = []
    summary: list[tuple[str, int, str | None]] = []

    for name, spec in servers_config.items():
        error = _spec_error(spec)
        if error is not None:
            summary.append((name, 0, error))
            continue
        server = MCPServer(name, spec["command"], spec.get("args") or [], spec.get("env") or {})
        try:
            server.start()
        except Exception as exc:  # noqa: BLE001 — a bad server must not crash PDO
            logger.warning("MCP server %r failed to start: %s", name, exc)
            server.stop()  # reap a process left behind by a failed handshake
            summary.append((name, 0, str(exc)))
            continue

        count = _register_tools(registry, server)
        started.append(server)
        _ACTIVE_SERVERS.append(server)
        summary.append((name, count, None))

    return started, summary
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def active_servers() -> list[MCPServer]:
    """Return a snapshot (a new list) of every server started by this process."""
    snapshot = [*_ACTIVE_SERVERS]
    return snapshot
|
pdo/prompts/system.md
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
You are **PDO** (Python Do) — a terminal-first AI agent. Your motto is **Think. Plan. Do.**
|
|
2
|
+
|
|
3
|
+
You are not a chatbot. You help the user accomplish real work on their machine by
|
|
4
|
+
reasoning about the goal, planning the steps, deciding whether tools are needed,
|
|
5
|
+
executing them safely, reviewing the result, and replying clearly.
|
|
6
|
+
|
|
7
|
+
## How you work
|
|
8
|
+
|
|
9
|
+
1. **Understand** the user's goal before acting. Ask a brief clarifying question
|
|
10
|
+
only when genuinely blocked — otherwise proceed with sensible defaults.
|
|
11
|
+
2. **Think and plan.** For multi-step tasks, work through the steps in order.
|
|
12
|
+
3. **Decide if tools are needed.** Use a tool only when it makes the answer
|
|
13
|
+
better or is required to complete the task. Plain conversation, explanations,
|
|
14
|
+
and questions you can answer directly should stay a plain reply with **no
|
|
15
|
+
tool calls**.
|
|
16
|
+
4. **Execute safely.** Prefer the smallest action that accomplishes the step.
|
|
17
|
+
Inspect before you change: read a file before editing it, list a directory
|
|
18
|
+
before assuming its contents.
|
|
19
|
+
5. **Review** what came back. If a tool returned an error, explain it and adjust
|
|
20
|
+
rather than blindly retrying.
|
|
21
|
+
6. **Respond clearly** in concise Markdown. Show commands and code in fenced
|
|
22
|
+
blocks. Summarise what you did and what (if anything) the user should do next.
|
|
23
|
+
|
|
24
|
+
## Tools
|
|
25
|
+
|
|
26
|
+
You have native function/tool calling. Available tools include reading, writing,
|
|
27
|
+
and appending files; listing and creating directories; running shell commands;
|
|
28
|
+
and saving, searching, and deleting long-term memories. Call tools by name with
|
|
29
|
+
JSON arguments — never describe a tool call in prose instead of making it.
|
|
30
|
+
|
|
31
|
+
## Safety
|
|
32
|
+
|
|
33
|
+
- Destructive or privileged commands (`rm`, `sudo`, `shutdown`, `reboot`, disk
|
|
34
|
+
operations, recursive deletes) require explicit user confirmation, which the
|
|
35
|
+
tool layer enforces. Do not try to bypass it.
|
|
36
|
+
- Writes are sandboxed to the working directory by default; writing elsewhere or
|
|
37
|
+
overwriting a file will prompt the user.
|
|
38
|
+
- When you are about to do something irreversible, say so plainly first.
|
|
39
|
+
|
|
40
|
+
## Style
|
|
41
|
+
|
|
42
|
+
- Be direct and practical. Lead with the answer or the result.
|
|
43
|
+
- Don't narrate your internal reasoning at length; show the outcome.
|
|
44
|
+
- Use the user's working directory as the default location for new files.
|
|
45
|
+
- When you remember something durable about the user or project, use the memory
|
|
46
|
+
tools so it persists across sessions.
|
pdo/providers.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
"""Known LLM providers for the ``/models`` picker.
|
|
2
|
+
|
|
3
|
+
Every provider listed here exposes an OpenAI-compatible chat API, so PDO reaches them
|
|
4
|
+
through the same :class:`pdo.llm.OpenAIClient` — only the ``base_url`` and the
|
|
5
|
+
API key differ. Adding a provider is just another entry in ``PROVIDERS``.
|
|
6
|
+
|
|
7
|
+
The model lists are a curated starting point, not an exhaustive catalogue; the
|
|
8
|
+
picker always offers a "custom model id" option so any model the provider
|
|
9
|
+
supports can be used.
|
|
10
|
+
"""
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(frozen=True)
class Provider:
    """One selectable LLM backend: how to reach it and which models to suggest.

    Instances are frozen (fields cannot be reassigned), but ``models`` is
    still an ordinary list.
    """

    # Short internal identifier, e.g. "openai".
    key: str
    # Human-readable name shown in the /models menu.
    label: str
    # Name of the environment variable that holds the API key.
    env_key: str
    # OpenAI-compatible endpoint; None selects the OpenAI default.
    base_url: str | None
    # Suggested model ids (a custom id can always be entered instead).
    models: list[str] = field(default_factory=list)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Built from each provider's own ``key`` so the mapping key and
# ``Provider.key`` can never drift apart. Insertion order is menu order.
PROVIDERS: dict[str, Provider] = {
    provider.key: provider
    for provider in (
        Provider(
            key="openai",
            label="OpenAI",
            env_key="OPENAI_API_KEY",
            base_url=None,
            models=[
                "gpt-4.1-mini",
                "gpt-4.1",
                "gpt-4.1-nano",
                "gpt-4o",
                "gpt-4o-mini",
                "o4-mini",
            ],
        ),
        Provider(
            key="anthropic",
            label="Anthropic (Claude)",
            env_key="ANTHROPIC_API_KEY",
            # Anthropic's OpenAI-compatibility endpoint.
            base_url="https://api.anthropic.com/v1/",
            models=[
                "claude-sonnet-4-5",
                "claude-opus-4-1",
                "claude-3-7-sonnet-latest",
                "claude-3-5-sonnet-latest",
                "claude-3-5-haiku-latest",
            ],
        ),
        Provider(
            key="openrouter",
            label="OpenRouter",
            env_key="OPENROUTER_API_KEY",
            base_url="https://openrouter.ai/api/v1",
            models=[
                "openai/gpt-4.1-mini",
                "anthropic/claude-3.7-sonnet",
                "anthropic/claude-3.5-sonnet",
                "google/gemini-flash-1.5",
                "meta-llama/llama-3.3-70b-instruct",
                "deepseek/deepseek-chat",
            ],
        ),
        Provider(
            key="ollama",
            label="Ollama (local)",
            env_key="OLLAMA_API_KEY",  # unused; a local server needs no key
            # OpenAI-compatible local endpoint. Override with OLLAMA_BASE_URL.
            base_url="http://localhost:11434/v1",
            models=[
                "llama3.2",
                "llama3.1",
                "qwen2.5",
                "mistral",
                "gemma2",
                "phi4",
                "deepseek-r1",
            ],
        ),
    )
}
|
pdo/rag.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""Lexical codebase retrieval (BM25).
|
|
2
|
+
|
|
3
|
+
``/index`` builds a chunk index of the working directory; the
|
|
4
|
+
``codebase_search`` tool ranks chunks against a query with BM25 and returns the
|
|
5
|
+
best snippets with ``path:line`` references.
|
|
6
|
+
|
|
7
|
+
We use lexical retrieval rather than embeddings on purpose: it needs no
|
|
8
|
+
embeddings endpoint (OpenRouter doesn't offer one), no API key, and no new
|
|
9
|
+
dependencies — and BM25 over code identifiers is strong in practice. Embedding
|
|
10
|
+
support can be layered on later without changing the tool interface.
|
|
11
|
+
"""
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import hashlib
|
|
15
|
+
import json
|
|
16
|
+
import logging
|
|
17
|
+
import math
|
|
18
|
+
import re
|
|
19
|
+
import time
|
|
20
|
+
from dataclasses import dataclass, field
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
|
|
23
|
+
from .config import get_home_dir
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)

# File extensions worth indexing: source code, docs, and config.
_INDEX_EXTS = {
    # Python, docs, plain text
    ".py", ".md", ".txt", ".rst",
    # config / structured data
    ".toml", ".ini", ".cfg", ".json", ".yml", ".yaml",
    # web
    ".js", ".jsx", ".ts", ".tsx", ".html", ".css",
    # shell
    ".sh", ".zsh",
    # compiled languages
    ".swift", ".rs", ".go", ".java", ".kt", ".c", ".h", ".cpp", ".hpp",
    # databases
    ".sql",
}

# Directory names never descended into: VCS/editor metadata, environments,
# dependencies, caches, and build output.
_SKIP_DIRS = {
    ".git", ".idea", ".vscode",
    ".venv", "venv", "node_modules",
    "__pycache__", ".ruff_cache", ".pytest_cache", ".mypy_cache",
    "dist", "build",
}

_MAX_FILE_BYTES = 200_000  # larger files are not indexed
_CHUNK_LINES = 40  # lines per chunk
_CHUNK_OVERLAP = 10  # lines shared by consecutive chunks of a file
_MAX_CHUNKS = 20_000  # indexing stops once this many chunks exist

# Okapi BM25 tuning: term-frequency saturation (k1) and length normalisation (b).
_K1 = 1.5
_B = 0.75

_WORD = re.compile(r"[A-Za-z0-9_]+")
# Zero-width split points at camelCase and acronym boundaries:
# fooBar -> foo|Bar, HTTPServer -> HTTP|Server.
_CAMEL = re.compile(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def tokenize(text: str) -> list[str]:
    """Lower-case every identifier-like run of ``text`` into search tokens.

    Each run of ``[A-Za-z0-9_]`` is emitted whole. When it is a compound
    (snake_case and/or camelCase with two or more non-empty parts), its
    lower-cased parts follow it, so ``fooBar`` yields
    ``["foobar", "foo", "bar"]``.
    """
    out: list[str] = []
    for word in _WORD.findall(text):
        out.append(word.lower())
        pieces: list[str] = []
        for segment in _CAMEL.split(word):
            pieces.extend(piece for piece in segment.split("_") if piece)
        if len(pieces) > 1:
            out.extend(piece.lower() for piece in pieces)
    return out
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclass
class Chunk:
    """A contiguous window of lines taken from one indexed file."""

    path: str  # file path relative to the index root
    start: int  # first line of the window, 1-based
    end: int  # last line of the window, 1-based and inclusive
    text: str  # the window's lines joined with "\n"
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass
class Index:
    """A chunked snapshot of one directory tree."""

    root: str  # directory the index covers (build_index stores it resolved)
    built: float  # creation timestamp as returned by time.time()
    chunks: list[Chunk] = field(default_factory=list)  # in file/line order
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _index_path(root: Path) -> Path:
    """Return the JSON file that stores the index for ``root``.

    The file lives in ``<home>/index/`` (created on demand). It is named after
    the first 16 hex digits of the SHA-1 of the resolved root path, so each
    directory maps to one stable file.
    """
    key = hashlib.sha1(str(root.resolve()).encode()).hexdigest()
    index_dir = get_home_dir() / "index"
    index_dir.mkdir(parents=True, exist_ok=True)
    return index_dir / f"{key[:16]}.json"
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _iter_files(root: Path):
    """Yield the files under ``root`` worth indexing, in sorted path order.

    A file qualifies when its extension is in ``_INDEX_EXTS`` and it is at
    most ``_MAX_FILE_BYTES`` bytes. Directories named in ``_SKIP_DIRS`` are
    pruned during the walk, so trees like ``node_modules`` or ``.git`` are
    never traversed. Only components *below* ``root`` are checked, so a
    project that itself lives under e.g. ``~/build/`` is still indexed.
    Symlinked directories are not followed.
    """
    import os  # local: nothing else in this module needs it

    candidates: list[Path] = []
    for dirpath, dirnames, filenames in os.walk(root):
        # Pruning in place stops os.walk from descending into skipped trees.
        dirnames[:] = [name for name in dirnames if name not in _SKIP_DIRS]
        candidates.extend(
            Path(dirpath, filename)
            for filename in filenames
            if Path(filename).suffix.lower() in _INDEX_EXTS
        )
    # Sort for a deterministic order (it decides which chunks survive the cap).
    for path in sorted(candidates):
        if not path.is_file():
            continue  # e.g. a dangling symlink
        try:
            if path.stat().st_size > _MAX_FILE_BYTES:
                continue
        except OSError:
            continue
        yield path
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def build_index(root: Path) -> Index:
    """Split every indexable file under ``root`` into overlapping chunks and save them.

    Each chunk covers up to ``_CHUNK_LINES`` lines, and consecutive chunks of
    a file share ``_CHUNK_OVERLAP`` lines. Windows containing only whitespace
    are skipped, and unreadable files are ignored. Indexing stops, with a
    warning, once ``_MAX_CHUNKS`` chunks exist. The index is persisted with
    :func:`save_index` before being returned.
    """
    root = root.resolve()
    index = Index(root=str(root), built=time.time())
    stride = _CHUNK_LINES - _CHUNK_OVERLAP
    cap_hit = False

    for file in _iter_files(root):
        try:
            lines = file.read_text("utf-8", "ignore").splitlines()
        except OSError:
            continue
        rel = str(file.relative_to(root))
        line_count = len(lines)

        for first in range(0, max(line_count, 1), stride):
            window = lines[first : first + _CHUNK_LINES]
            if all(not ln.strip() for ln in window):
                continue  # nothing but whitespace
            index.chunks.append(
                Chunk(path=rel, start=first + 1, end=first + len(window), text="\n".join(window))
            )
            if len(index.chunks) >= _MAX_CHUNKS:
                logger.warning("Index chunk cap reached (%d); stopping", _MAX_CHUNKS)
                cap_hit = True
                break
            if first + _CHUNK_LINES >= line_count:
                break  # this window already reached the end of the file

        if cap_hit:
            break

    save_index(index)
    return index
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def save_index(index: Index) -> None:
    """Persist ``index`` as JSON at the path chosen by ``_index_path``.

    The payload is first written to a per-process temporary sibling file and
    then renamed over the real one (``Path.replace``; atomic on POSIX). An
    interrupted save (e.g. Ctrl-C during a large ``/index``) or a concurrent
    writer therefore never leaves a truncated, unloadable index behind.
    """
    import os  # local: only needed for a per-process temp-file name

    payload = {
        "root": index.root,
        "built": index.built,
        "chunks": [vars(chunk) for chunk in index.chunks],
    }
    target = _index_path(Path(index.root))
    tmp = target.with_name(f"{target.name}.{os.getpid()}.tmp")
    try:
        tmp.write_text(json.dumps(payload), encoding="utf-8")
        tmp.replace(target)
    except BaseException:
        try:
            tmp.unlink(missing_ok=True)
        except OSError:
            pass  # best-effort cleanup; re-raise the original error below
        raise
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def load_index(root: Path) -> Index | None:
    """Load the saved index for ``root``; return None if absent or unreadable.

    A corrupt file (bad JSON, non-UTF-8 bytes, or the wrong shape) is logged
    and treated the same as a missing one.
    """
    path = _index_path(root)
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
        return Index(
            root=data["root"],
            built=data["built"],
            chunks=[Chunk(**chunk) for chunk in data["chunks"]],
        )
    except FileNotFoundError:
        return None
    except (ValueError, KeyError, TypeError, OSError) as exc:
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        logger.warning("Could not load index %s: %s", path, exc)
        return None
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@dataclass
class SearchResult:
    """One ranked hit returned by ``search``."""

    chunk: Chunk  # the matching chunk
    score: float  # BM25 relevance; higher ranks first
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def search(index: Index, query: str, top_k: int = 5) -> list[SearchResult]:
    """Rank ``index``'s chunks against ``query`` with Okapi BM25.

    Scoring uses the non-negative IDF variant
    ``log(1 + (N - df + 0.5) / (df + 0.5))`` with k1 = ``_K1`` and
    b = ``_B``; repeated query terms count once. Returns at most ``top_k``
    chunks with a positive score, best first (ties keep index order). A
    ``top_k`` of zero or less yields []; ``top_k=None`` returns every
    positive hit.
    """
    from collections import Counter  # local: only used for term frequencies

    if top_k is not None and top_k <= 0:
        return []
    query_tokens = set(tokenize(query))
    if not query_tokens or not index.chunks:
        return []

    token_lists = [tokenize(chunk.text) for chunk in index.chunks]
    n = len(token_lists)
    avg_len = sum(len(t) for t in token_lists) / n

    # Document frequency per query token.
    df = dict.fromkeys(query_tokens, 0)
    for tokens in token_lists:
        for token in query_tokens.intersection(tokens):
            df[token] += 1

    # IDF depends only on the term, so compute it once instead of per chunk.
    # Terms that occur in no chunk are omitted and can never contribute.
    idf = {
        token: math.log(1 + (n - count + 0.5) / (count + 0.5))
        for token, count in df.items()
        if count
    }

    results: list[SearchResult] = []
    for chunk, tokens in zip(index.chunks, token_lists, strict=True):
        if not tokens:
            continue
        # One O(len) pass instead of tokens.count() per query term.
        counts = Counter(tokens)
        length_norm = _K1 * (1 - _B + _B * len(tokens) / avg_len)
        score = 0.0
        for token in query_tokens:
            tf = counts[token]
            if tf == 0 or token not in idf:
                continue
            score += idf[token] * (tf * (_K1 + 1)) / (tf + length_norm)
        if score > 0:
            results.append(SearchResult(chunk=chunk, score=score))

    results.sort(key=lambda r: r.score, reverse=True)
    return results[:top_k]
|