ata-coder 2.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118) hide show
  1. ata_coder/__init__.py +1 -0
  2. ata_coder/agent.py +874 -0
  3. ata_coder/agent_compact.py +190 -0
  4. ata_coder/agent_controller.py +218 -0
  5. ata_coder/agent_extension.py +69 -0
  6. ata_coder/agent_routing.py +105 -0
  7. ata_coder/agent_subsystems.py +72 -0
  8. ata_coder/agent_tools.py +318 -0
  9. ata_coder/agent_undo.py +63 -0
  10. ata_coder/anthropic_client.py +465 -0
  11. ata_coder/change_tracker.py +368 -0
  12. ata_coder/clawd_integration.py +574 -0
  13. ata_coder/commands/__init__.py +128 -0
  14. ata_coder/commands/_core.py +184 -0
  15. ata_coder/commands/_safety.py +95 -0
  16. ata_coder/commands/_settings.py +241 -0
  17. ata_coder/commands/_workflow.py +451 -0
  18. ata_coder/commands.py +974 -0
  19. ata_coder/config.py +257 -0
  20. ata_coder/core/__init__.py +35 -0
  21. ata_coder/core/events.py +73 -0
  22. ata_coder/core/queue.py +85 -0
  23. ata_coder/core/state.py +17 -0
  24. ata_coder/event_queue.py +5 -0
  25. ata_coder/extension.py +654 -0
  26. ata_coder/extensions/__init__.py +1 -0
  27. ata_coder/extensions/hello_skill.py +47 -0
  28. ata_coder/fool_proof.py +295 -0
  29. ata_coder/git_workflow.py +371 -0
  30. ata_coder/gui.py +511 -0
  31. ata_coder/llm_client.py +543 -0
  32. ata_coder/main.py +814 -0
  33. ata_coder/mcp_client.py +1095 -0
  34. ata_coder/memory.py +539 -0
  35. ata_coder/model_registry.py +134 -0
  36. ata_coder/model_router.py +105 -0
  37. ata_coder/permissions.py +274 -0
  38. ata_coder/privilege.py +464 -0
  39. ata_coder/project.py +273 -0
  40. ata_coder/prompt_template.py +423 -0
  41. ata_coder/prompts/auto-mode.md +7 -0
  42. ata_coder/prompts/coding-rules.md +40 -0
  43. ata_coder/prompts/execution-guardrails.md +14 -0
  44. ata_coder/prompts/memory-system.md +24 -0
  45. ata_coder/prompts/output-style.md +23 -0
  46. ata_coder/prompts/safety.md +17 -0
  47. ata_coder/prompts/slash-commands.md +24 -0
  48. ata_coder/prompts/sub-agents.md +38 -0
  49. ata_coder/prompts/system-reminders.md +17 -0
  50. ata_coder/prompts/system.md +105 -0
  51. ata_coder/prompts/tool-policy.md +46 -0
  52. ata_coder/repl_theme.py +99 -0
  53. ata_coder/repl_tracker.py +89 -0
  54. ata_coder/repl_ui.py +1214 -0
  55. ata_coder/safety_guard.py +434 -0
  56. ata_coder/self_correct.py +346 -0
  57. ata_coder/server.py +882 -0
  58. ata_coder/server_session.py +159 -0
  59. ata_coder/server_shell.py +129 -0
  60. ata_coder/session.py +431 -0
  61. ata_coder/settings.py +439 -0
  62. ata_coder/setup_wizard.py +136 -0
  63. ata_coder/skill_extension.py +92 -0
  64. ata_coder/skills/architect/SKILL.md +42 -0
  65. ata_coder/skills/code-reviewer/SKILL.md +37 -0
  66. ata_coder/skills/codecraft/SKILL.md +452 -0
  67. ata_coder/skills/debugger/SKILL.md +45 -0
  68. ata_coder/skills/doc-writer/SKILL.md +36 -0
  69. ata_coder/skills/general-coder/SKILL.md +76 -0
  70. ata_coder/skills/math-calculator/README.md +40 -0
  71. ata_coder/skills/math-calculator/SKILL.md +59 -0
  72. ata_coder/skills/math-calculator/handler.py +103 -0
  73. ata_coder/skills/math-calculator/prompts/system.md +8 -0
  74. ata_coder/skills/math-calculator/requirements.txt +2 -0
  75. ata_coder/skills/math-calculator/resources/constants.json +8 -0
  76. ata_coder/skills/math-calculator/tests/test_handler.py +53 -0
  77. ata_coder/skills/security-auditor/SKILL.md +40 -0
  78. ata_coder/skills/test-writer/SKILL.md +36 -0
  79. ata_coder/skills/weather-skill/README.md +45 -0
  80. ata_coder/skills/weather-skill/handler.py +76 -0
  81. ata_coder/skills/weather-skill/manifest.json +48 -0
  82. ata_coder/skills/weather-skill/prompts/system_prompt.txt +9 -0
  83. ata_coder/skills/weather-skill/prompts/user_prompt_template.txt +3 -0
  84. ata_coder/skills/weather-skill/requirements.txt +1 -0
  85. ata_coder/skills/weather-skill/resources/city_list.json +17 -0
  86. ata_coder/skills/weather-skill/resources/error_messages.json +7 -0
  87. ata_coder/skills/weather-skill/tests/test_handler.py +28 -0
  88. ata_coder/skills/weather-skill/weather_utils.py +50 -0
  89. ata_coder/skills.py +1014 -0
  90. ata_coder/sub_agent.py +273 -0
  91. ata_coder/sub_agent_manager.py +203 -0
  92. ata_coder/system_prompt_builder.py +146 -0
  93. ata_coder/task_planner.py +391 -0
  94. ata_coder/terminal.py +318 -0
  95. ata_coder/test_runner.py +219 -0
  96. ata_coder/thread_supervisor.py +195 -0
  97. ata_coder/tool_defs.py +335 -0
  98. ata_coder/tools/__init__.py +11 -0
  99. ata_coder/tools/definitions.py +335 -0
  100. ata_coder/tools/executor.py +1036 -0
  101. ata_coder/tools/result.py +26 -0
  102. ata_coder/tools/subagent.py +332 -0
  103. ata_coder/tools/web.py +361 -0
  104. ata_coder/tools.py +1576 -0
  105. ata_coder/types.py +92 -0
  106. ata_coder/utils.py +113 -0
  107. ata_coder/web/css/style.css +180 -0
  108. ata_coder/web/index.html +84 -0
  109. ata_coder/web/js/app.js +489 -0
  110. ata_coder/web/package-lock.json +25 -0
  111. ata_coder/web/package.json +10 -0
  112. ata_coder/web/tsconfig.json +13 -0
  113. ata_coder-2.4.2.dist-info/METADATA +799 -0
  114. ata_coder-2.4.2.dist-info/RECORD +118 -0
  115. ata_coder-2.4.2.dist-info/WHEEL +5 -0
  116. ata_coder-2.4.2.dist-info/entry_points.txt +2 -0
  117. ata_coder-2.4.2.dist-info/licenses/LICENSE +21 -0
  118. ata_coder-2.4.2.dist-info/top_level.txt +1 -0
ata_coder/memory.py ADDED
@@ -0,0 +1,539 @@
1
+ """
2
+ Persistent memory system for ATA Coder.
3
+
4
+ Stores facts, user preferences, feedback, and project context across sessions.
5
+ Uses a file-based approach:
6
+ - memory/MEMORY.md — index of all memories (loaded on startup)
7
+ - memory/<slug>.md — individual memory files with YAML frontmatter
8
+
9
+ Memory types:
10
+ - user: who the user is, their preferences, expertise
11
+ - feedback: user guidance on how the agent should work
12
+ - project: ongoing goals, constraints, architecture decisions
13
+ - reference: pointers to external resources (URLs, docs, etc.)
14
+ """
15
+
16
+ import json
17
+ import logging
18
+ import os
19
+ import re
20
+ import threading
21
+ from dataclasses import dataclass, field
22
+ from datetime import datetime, timezone
23
+ from pathlib import Path
24
+ from typing import Any
25
+
26
+ from .utils import try_import_yaml
27
+
28
logger = logging.getLogger(name=__name__)

# Optional PyYAML support. When _yaml_mod is None the frontmatter helpers in
# this module fall back to JSON for the header block.
(_yaml_mod, HAS_YAML) = try_import_yaml()
31
+
32
+
33
+ # ── Memory data model ────────────────────────────────────────────────────────
34
+
35
@dataclass
class Memory:
    """A single memory entry, persisted as ``<name>.md`` with a frontmatter header."""

    name: str  # kebab-case slug, used as the filename stem
    description: str  # one-line summary (used for relevance and the index)
    content: str  # the memory body
    metadata: dict[str, Any] = field(default_factory=dict)  # type, tags, etc.
    created: str = ""  # ISO timestamp
    updated: str = ""  # ISO timestamp

    @property
    def memory_type(self) -> str:
        """Return ``metadata["type"]``, defaulting to ``"reference"``."""
        return self.metadata.get("type", "reference")

    @property
    def file_path(self) -> str:
        """Return the file name (relative to the memory directory) for this memory."""
        return f"{self.name}.md"

    def to_frontmatter(self) -> str:
        """Serialize to markdown with a ``---``-delimited frontmatter header.

        The header is YAML when PyYAML is available, otherwise JSON. ``created``
        falls back to the current UTC time when unset; ``updated`` is always
        the current UTC time.
        """
        now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
        meta = {
            "name": self.name,
            "description": self.description,
            "metadata": self.metadata,
            "created": self.created or now,
            "updated": now,
        }
        if _yaml_mod is not None:
            header = _yaml_mod.dump(meta, default_flow_style=False, allow_unicode=True, sort_keys=False)
        else:
            header = json.dumps(meta, indent=2, ensure_ascii=False)
        # yaml.dump ends with a newline but json.dumps does not. Without it the
        # closing "---" would share a line with "}" and from_frontmatter's
        # "\n---" delimiter would never match, making the file unreadable.
        if not header.endswith("\n"):
            header += "\n"
        return f"---\n{header}---\n\n{self.content}"

    @classmethod
    def from_frontmatter(cls, raw: str) -> "Memory | None":
        """Parse a markdown string with a ``---`` frontmatter header into a Memory.

        The header is parsed with ``yaml.safe_load`` when PyYAML is available
        (which also accepts JSON headers), otherwise with ``json.loads``.

        Returns:
            The parsed Memory, or None when the delimiters are missing, the
            header fails to parse (a warning is logged), or the header is not
            a mapping.
        """
        match = re.match(r"^---\s*\n(.*?)\n---\s*\n(.*)", raw, re.DOTALL)
        if not match:
            return None
        front_str, content = match.group(1), match.group(2).strip()
        try:
            if _yaml_mod is not None:
                meta = _yaml_mod.safe_load(front_str)
            else:
                meta = json.loads(front_str)
        except Exception as e:
            logger.warning("Failed to parse frontmatter: %s", e)
            return None
        if not isinstance(meta, dict):
            return None
        metadata = meta.get("metadata")
        return cls(
            # "name:" / "description:" with no value load as None; callers
            # call .lower() on these and build file names from them.
            name=str(meta.get("name") or "unknown"),
            description=str(meta.get("description") or ""),
            content=content,
            # An empty "metadata:" key loads as None; memory_type needs a dict.
            metadata=metadata if isinstance(metadata, dict) else {},
            created=meta.get("created", ""),
            updated=meta.get("updated", ""),
        )
92
+
93
+
94
+ # ── Memory store ─────────────────────────────────────────────────────────────
95
+
96
class MemoryStore:
    """
    Persistent file-based memory store.

    On initialization, reads MEMORY.md for the index, then loads every
    ``*.md`` memory file in the directory. Memory files and the index are
    written via a temp file + ``os.replace``. Mutations and multi-step reads
    are serialized with a re-entrant lock.
    """

    def __init__(self, memory_dir: str | Path | None = None):
        """Open (creating if needed) the memory store at *memory_dir*.

        When *memory_dir* is None, ``get_settings().memory_dir`` is used,
        falling back to ``~/.ata_coder/memory`` if settings cannot be loaded.
        """
        if memory_dir is None:
            try:
                from .settings import get_settings
                memory_dir = get_settings().memory_dir
            except Exception:
                memory_dir = Path.home() / ".ata_coder" / "memory"
        self.memory_dir = Path(memory_dir)
        self.memory_dir.mkdir(parents=True, exist_ok=True)

        self._index_path = self.memory_dir / "MEMORY.md"
        self._memories: dict[str, Memory] = {}
        self._index_entries: list[str] = []  # non-blank lines from MEMORY.md
        # IDF cache — invalidated on add/delete
        self._idf_cache: dict[str, float] | None = None
        self._idf_doc_count: int = 0
        self._lock = threading.RLock()  # protect concurrent read/write

        self._load_index()
        self._load_all()

    # ── Helpers ───────────────────────────────────────────────────────────

    @staticmethod
    def _utc_now() -> str:
        """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ``."""
        return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

    @staticmethod
    def _atomic_write(path: Path, text: str) -> None:
        """Write *text* to *path* via a sibling temp file and ``os.replace``.

        ``os.replace`` overwrites an existing target on every platform (and is
        atomic on POSIX), so readers never observe a half-written file. The
        temp file is removed and the error re-raised if anything fails.
        """
        tmp = path.with_name(path.name + ".tmp")
        try:
            with open(tmp, "w", encoding="utf-8") as f:
                f.write(text)
            os.replace(tmp, path)
        except Exception:
            try:
                tmp.unlink(missing_ok=True)
            except OSError:
                pass
            raise

    @staticmethod
    def _index_entry(memory: Memory) -> str:
        """Return the MEMORY.md line for *memory*.

        Newlines in the description are flattened so one memory always
        occupies exactly one index line.
        """
        description = " ".join(str(memory.description).splitlines())
        return f"- [{description}]({memory.file_path})"

    @staticmethod
    def _parse_timestamp(value: Any) -> datetime | None:
        """Parse a ``created``/``updated`` value into an aware datetime, or None.

        Accepts ISO-8601 strings (a trailing "Z" is allowed) and datetime
        objects (YAML may load unquoted timestamps as datetimes). Naive values
        are assumed to be UTC.
        """
        if isinstance(value, datetime):
            dt = value
        elif isinstance(value, str) and value:
            try:
                dt = datetime.fromisoformat(value.replace("Z", "+00:00"))
            except ValueError:
                return None
        else:
            return None
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return dt

    def _path_for(self, memory: Memory) -> Path:
        """Return the on-disk path for *memory* inside ``memory_dir``.

        Raises:
            ValueError: if the name is empty, would place the file outside
                ``memory_dir`` (path separators, "..", drive prefixes), or
                would collide with the MEMORY.md index file.
        """
        name = str(memory.name)
        path = self.memory_dir / memory.file_path
        # abspath normalizes ".." without following symlinks.
        inside = os.path.dirname(os.path.abspath(path)) == os.path.abspath(self.memory_dir)
        is_index = os.path.normcase(memory.file_path) == os.path.normcase("MEMORY.md")
        if not name or not inside or is_index:
            raise ValueError(f"Invalid memory name: {name!r}")
        return path

    def _stamp(self, memory: Memory) -> None:
        """Set *memory*'s timestamps: keep an existing entry's ``created``, refresh ``updated``."""
        now = self._utc_now()
        existing = self._memories.get(memory.name)
        if existing:
            memory.created = existing.created
        else:
            memory.created = memory.created or now
        memory.updated = now

    def _upsert_index_entry(self, memory: Memory) -> None:
        """Replace *memory*'s existing index line, or append one if absent."""
        entry = self._index_entry(memory)
        marker = f"]({memory.file_path})"
        for i, line in enumerate(self._index_entries):
            if marker in line:
                self._index_entries[i] = entry
                return
        self._index_entries.append(entry)

    # ── Loading ───────────────────────────────────────────────────────────

    def _load_index(self) -> None:
        """Load non-blank lines of MEMORY.md, creating an empty index if missing."""
        if self._index_path.exists():
            try:
                with open(self._index_path, "r", encoding="utf-8") as f:
                    self._index_entries = [line.strip() for line in f if line.strip()]
                logger.debug("Loaded MEMORY.md: %d entries", len(self._index_entries))
            except Exception as e:
                logger.warning("Failed to load MEMORY.md: %s", e)
                self._index_entries = []
        else:
            # Create empty index
            self._write_index()

    def _load_all(self) -> None:
        """Load every ``*.md`` file except MEMORY.md, keyed by frontmatter name.

        Unreadable or unparseable files are logged and skipped.
        """
        if not self.memory_dir.exists():
            return

        for file_path in self.memory_dir.glob("*.md"):
            if file_path.name == "MEMORY.md":
                continue
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    raw = f.read()
                memory = Memory.from_frontmatter(raw)
                if memory:
                    self._memories[memory.name] = memory
                else:
                    logger.warning("Failed to parse memory file: %s", file_path.name)
            except Exception as e:
                logger.warning("Failed to read memory file %s: %s", file_path.name, e)

        logger.debug("Loaded %d memories from disk", len(self._memories))

    def _write_index(self) -> None:
        """Write the index file atomically; failures are logged, not raised."""
        try:
            self._atomic_write(
                self._index_path,
                "".join(f"{entry}\n" for entry in self._index_entries),
            )
        except Exception as e:
            logger.warning("Failed to write MEMORY.md: %s", e)

    # ── CRUD operations ──────────────────────────────────────────────────

    def add(self, memory: Memory) -> Memory:
        """
        Add or update a memory. If one with the same name exists, update it
        and keep its original ``created`` timestamp.

        Returns:
            The same *memory* object, with timestamps filled in.

        Raises:
            ValueError: if the memory name is unsafe to use as a file name.
            Exception: any error raised while writing the memory file is
                logged and re-raised.
        """
        with self._lock:
            file_path = self._path_for(memory)
            self._stamp(memory)
            try:
                self._atomic_write(file_path, memory.to_frontmatter())
            except Exception as e:
                logger.error("Failed to write memory file %s: %s", file_path, e)
                raise

            self._memories[memory.name] = memory
            self._idf_cache = None  # invalidate IDF cache
            self._upsert_index_entry(memory)
            self._write_index()
            logger.info("Saved memory: %s", memory.name)
            return memory

    def save_batch(self, memories: list[Memory]) -> list[Memory]:
        """Save multiple memories, writing the index only once.

        Each memory is timestamped like :meth:`add`. Memories with an unsafe
        name or whose file cannot be written are logged and skipped. The
        input list is returned unchanged.
        """
        with self._lock:
            saved = 0
            for memory in memories:
                try:
                    file_path = self._path_for(memory)
                except ValueError as e:
                    logger.error("Skipping memory: %s", e)
                    continue
                self._stamp(memory)
                try:
                    self._atomic_write(file_path, memory.to_frontmatter())
                except Exception as e:
                    logger.error("Failed to write memory file %s: %s", file_path, e)
                    continue
                self._memories[memory.name] = memory
                saved += 1
            self._idf_cache = None
            # Rebuild and write index once
            self._rebuild_index()
            logger.info("Batch saved %d memories", saved)
            return memories

    def _rebuild_index(self) -> None:
        """Regenerate the index from all loaded memories and write it (batch-safe)."""
        with self._lock:
            self._index_entries = [self._index_entry(m) for m in self._memories.values()]
            self._write_index()

    def flush(self) -> None:
        """Force-write the index to disk (call before shutdown)."""
        self._rebuild_index()

    def get(self, name: str) -> Memory | None:
        """Get a memory by name (slug), or None if unknown."""
        return self._memories.get(name)

    def delete(self, name: str) -> bool:
        """Delete a memory by name.

        Removes the memory, its file (unless its name is unsafe as a path) and
        its index line. Returns False if no memory has that name.
        """
        with self._lock:
            memory = self._memories.pop(name, None)
            if memory is None:
                return False
            self._idf_cache = None

            try:
                file_path = self._path_for(memory)
            except ValueError as e:
                # Names come from frontmatter on load; never unlink outside memory_dir.
                logger.warning("Not deleting file for memory %r: %s", name, e)
            else:
                try:
                    file_path.unlink(missing_ok=True)
                except OSError as e:
                    logger.warning("Failed to delete memory file: %s", e)

            marker = f"]({memory.file_path})"
            self._index_entries = [line for line in self._index_entries if marker not in line]
            self._write_index()
            logger.info("Deleted memory: %s", name)
            return True

    def list_all(self, memory_type: str | None = None) -> list[Memory]:
        """List memories (optionally only those of *memory_type*), newest ``updated`` first."""
        with self._lock:
            memories = list(self._memories.values())
        if memory_type:
            memories = [m for m in memories if m.memory_type == memory_type]
        # str() so YAML-loaded datetime values sort alongside ISO strings.
        return sorted(memories, key=lambda m: str(m.updated or ""), reverse=True)

    def search(self, query: str) -> list[Memory]:
        """Search memories by TF-IDF-weighted token overlap.

        Returns memories sorted by relevance score (descending); a blank
        query returns [].
        """
        return [m for _, m in self._search_scored(query)]

    def _search_scored(self, query: str) -> list[tuple[float, Memory]]:
        """
        Score every memory against *query* with IDF-weighted token overlap,
        whole-query substring ("phrase") bonuses and a recency boost.

        Tokens are lower-cased whitespace-separated words; kebab-case names
        are additionally split on "-". Returns (score, memory) pairs with a
        positive score, sorted by score descending. A blank query returns []
        (an empty string is a substring of everything and would otherwise
        give every memory the full phrase bonus).
        """
        import math
        from datetime import timedelta

        query_lower = query.lower().strip()
        if not query_lower:
            return []
        query_tokens = set(query_lower.split())

        # ── Document frequencies for IDF weighting (cached until add/delete) ──
        with self._lock:
            memories = list(self._memories.values())
            if not memories:
                return []
            doc_count = len(memories)
            if self._idf_cache is None or self._idf_doc_count != doc_count:
                token_df: dict[str, int] = {}
                for m in memories:
                    # Split names on "-" to match how name tokens are scored below.
                    text = f"{m.name.replace('-', ' ')} {m.description} {m.content}".lower()
                    for word in set(text.split()):
                        token_df[word] = token_df.get(word, 0) + 1
                self._idf_cache = {
                    t: math.log((doc_count + 1) / (df + 1)) + 1.0
                    for t, df in token_df.items()
                }
                self._idf_doc_count = doc_count
            idf_map = self._idf_cache

        # ── Score each memory ──────────────────────────────────────────
        recent_cutoff = datetime.now(timezone.utc) - timedelta(hours=1)
        results: list[tuple[float, Memory]] = []
        for memory in memories:
            name_lower = memory.name.lower()
            desc_lower = memory.description.lower()
            content_lower = memory.content.lower()
            name_tokens = set(name_lower.replace("-", " ").split())
            desc_tokens = set(desc_lower.split())
            content_tokens = set(content_lower.split())

            score = 0.0
            # Phrase bonus: full query appears as substring
            if query_lower in name_lower:
                score += 15.0
            if query_lower in desc_lower:
                score += 8.0
            if query_lower in content_lower:
                score += 4.0

            # Token-level IDF-weighted match; unseen tokens get neutral weight.
            for token in query_tokens:
                weight = idf_map.get(token, 1.0)
                if token in name_tokens:
                    score += 6.0 * weight  # name match — highest signal
                if token in desc_tokens:
                    score += 3.0 * weight  # description match — medium signal
                if token in content_tokens:
                    score += 1.5 * weight  # content match — lower signal

            # Recency boost: memories touched in the last hour get +2
            updated = self._parse_timestamp(memory.updated)
            if updated is not None and updated > recent_cutoff:
                score += 2.0

            if score > 0:
                results.append((score, memory))

        results.sort(key=lambda pair: pair[0], reverse=True)
        return results

    # ── Recall for context ───────────────────────────────────────────────

    def recall_context(self, user_input: str, max_memories: int = 5,
                       min_score: float = 3.0) -> str:
        """
        Recall memories relevant to *user_input* for inclusion in the system
        prompt. Only memories scoring at least *min_score* are included (at
        most *max_memories*), so the prompt doesn't get polluted with noise.

        Increments ``access_count`` and sets ``last_accessed`` in each
        returned memory's metadata (in memory only; they reach disk on the
        next save of that memory). Returns "" when nothing is relevant.
        """
        if not self._memories:
            return ""

        scored = self._search_scored(user_input)
        relevant = [m for score, m in scored if score >= min_score][:max_memories]
        if not relevant:
            return ""

        # Bump access tracking
        now = self._utc_now()
        with self._lock:
            for m in relevant:
                m.metadata["last_accessed"] = now
                m.metadata["access_count"] = m.metadata.get("access_count", 0) + 1

        lines = ["\n## Relevant Memories"]
        for memory in relevant:
            lines.append(f"\n### {memory.description}")
            content = memory.content
            if len(content) > 500:
                content = content[:500] + "..."
            lines.append(content)
            refs = self._extract_links(memory.content)
            if refs:
                lines.append(f"Related: {', '.join(refs)}")
        return "\n".join(lines)

    def _extract_links(self, content: str) -> list[str]:
        """Extract [[wiki-style]] link targets from content."""
        return re.findall(r"\[\[([^\]]+)\]\]", content)

    def get_memory_context(self, max_total: int = 8) -> str:
        """
        Return a compact summary of recently-updated memories for the
        system prompt.

        Takes the *max_total* most recently updated memories, groups them by
        type (user, project, feedback, reference — other types are omitted)
        and lists at most three descriptions per type.
        """
        with self._lock:
            memories = list(self._memories.values())
        if not memories:
            return ""

        recent = sorted(memories, key=lambda m: str(m.updated or ""), reverse=True)[:max_total]
        if not recent:
            return ""

        by_type: dict[str, list[Memory]] = {}
        for m in recent:
            by_type.setdefault(m.memory_type, []).append(m)

        lines = ["\n## Persistent Memory"]
        for mtype in ["user", "project", "feedback", "reference"]:
            entries = by_type.get(mtype, [])
            if entries:
                lines.append(f"\n### {mtype.title()}")
                for m in entries[:3]:
                    lines.append(f"- {m.description}")
        return "\n".join(lines)

    # ── Auto-suggest from conversation ──────────────────────────────────

    def suggest_from_conversation(self, user_messages: list[str],
                                  file_ops: list[str] | None = None,
                                  tool_errors: list[str] | None = None) -> list[str]:
        """Analyse recent messages for facts worth saving as memories.

        Returns at most five de-duplicated, human-readable suggestions like
        ``"Project uses cmake: ..."`` that the agent can surface to the user
        with a quick save prompt. *file_ops* is accepted but currently unused.
        """
        suggestions: list[str] = []

        def suggest(text: str) -> None:
            # One hint can fire for many messages/errors; keep a single copy so
            # duplicates don't crowd out other suggestions under the cap.
            if text not in suggestions:
                suggestions.append(text)

        # Heuristic 1: explicit "remember …" or "save …" directives
        # (English keywords plus Chinese "记" / "记住" / "备忘").
        remember_keywords = ("remember", "save this", "don't forget",
                             "记", "记住", "备忘")
        for msg in user_messages:
            lower = msg.lower()
            if any(kw in lower for kw in remember_keywords):
                suggest(f"User asked to remember: {msg[:120]}")

        # Heuristic 2: project-specific toolchains (first hit per message)
        toolchain_keywords = ["idf.py", "esp-idf", "esptool", "cmake", "platformio",
                              "arduino", "stm32", "nrf", "zephyr"]
        for msg in user_messages:
            lower = msg.lower()
            for kw in toolchain_keywords:
                if kw.lower() in lower:
                    suggest(f"Project uses {kw}: {msg[:120]}")
                    break

        # Heuristic 3: device ports / serial config
        for msg in user_messages:
            port_match = re.search(r'COM\d+|/dev/tty\w+', msg)
            if port_match:
                suggest(f"Device port {port_match.group()}: {msg[:120]}")

        # Heuristic 4: operational learnings from tool errors
        for err in tool_errors or ():
            lower = err.lower()
            if "not in the allowed list" in lower:
                suggest(
                    "ops: Some shell commands are blocked by the whitelist. "
                    "Use python -c \"import subprocess; subprocess.run([...], cwd='...')\" "
                    "as a workaround for tools not on PATH."
                )
                break
            if "command not found" in lower or "not recognized" in lower:
                # Windows quotes the name ("'idf.py' is not recognized ...");
                # bash does not ("bash: foo: command not found").
                m = (re.search(r"'([^'\s]+)'", err)
                     or re.search(r"([\w.\-/]+): command not found", err))
                cmd = m.group(1) if m else "?"
                suggest(
                    f"ops: Command '{cmd}' not found — use full path or "
                    f"python subprocess wrapper."
                )

        return suggestions[:5]  # cap to avoid overwhelming the user
507
+
508
+
509
+ # ── Convenience functions ────────────────────────────────────────────────────
510
+
511
def create_memory(
    name: str,
    description: str,
    content: str,
    memory_type: str = "reference",
    store: MemoryStore | None = None,
) -> Memory:
    """Build a Memory tagged with *memory_type* and persist it.

    Uses *store* when given, otherwise the process-wide store from
    ``get_memory_store()``. Returns the result of ``MemoryStore.add``.
    """
    target = store if store is not None else get_memory_store()
    fresh = Memory(
        name=name,
        description=description,
        content=content,
        metadata={"type": memory_type},
    )
    return target.add(fresh)
528
+
529
+
530
+ # ── Global instance ──────────────────────────────────────────────────────────
531
+
532
_memory_store: MemoryStore | None = None
# Guards lazy creation of the shared store (see get_memory_store).
_memory_store_lock = threading.Lock()


def get_memory_store(memory_dir: str | Path | None = None) -> MemoryStore:
    """Return the process-wide MemoryStore, creating it on first use.

    *memory_dir* only has an effect on the call that creates the store;
    later calls return the existing instance regardless of the argument.

    Creation uses double-checked locking so that concurrent first calls
    cannot build two stores, whose MEMORY.md writes would otherwise
    overwrite each other.
    """
    global _memory_store
    store = _memory_store
    if store is None:
        with _memory_store_lock:
            if _memory_store is None:
                _memory_store = MemoryStore(memory_dir)
            store = _memory_store
    return store
ata_coder/model_registry.py ADDED
@@ -0,0 +1,134 @@
1
+ """
2
+ Shared model metadata — single source of truth for pricing, URL building,
3
+ and model info. Eliminates the duplicated price tables and URL construction
4
+ that were scattered across commands.py, repl_ui.py, main.py, and llm_client.py.
5
+ """
6
+
7
+ from dataclasses import dataclass
8
+
9
+
10
@dataclass(frozen=True)
class ModelInfo:
    """Immutable pricing and provider metadata for one model.

    Attributes:
        model_id: Model identifier as used by the provider API.
        input_price_per_1m: USD charged per one million input tokens.
        output_price_per_1m: USD charged per one million output tokens.
        provider: Provider tag such as "openai", "deepseek", "anthropic" or
            "local"; ``get_model_info`` uses "unknown" for unregistered
            models. Defaults to the empty string.
    """

    model_id: str
    input_price_per_1m: float
    output_price_per_1m: float
    provider: str = ""
17
+
18
+
19
+ # ── Registry ─────────────────────────────────────────────────────────────────
20
+
21
MODEL_REGISTRY: dict[str, ModelInfo] = {
    info.model_id: info
    for info in (
        # OpenAI
        ModelInfo("gpt-4o", 2.50, 10.00, "openai"),
        ModelInfo("gpt-4o-mini", 0.15, 0.60, "openai"),
        ModelInfo("gpt-4-turbo", 10.00, 30.00, "openai"),
        ModelInfo("gpt-4", 30.00, 60.00, "openai"),
        # DeepSeek
        ModelInfo("deepseek-chat", 0.14, 0.28, "deepseek"),
        ModelInfo("deepseek-coder", 0.14, 0.28, "deepseek"),
        ModelInfo("deepseek-v4-pro", 0.14, 0.28, "deepseek"),
        ModelInfo("deepseek-v4-flash", 0.14, 0.28, "deepseek"),
        # Anthropic
        ModelInfo("claude-sonnet-4-6", 3.00, 15.00, "anthropic"),
        ModelInfo("claude-opus-4-8", 15.00, 75.00, "anthropic"),
        # Locally hosted (free)
        ModelInfo("qwen2.5-coder-14b", 0.00, 0.00, "local"),
    )
}

# USD-per-1M-token prices assumed for models missing from MODEL_REGISTRY.
_FALLBACK_INPUT_PRICE = 1.00
_FALLBACK_OUTPUT_PRICE = 5.00
38
+
39
+
40
def get_model_info(model_id: str) -> ModelInfo:
    """Look up a model in the registry. Returns a fallback for unknown models.

    Resolution order:
      1. Exact match (e.g. "gpt-4o" → openai)
      2. Strip bracket suffixes like ``[1m]`` / ``[context]`` and lower-case,
         then exact match (e.g. "GPT-4o[1m]" → openai)
      3. Case-insensitive substring match — the longest known key found
         inside *model_id* wins (e.g. "some-prefix-deepseek-chat-v2" → deepseek)

    Unknown models get a ModelInfo with the fallback prices and provider
    ``"unknown"``.
    """
    import re

    info = MODEL_REGISTRY.get(model_id)
    if info is not None:
        return info

    # Remove each [...] group separately (non-greedy) so text between two
    # suffixes survives: "a[x]-b[y]" → "a-b", not "a".
    clean = re.sub(r"\[[^\]]*\]", "", model_id).strip().lower()
    info = MODEL_REGISTRY.get(clean)
    if info is not None:
        return info

    # Longest key first so "gpt-4o" wins over its prefix "gpt-4".
    lowered = model_id.lower()
    for key in sorted(MODEL_REGISTRY, key=len, reverse=True):
        if key.lower() in lowered:
            return MODEL_REGISTRY[key]
    return ModelInfo(model_id, _FALLBACK_INPUT_PRICE, _FALLBACK_OUTPUT_PRICE, "unknown")
62
+
63
+
64
def get_model_cost(model_id: str) -> tuple[float, float]:
    """Return ``(input, output)`` USD prices per 1M tokens for *model_id*.

    Resolution (including fallback prices for unknown models) is delegated
    to ``get_model_info``.
    """
    model = get_model_info(model_id)
    price_in = model.input_price_per_1m
    price_out = model.output_price_per_1m
    return price_in, price_out
68
+
69
+
70
def estimate_cost(token_count: int, model_id: str,
                  input_ratio: float = 0.7) -> float:
    """Estimate the USD cost of *token_count* total tokens on *model_id*.

    ``int(token_count * input_ratio)`` tokens are billed at the input price
    (default 70%); the remainder is billed at the output price.
    """
    price_in, price_out = get_model_cost(model_id)
    n_input = int(token_count * input_ratio)
    n_output = token_count - n_input
    input_cost = n_input / 1_000_000 * price_in
    output_cost = n_output / 1_000_000 * price_out
    return input_cost + output_cost
80
+
81
+
82
+ # ── URL building ─────────────────────────────────────────────────────────────
83
+
84
def build_api_url(base_url: str, endpoint: str = "chat/completions") -> str:
    """
    Build a complete OpenAI-compatible API URL from a base URL and endpoint.

    Normalizes the base URL by appending ``/v1`` unless its *path* already
    contains a version segment (``/v1``, ``/v2``, ``/v1beta``, ...):
        https://api.openai.com      → https://api.openai.com/v1/chat/completions
        https://api.deepseek.com/v1 → https://api.deepseek.com/v1/chat/completions
        https://api.deepseek.com/v2 → https://api.deepseek.com/v2/chat/completions
        https://v1.example.com      → https://v1.example.com/v1/chat/completions

    Use endpoint="" to get just the versioned base, e.g. for /models listing.
    """
    import re
    from urllib.parse import urlsplit

    base = base_url.rstrip("/")
    # Inspect only the path so a host like "v1.example.com" is not mistaken
    # for an already-versioned URL. Scheme-less inputs are checked as-is.
    path = urlsplit(base).path if "://" in base else base
    if not re.search(r"/v\d+", path):
        base += "/v1"
    if endpoint:
        return f"{base}/{endpoint.lstrip('/')}"
    return base
102
+
103
+
104
def build_models_url(base_url: str) -> str:
    """Build the ``/models`` listing URL from a base URL.

    Delegates to ``build_api_url`` so both helpers agree on when a version
    segment (``/v1``, ``/v2``, ``/v3``, ...) is already present in the path;
    otherwise ``/v1`` is inserted before ``/models``.
    """
    return build_api_url(base_url, "models")
111
+
112
+
113
+ # ── Model list from API ──────────────────────────────────────────────────────
114
+
115
def fetch_available_models(base_url: str, api_key: str, timeout: float = 10.0) -> list[str]:
    """
    Fetch the available model list from the API's /models endpoint.

    Expects the OpenAI response shape ``{"data": [{"id": ...}, ...]}``.
    Entries that are not objects or lack an ``id`` are skipped. Returns an
    empty list on any request, HTTP-status or JSON-decoding failure.

    .. note::
       This uses synchronous ``httpx.get()``. Callers inside an async
       event loop should use ``asyncio.to_thread(fetch_available_models, ...)``
       to avoid blocking the loop.
    """
    import httpx

    url = build_models_url(base_url)
    # Keyless local servers: don't send a malformed "Bearer " header.
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    try:
        resp = httpx.get(url, headers=headers, timeout=timeout)
        resp.raise_for_status()
        payload = resp.json()
    except Exception:
        # Best-effort helper: any failure means "no models known".
        return []

    items = payload.get("data") if isinstance(payload, dict) else None
    if not isinstance(items, list):
        return []
    # Skip malformed entries instead of discarding the whole list.
    return [str(item["id"]) for item in items if isinstance(item, dict) and item.get("id")]