agent-cli 0.70.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_cli/__init__.py +5 -0
- agent_cli/__main__.py +6 -0
- agent_cli/_extras.json +14 -0
- agent_cli/_requirements/.gitkeep +0 -0
- agent_cli/_requirements/audio.txt +79 -0
- agent_cli/_requirements/faster-whisper.txt +215 -0
- agent_cli/_requirements/kokoro.txt +425 -0
- agent_cli/_requirements/llm.txt +183 -0
- agent_cli/_requirements/memory.txt +355 -0
- agent_cli/_requirements/mlx-whisper.txt +222 -0
- agent_cli/_requirements/piper.txt +176 -0
- agent_cli/_requirements/rag.txt +402 -0
- agent_cli/_requirements/server.txt +154 -0
- agent_cli/_requirements/speed.txt +77 -0
- agent_cli/_requirements/vad.txt +155 -0
- agent_cli/_requirements/wyoming.txt +71 -0
- agent_cli/_tools.py +368 -0
- agent_cli/agents/__init__.py +23 -0
- agent_cli/agents/_voice_agent_common.py +136 -0
- agent_cli/agents/assistant.py +383 -0
- agent_cli/agents/autocorrect.py +284 -0
- agent_cli/agents/chat.py +496 -0
- agent_cli/agents/memory/__init__.py +31 -0
- agent_cli/agents/memory/add.py +190 -0
- agent_cli/agents/memory/proxy.py +160 -0
- agent_cli/agents/rag_proxy.py +128 -0
- agent_cli/agents/speak.py +209 -0
- agent_cli/agents/transcribe.py +671 -0
- agent_cli/agents/transcribe_daemon.py +499 -0
- agent_cli/agents/voice_edit.py +291 -0
- agent_cli/api.py +22 -0
- agent_cli/cli.py +106 -0
- agent_cli/config.py +503 -0
- agent_cli/config_cmd.py +307 -0
- agent_cli/constants.py +27 -0
- agent_cli/core/__init__.py +1 -0
- agent_cli/core/audio.py +461 -0
- agent_cli/core/audio_format.py +299 -0
- agent_cli/core/chroma.py +88 -0
- agent_cli/core/deps.py +191 -0
- agent_cli/core/openai_proxy.py +139 -0
- agent_cli/core/process.py +195 -0
- agent_cli/core/reranker.py +120 -0
- agent_cli/core/sse.py +87 -0
- agent_cli/core/transcription_logger.py +70 -0
- agent_cli/core/utils.py +526 -0
- agent_cli/core/vad.py +175 -0
- agent_cli/core/watch.py +65 -0
- agent_cli/dev/__init__.py +14 -0
- agent_cli/dev/cli.py +1588 -0
- agent_cli/dev/coding_agents/__init__.py +19 -0
- agent_cli/dev/coding_agents/aider.py +24 -0
- agent_cli/dev/coding_agents/base.py +167 -0
- agent_cli/dev/coding_agents/claude.py +39 -0
- agent_cli/dev/coding_agents/codex.py +24 -0
- agent_cli/dev/coding_agents/continue_dev.py +15 -0
- agent_cli/dev/coding_agents/copilot.py +24 -0
- agent_cli/dev/coding_agents/cursor_agent.py +48 -0
- agent_cli/dev/coding_agents/gemini.py +28 -0
- agent_cli/dev/coding_agents/opencode.py +15 -0
- agent_cli/dev/coding_agents/registry.py +49 -0
- agent_cli/dev/editors/__init__.py +19 -0
- agent_cli/dev/editors/base.py +89 -0
- agent_cli/dev/editors/cursor.py +15 -0
- agent_cli/dev/editors/emacs.py +46 -0
- agent_cli/dev/editors/jetbrains.py +56 -0
- agent_cli/dev/editors/nano.py +31 -0
- agent_cli/dev/editors/neovim.py +33 -0
- agent_cli/dev/editors/registry.py +59 -0
- agent_cli/dev/editors/sublime.py +20 -0
- agent_cli/dev/editors/vim.py +42 -0
- agent_cli/dev/editors/vscode.py +15 -0
- agent_cli/dev/editors/zed.py +20 -0
- agent_cli/dev/project.py +568 -0
- agent_cli/dev/registry.py +52 -0
- agent_cli/dev/skill/SKILL.md +141 -0
- agent_cli/dev/skill/examples.md +571 -0
- agent_cli/dev/terminals/__init__.py +19 -0
- agent_cli/dev/terminals/apple_terminal.py +82 -0
- agent_cli/dev/terminals/base.py +56 -0
- agent_cli/dev/terminals/gnome.py +51 -0
- agent_cli/dev/terminals/iterm2.py +84 -0
- agent_cli/dev/terminals/kitty.py +77 -0
- agent_cli/dev/terminals/registry.py +48 -0
- agent_cli/dev/terminals/tmux.py +58 -0
- agent_cli/dev/terminals/warp.py +132 -0
- agent_cli/dev/terminals/zellij.py +78 -0
- agent_cli/dev/worktree.py +856 -0
- agent_cli/docs_gen.py +417 -0
- agent_cli/example-config.toml +185 -0
- agent_cli/install/__init__.py +5 -0
- agent_cli/install/common.py +89 -0
- agent_cli/install/extras.py +174 -0
- agent_cli/install/hotkeys.py +48 -0
- agent_cli/install/services.py +87 -0
- agent_cli/memory/__init__.py +7 -0
- agent_cli/memory/_files.py +250 -0
- agent_cli/memory/_filters.py +63 -0
- agent_cli/memory/_git.py +157 -0
- agent_cli/memory/_indexer.py +142 -0
- agent_cli/memory/_ingest.py +408 -0
- agent_cli/memory/_persistence.py +182 -0
- agent_cli/memory/_prompt.py +91 -0
- agent_cli/memory/_retrieval.py +294 -0
- agent_cli/memory/_store.py +169 -0
- agent_cli/memory/_streaming.py +44 -0
- agent_cli/memory/_tasks.py +48 -0
- agent_cli/memory/api.py +113 -0
- agent_cli/memory/client.py +272 -0
- agent_cli/memory/engine.py +361 -0
- agent_cli/memory/entities.py +43 -0
- agent_cli/memory/models.py +112 -0
- agent_cli/opts.py +433 -0
- agent_cli/py.typed +0 -0
- agent_cli/rag/__init__.py +3 -0
- agent_cli/rag/_indexer.py +67 -0
- agent_cli/rag/_indexing.py +226 -0
- agent_cli/rag/_prompt.py +30 -0
- agent_cli/rag/_retriever.py +156 -0
- agent_cli/rag/_store.py +48 -0
- agent_cli/rag/_utils.py +218 -0
- agent_cli/rag/api.py +175 -0
- agent_cli/rag/client.py +299 -0
- agent_cli/rag/engine.py +302 -0
- agent_cli/rag/models.py +55 -0
- agent_cli/scripts/.runtime/.gitkeep +0 -0
- agent_cli/scripts/__init__.py +1 -0
- agent_cli/scripts/check_plugin_skill_sync.py +50 -0
- agent_cli/scripts/linux-hotkeys/README.md +63 -0
- agent_cli/scripts/linux-hotkeys/toggle-autocorrect.sh +45 -0
- agent_cli/scripts/linux-hotkeys/toggle-transcription.sh +58 -0
- agent_cli/scripts/linux-hotkeys/toggle-voice-edit.sh +58 -0
- agent_cli/scripts/macos-hotkeys/README.md +45 -0
- agent_cli/scripts/macos-hotkeys/skhd-config-example +5 -0
- agent_cli/scripts/macos-hotkeys/toggle-autocorrect.sh +12 -0
- agent_cli/scripts/macos-hotkeys/toggle-transcription.sh +37 -0
- agent_cli/scripts/macos-hotkeys/toggle-voice-edit.sh +37 -0
- agent_cli/scripts/nvidia-asr-server/README.md +99 -0
- agent_cli/scripts/nvidia-asr-server/pyproject.toml +27 -0
- agent_cli/scripts/nvidia-asr-server/server.py +255 -0
- agent_cli/scripts/nvidia-asr-server/shell.nix +32 -0
- agent_cli/scripts/nvidia-asr-server/uv.lock +4654 -0
- agent_cli/scripts/run-openwakeword.sh +11 -0
- agent_cli/scripts/run-piper-windows.ps1 +30 -0
- agent_cli/scripts/run-piper.sh +24 -0
- agent_cli/scripts/run-whisper-linux.sh +40 -0
- agent_cli/scripts/run-whisper-macos.sh +6 -0
- agent_cli/scripts/run-whisper-windows.ps1 +51 -0
- agent_cli/scripts/run-whisper.sh +9 -0
- agent_cli/scripts/run_faster_whisper_server.py +136 -0
- agent_cli/scripts/setup-linux-hotkeys.sh +72 -0
- agent_cli/scripts/setup-linux.sh +108 -0
- agent_cli/scripts/setup-macos-hotkeys.sh +61 -0
- agent_cli/scripts/setup-macos.sh +76 -0
- agent_cli/scripts/setup-windows.ps1 +63 -0
- agent_cli/scripts/start-all-services-windows.ps1 +53 -0
- agent_cli/scripts/start-all-services.sh +178 -0
- agent_cli/scripts/sync_extras.py +138 -0
- agent_cli/server/__init__.py +3 -0
- agent_cli/server/cli.py +721 -0
- agent_cli/server/common.py +222 -0
- agent_cli/server/model_manager.py +288 -0
- agent_cli/server/model_registry.py +225 -0
- agent_cli/server/proxy/__init__.py +3 -0
- agent_cli/server/proxy/api.py +444 -0
- agent_cli/server/streaming.py +67 -0
- agent_cli/server/tts/__init__.py +3 -0
- agent_cli/server/tts/api.py +335 -0
- agent_cli/server/tts/backends/__init__.py +82 -0
- agent_cli/server/tts/backends/base.py +139 -0
- agent_cli/server/tts/backends/kokoro.py +403 -0
- agent_cli/server/tts/backends/piper.py +253 -0
- agent_cli/server/tts/model_manager.py +201 -0
- agent_cli/server/tts/model_registry.py +28 -0
- agent_cli/server/tts/wyoming_handler.py +249 -0
- agent_cli/server/whisper/__init__.py +3 -0
- agent_cli/server/whisper/api.py +413 -0
- agent_cli/server/whisper/backends/__init__.py +89 -0
- agent_cli/server/whisper/backends/base.py +97 -0
- agent_cli/server/whisper/backends/faster_whisper.py +225 -0
- agent_cli/server/whisper/backends/mlx.py +270 -0
- agent_cli/server/whisper/languages.py +116 -0
- agent_cli/server/whisper/model_manager.py +157 -0
- agent_cli/server/whisper/model_registry.py +28 -0
- agent_cli/server/whisper/wyoming_handler.py +203 -0
- agent_cli/services/__init__.py +343 -0
- agent_cli/services/_wyoming_utils.py +64 -0
- agent_cli/services/asr.py +506 -0
- agent_cli/services/llm.py +228 -0
- agent_cli/services/tts.py +450 -0
- agent_cli/services/wake_word.py +142 -0
- agent_cli-0.70.5.dist-info/METADATA +2118 -0
- agent_cli-0.70.5.dist-info/RECORD +196 -0
- agent_cli-0.70.5.dist-info/WHEEL +4 -0
- agent_cli-0.70.5.dist-info/entry_points.txt +4 -0
- agent_cli-0.70.5.dist-info/licenses/LICENSE +21 -0
agent_cli/memory/_store.py
ADDED
@@ -0,0 +1,169 @@
"""ChromaDB helpers for memory storage."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from agent_cli.constants import DEFAULT_OPENAI_EMBEDDING_MODEL
from agent_cli.core.chroma import delete as delete_docs
from agent_cli.core.chroma import init_collection, upsert
from agent_cli.memory._filters import to_chroma_where
from agent_cli.memory.models import MemoryMetadata, StoredMemory

if TYPE_CHECKING:
    from collections.abc import Sequence
    from pathlib import Path

    from chromadb import Collection


def init_memory_collection(
    persistence_path: Path,
    *,
    collection_name: str = "memory",
    embedding_model: str = DEFAULT_OPENAI_EMBEDDING_MODEL,
    openai_base_url: str | None = None,
    openai_api_key: str | None = None,
) -> Collection:
    """Initialize or create the memory collection."""
    return init_collection(
        persistence_path,
        name=collection_name,
        embedding_model=embedding_model,
        openai_base_url=openai_base_url,
        openai_api_key=openai_api_key,
        subdir="chroma",
    )


def upsert_memories(
    collection: Collection,
    ids: list[str],
    contents: list[str],
    metadatas: Sequence[MemoryMetadata],
) -> None:
    """Persist memory entries."""
    upsert(collection, ids=ids, documents=contents, metadatas=metadatas)


def query_memories(
    collection: Collection,
    *,
    conversation_id: str,
    text: str,
    n_results: int,
    filters: dict[str, Any] | None = None,
) -> list[StoredMemory]:
    """Query for relevant memory entries and return structured results."""
    base_filters: list[dict[str, Any]] = [
        {"conversation_id": conversation_id},
        {"role": {"$ne": "summary"}},
    ]
    if filters:
        chroma_filters = to_chroma_where(filters)
        if chroma_filters:
            base_filters.append(chroma_filters)
    raw = collection.query(
        query_texts=[text],
        n_results=n_results,
        where={"$and": base_filters},
        include=["documents", "metadatas", "distances", "embeddings"],
    )
    docs_list = raw.get("documents")
    docs = docs_list[0] if docs_list else []

    metas_list = raw.get("metadatas")
    metas = metas_list[0] if metas_list else []

    ids_list = raw.get("ids")
    ids = ids_list[0] if ids_list else []

    dists_list = raw.get("distances")
    distances = dists_list[0] if dists_list else []

    raw_embeddings = raw.get("embeddings")
    embeddings: list[Any] = []
    if raw_embeddings and len(raw_embeddings) > 0 and raw_embeddings[0] is not None:
        embeddings = raw_embeddings[0]

    if len(embeddings) != len(docs):
        msg = f"Chroma returned embeddings of unexpected length: {len(embeddings)} vs {len(docs)}"
        raise ValueError(msg)
    records: list[StoredMemory] = []
    for doc, meta, doc_id, dist, emb in zip(
        docs,
        metas,
        ids,
        distances,
        embeddings,
        strict=False,
    ):
        assert doc_id is not None
        records.append(
            StoredMemory(
                id=doc_id,
                content=doc,
                metadata=MemoryMetadata(**dict(meta)),
                distance=float(dist) if dist is not None else None,
                embedding=[float(x) for x in emb] if emb is not None else None,
            ),
        )
    return records


def get_summary_entry(
    collection: Collection,
    conversation_id: str,
    *,
    role: str = "summary",
) -> StoredMemory | None:
    """Return the latest summary entry for a conversation, if present."""
    result = collection.get(
        where={"$and": [{"conversation_id": conversation_id}, {"role": role}]},
    )
    docs = result.get("documents") or []
    metas = result.get("metadatas") or []
    ids = result.get("ids") or []

    if not docs or not metas or not ids:
        return None

    return StoredMemory(
        id=ids[0],
        content=docs[0],
        metadata=MemoryMetadata(**dict(metas[0])),
        distance=None,
    )


def list_conversation_entries(
    collection: Collection,
    conversation_id: str,
    *,
    include_summary: bool = False,
) -> list[StoredMemory]:
    """List all entries for a conversation (optionally excluding summary)."""
    filters: list[dict[str, Any]] = [{"conversation_id": conversation_id}]
    if not include_summary:
        filters.append({"role": {"$ne": "summary"}})
    result = collection.get(where={"$and": filters} if len(filters) > 1 else filters[0])
    docs = result.get("documents") or []
    metas = result.get("metadatas") or []
    ids = result.get("ids") or []

    records: list[StoredMemory] = []
    for doc, meta, entry_id in zip(docs, metas, ids, strict=False):
        records.append(
            StoredMemory(
                id=entry_id,
                content=doc,
                metadata=MemoryMetadata(**dict(meta)),
                distance=None,
            ),
        )
    return records


def delete_entries(collection: Collection, ids: list[str]) -> None:
    """Delete entries by ID."""
    delete_docs(collection, ids)
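A minimal usage sketch of these store helpers. The exact field set of `MemoryMetadata` lives in `agent_cli/memory/models.py` (not shown in this diff); the `conversation_id`, `role`, and `created_at` fields used below are assumed from how other modules in this diff read them, and embedding credentials are assumed to be available (e.g. via environment):

# Hypothetical usage; MemoryMetadata field names are assumptions.
from pathlib import Path

from agent_cli.memory._store import (
    init_memory_collection,
    query_memories,
    upsert_memories,
)
from agent_cli.memory.models import MemoryMetadata

collection = init_memory_collection(Path("~/.agent-cli/memory").expanduser())
upsert_memories(
    collection,
    ids=["m1"],
    contents=["The user prefers dark mode."],
    metadatas=[
        MemoryMetadata(
            conversation_id="default",
            role="user",
            created_at="2024-01-01T00:00:00Z",
        ),
    ],
)
# Summary entries (role == "summary") are excluded from query results.
for entry in query_memories(
    collection,
    conversation_id="default",
    text="UI preferences",
    n_results=3,
):
    print(entry.id, entry.distance, entry.content)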
agent_cli/memory/_streaming.py
ADDED
@@ -0,0 +1,44 @@
"""Streaming helpers for chat completions."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from agent_cli.core.sse import extract_content_from_chunk, parse_chunk

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator


async def stream_chat_sse(
    *,
    openai_base_url: str,
    payload: dict[str, Any],
    headers: dict[str, str] | None = None,
    request_timeout: float = 120.0,
) -> AsyncGenerator[str, None]:
    """Stream Server-Sent Events from an OpenAI-compatible chat completion endpoint."""
    import httpx  # noqa: PLC0415

    url = f"{openai_base_url.rstrip('/')}/chat/completions"
    async with (
        httpx.AsyncClient(timeout=request_timeout) as client,
        client.stream("POST", url, json=payload, headers=headers) as response,
    ):
        if response.status_code != 200:  # noqa: PLR2004
            error_text = await response.aread()
            yield f"data: {error_text.decode(errors='ignore')}\n\n"
            return
        async for line in response.aiter_lines():
            if line:
                yield line


def accumulate_assistant_text(line: str, buffer: list[str]) -> None:
    """Parse SSE line and append any assistant text delta into buffer."""
    chunk = parse_chunk(line)
    if chunk is None:
        return
    piece = extract_content_from_chunk(chunk)
    if piece:
        buffer.append(piece)
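A sketch of how these two helpers compose: relay each raw SSE line to the caller while accumulating the assistant's text deltas into a buffer. The base URL, model name, and prompt are placeholders:

# Sketch: stream a completion and collect the assistant's text as it arrives.
import asyncio

from agent_cli.memory._streaming import accumulate_assistant_text, stream_chat_sse


async def main() -> None:
    buffer: list[str] = []
    async for line in stream_chat_sse(
        openai_base_url="http://localhost:11434/v1",  # placeholder endpoint
        payload={
            "model": "llama3",  # placeholder model
            "messages": [{"role": "user", "content": "Hi"}],
            "stream": True,
        },
    ):
        print(line)  # forward the raw SSE line unchanged
        accumulate_assistant_text(line, buffer)  # collect text deltas
    print("assistant said:", "".join(buffer))


asyncio.run(main())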
agent_cli/memory/_tasks.py
ADDED
@@ -0,0 +1,48 @@
"""Utilities for tracking background tasks in the memory proxy."""

from __future__ import annotations

import asyncio
import logging
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from collections.abc import Coroutine

LOGGER = logging.getLogger(__name__)

_BACKGROUND_TASKS: set[asyncio.Task[Any]] = set()


def _track_background(task: asyncio.Task[Any], label: str) -> asyncio.Task[Any]:
    """Track background tasks and surface failures."""
    _BACKGROUND_TASKS.add(task)

    def _done_callback(done: asyncio.Task[Any]) -> None:
        _BACKGROUND_TASKS.discard(done)
        if done.cancelled():
            LOGGER.debug("Background task %s cancelled", label)
            return
        exc = done.exception()
        if exc:
            LOGGER.exception("Background task %s failed", label, exc_info=exc)

    task.add_done_callback(_done_callback)
    return task


def run_in_background(
    coro: asyncio.Task[Any] | Coroutine[Any, Any, Any],
    label: str,
) -> asyncio.Task[Any]:
    """Create and track a background asyncio task."""
    task = coro if isinstance(coro, asyncio.Task) else asyncio.create_task(coro)
    task.set_name(f"memory-{label}")
    return _track_background(task, label)


async def wait_for_background_tasks() -> None:
    """Await any in-flight background tasks (useful in tests)."""
    while _BACKGROUND_TASKS:
        tasks = list(_BACKGROUND_TASKS)
        await asyncio.gather(*tasks, return_exceptions=False)
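The module-level set exists because `asyncio.create_task` only keeps a weak reference to its result; without a strong reference the task can be garbage-collected mid-flight. A sketch of the intended pattern (the coroutine below is a stand-in):

# Sketch: fire-and-forget a coroutine with a strong reference kept,
# then drain the task set, e.g. at the end of a test.
import asyncio

from agent_cli.memory._tasks import run_in_background, wait_for_background_tasks


async def persist_summary() -> None:
    await asyncio.sleep(0.1)  # stand-in for a slow background write


async def main() -> None:
    run_in_background(persist_summary(), "summary")
    await wait_for_background_tasks()  # loop until the task set is empty


asyncio.run(main())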
agent_cli/memory/api.py
ADDED
@@ -0,0 +1,113 @@
"""FastAPI application factory for memory proxy."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware

from agent_cli.constants import DEFAULT_OPENAI_EMBEDDING_MODEL
from agent_cli.core.openai_proxy import proxy_request_to_upstream
from agent_cli.memory.client import MemoryClient
from agent_cli.memory.models import ChatRequest  # noqa: TC001

if TYPE_CHECKING:
    from pathlib import Path

LOGGER = logging.getLogger(__name__)


def create_app(
    memory_path: Path,
    openai_base_url: str,
    embedding_model: str = DEFAULT_OPENAI_EMBEDDING_MODEL,
    embedding_api_key: str | None = None,
    chat_api_key: str | None = None,
    default_top_k: int = 5,
    enable_summarization: bool = True,
    max_entries: int = 500,
    mmr_lambda: float = 0.7,
    recency_weight: float = 0.2,
    score_threshold: float | None = None,
    enable_git_versioning: bool = True,
) -> FastAPI:
    """Create the FastAPI app for memory-backed chat."""
    LOGGER.info("Initializing memory client...")

    client = MemoryClient(
        memory_path=memory_path,
        openai_base_url=openai_base_url,
        embedding_model=embedding_model,
        embedding_api_key=embedding_api_key,
        chat_api_key=chat_api_key,
        default_top_k=default_top_k,
        enable_summarization=enable_summarization,
        max_entries=max_entries,
        mmr_lambda=mmr_lambda,
        recency_weight=recency_weight,
        score_threshold=score_threshold,
        start_watcher=False,  # We control start/stop via app events
        enable_git_versioning=enable_git_versioning,
    )

    app = FastAPI(title="Memory Proxy")

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request, chat_request: ChatRequest) -> Any:
        auth_header = request.headers.get("Authorization")
        api_key = None
        if auth_header and auth_header.startswith("Bearer "):
            api_key = auth_header.split(" ")[1]

        return await client.chat(
            messages=chat_request.messages,
            conversation_id=chat_request.memory_id or "default",
            model=chat_request.model,
            stream=chat_request.stream or False,
            api_key=api_key,
            memory_top_k=chat_request.memory_top_k,
            recency_weight=chat_request.memory_recency_weight,
            score_threshold=chat_request.memory_score_threshold,
        )

    @app.on_event("startup")
    async def start_watch() -> None:
        client.start()

    @app.on_event("shutdown")
    async def stop_watch() -> None:
        await client.stop()

    @app.get("/health")
    def health() -> dict[str, str]:
        return {
            "status": "ok",
            "memory_store": str(client.memory_path),
            "openai_base_url": client.openai_base_url,
            "default_top_k": str(client.default_top_k),
        }

    @app.api_route(
        "/{path:path}",
        methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
    )
    async def proxy_catch_all(request: Request, path: str) -> Any:
        """Forward any other request to the upstream provider."""
        return await proxy_request_to_upstream(
            request,
            path,
            client.openai_base_url,
            client.chat_api_key,
        )

    return app
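Because the catch-all route forwards everything except `/v1/chat/completions` and `/health` to the upstream provider, an existing OpenAI-compatible client only needs to swap its base URL to gain memory augmentation. A sketch of serving the factory locally; `uvicorn` is an assumed dependency here (it is not part of this file), and the paths, host, and port are placeholders:

# Sketch: serve the memory proxy with uvicorn (assumed installed).
from pathlib import Path

import uvicorn

from agent_cli.memory.api import create_app

app = create_app(
    memory_path=Path("~/.agent-cli/memory").expanduser(),  # placeholder path
    openai_base_url="http://localhost:11434/v1",  # placeholder upstream
)
uvicorn.run(app, host="127.0.0.1", port=8100)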
agent_cli/memory/client.py
ADDED
@@ -0,0 +1,272 @@
"""High-level client for interacting with the memory system."""

from __future__ import annotations

import asyncio
import logging
from contextlib import suppress
from typing import TYPE_CHECKING, Any, Self

from agent_cli.constants import DEFAULT_OPENAI_EMBEDDING_MODEL, DEFAULT_OPENAI_MODEL
from agent_cli.core.reranker import get_reranker_model
from agent_cli.memory._files import ensure_store_dirs
from agent_cli.memory._git import init_repo
from agent_cli.memory._indexer import MemoryIndex, initial_index, watch_memory_store
from agent_cli.memory._ingest import extract_and_store_facts_and_summaries
from agent_cli.memory._persistence import evict_if_needed
from agent_cli.memory._retrieval import augment_chat_request
from agent_cli.memory._store import init_memory_collection, list_conversation_entries
from agent_cli.memory.engine import process_chat_request
from agent_cli.memory.models import ChatRequest, MemoryRetrieval, Message

if TYPE_CHECKING:
    from pathlib import Path

    from chromadb import Collection

    from agent_cli.core.reranker import OnnxCrossEncoder


logger = logging.getLogger("agent_cli.memory.client")


class MemoryClient:
    """A client for interacting with the memory system (add, search, chat).

    This class decouples the memory logic from the HTTP server, allowing
    direct usage in other applications or scripts.
    """

    def __init__(
        self,
        memory_path: Path,
        openai_base_url: str,
        embedding_model: str = DEFAULT_OPENAI_EMBEDDING_MODEL,
        embedding_api_key: str | None = None,
        chat_api_key: str | None = None,
        enable_summarization: bool = True,
        default_top_k: int = 5,
        max_entries: int = 500,
        mmr_lambda: float = 0.7,
        recency_weight: float = 0.2,
        score_threshold: float | None = None,
        start_watcher: bool = False,
        enable_git_versioning: bool = True,
    ) -> None:
        """Initialize the memory client."""
        self.memory_path = memory_path.resolve()
        self.openai_base_url = openai_base_url.rstrip("/")
        self.chat_api_key = chat_api_key
        self.enable_summarization = enable_summarization
        self.default_top_k = default_top_k
        self.max_entries = max_entries
        self.mmr_lambda = mmr_lambda
        self.recency_weight = recency_weight
        self.score_threshold = score_threshold
        self.enable_git_versioning = enable_git_versioning

        _, snapshot_path = ensure_store_dirs(self.memory_path)

        if self.enable_git_versioning:
            init_repo(self.memory_path)

        logger.info("Initializing memory collection...")
        self.collection: Collection = init_memory_collection(
            self.memory_path,
            embedding_model=embedding_model,
            openai_base_url=self.openai_base_url,
            openai_api_key=embedding_api_key,
        )

        self.index = MemoryIndex.from_snapshot(snapshot_path)
        initial_index(self.collection, self.memory_path, index=self.index)

        logger.info("Loading reranker model...")
        self.reranker_model: OnnxCrossEncoder = get_reranker_model()

        self._watch_task: asyncio.Task | None = None
        if start_watcher:
            self.start()

    def start(self) -> None:
        """Start the background file watcher."""
        if self._watch_task is None:
            self._watch_task = asyncio.create_task(
                watch_memory_store(self.collection, self.memory_path, index=self.index),
            )

    async def stop(self) -> None:
        """Stop the background file watcher."""
        if self._watch_task:
            self._watch_task.cancel()
            with suppress(asyncio.CancelledError):
                await self._watch_task
            self._watch_task = None

    async def __aenter__(self) -> Self:
        """Start the client context."""
        self.start()
        return self

    async def __aexit__(self, *args: object) -> None:
        """Stop the client context."""
        await self.stop()

    async def add(
        self,
        text: str,
        conversation_id: str = "default",
        model: str = DEFAULT_OPENAI_MODEL,
    ) -> None:
        """Add a memory by extracting facts from text and reconciling them.

        This mimics the 'mem0.add' behavior but uses our advanced reconciliation
        pipeline (Add/Update/Delete) and updates the conversation summary.
        """
        await extract_and_store_facts_and_summaries(
            collection=self.collection,
            memory_root=self.memory_path,
            conversation_id=conversation_id,
            user_message=text,
            assistant_message=None,
            openai_base_url=self.openai_base_url,
            api_key=self.chat_api_key,
            model=model,
            enable_git_versioning=self.enable_git_versioning,
            enable_summarization=self.enable_summarization,
        )
        evict_if_needed(self.collection, self.memory_path, conversation_id, self.max_entries)

    async def search(
        self,
        query: str,
        conversation_id: str = "default",
        top_k: int | None = None,
        model: str = DEFAULT_OPENAI_MODEL,
        recency_weight: float | None = None,
        score_threshold: float | None = None,
        filters: dict[str, Any] | None = None,
    ) -> MemoryRetrieval:
        """Search for memories relevant to a query.

        Args:
            query: The search query text.
            conversation_id: Conversation scope for the search.
            top_k: Number of results to return.
            model: Model for any LLM operations.
            recency_weight: Weight for recency scoring (0-1).
            score_threshold: Minimum relevance score threshold.
            filters: Optional metadata filters. Examples:
                - {"role": "user"} - exact match
                - {"created_at": {"gte": "2024-01-01"}} - comparison
                - {"$or": [{"role": "user"}, {"role": "assistant"}]} - logical OR
                Operators: eq, ne, gt, gte, lt, lte, in, nin

        """
        dummy_request = ChatRequest(
            messages=[Message(role="user", content=query)],
            model=model,
            memory_id=conversation_id,
            memory_top_k=top_k or self.default_top_k,
        )

        _, retrieval, _, _ = await augment_chat_request(
            dummy_request,
            self.collection,
            reranker_model=self.reranker_model,
            default_top_k=top_k or self.default_top_k,
            include_global=True,
            mmr_lambda=self.mmr_lambda,
            recency_weight=recency_weight if recency_weight is not None else self.recency_weight,
            score_threshold=score_threshold
            if score_threshold is not None
            else self.score_threshold,
            filters=filters,
        )
        return retrieval or MemoryRetrieval(entries=[])

    def list_all(
        self,
        conversation_id: str = "default",
        include_summary: bool = False,
    ) -> list[dict[str, Any]]:
        """List all stored memories for a conversation.

        Args:
            conversation_id: Conversation scope.
            include_summary: Whether to include summary entries.

        Returns:
            List of memory entries with id, content, and metadata.

        """
        entries = list_conversation_entries(
            self.collection,
            conversation_id,
            include_summary=include_summary,
        )
        return [
            {
                "id": e.id,
                "content": e.content,
                "role": e.metadata.role,
                "created_at": e.metadata.created_at,
            }
            for e in entries
        ]

    async def chat(
        self,
        messages: list[dict[str, str]] | list[Any],
        conversation_id: str = "default",
        model: str = DEFAULT_OPENAI_MODEL,
        stream: bool = False,
        api_key: str | None = None,
        memory_top_k: int | None = None,
        recency_weight: float | None = None,
        score_threshold: float | None = None,
        filters: dict[str, Any] | None = None,
    ) -> Any:
        """Process a chat request (Augment -> LLM -> Update Memory).

        Args:
            messages: Chat messages.
            conversation_id: Conversation scope.
            model: LLM model to use.
            stream: Whether to stream the response.
            api_key: Optional API key override.
            memory_top_k: Number of memories to retrieve.
            recency_weight: Weight for recency scoring (0-1).
            score_threshold: Minimum relevance score threshold.
            filters: Optional metadata filters for memory retrieval.

        """
        req = ChatRequest(
            messages=messages,  # type: ignore[arg-type]
            model=model,
            memory_id=conversation_id,
            stream=stream,
            memory_top_k=memory_top_k if memory_top_k is not None else self.default_top_k,
            memory_recency_weight=recency_weight,
            memory_score_threshold=score_threshold,
        )

        return await process_chat_request(
            req,
            collection=self.collection,
            memory_root=self.memory_path,
            openai_base_url=self.openai_base_url,
            reranker_model=self.reranker_model,
            default_top_k=self.default_top_k,
            api_key=api_key or self.chat_api_key,
            enable_summarization=self.enable_summarization,
            max_entries=self.max_entries,
            mmr_lambda=self.mmr_lambda,
            recency_weight=recency_weight if recency_weight is not None else self.recency_weight,
            score_threshold=score_threshold
            if score_threshold is not None
            else self.score_threshold,
            postprocess_in_background=True,
            enable_git_versioning=self.enable_git_versioning,
            filters=filters,
        )
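A sketch of direct, non-HTTP usage via the async context manager, which starts and stops the file watcher around the block. The base URL and memory path are placeholders, and the shape of the entries inside `MemoryRetrieval` is an assumption (treated here as `StoredMemory`-like objects with a `content` attribute):

# Sketch: use MemoryClient directly in a script (no FastAPI server).
import asyncio
from pathlib import Path

from agent_cli.memory.client import MemoryClient


async def main() -> None:
    async with MemoryClient(
        memory_path=Path("~/.agent-cli/memory").expanduser(),  # placeholder path
        openai_base_url="http://localhost:11434/v1",  # placeholder upstream
    ) as client:
        await client.add("I live in Amsterdam.", conversation_id="demo")
        retrieval = await client.search("Where does the user live?", conversation_id="demo")
        for entry in retrieval.entries:  # assumed StoredMemory-like entries
            print(entry.content)


asyncio.run(main())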