agent-cli 0.70.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_cli/__init__.py +5 -0
- agent_cli/__main__.py +6 -0
- agent_cli/_extras.json +14 -0
- agent_cli/_requirements/.gitkeep +0 -0
- agent_cli/_requirements/audio.txt +79 -0
- agent_cli/_requirements/faster-whisper.txt +215 -0
- agent_cli/_requirements/kokoro.txt +425 -0
- agent_cli/_requirements/llm.txt +183 -0
- agent_cli/_requirements/memory.txt +355 -0
- agent_cli/_requirements/mlx-whisper.txt +222 -0
- agent_cli/_requirements/piper.txt +176 -0
- agent_cli/_requirements/rag.txt +402 -0
- agent_cli/_requirements/server.txt +154 -0
- agent_cli/_requirements/speed.txt +77 -0
- agent_cli/_requirements/vad.txt +155 -0
- agent_cli/_requirements/wyoming.txt +71 -0
- agent_cli/_tools.py +368 -0
- agent_cli/agents/__init__.py +23 -0
- agent_cli/agents/_voice_agent_common.py +136 -0
- agent_cli/agents/assistant.py +383 -0
- agent_cli/agents/autocorrect.py +284 -0
- agent_cli/agents/chat.py +496 -0
- agent_cli/agents/memory/__init__.py +31 -0
- agent_cli/agents/memory/add.py +190 -0
- agent_cli/agents/memory/proxy.py +160 -0
- agent_cli/agents/rag_proxy.py +128 -0
- agent_cli/agents/speak.py +209 -0
- agent_cli/agents/transcribe.py +671 -0
- agent_cli/agents/transcribe_daemon.py +499 -0
- agent_cli/agents/voice_edit.py +291 -0
- agent_cli/api.py +22 -0
- agent_cli/cli.py +106 -0
- agent_cli/config.py +503 -0
- agent_cli/config_cmd.py +307 -0
- agent_cli/constants.py +27 -0
- agent_cli/core/__init__.py +1 -0
- agent_cli/core/audio.py +461 -0
- agent_cli/core/audio_format.py +299 -0
- agent_cli/core/chroma.py +88 -0
- agent_cli/core/deps.py +191 -0
- agent_cli/core/openai_proxy.py +139 -0
- agent_cli/core/process.py +195 -0
- agent_cli/core/reranker.py +120 -0
- agent_cli/core/sse.py +87 -0
- agent_cli/core/transcription_logger.py +70 -0
- agent_cli/core/utils.py +526 -0
- agent_cli/core/vad.py +175 -0
- agent_cli/core/watch.py +65 -0
- agent_cli/dev/__init__.py +14 -0
- agent_cli/dev/cli.py +1588 -0
- agent_cli/dev/coding_agents/__init__.py +19 -0
- agent_cli/dev/coding_agents/aider.py +24 -0
- agent_cli/dev/coding_agents/base.py +167 -0
- agent_cli/dev/coding_agents/claude.py +39 -0
- agent_cli/dev/coding_agents/codex.py +24 -0
- agent_cli/dev/coding_agents/continue_dev.py +15 -0
- agent_cli/dev/coding_agents/copilot.py +24 -0
- agent_cli/dev/coding_agents/cursor_agent.py +48 -0
- agent_cli/dev/coding_agents/gemini.py +28 -0
- agent_cli/dev/coding_agents/opencode.py +15 -0
- agent_cli/dev/coding_agents/registry.py +49 -0
- agent_cli/dev/editors/__init__.py +19 -0
- agent_cli/dev/editors/base.py +89 -0
- agent_cli/dev/editors/cursor.py +15 -0
- agent_cli/dev/editors/emacs.py +46 -0
- agent_cli/dev/editors/jetbrains.py +56 -0
- agent_cli/dev/editors/nano.py +31 -0
- agent_cli/dev/editors/neovim.py +33 -0
- agent_cli/dev/editors/registry.py +59 -0
- agent_cli/dev/editors/sublime.py +20 -0
- agent_cli/dev/editors/vim.py +42 -0
- agent_cli/dev/editors/vscode.py +15 -0
- agent_cli/dev/editors/zed.py +20 -0
- agent_cli/dev/project.py +568 -0
- agent_cli/dev/registry.py +52 -0
- agent_cli/dev/skill/SKILL.md +141 -0
- agent_cli/dev/skill/examples.md +571 -0
- agent_cli/dev/terminals/__init__.py +19 -0
- agent_cli/dev/terminals/apple_terminal.py +82 -0
- agent_cli/dev/terminals/base.py +56 -0
- agent_cli/dev/terminals/gnome.py +51 -0
- agent_cli/dev/terminals/iterm2.py +84 -0
- agent_cli/dev/terminals/kitty.py +77 -0
- agent_cli/dev/terminals/registry.py +48 -0
- agent_cli/dev/terminals/tmux.py +58 -0
- agent_cli/dev/terminals/warp.py +132 -0
- agent_cli/dev/terminals/zellij.py +78 -0
- agent_cli/dev/worktree.py +856 -0
- agent_cli/docs_gen.py +417 -0
- agent_cli/example-config.toml +185 -0
- agent_cli/install/__init__.py +5 -0
- agent_cli/install/common.py +89 -0
- agent_cli/install/extras.py +174 -0
- agent_cli/install/hotkeys.py +48 -0
- agent_cli/install/services.py +87 -0
- agent_cli/memory/__init__.py +7 -0
- agent_cli/memory/_files.py +250 -0
- agent_cli/memory/_filters.py +63 -0
- agent_cli/memory/_git.py +157 -0
- agent_cli/memory/_indexer.py +142 -0
- agent_cli/memory/_ingest.py +408 -0
- agent_cli/memory/_persistence.py +182 -0
- agent_cli/memory/_prompt.py +91 -0
- agent_cli/memory/_retrieval.py +294 -0
- agent_cli/memory/_store.py +169 -0
- agent_cli/memory/_streaming.py +44 -0
- agent_cli/memory/_tasks.py +48 -0
- agent_cli/memory/api.py +113 -0
- agent_cli/memory/client.py +272 -0
- agent_cli/memory/engine.py +361 -0
- agent_cli/memory/entities.py +43 -0
- agent_cli/memory/models.py +112 -0
- agent_cli/opts.py +433 -0
- agent_cli/py.typed +0 -0
- agent_cli/rag/__init__.py +3 -0
- agent_cli/rag/_indexer.py +67 -0
- agent_cli/rag/_indexing.py +226 -0
- agent_cli/rag/_prompt.py +30 -0
- agent_cli/rag/_retriever.py +156 -0
- agent_cli/rag/_store.py +48 -0
- agent_cli/rag/_utils.py +218 -0
- agent_cli/rag/api.py +175 -0
- agent_cli/rag/client.py +299 -0
- agent_cli/rag/engine.py +302 -0
- agent_cli/rag/models.py +55 -0
- agent_cli/scripts/.runtime/.gitkeep +0 -0
- agent_cli/scripts/__init__.py +1 -0
- agent_cli/scripts/check_plugin_skill_sync.py +50 -0
- agent_cli/scripts/linux-hotkeys/README.md +63 -0
- agent_cli/scripts/linux-hotkeys/toggle-autocorrect.sh +45 -0
- agent_cli/scripts/linux-hotkeys/toggle-transcription.sh +58 -0
- agent_cli/scripts/linux-hotkeys/toggle-voice-edit.sh +58 -0
- agent_cli/scripts/macos-hotkeys/README.md +45 -0
- agent_cli/scripts/macos-hotkeys/skhd-config-example +5 -0
- agent_cli/scripts/macos-hotkeys/toggle-autocorrect.sh +12 -0
- agent_cli/scripts/macos-hotkeys/toggle-transcription.sh +37 -0
- agent_cli/scripts/macos-hotkeys/toggle-voice-edit.sh +37 -0
- agent_cli/scripts/nvidia-asr-server/README.md +99 -0
- agent_cli/scripts/nvidia-asr-server/pyproject.toml +27 -0
- agent_cli/scripts/nvidia-asr-server/server.py +255 -0
- agent_cli/scripts/nvidia-asr-server/shell.nix +32 -0
- agent_cli/scripts/nvidia-asr-server/uv.lock +4654 -0
- agent_cli/scripts/run-openwakeword.sh +11 -0
- agent_cli/scripts/run-piper-windows.ps1 +30 -0
- agent_cli/scripts/run-piper.sh +24 -0
- agent_cli/scripts/run-whisper-linux.sh +40 -0
- agent_cli/scripts/run-whisper-macos.sh +6 -0
- agent_cli/scripts/run-whisper-windows.ps1 +51 -0
- agent_cli/scripts/run-whisper.sh +9 -0
- agent_cli/scripts/run_faster_whisper_server.py +136 -0
- agent_cli/scripts/setup-linux-hotkeys.sh +72 -0
- agent_cli/scripts/setup-linux.sh +108 -0
- agent_cli/scripts/setup-macos-hotkeys.sh +61 -0
- agent_cli/scripts/setup-macos.sh +76 -0
- agent_cli/scripts/setup-windows.ps1 +63 -0
- agent_cli/scripts/start-all-services-windows.ps1 +53 -0
- agent_cli/scripts/start-all-services.sh +178 -0
- agent_cli/scripts/sync_extras.py +138 -0
- agent_cli/server/__init__.py +3 -0
- agent_cli/server/cli.py +721 -0
- agent_cli/server/common.py +222 -0
- agent_cli/server/model_manager.py +288 -0
- agent_cli/server/model_registry.py +225 -0
- agent_cli/server/proxy/__init__.py +3 -0
- agent_cli/server/proxy/api.py +444 -0
- agent_cli/server/streaming.py +67 -0
- agent_cli/server/tts/__init__.py +3 -0
- agent_cli/server/tts/api.py +335 -0
- agent_cli/server/tts/backends/__init__.py +82 -0
- agent_cli/server/tts/backends/base.py +139 -0
- agent_cli/server/tts/backends/kokoro.py +403 -0
- agent_cli/server/tts/backends/piper.py +253 -0
- agent_cli/server/tts/model_manager.py +201 -0
- agent_cli/server/tts/model_registry.py +28 -0
- agent_cli/server/tts/wyoming_handler.py +249 -0
- agent_cli/server/whisper/__init__.py +3 -0
- agent_cli/server/whisper/api.py +413 -0
- agent_cli/server/whisper/backends/__init__.py +89 -0
- agent_cli/server/whisper/backends/base.py +97 -0
- agent_cli/server/whisper/backends/faster_whisper.py +225 -0
- agent_cli/server/whisper/backends/mlx.py +270 -0
- agent_cli/server/whisper/languages.py +116 -0
- agent_cli/server/whisper/model_manager.py +157 -0
- agent_cli/server/whisper/model_registry.py +28 -0
- agent_cli/server/whisper/wyoming_handler.py +203 -0
- agent_cli/services/__init__.py +343 -0
- agent_cli/services/_wyoming_utils.py +64 -0
- agent_cli/services/asr.py +506 -0
- agent_cli/services/llm.py +228 -0
- agent_cli/services/tts.py +450 -0
- agent_cli/services/wake_word.py +142 -0
- agent_cli-0.70.5.dist-info/METADATA +2118 -0
- agent_cli-0.70.5.dist-info/RECORD +196 -0
- agent_cli-0.70.5.dist-info/WHEEL +4 -0
- agent_cli-0.70.5.dist-info/entry_points.txt +4 -0
- agent_cli-0.70.5.dist-info/licenses/LICENSE +21 -0
agent_cli/rag/api.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""FastAPI application factory for RAG."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
import threading
|
|
8
|
+
from contextlib import asynccontextmanager, suppress
|
|
9
|
+
from typing import TYPE_CHECKING, Any
|
|
10
|
+
|
|
11
|
+
from fastapi import FastAPI, Request
|
|
12
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
13
|
+
|
|
14
|
+
from agent_cli.constants import DEFAULT_OPENAI_EMBEDDING_MODEL
|
|
15
|
+
from agent_cli.core.chroma import init_collection
|
|
16
|
+
from agent_cli.core.openai_proxy import proxy_request_to_upstream
|
|
17
|
+
from agent_cli.core.reranker import get_reranker_model
|
|
18
|
+
from agent_cli.rag._indexer import watch_docs
|
|
19
|
+
from agent_cli.rag._indexing import initial_index, load_hashes_from_metadata
|
|
20
|
+
from agent_cli.rag._store import get_all_metadata
|
|
21
|
+
from agent_cli.rag.engine import process_chat_request
|
|
22
|
+
from agent_cli.rag.models import ChatRequest # noqa: TC001
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from pathlib import Path
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
LOGGER = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def create_app(
    docs_folder: Path,
    chroma_path: Path,
    openai_base_url: str,
    embedding_model: str = DEFAULT_OPENAI_EMBEDDING_MODEL,
    embedding_api_key: str | None = None,
    chat_api_key: str | None = None,
    limit: int = 3,
    enable_rag_tools: bool = True,
) -> FastAPI:
    """Create the FastAPI app."""
    # Build all shared state up front; the route handlers below close over it.
    LOGGER.info("Initializing RAG components...")

    LOGGER.info("Loading vector database (ChromaDB)...")
    collection = init_collection(
        chroma_path,
        name="docs",
        embedding_model=embedding_model,
        openai_base_url=openai_base_url,
        openai_api_key=embedding_api_key,
    )

    LOGGER.info("Loading reranker model (CrossEncoder)...")
    reranker = get_reranker_model()

    LOGGER.info("Loading existing file index...")
    file_hashes, file_mtimes = load_hashes_from_metadata(collection)
    LOGGER.info("Loaded %d files from index.", len(file_hashes))

    docs_folder.mkdir(exist_ok=True, parents=True)

    @asynccontextmanager
    async def lifespan(_app: FastAPI):  # noqa: ANN202
        LOGGER.info("Starting file watcher...")
        # Hold a strong reference to the watcher task so it is not
        # garbage-collected while the app is running.
        pending = set()
        watch_task = asyncio.create_task(
            watch_docs(collection, docs_folder, file_hashes, file_mtimes),
        )
        pending.add(watch_task)
        watch_task.add_done_callback(pending.discard)

        LOGGER.info("Starting initial index scan...")
        # The initial scan is blocking work, so run it off the event loop.
        threading.Thread(
            target=initial_index,
            args=(collection, docs_folder, file_hashes, file_mtimes),
            daemon=True,
        ).start()
        yield
        # On shutdown, cancel the watcher and wait for it to finish.
        watch_task.cancel()
        with suppress(asyncio.CancelledError):
            await watch_task

    app = FastAPI(title="RAG Proxy", lifespan=lifespan)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request, chat_request: ChatRequest) -> Any:
        # Prefer a bearer token supplied by the caller; otherwise fall back
        # to the server-configured chat API key.
        auth = request.headers.get("Authorization")
        api_key = None
        if auth and auth.startswith("Bearer "):
            api_key = auth.split(" ")[1]
        if not api_key:
            api_key = chat_api_key

        return await process_chat_request(
            chat_request,
            collection,
            reranker,
            openai_base_url.rstrip("/"),
            docs_folder,
            default_top_k=limit,
            api_key=api_key,
            enable_rag_tools=enable_rag_tools,
        )

    @app.post("/reindex")
    def reindex_all() -> dict[str, Any]:
        """Manually reindex all files."""
        LOGGER.info("Manual reindex requested.")
        threading.Thread(
            target=initial_index,
            args=(collection, docs_folder, file_hashes, file_mtimes),
            daemon=True,
        ).start()
        return {"status": "started reindexing", "total_chunks": collection.count()}

    @app.get("/files")
    def list_files() -> dict[str, Any]:
        """List all indexed files."""
        metadatas = get_all_metadata(collection)

        # Aggregate per-chunk metadata into one entry per file path.
        files: dict[str, Any] = {}
        for md in metadatas:
            if not md:
                continue
            path_key = md["file_path"]
            if path_key not in files:
                files[path_key] = {
                    "name": md["source"],
                    "path": path_key,
                    "type": md["file_type"],
                    "chunks": 0,
                    "indexed_at": md["indexed_at"],
                }
            files[path_key]["chunks"] += 1

        return {"files": list(files.values()), "total": len(files)}

    @app.get("/health")
    def health() -> dict[str, str]:
        return {
            "status": "ok",
            "rag_docs": str(docs_folder),
            "openai_base_url": openai_base_url,
            "embedding_model": embedding_model,
            "limit": str(limit),
        }

    @app.api_route(
        "/{path:path}",
        methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
    )
    async def proxy_catch_all(request: Request, path: str) -> Any:
        """Forward any other request to the upstream provider."""
        return await proxy_request_to_upstream(
            request,
            path,
            openai_base_url,
            chat_api_key,
        )

    return app
|
agent_cli/rag/client.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
"""RagClient - Composable RAG abstraction for indexing and search."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import uuid
|
|
7
|
+
from datetime import UTC, datetime
|
|
8
|
+
from typing import TYPE_CHECKING, Any
|
|
9
|
+
|
|
10
|
+
from agent_cli.constants import (
|
|
11
|
+
DEFAULT_OPENAI_BASE_URL,
|
|
12
|
+
DEFAULT_OPENAI_EMBEDDING_MODEL,
|
|
13
|
+
)
|
|
14
|
+
from agent_cli.core.chroma import init_collection
|
|
15
|
+
from agent_cli.core.reranker import get_reranker_model
|
|
16
|
+
from agent_cli.rag._retriever import format_context, rerank_and_filter
|
|
17
|
+
from agent_cli.rag._utils import chunk_text, load_document_text
|
|
18
|
+
from agent_cli.rag.models import RagSource, RetrievalResult
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
|
|
23
|
+
from chromadb import Collection
|
|
24
|
+
|
|
25
|
+
from agent_cli.core.reranker import OnnxCrossEncoder
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger("agent_cli.rag.client")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class RagClient:
    """A composable RAG index for adding documents and searching.

    Designed for building personal knowledge systems. Supports:
    - Adding raw text with metadata (for chat ingestion)
    - Adding files (auto-chunked)
    - Search with metadata filtering
    - Delete by ID or metadata filter

    Example:
        index = RagClient(chroma_path=Path("./my_index"))
        index.add("User asked about Python", metadata={"source": "chatgpt"})
        results = index.search("Python", filters={"source": "chatgpt"})

    """

    def __init__(
        self,
        chroma_path: Path,
        embedding_model: str = DEFAULT_OPENAI_EMBEDDING_MODEL,
        openai_base_url: str = DEFAULT_OPENAI_BASE_URL,
        openai_api_key: str | None = None,
        collection_name: str = "rag_index",
        chunk_size: int = 1200,
        chunk_overlap: int = 200,
    ) -> None:
        """Initialize the RAG index.

        Args:
            chroma_path: Path for ChromaDB persistence.
            embedding_model: OpenAI embedding model name.
            openai_base_url: Base URL for embedding API.
            openai_api_key: API key for embeddings.
            collection_name: Name of the ChromaDB collection.
            chunk_size: Maximum chunk size in characters.
            chunk_overlap: Overlap between chunks in characters.

        """
        self.chroma_path = chroma_path
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap

        logger.info("Initializing RAG index at %s", chroma_path)
        self.collection: Collection = init_collection(
            chroma_path,
            name=collection_name,
            embedding_model=embedding_model,
            openai_base_url=openai_base_url,
            openai_api_key=openai_api_key,
        )

        logger.info("Loading reranker model...")
        self.reranker: OnnxCrossEncoder = get_reranker_model()

    def add(
        self,
        text: str,
        metadata: dict[str, Any] | None = None,
        doc_id: str | None = None,
    ) -> str:
        """Add text to the index.

        Text is automatically chunked if it exceeds chunk_size.

        Args:
            text: The text content to add.
            metadata: Optional metadata dict (e.g., {"source": "chatgpt"}).
            doc_id: Optional document ID. Auto-generated if not provided.

        Returns:
            The document ID (useful for deletion).

        """
        if doc_id is None:
            doc_id = str(uuid.uuid4())
        base_meta = dict(metadata) if metadata else {}

        # Stamp every chunk with the indexing time and owning document.
        base_meta["indexed_at"] = datetime.now(UTC).isoformat()
        base_meta["doc_id"] = doc_id

        chunks = chunk_text(text, self._chunk_size, self._chunk_overlap)
        if not chunks:
            logger.warning("No chunks generated for doc_id=%s", doc_id)
            return doc_id

        # Chunk IDs are "<doc_id>:<index>" so a document's chunks can be
        # found and deleted together later.
        total = len(chunks)
        ids = [f"{doc_id}:{idx}" for idx in range(total)]
        metadatas = [
            {**base_meta, "chunk_id": idx, "total_chunks": total}
            for idx in range(total)
        ]

        # Upsert so re-adding the same doc_id replaces its chunks.
        self.collection.upsert(ids=ids, documents=chunks, metadatas=metadatas)
        logger.info("Added doc_id=%s with %d chunks", doc_id, total)

        return doc_id

    def add_file(
        self,
        file_path: Path,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Add a file to the index.

        File is read, chunked, and indexed with file metadata.

        Args:
            file_path: Path to the file to add.
            metadata: Optional additional metadata.

        Returns:
            The document ID.

        Raises:
            ValueError: If file cannot be read.

        """
        text = load_document_text(file_path)
        if text is None:
            msg = f"Could not read file: {file_path}"
            raise ValueError(msg)

        # Caller-supplied metadata wins over the derived file fields.
        combined = {
            "source": file_path.name,
            "file_path": str(file_path),
            "file_type": file_path.suffix.lstrip("."),
            **(metadata or {}),
        }
        return self.add(text, combined)

    def search(
        self,
        query: str,
        top_k: int = 5,
        filters: dict[str, Any] | None = None,
        min_score: float = 0.2,
    ) -> RetrievalResult:
        """Search the index with optional metadata filtering.

        Uses bi-encoder for initial retrieval, then cross-encoder for reranking.

        Args:
            query: The search query.
            top_k: Number of results to return.
            filters: ChromaDB where clause (e.g., {"source": "chatgpt"}).
            min_score: Minimum relevance score threshold. Results below this are filtered out.

        Returns:
            RetrievalResult with context string and sources.

        """
        # Fetch 3x the requested count so the reranker has candidates to cut.
        raw = self.collection.query(
            query_texts=[query],
            n_results=top_k * 3,
            where=filters,
            include=["documents", "metadatas", "distances"],
        )

        documents = raw.get("documents", [[]])[0]
        metadatas = raw.get("metadatas", [[]])[0]
        if not documents:
            return RetrievalResult(context="", sources=[])

        # Cross-encoder rerank, dropping anything under min_score.
        ranked = rerank_and_filter(
            self.reranker, query, documents, metadatas, top_k, min_score
        )
        if not ranked:
            return RetrievalResult(context="", sources=[])

        sources = []
        for _doc, meta, score in ranked:
            sources.append(
                RagSource(
                    source=meta.get("source", "unknown"),
                    path=meta.get("file_path", meta.get("doc_id", "unknown")),
                    chunk_id=meta.get("chunk_id", 0),
                    score=float(score),
                ),
            )
        return RetrievalResult(context=format_context(ranked), sources=sources)

    def delete(self, doc_id: str) -> int:
        """Delete all chunks for a document ID.

        Args:
            doc_id: The document ID to delete.

        Returns:
            Number of chunks deleted.

        """
        # Look up chunk IDs only; no payloads needed for deletion.
        found = self.collection.get(where={"doc_id": doc_id}, include=[])
        chunk_ids = found.get("ids", [])

        if chunk_ids:
            self.collection.delete(ids=chunk_ids)
            logger.info("Deleted %d chunks for doc_id=%s", len(chunk_ids), doc_id)

        return len(chunk_ids)

    def delete_by_metadata(self, filters: dict[str, Any]) -> int:
        """Delete all documents matching a metadata filter.

        Args:
            filters: ChromaDB where clause.

        Returns:
            Number of chunks deleted.

        """
        found = self.collection.get(where=filters, include=[])
        chunk_ids = found.get("ids", [])

        if chunk_ids:
            self.collection.delete(ids=chunk_ids)
            logger.info("Deleted %d chunks matching filters=%s", len(chunk_ids), filters)

        return len(chunk_ids)

    def count(self, filters: dict[str, Any] | None = None) -> int:
        """Count documents in the index.

        Args:
            filters: Optional ChromaDB where clause.

        Returns:
            Number of chunks (not documents).

        """
        if filters is not None:
            matched = self.collection.get(where=filters, include=[])
            return len(matched.get("ids", []))
        return self.collection.count()

    def list_sources(self) -> list[str]:
        """List unique source values in the index.

        Returns:
            Sorted list of unique source values.

        """
        everything = self.collection.get(include=["metadatas"])
        metadatas = everything.get("metadatas", []) or []

        unique = {md.get("source") for md in metadatas if md and md.get("source")}
        return sorted(unique)
|