ata-coder 2.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ata_coder/__init__.py +1 -0
- ata_coder/agent.py +874 -0
- ata_coder/agent_compact.py +190 -0
- ata_coder/agent_controller.py +218 -0
- ata_coder/agent_extension.py +69 -0
- ata_coder/agent_routing.py +105 -0
- ata_coder/agent_subsystems.py +72 -0
- ata_coder/agent_tools.py +318 -0
- ata_coder/agent_undo.py +63 -0
- ata_coder/anthropic_client.py +465 -0
- ata_coder/change_tracker.py +368 -0
- ata_coder/clawd_integration.py +574 -0
- ata_coder/commands/__init__.py +128 -0
- ata_coder/commands/_core.py +184 -0
- ata_coder/commands/_safety.py +95 -0
- ata_coder/commands/_settings.py +241 -0
- ata_coder/commands/_workflow.py +451 -0
- ata_coder/commands.py +974 -0
- ata_coder/config.py +257 -0
- ata_coder/core/__init__.py +35 -0
- ata_coder/core/events.py +73 -0
- ata_coder/core/queue.py +85 -0
- ata_coder/core/state.py +17 -0
- ata_coder/event_queue.py +5 -0
- ata_coder/extension.py +654 -0
- ata_coder/extensions/__init__.py +1 -0
- ata_coder/extensions/hello_skill.py +47 -0
- ata_coder/fool_proof.py +295 -0
- ata_coder/git_workflow.py +371 -0
- ata_coder/gui.py +511 -0
- ata_coder/llm_client.py +543 -0
- ata_coder/main.py +814 -0
- ata_coder/mcp_client.py +1095 -0
- ata_coder/memory.py +539 -0
- ata_coder/model_registry.py +134 -0
- ata_coder/model_router.py +105 -0
- ata_coder/permissions.py +274 -0
- ata_coder/privilege.py +464 -0
- ata_coder/project.py +273 -0
- ata_coder/prompt_template.py +423 -0
- ata_coder/prompts/auto-mode.md +7 -0
- ata_coder/prompts/coding-rules.md +40 -0
- ata_coder/prompts/execution-guardrails.md +14 -0
- ata_coder/prompts/memory-system.md +24 -0
- ata_coder/prompts/output-style.md +23 -0
- ata_coder/prompts/safety.md +17 -0
- ata_coder/prompts/slash-commands.md +24 -0
- ata_coder/prompts/sub-agents.md +38 -0
- ata_coder/prompts/system-reminders.md +17 -0
- ata_coder/prompts/system.md +105 -0
- ata_coder/prompts/tool-policy.md +46 -0
- ata_coder/repl_theme.py +99 -0
- ata_coder/repl_tracker.py +89 -0
- ata_coder/repl_ui.py +1214 -0
- ata_coder/safety_guard.py +434 -0
- ata_coder/self_correct.py +346 -0
- ata_coder/server.py +882 -0
- ata_coder/server_session.py +159 -0
- ata_coder/server_shell.py +129 -0
- ata_coder/session.py +431 -0
- ata_coder/settings.py +439 -0
- ata_coder/setup_wizard.py +136 -0
- ata_coder/skill_extension.py +92 -0
- ata_coder/skills/architect/SKILL.md +42 -0
- ata_coder/skills/code-reviewer/SKILL.md +37 -0
- ata_coder/skills/codecraft/SKILL.md +452 -0
- ata_coder/skills/debugger/SKILL.md +45 -0
- ata_coder/skills/doc-writer/SKILL.md +36 -0
- ata_coder/skills/general-coder/SKILL.md +76 -0
- ata_coder/skills/math-calculator/README.md +40 -0
- ata_coder/skills/math-calculator/SKILL.md +59 -0
- ata_coder/skills/math-calculator/handler.py +103 -0
- ata_coder/skills/math-calculator/prompts/system.md +8 -0
- ata_coder/skills/math-calculator/requirements.txt +2 -0
- ata_coder/skills/math-calculator/resources/constants.json +8 -0
- ata_coder/skills/math-calculator/tests/test_handler.py +53 -0
- ata_coder/skills/security-auditor/SKILL.md +40 -0
- ata_coder/skills/test-writer/SKILL.md +36 -0
- ata_coder/skills/weather-skill/README.md +45 -0
- ata_coder/skills/weather-skill/handler.py +76 -0
- ata_coder/skills/weather-skill/manifest.json +48 -0
- ata_coder/skills/weather-skill/prompts/system_prompt.txt +9 -0
- ata_coder/skills/weather-skill/prompts/user_prompt_template.txt +3 -0
- ata_coder/skills/weather-skill/requirements.txt +1 -0
- ata_coder/skills/weather-skill/resources/city_list.json +17 -0
- ata_coder/skills/weather-skill/resources/error_messages.json +7 -0
- ata_coder/skills/weather-skill/tests/test_handler.py +28 -0
- ata_coder/skills/weather-skill/weather_utils.py +50 -0
- ata_coder/skills.py +1014 -0
- ata_coder/sub_agent.py +273 -0
- ata_coder/sub_agent_manager.py +203 -0
- ata_coder/system_prompt_builder.py +146 -0
- ata_coder/task_planner.py +391 -0
- ata_coder/terminal.py +318 -0
- ata_coder/test_runner.py +219 -0
- ata_coder/thread_supervisor.py +195 -0
- ata_coder/tool_defs.py +335 -0
- ata_coder/tools/__init__.py +11 -0
- ata_coder/tools/definitions.py +335 -0
- ata_coder/tools/executor.py +1036 -0
- ata_coder/tools/result.py +26 -0
- ata_coder/tools/subagent.py +332 -0
- ata_coder/tools/web.py +361 -0
- ata_coder/tools.py +1576 -0
- ata_coder/types.py +92 -0
- ata_coder/utils.py +113 -0
- ata_coder/web/css/style.css +180 -0
- ata_coder/web/index.html +84 -0
- ata_coder/web/js/app.js +489 -0
- ata_coder/web/package-lock.json +25 -0
- ata_coder/web/package.json +10 -0
- ata_coder/web/tsconfig.json +13 -0
- ata_coder-2.4.2.dist-info/METADATA +799 -0
- ata_coder-2.4.2.dist-info/RECORD +118 -0
- ata_coder-2.4.2.dist-info/WHEEL +5 -0
- ata_coder-2.4.2.dist-info/entry_points.txt +2 -0
- ata_coder-2.4.2.dist-info/licenses/LICENSE +21 -0
- ata_coder-2.4.2.dist-info/top_level.txt +1 -0
ata_coder/memory.py
ADDED
|
@@ -0,0 +1,539 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Persistent memory system for ATA Coder.
|
|
3
|
+
|
|
4
|
+
Stores facts, user preferences, feedback, and project context across sessions.
|
|
5
|
+
Uses a file-based approach:
|
|
6
|
+
- memory/MEMORY.md — index of all memories (loaded on startup)
|
|
7
|
+
- memory/<slug>.md — individual memory files with YAML frontmatter
|
|
8
|
+
|
|
9
|
+
Memory types:
|
|
10
|
+
- user: who the user is, their preferences, expertise
|
|
11
|
+
- feedback: user guidance on how the agent should work
|
|
12
|
+
- project: ongoing goals, constraints, architecture decisions
|
|
13
|
+
- reference: pointers to external resources (URLs, docs, etc.)
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
import re
|
|
20
|
+
import threading
|
|
21
|
+
from dataclasses import dataclass, field
|
|
22
|
+
from datetime import datetime, timezone
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
from typing import Any
|
|
25
|
+
|
|
26
|
+
from .utils import try_import_yaml
|
|
27
|
+
|
|
28
|
+
# Module-level logger for this file's load/save/parse diagnostics.
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
# Optional YAML support. try_import_yaml presumably returns (yaml module or
# None, availability flag) — TODO confirm in utils. Memory serialization only
# checks `_yaml_mod is not None` and falls back to JSON frontmatter otherwise.
_yaml_mod, HAS_YAML = try_import_yaml()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# ── Memory data model ────────────────────────────────────────────────────────
|
|
34
|
+
|
|
35
|
+
@dataclass
class Memory:
    """A single memory entry.

    On disk a memory is a markdown file: a ``---``-delimited frontmatter
    header (YAML when ``_yaml_mod`` is available, JSON otherwise) followed by
    the free-form body.
    """

    name: str  # kebab-case slug, used as the file-name stem
    description: str  # one-line summary (used for relevance and the index)
    content: str  # the memory body
    metadata: dict[str, Any] = field(default_factory=dict)  # type, tags, etc.
    created: str = ""  # UTC timestamp, "%Y-%m-%dT%H:%M:%SZ"
    updated: str = ""  # UTC timestamp, "%Y-%m-%dT%H:%M:%SZ"

    @property
    def memory_type(self) -> str:
        """Category from ``metadata["type"]``; defaults to "reference"."""
        return self.metadata.get("type", "reference")

    @property
    def file_path(self) -> str:
        """File name of this memory, relative to the memory directory."""
        return f"{self.name}.md"

    @staticmethod
    def _as_text(value: Any, default: str = "") -> str:
        """Coerce a parsed frontmatter value to ``str``.

        ``None`` becomes *default*. ``datetime`` values (which a YAML loader
        may produce for unquoted timestamps) are rendered back in the
        "%Y-%m-%dT%H:%M:%SZ" UTC form used when writing; naive values are
        treated as UTC.
        """
        if value is None:
            return default
        if isinstance(value, datetime):
            if value.tzinfo is None:
                value = value.replace(tzinfo=timezone.utc)
            return value.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
        return str(value)

    def to_frontmatter(self) -> str:
        """Serialize to a markdown file with a frontmatter header.

        ``updated`` is always stamped with the current UTC time; ``created``
        falls back to the same value when it is not set yet.
        """
        now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
        meta = {
            "name": self.name,
            "description": self.description,
            "metadata": self.metadata,
            "created": self.created or now,
            "updated": now,
        }
        if _yaml_mod is not None:
            header = _yaml_mod.dump(meta, default_flow_style=False, allow_unicode=True, sort_keys=False)
        else:
            header = json.dumps(meta, indent=2, ensure_ascii=False)
        # json.dumps emits no trailing newline. Without one the closing
        # delimiter would be glued to "}" and from_frontmatter could not
        # find it, so JSON-serialized memories would never load again.
        if not header.endswith("\n"):
            header += "\n"
        return f"---\n{header}---\n\n{self.content}"

    @classmethod
    def from_frontmatter(cls, raw: str) -> "Memory | None":
        """Parse a markdown file with a frontmatter header into a Memory.

        Returns None when there is no ``---`` header, the header cannot be
        parsed, or it does not contain a mapping. A missing or non-mapping
        ``metadata`` field becomes an empty dict.
        """
        match = re.match(r"^---\s*\n(.*?)\n---\s*\n(.*)", raw, re.DOTALL)
        if not match:
            return None
        front_str, content = match.group(1), match.group(2).strip()
        try:
            if _yaml_mod is not None:
                meta = _yaml_mod.safe_load(front_str)
            else:
                meta = json.loads(front_str)
        except Exception as e:
            logger.warning("Failed to parse frontmatter: %s", e)
            return None
        if not isinstance(meta, dict):
            return None
        metadata = meta.get("metadata")
        if not isinstance(metadata, dict):
            # e.g. "metadata: null" — memory_type needs a dict to call .get on.
            metadata = {}
        return cls(
            name=cls._as_text(meta.get("name"), "unknown"),
            description=cls._as_text(meta.get("description")),
            content=content,
            metadata=metadata,
            created=cls._as_text(meta.get("created")),
            updated=cls._as_text(meta.get("updated")),
        )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# ── Memory store ─────────────────────────────────────────────────────────────
|
|
95
|
+
|
|
96
|
+
class MemoryStore:
|
|
97
|
+
"""
|
|
98
|
+
Persistent file-based memory store.
|
|
99
|
+
|
|
100
|
+
On initialization, reads MEMORY.md for the index, then loads individual
|
|
101
|
+
memory files on demand or all at once.
|
|
102
|
+
"""
|
|
103
|
+
|
|
104
|
+
def __init__(self, memory_dir: str | Path | None = None):
|
|
105
|
+
if memory_dir is None:
|
|
106
|
+
try:
|
|
107
|
+
from .settings import get_settings
|
|
108
|
+
memory_dir = get_settings().memory_dir
|
|
109
|
+
except Exception:
|
|
110
|
+
memory_dir = Path.home() / ".ata_coder" / "memory"
|
|
111
|
+
self.memory_dir = Path(memory_dir)
|
|
112
|
+
self.memory_dir.mkdir(parents=True, exist_ok=True)
|
|
113
|
+
|
|
114
|
+
self._index_path = self.memory_dir / "MEMORY.md"
|
|
115
|
+
self._memories: dict[str, Memory] = {}
|
|
116
|
+
self._index_entries: list[str] = [] # lines from MEMORY.md
|
|
117
|
+
# IDF cache — invalidated on add/delete
|
|
118
|
+
self._idf_cache: dict[str, float] | None = None
|
|
119
|
+
self._idf_doc_count: int = 0
|
|
120
|
+
self._lock = threading.RLock() # protect concurrent read/write
|
|
121
|
+
|
|
122
|
+
self._load_index()
|
|
123
|
+
self._load_all()
|
|
124
|
+
|
|
125
|
+
# ── Loading ───────────────────────────────────────────────────────────
|
|
126
|
+
|
|
127
|
+
def _load_index(self) -> None:
|
|
128
|
+
"""Load the MEMORY.md index file."""
|
|
129
|
+
if self._index_path.exists():
|
|
130
|
+
try:
|
|
131
|
+
with open(self._index_path, "r", encoding="utf-8") as f:
|
|
132
|
+
self._index_entries = [
|
|
133
|
+
line.strip() for line in f.readlines() if line.strip()
|
|
134
|
+
]
|
|
135
|
+
logger.debug(
|
|
136
|
+
"Loaded MEMORY.md: %d entries", len(self._index_entries)
|
|
137
|
+
)
|
|
138
|
+
except Exception as e:
|
|
139
|
+
logger.warning("Failed to load MEMORY.md: %s", e)
|
|
140
|
+
self._index_entries = []
|
|
141
|
+
else:
|
|
142
|
+
# Create empty index
|
|
143
|
+
self._write_index()
|
|
144
|
+
|
|
145
|
+
def _load_all(self) -> None:
|
|
146
|
+
"""Load all memory files from the directory."""
|
|
147
|
+
if not self.memory_dir.exists():
|
|
148
|
+
return
|
|
149
|
+
|
|
150
|
+
for file_path in self.memory_dir.glob("*.md"):
|
|
151
|
+
if file_path.name == "MEMORY.md":
|
|
152
|
+
continue
|
|
153
|
+
try:
|
|
154
|
+
with open(file_path, "r", encoding="utf-8") as f:
|
|
155
|
+
raw = f.read()
|
|
156
|
+
memory = Memory.from_frontmatter(raw)
|
|
157
|
+
if memory:
|
|
158
|
+
self._memories[memory.name] = memory
|
|
159
|
+
else:
|
|
160
|
+
logger.warning("Failed to parse memory file: %s", file_path.name)
|
|
161
|
+
except Exception as e:
|
|
162
|
+
logger.warning("Failed to read memory file %s: %s", file_path.name, e)
|
|
163
|
+
|
|
164
|
+
logger.debug("Loaded %d memories from disk", len(self._memories))
|
|
165
|
+
|
|
166
|
+
def _write_index(self) -> None:
|
|
167
|
+
"""Write the index file atomically (write-then-rename)."""
|
|
168
|
+
tmp = self._index_path.with_suffix(".tmp")
|
|
169
|
+
try:
|
|
170
|
+
with open(tmp, "w", encoding="utf-8") as f:
|
|
171
|
+
for entry in self._index_entries:
|
|
172
|
+
f.write(entry + "\n")
|
|
173
|
+
# os.replace is atomic cross-platform; Path.replace raises
|
|
174
|
+
# FileExistsError on Windows for existing targets.
|
|
175
|
+
os.replace(tmp, self._index_path)
|
|
176
|
+
except Exception as e:
|
|
177
|
+
logger.warning("Failed to write MEMORY.md: %s", e)
|
|
178
|
+
|
|
179
|
+
# ── CRUD operations ──────────────────────────────────────────────────
|
|
180
|
+
|
|
181
|
+
def add(self, memory: Memory) -> Memory:
|
|
182
|
+
"""
|
|
183
|
+
Add or update a memory. If one with the same name exists, update it.
|
|
184
|
+
"""
|
|
185
|
+
with self._lock:
|
|
186
|
+
self._idf_cache = None # invalidate IDF cache
|
|
187
|
+
existing = self._memories.get(memory.name)
|
|
188
|
+
|
|
189
|
+
now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
|
|
190
|
+
if existing:
|
|
191
|
+
memory.created = existing.created
|
|
192
|
+
else:
|
|
193
|
+
memory.created = memory.created or now
|
|
194
|
+
memory.updated = now
|
|
195
|
+
|
|
196
|
+
# Write memory file
|
|
197
|
+
file_path = self.memory_dir / memory.file_path
|
|
198
|
+
try:
|
|
199
|
+
with open(file_path, "w", encoding="utf-8") as f:
|
|
200
|
+
f.write(memory.to_frontmatter())
|
|
201
|
+
except Exception as e:
|
|
202
|
+
logger.error("Failed to write memory file %s: %s", file_path, e)
|
|
203
|
+
raise
|
|
204
|
+
|
|
205
|
+
self._memories[memory.name] = memory
|
|
206
|
+
|
|
207
|
+
# Update index
|
|
208
|
+
entry = f"- [{memory.description}]({memory.file_path})"
|
|
209
|
+
replaced = False
|
|
210
|
+
for i, line in enumerate(self._index_entries):
|
|
211
|
+
if f"]({memory.file_path})" in line:
|
|
212
|
+
self._index_entries[i] = entry
|
|
213
|
+
replaced = True
|
|
214
|
+
break
|
|
215
|
+
if not replaced:
|
|
216
|
+
self._index_entries.append(entry)
|
|
217
|
+
|
|
218
|
+
self._write_index()
|
|
219
|
+
logger.info("Saved memory: %s", memory.name)
|
|
220
|
+
return memory
|
|
221
|
+
|
|
222
|
+
def save_batch(self, memories: list[Memory]) -> list[Memory]:
|
|
223
|
+
"""Save multiple memories efficiently — writes index only once."""
|
|
224
|
+
with self._lock:
|
|
225
|
+
self._idf_cache = None
|
|
226
|
+
for memory in memories:
|
|
227
|
+
file_path = self.memory_dir / memory.file_path
|
|
228
|
+
try:
|
|
229
|
+
with open(file_path, "w", encoding="utf-8") as f:
|
|
230
|
+
f.write(memory.to_frontmatter())
|
|
231
|
+
except Exception as e:
|
|
232
|
+
logger.error("Failed to write memory file %s: %s", file_path, e)
|
|
233
|
+
continue
|
|
234
|
+
self._memories[memory.name] = memory
|
|
235
|
+
# Rebuild and write index once
|
|
236
|
+
self._rebuild_index()
|
|
237
|
+
logger.info("Batch saved %d memories", len(memories))
|
|
238
|
+
return memories
|
|
239
|
+
|
|
240
|
+
def _rebuild_index(self) -> None:
|
|
241
|
+
"""Rebuild the index from all loaded memories (batch-safe)."""
|
|
242
|
+
self._index_entries = [
|
|
243
|
+
f"- [{m.description}]({m.file_path})"
|
|
244
|
+
for m in self._memories.values()
|
|
245
|
+
]
|
|
246
|
+
self._write_index()
|
|
247
|
+
|
|
248
|
+
def flush(self) -> None:
|
|
249
|
+
"""Force-write the index to disk (call before shutdown)."""
|
|
250
|
+
self._rebuild_index()
|
|
251
|
+
|
|
252
|
+
def get(self, name: str) -> Memory | None:
|
|
253
|
+
"""Get a memory by name (slug)."""
|
|
254
|
+
return self._memories.get(name)
|
|
255
|
+
|
|
256
|
+
def delete(self, name: str) -> bool:
|
|
257
|
+
"""Delete a memory by name."""
|
|
258
|
+
with self._lock:
|
|
259
|
+
self._idf_cache = None
|
|
260
|
+
memory = self._memories.pop(name, None)
|
|
261
|
+
if memory is None:
|
|
262
|
+
return False
|
|
263
|
+
|
|
264
|
+
file_path = self.memory_dir / memory.file_path
|
|
265
|
+
if file_path.exists():
|
|
266
|
+
try:
|
|
267
|
+
file_path.unlink()
|
|
268
|
+
except Exception as e:
|
|
269
|
+
logger.warning("Failed to delete memory file: %s", e)
|
|
270
|
+
|
|
271
|
+
self._index_entries = [
|
|
272
|
+
line for line in self._index_entries
|
|
273
|
+
if f"]({memory.file_path})" not in line
|
|
274
|
+
]
|
|
275
|
+
self._write_index()
|
|
276
|
+
logger.info("Deleted memory: %s", name)
|
|
277
|
+
return True
|
|
278
|
+
|
|
279
|
+
def list_all(self, memory_type: str | None = None) -> list[Memory]:
|
|
280
|
+
"""List all memories, optionally filtered by type."""
|
|
281
|
+
memories = list(self._memories.values())
|
|
282
|
+
if memory_type:
|
|
283
|
+
memories = [m for m in memories if m.memory_type == memory_type]
|
|
284
|
+
# Sort by updated (handle both string and datetime types)
|
|
285
|
+
def sort_key(m: Memory) -> str:
|
|
286
|
+
return str(m.updated or "")
|
|
287
|
+
return sorted(memories, key=sort_key, reverse=True)
|
|
288
|
+
|
|
289
|
+
def search(self, query: str) -> list[Memory]:
|
|
290
|
+
"""Search memories by TF-IDF-weighted token overlap.
|
|
291
|
+
Returns memories sorted by relevance score (descending).
|
|
292
|
+
"""
|
|
293
|
+
scored = self._search_scored(query)
|
|
294
|
+
return [m for _, m in scored]
|
|
295
|
+
|
|
296
|
+
def _search_scored(self, query: str) -> list[tuple[float, Memory]]:
|
|
297
|
+
"""
|
|
298
|
+
Score every memory against *query* with TF-IDF-weighted token
|
|
299
|
+
overlap plus phrase bonuses and recency boost.
|
|
300
|
+
|
|
301
|
+
Returns (score, memory) pairs sorted by score descending.
|
|
302
|
+
"""
|
|
303
|
+
if not self._memories:
|
|
304
|
+
return []
|
|
305
|
+
|
|
306
|
+
query_lower = query.lower()
|
|
307
|
+
query_tokens = set(query_lower.split())
|
|
308
|
+
|
|
309
|
+
# ── Pre-compute document frequencies for IDF weighting ──────────
|
|
310
|
+
# Use cached IDF when available; rebuild only when memories change.
|
|
311
|
+
import math as _math
|
|
312
|
+
doc_count = len(self._memories)
|
|
313
|
+
if self._idf_cache is None or self._idf_doc_count != doc_count:
|
|
314
|
+
token_df: dict[str, int] = {}
|
|
315
|
+
for m in self._memories.values():
|
|
316
|
+
text = f"{m.name} {m.description} {m.content}".lower()
|
|
317
|
+
seen: set[str] = set()
|
|
318
|
+
for word in text.split():
|
|
319
|
+
if word not in seen:
|
|
320
|
+
token_df[word] = token_df.get(word, 0) + 1
|
|
321
|
+
seen.add(word)
|
|
322
|
+
# Pre-compute IDF for every token
|
|
323
|
+
self._idf_cache = {
|
|
324
|
+
t: _math.log((doc_count + 1) / (df + 1)) + 1.0
|
|
325
|
+
for t, df in token_df.items()
|
|
326
|
+
}
|
|
327
|
+
self._idf_doc_count = doc_count
|
|
328
|
+
|
|
329
|
+
idf_map = self._idf_cache
|
|
330
|
+
|
|
331
|
+
def idf(token: str) -> float:
|
|
332
|
+
return idf_map.get(token, 1.0) # unseen tokens get neutral weight
|
|
333
|
+
|
|
334
|
+
# ── Score each memory ──────────────────────────────────────────
|
|
335
|
+
results: list[tuple[float, Memory]] = []
|
|
336
|
+
for memory in self._memories.values():
|
|
337
|
+
score = 0.0
|
|
338
|
+
name_lower = memory.name.lower()
|
|
339
|
+
desc_lower = memory.description.lower()
|
|
340
|
+
content_lower = memory.content.lower()
|
|
341
|
+
|
|
342
|
+
# Phrase bonus: full query appears as substring
|
|
343
|
+
if query_lower in name_lower:
|
|
344
|
+
score += 15.0
|
|
345
|
+
if query_lower in desc_lower:
|
|
346
|
+
score += 8.0
|
|
347
|
+
if query_lower in content_lower:
|
|
348
|
+
score += 4.0
|
|
349
|
+
|
|
350
|
+
# Token-level IDF-weighted match
|
|
351
|
+
for token in query_tokens:
|
|
352
|
+
w = idf(token)
|
|
353
|
+
if token in name_lower.replace("-", " ").split():
|
|
354
|
+
score += 6.0 * w # name match — highest signal
|
|
355
|
+
if token in set(desc_lower.split()):
|
|
356
|
+
score += 3.0 * w # description match — medium signal
|
|
357
|
+
if token in set(content_lower.split()):
|
|
358
|
+
score += 1.5 * w # content match — lower signal
|
|
359
|
+
|
|
360
|
+
# Recency boost: memories touched in the last hour get +2
|
|
361
|
+
try:
|
|
362
|
+
from datetime import datetime, timezone, timedelta
|
|
363
|
+
updated = memory.updated or ""
|
|
364
|
+
if updated:
|
|
365
|
+
dt = datetime.fromisoformat(updated.replace("Z", "+00:00"))
|
|
366
|
+
if dt > datetime.now(timezone.utc) - timedelta(hours=1):
|
|
367
|
+
score += 2.0
|
|
368
|
+
except (ValueError, TypeError):
|
|
369
|
+
pass
|
|
370
|
+
|
|
371
|
+
if score > 0:
|
|
372
|
+
results.append((score, memory))
|
|
373
|
+
|
|
374
|
+
results.sort(key=lambda x: x[0], reverse=True)
|
|
375
|
+
return results
|
|
376
|
+
|
|
377
|
+
# ── Recall for context ───────────────────────────────────────────────
|
|
378
|
+
|
|
379
|
+
def recall_context(self, user_input: str, max_memories: int = 5,
|
|
380
|
+
min_score: float = 3.0) -> str:
|
|
381
|
+
"""
|
|
382
|
+
Recall memories relevant to *user_input* for inclusion in the system
|
|
383
|
+
prompt. Only returns memories whose relevance score exceeds
|
|
384
|
+
*min_score* so the prompt doesn't get polluted with noise.
|
|
385
|
+
"""
|
|
386
|
+
if not self._memories:
|
|
387
|
+
return ""
|
|
388
|
+
|
|
389
|
+
# Re-use the scored search
|
|
390
|
+
scored = self._search_scored(user_input)
|
|
391
|
+
relevant = [m for score, m in scored if score >= min_score][:max_memories]
|
|
392
|
+
if not relevant:
|
|
393
|
+
return ""
|
|
394
|
+
|
|
395
|
+
# Bump access tracking
|
|
396
|
+
now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
|
|
397
|
+
for m in relevant:
|
|
398
|
+
m.metadata["last_accessed"] = now
|
|
399
|
+
m.metadata["access_count"] = m.metadata.get("access_count", 0) + 1
|
|
400
|
+
|
|
401
|
+
lines = ["\n## Relevant Memories"]
|
|
402
|
+
for memory in relevant:
|
|
403
|
+
lines.append(f"\n### {memory.description}")
|
|
404
|
+
content = memory.content
|
|
405
|
+
if len(content) > 500:
|
|
406
|
+
content = content[:500] + "..."
|
|
407
|
+
lines.append(content)
|
|
408
|
+
refs = self._extract_links(memory.content)
|
|
409
|
+
if refs:
|
|
410
|
+
lines.append(f"Related: {', '.join(refs)}")
|
|
411
|
+
return "\n".join(lines)
|
|
412
|
+
|
|
413
|
+
def _extract_links(self, content: str) -> list[str]:
|
|
414
|
+
"""Extract [[wiki-style]] links from content."""
|
|
415
|
+
return re.findall(r"\[\[([^\]]+)\]\]", content)
|
|
416
|
+
|
|
417
|
+
def get_memory_context(self, max_total: int = 8) -> str:
|
|
418
|
+
"""
|
|
419
|
+
Return a compact summary of recently-updated memories for the
|
|
420
|
+
system prompt. Capped at *max_total* entries so the prompt
|
|
421
|
+
doesn't bloat when the user has dozens of memories.
|
|
422
|
+
"""
|
|
423
|
+
if not self._memories:
|
|
424
|
+
return ""
|
|
425
|
+
|
|
426
|
+
def _sort_key(m: Memory) -> str:
|
|
427
|
+
return str(m.updated or "")
|
|
428
|
+
|
|
429
|
+
recent = sorted(self._memories.values(), key=_sort_key, reverse=True)[:max_total]
|
|
430
|
+
if not recent:
|
|
431
|
+
return ""
|
|
432
|
+
|
|
433
|
+
lines = ["\n## Persistent Memory"]
|
|
434
|
+
by_type: dict[str, list[Memory]] = {}
|
|
435
|
+
for m in recent:
|
|
436
|
+
by_type.setdefault(m.memory_type, []).append(m)
|
|
437
|
+
|
|
438
|
+
for mtype in ["user", "project", "feedback", "reference"]:
|
|
439
|
+
entries = by_type.get(mtype, [])
|
|
440
|
+
if entries:
|
|
441
|
+
lines.append(f"\n### {mtype.title()}")
|
|
442
|
+
for m in entries[:3]:
|
|
443
|
+
lines.append(f"- {m.description}")
|
|
444
|
+
return "\n".join(lines)
|
|
445
|
+
|
|
446
|
+
# ── Auto-suggest from conversation ──────────────────────────────────
|
|
447
|
+
|
|
448
|
+
def suggest_from_conversation(self, user_messages: list[str],
|
|
449
|
+
file_ops: list[str] | None = None,
|
|
450
|
+
tool_errors: list[str] | None = None) -> list[str]:
|
|
451
|
+
"""Analyse recent messages for facts worth saving as memories.
|
|
452
|
+
|
|
453
|
+
Returns a list of human-readable suggestions like
|
|
454
|
+
``"user prefers YAML over JSON for config"`` that the agent can
|
|
455
|
+
surface to the user with a quick save prompt.
|
|
456
|
+
"""
|
|
457
|
+
suggestions: list[str] = []
|
|
458
|
+
|
|
459
|
+
# Heuristic 1: explicit "remember …" or "save …" directives
|
|
460
|
+
for msg in user_messages:
|
|
461
|
+
lower = msg.lower()
|
|
462
|
+
if any(kw in lower for kw in ("remember", "save this", "don't forget",
|
|
463
|
+
"记", "记住", "备忘")):
|
|
464
|
+
suggestions.append(f"User asked to remember: {msg[:120]}")
|
|
465
|
+
|
|
466
|
+
# Heuristic 2: project-specific paths or toolchains mentioned
|
|
467
|
+
toolchain_keywords = ["idf.py", "esp-idf", "esptool", "cmake", "platformio",
|
|
468
|
+
"arduino", "stm32", "nrf", "zephyr"]
|
|
469
|
+
for msg in user_messages:
|
|
470
|
+
for kw in toolchain_keywords:
|
|
471
|
+
if kw.lower() in msg.lower():
|
|
472
|
+
suggestions.append(
|
|
473
|
+
f"Project uses {kw}: {msg[:120]}"
|
|
474
|
+
)
|
|
475
|
+
break
|
|
476
|
+
|
|
477
|
+
# Heuristic 3: device ports / serial config
|
|
478
|
+
import re as _re
|
|
479
|
+
for msg in user_messages:
|
|
480
|
+
port_match = _re.search(r'COM\d+|/dev/tty\w+', msg)
|
|
481
|
+
if port_match:
|
|
482
|
+
suggestions.append(
|
|
483
|
+
f"Device port {port_match.group()}: {msg[:120]}"
|
|
484
|
+
)
|
|
485
|
+
|
|
486
|
+
# Heuristic 4: operational learnings — detect "X failed → Y worked" patterns
|
|
487
|
+
if tool_errors:
|
|
488
|
+
for err in tool_errors:
|
|
489
|
+
lower = err.lower()
|
|
490
|
+
if "not in the allowed list" in lower:
|
|
491
|
+
suggestions.append(
|
|
492
|
+
"ops: Some shell commands are blocked by the whitelist. "
|
|
493
|
+
"Use python -c \"import subprocess; subprocess.run([...], cwd='...')\" "
|
|
494
|
+
"as a workaround for tools not on PATH."
|
|
495
|
+
)
|
|
496
|
+
break
|
|
497
|
+
if "command not found" in lower or "not recognized" in lower:
|
|
498
|
+
# Extract the command name
|
|
499
|
+
m = _re.search(r"'(\w+)'", err)
|
|
500
|
+
cmd = m.group(1) if m else "?"
|
|
501
|
+
suggestions.append(
|
|
502
|
+
f"ops: Command '{cmd}' not found — use full path or "
|
|
503
|
+
f"python subprocess wrapper."
|
|
504
|
+
)
|
|
505
|
+
|
|
506
|
+
return suggestions[:5] # cap to avoid overwhelming the user
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
# ── Convenience functions ────────────────────────────────────────────────────
|
|
510
|
+
|
|
511
|
+
def create_memory(
    name: str,
    description: str,
    content: str,
    memory_type: str = "reference",
    store: MemoryStore | None = None,
) -> Memory:
    """Build a Memory tagged with *memory_type* and save it.

    The global store from ``get_memory_store()`` is used unless *store* is
    given; the result of ``store.add`` (the saved memory) is returned.
    """
    target = get_memory_store() if store is None else store
    return target.add(
        Memory(
            name=name,
            description=description,
            content=content,
            metadata={"type": memory_type},
        )
    )
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
# ── Global instance ──────────────────────────────────────────────────────────
|
|
531
|
+
|
|
532
|
+
_memory_store: MemoryStore | None = None


def get_memory_store(memory_dir: str | Path | None = None) -> MemoryStore:
    """Return the process-wide MemoryStore, creating it on the first call.

    *memory_dir* is only honoured by the call that creates the store. A later
    call that asks for a different directory still receives the existing
    store, but the mismatch is logged instead of being silently ignored.
    """
    global _memory_store
    if _memory_store is None:
        _memory_store = MemoryStore(memory_dir)
    elif memory_dir is not None and Path(memory_dir) != _memory_store.memory_dir:
        logger.warning(
            "get_memory_store(%s) ignored: store already uses %s",
            memory_dir, _memory_store.memory_dir,
        )
    return _memory_store
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Shared model metadata — single source of truth for pricing, URL building,
|
|
3
|
+
and model info. Eliminates the duplicated price tables and URL construction
|
|
4
|
+
that were scattered across commands.py, repl_ui.py, main.py, and llm_client.py.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(frozen=True)
class ModelInfo:
    """Read-only pricing and provider metadata for one model."""

    # Identifier of the model as used in API requests and registry keys.
    model_id: str
    # Cost in USD for one million input (prompt) tokens.
    input_price_per_1m: float
    # Cost in USD for one million output (completion) tokens.
    output_price_per_1m: float
    # Provider tag: "openai", "deepseek", "anthropic" or "local"
    # ("unknown" for fallback entries built by get_model_info).
    provider: str = ""
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# ── Registry ─────────────────────────────────────────────────────────────────
|
|
20
|
+
|
|
21
|
+
# Known models, keyed by each entry's own model_id so a key can never drift
# from the id stored in its ModelInfo. Prices are USD per 1M tokens.
# NOTE(review): these are hard-coded snapshots — verify against the
# providers' current price lists.
MODEL_REGISTRY: dict[str, ModelInfo] = {
    info.model_id: info
    for info in (
        # OpenAI
        ModelInfo("gpt-4o", 2.50, 10.00, "openai"),
        ModelInfo("gpt-4o-mini", 0.15, 0.60, "openai"),
        ModelInfo("gpt-4-turbo", 10.00, 30.00, "openai"),
        ModelInfo("gpt-4", 30.00, 60.00, "openai"),
        # DeepSeek
        ModelInfo("deepseek-chat", 0.14, 0.28, "deepseek"),
        ModelInfo("deepseek-coder", 0.14, 0.28, "deepseek"),
        ModelInfo("deepseek-v4-pro", 0.14, 0.28, "deepseek"),
        ModelInfo("deepseek-v4-flash", 0.14, 0.28, "deepseek"),
        # Anthropic
        ModelInfo("claude-sonnet-4-6", 3.00, 15.00, "anthropic"),
        ModelInfo("claude-opus-4-8", 15.00, 75.00, "anthropic"),
        # Local (priced at zero)
        ModelInfo("qwen2.5-coder-14b", 0.00, 0.00, "local"),
    )
}

# Prices (USD per 1M tokens) assumed for models missing from the registry.
_FALLBACK_INPUT_PRICE = 1.00
_FALLBACK_OUTPUT_PRICE = 5.00
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def get_model_info(model_id: str) -> ModelInfo:
    """Look up a model in the registry. Returns a fallback for unknown models.

    Resolution order:
      1. Exact match (e.g. "gpt-4o" → openai)
      2. Strip bracket groups like ``[1m]`` / ``[context]``, then exact match
      3. Case-insensitive substring match — the longest known key found
         inside *model_id* wins (e.g. "Some-Prefix-DeepSeek-Chat-v2" →
         deepseek)

    Unknown ids get a ModelInfo with the fallback prices and provider
    "unknown".
    """
    import re

    # Exact match first
    info = MODEL_REGISTRY.get(model_id)
    if info is not None:
        return info
    # Strip suffixes providers append ("[1m]", "[context]", ...). Each group
    # is removed on its own (non-greedy), so text between two groups survives.
    clean = re.sub(r"\[[^\]]*\]", "", model_id).strip()
    if clean in MODEL_REGISTRY:
        return MODEL_REGISTRY[clean]
    # Substring match — longest key first so "gpt-4o" wins over "gpt-4".
    # Lower-cased so differently-capitalized ids still resolve.
    haystack = model_id.lower()
    for key in sorted(MODEL_REGISTRY, key=len, reverse=True):
        if key.lower() in haystack:
            return MODEL_REGISTRY[key]
    return ModelInfo(model_id, _FALLBACK_INPUT_PRICE, _FALLBACK_OUTPUT_PRICE, "unknown")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def get_model_cost(model_id: str) -> tuple[float, float]:
    """Return the model's USD prices per 1M tokens as ``(input, output)``."""
    model = get_model_info(model_id)
    input_price = model.input_price_per_1m
    output_price = model.output_price_per_1m
    return input_price, output_price
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def estimate_cost(token_count: int, model_id: str,
                  input_ratio: float = 0.7) -> float:
    """
    Estimate the USD cost of *token_count* tokens on *model_id*.

    ``int(token_count * input_ratio)`` tokens (70% by default) are billed as
    input and the remainder as output, each at the model's per-1M price.
    """
    price_in, price_out = get_model_cost(model_id)
    n_input = int(token_count * input_ratio)
    n_output = token_count - n_input
    input_cost = (n_input / 1_000_000) * price_in
    output_cost = (n_output / 1_000_000) * price_out
    return input_cost + output_cost
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# ── URL building ─────────────────────────────────────────────────────────────
|
|
83
|
+
|
|
84
|
+
def build_api_url(base_url: str, endpoint: str = "chat/completions") -> str:
    """
    Join an OpenAI-compatible base URL and an endpoint into a full API URL.

    Trailing slashes are dropped and "/v1" is appended unless the base
    already contains a "/v<digits>" segment:
        https://api.openai.com      → https://api.openai.com/v1/chat/completions
        https://api.deepseek.com/v1 → https://api.deepseek.com/v1/chat/completions
        https://api.deepseek.com/v2 → https://api.deepseek.com/v2/chat/completions

    Pass endpoint="" to get only the versioned base (e.g. for /models).
    """
    import re

    root = base_url.rstrip("/")
    versioned = root if re.search(r'/v\d+', root) else root + "/v1"
    if not endpoint:
        return versioned
    return versioned + "/" + endpoint.lstrip("/")
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def build_models_url(base_url: str) -> str:
    """Build the ``/models`` endpoint URL from a base URL.

    Version detection matches ``build_api_url``: any "/v<digits>" segment
    counts as versioned (previously only "/v1" and "/v2" did, so e.g. a
    "/v3" base wrongly became ".../v3/v1/models"). Unversioned bases get
    "/v1/models".
    """
    import re

    base = base_url.rstrip("/")
    if re.search(r"/v\d+", base):
        return f"{base}/models"
    return f"{base}/v1/models"
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# ── Model list from API ──────────────────────────────────────────────────────
|
|
114
|
+
|
|
115
|
+
def fetch_available_models(base_url: str, api_key: str, timeout: float = 10.0) -> list[str]:
    """
    Fetch the available model list from the API's /models endpoint.

    Returns the non-empty string ``id`` of every entry, or an empty list if
    the request, HTTP status or JSON decoding fails. Entries that are not
    objects, or that lack a usable ``id``, are skipped instead of producing
    blank ids.

    .. note::
        This uses synchronous ``httpx.get()``. Callers inside an async
        event loop should use ``asyncio.to_thread(fetch_available_models, ...)``
        to avoid blocking the loop.
    """
    import httpx

    url = build_models_url(base_url)
    headers = {"Authorization": f"Bearer {api_key}"}
    try:
        resp = httpx.get(url, headers=headers, timeout=timeout)
        resp.raise_for_status()
        payload = resp.json()
    except Exception:
        # Best-effort lookup: any network/HTTP/decoding failure means "no list".
        return []

    # OpenAI-style bodies wrap the list as {"data": [...]}; a bare list is
    # accepted too in case an OpenAI-compatible server returns one.
    entries = payload.get("data") if isinstance(payload, dict) else payload
    if not isinstance(entries, list):
        return []
    return [
        entry["id"]
        for entry in entries
        if isinstance(entry, dict) and isinstance(entry.get("id"), str) and entry["id"]
    ]
|