piragi-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- piragi-0.1.0.dist-info/METADATA +149 -0
- piragi-0.1.0.dist-info/RECORD +14 -0
- piragi-0.1.0.dist-info/WHEEL +4 -0
- piragi-0.1.0.dist-info/licenses/LICENSE +21 -0
- ragi/__init__.py +28 -0
- ragi/async_updater.py +345 -0
- ragi/change_detection.py +211 -0
- ragi/chunking.py +150 -0
- ragi/core.py +318 -0
- ragi/embeddings.py +150 -0
- ragi/loader.py +131 -0
- ragi/retrieval.py +125 -0
- ragi/store.py +177 -0
- ragi/types.py +54 -0
ragi/embeddings.py
ADDED
@@ -0,0 +1,150 @@
+"""Embedding generation using local or remote models."""
+
+import os
+from typing import List, Optional
+
+from .types import Chunk
+
+
+class EmbeddingGenerator:
+    """Generate embeddings using local sentence-transformers or remote API."""
+
+    def __init__(
+        self,
+        model: str = "nvidia/llama-embed-nemotron-8b",
+        device: str | None = None,
+        base_url: str | None = None,
+        api_key: str | None = None,
+    ) -> None:
+        """
+        Initialize the embedding generator.
+
+        Args:
+            model: Embedding model to use (default: nvidia/llama-embed-nemotron-8b)
+            device: Device to run on ('cuda', 'cpu', or None for auto-detect) - only for local models
+            base_url: Optional API base URL for remote embeddings (e.g., https://api.openai.com/v1)
+            api_key: Optional API key for remote embeddings
+        """
+        self.model_name = model
+        self.base_url = base_url
+        self.api_key = api_key
+        self.use_remote = base_url is not None
+
+        if self.use_remote:
+            # Use OpenAI-compatible API client
+            from openai import OpenAI
+
+            if self.api_key is None:
+                self.api_key = os.getenv("EMBEDDING_API_KEY", "not-needed")
+
+            self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
+            self.model = None
+        else:
+            # Use local sentence-transformers
+            from sentence_transformers import SentenceTransformer
+
+            self.model = SentenceTransformer(
+                model,
+                trust_remote_code=True,
+                device=device,
+            )
+            self.client = None
+
+    def embed_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
+        """
+        Generate embeddings for a list of chunks.
+
+        Args:
+            chunks: List of chunks to embed
+
+        Returns:
+            Chunks with embeddings added
+        """
+        if not chunks:
+            return chunks
+
+        # Extract texts
+        texts = [chunk.text for chunk in chunks]
+
+        # Generate embeddings in batch (documents don't need instruction prefix)
+        embeddings = self._generate_embeddings(texts)
+
+        # Add embeddings to chunks
+        for chunk, embedding in zip(chunks, embeddings):
+            chunk.embedding = embedding.tolist()
+
+        return chunks
+
+    def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """
+        Generate embeddings for a list of texts.
+
+        Args:
+            texts: List of text strings
+
+        Returns:
+            List of embedding vectors
+        """
+        try:
+            if self.use_remote:
+                # Use OpenAI-compatible API
+                response = self.client.embeddings.create(
+                    input=texts,
+                    model=self.model_name,
+                )
+                return [item.embedding for item in response.data]
+            else:
+                # Use local sentence-transformers
+                # Use encode_document for document chunks if available
+                if hasattr(self.model, "encode_document"):
+                    embeddings = self.model.encode_document(texts)
+                else:
+                    embeddings = self.model.encode(texts)
+                return embeddings
+
+        except Exception as e:
+            raise RuntimeError(f"Failed to generate embeddings: {e}")
+
+    def embed_query(self, query: str, task_instruction: str | None = None) -> List[float]:
+        """
+        Generate embedding for a single query.
+
+        Args:
+            query: Query text
+            task_instruction: Optional task instruction for query
+                (e.g., "Retrieve relevant documents for this question")
+
+        Returns:
+            Embedding vector
+        """
+        try:
+            if self.use_remote:
+                # Use OpenAI-compatible API
+                query_text = query
+                if task_instruction:
+                    query_text = f"{task_instruction}\n{query}"
+
+                response = self.client.embeddings.create(
+                    input=query_text,
+                    model=self.model_name,
+                )
+                return response.data[0].embedding
+            else:
+                # Use local sentence-transformers
+                # Use encode_query for search queries if available
+                if hasattr(self.model, "encode_query"):
+                    if task_instruction:
+                        query_with_instruction = f"{task_instruction}\n{query}"
+                        embedding = self.model.encode_query(query_with_instruction)
+                    else:
+                        embedding = self.model.encode_query(query)
+                else:
+                    query_text = query
+                    if task_instruction:
+                        query_text = f"{task_instruction}\n{query}"
+                    embedding = self.model.encode(query_text)
+
+            return embedding.tolist()
+
+        except Exception as e:
+            raise RuntimeError(f"Failed to generate query embedding: {e}")
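For orientation, here is a minimal local-embedding sketch of how `EmbeddingGenerator` could be exercised. It is not part of the published wheel: the small sentence-transformers model name is an illustrative assumption (the package default is the 8B nemotron model shown above), and it assumes `sentence-transformers` is installed.

```python
from ragi.embeddings import EmbeddingGenerator
from ragi.types import Chunk

# Hypothetical small model for a quick local test; the package default is
# nvidia/llama-embed-nemotron-8b.
generator = EmbeddingGenerator(model="sentence-transformers/all-MiniLM-L6-v2")

# Embed one document chunk and one query.
chunks = generator.embed_chunks(
    [Chunk(text="LanceDB persists vectors on disk.", source="notes.md", chunk_index=0)]
)
query_vector = generator.embed_query(
    "Where are vectors stored?",
    task_instruction="Retrieve relevant documents for this question",
)
print(len(chunks[0].embedding), len(query_vector))
```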
ragi/loader.py
ADDED
@@ -0,0 +1,131 @@
+"""Document loading using markitdown."""
+
+import glob
+import os
+from pathlib import Path
+from typing import List, Union
+from urllib.parse import urlparse
+
+from markitdown import MarkItDown
+
+from .types import Document
+
+
+class DocumentLoader:
+    """Load documents from various sources using markitdown."""
+
+    def __init__(self) -> None:
+        """Initialize the document loader."""
+        self.converter = MarkItDown()
+
+    def load(self, source: Union[str, List[str]]) -> List[Document]:
+        """
+        Load documents from file paths, URLs, or glob patterns.
+
+        Args:
+            source: Single path/URL, list of paths/URLs, or glob pattern
+
+        Returns:
+            List of loaded documents
+        """
+        if isinstance(source, str):
+            sources = [source]
+        else:
+            sources = source
+
+        documents = []
+        for src in sources:
+            documents.extend(self._load_single(src))
+
+        return documents
+
+    def _load_single(self, source: str) -> List[Document]:
+        """Load from a single source (file, URL, or glob pattern)."""
+        # Check if it's a URL
+        if self._is_url(source):
+            return [self._load_url(source)]
+
+        # Check if it's a glob pattern
+        if any(char in source for char in ["*", "?", "[", "]"]):
+            return self._load_glob(source)
+
+        # Single file
+        if os.path.isfile(source):
+            return [self._load_file(source)]
+
+        # Directory - load all files
+        if os.path.isdir(source):
+            return self._load_directory(source)
+
+        raise ValueError(f"Invalid source: {source}")
+
+    def _is_url(self, source: str) -> bool:
+        """Check if source is a URL."""
+        try:
+            result = urlparse(source)
+            return all([result.scheme, result.netloc])
+        except Exception:
+            return False
+
+    def _load_file(self, file_path: str) -> Document:
+        """Load a single file."""
+        try:
+            result = self.converter.convert(file_path)
+            content = result.text_content
+
+            # Extract metadata
+            metadata = {
+                "filename": os.path.basename(file_path),
+                "file_type": Path(file_path).suffix.lstrip("."),
+                "file_path": os.path.abspath(file_path),
+            }
+
+            return Document(content=content, source=file_path, metadata=metadata)
+
+        except Exception as e:
+            raise RuntimeError(f"Failed to load file {file_path}: {e}")
+
+    def _load_url(self, url: str) -> Document:
+        """Load content from a URL."""
+        try:
+            result = self.converter.convert(url)
+            content = result.text_content
+
+            metadata = {
+                "url": url,
+                "source_type": "url",
+            }
+
+            return Document(content=content, source=url, metadata=metadata)
+
+        except Exception as e:
+            raise RuntimeError(f"Failed to load URL {url}: {e}")
+
+    def _load_glob(self, pattern: str) -> List[Document]:
+        """Load files matching a glob pattern."""
+        files = glob.glob(pattern, recursive=True)
+        files = [f for f in files if os.path.isfile(f)]
+
+        if not files:
+            raise ValueError(f"No files found matching pattern: {pattern}")
+
+        return [self._load_file(f) for f in files]
+
+    def _load_directory(self, directory: str) -> List[Document]:
+        """Load all files from a directory recursively."""
+        pattern = os.path.join(directory, "**", "*")
+        files = glob.glob(pattern, recursive=True)
+        files = [f for f in files if os.path.isfile(f)]
+
+        if not files:
+            raise ValueError(f"No files found in directory: {directory}")
+
+        documents = []
+        for f in files:
+            try:
+                documents.append(self._load_file(f))
+            except Exception:
+                # Skip files that can't be processed
+                continue
+
+        return documents
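A short sketch of `DocumentLoader` usage follows. It is illustrative only and not shipped with the package; the directory, file, and URL are placeholders, and `markitdown` must be installed for the conversions to run.

```python
from ragi.loader import DocumentLoader

loader = DocumentLoader()

# Placeholders: a local directory, a single file, and a URL.
docs = loader.load(["./docs", "report.pdf", "https://example.com/page.html"])

for doc in docs:
    # Files carry "file_type"; URLs carry "source_type".
    kind = doc.metadata.get("file_type") or doc.metadata.get("source_type")
    print(doc.source, kind, len(doc.content))
```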
ragi/retrieval.py
ADDED
@@ -0,0 +1,125 @@
+"""Retrieval and answer generation using OpenAI-compatible APIs."""
+
+import os
+from typing import List, Optional
+
+from openai import OpenAI
+
+from .types import Answer, Citation
+
+
+class Retriever:
+    """Generate answers from retrieved chunks using OpenAI-compatible LLM APIs."""
+
+    def __init__(
+        self,
+        model: str = "llama3.2",
+        api_key: str | None = None,
+        base_url: str | None = None,
+    ) -> None:
+        """
+        Initialize the retriever.
+
+        Args:
+            model: Model name to use (default: llama3.2 for Ollama)
+            api_key: API key (optional for local models like Ollama)
+            base_url: Base URL for OpenAI-compatible API (e.g., http://localhost:11434/v1 for Ollama)
+        """
+        self.model = model
+
+        # Default to Ollama if no base_url provided
+        if base_url is None:
+            base_url = os.getenv("LLM_BASE_URL", "http://localhost:11434/v1")
+
+        # API key is optional for local models
+        if api_key is None:
+            api_key = os.getenv("LLM_API_KEY", "not-needed")
+
+        self.client = OpenAI(
+            api_key=api_key,
+            base_url=base_url,
+        )
+
+    def generate_answer(
+        self,
+        query: str,
+        citations: List[Citation],
+        system_prompt: Optional[str] = None,
+    ) -> Answer:
+        """
+        Generate an answer from retrieved citations.
+
+        Args:
+            query: User's question
+            citations: Retrieved citations
+            system_prompt: Optional custom system prompt
+
+        Returns:
+            Answer with citations
+        """
+        if not citations:
+            return Answer(
+                text="I couldn't find any relevant information to answer your question.",
+                citations=[],
+                query=query,
+            )
+
+        # Build context from citations
+        context = self._build_context(citations)
+
+        # Generate answer
+        answer_text = self._generate_with_llm(query, context, system_prompt)
+
+        return Answer(
+            text=answer_text,
+            citations=citations,
+            query=query,
+        )
+
+    def _build_context(self, citations: List[Citation]) -> str:
+        """Build context string from citations."""
+        context_parts = []
+
+        for i, citation in enumerate(citations, 1):
+            source_info = f"Source {i} ({citation.source}):"
+            context_parts.append(f"{source_info}\n{citation.chunk}\n")
+
+        return "\n".join(context_parts)
+
+    def _generate_with_llm(
+        self,
+        query: str,
+        context: str,
+        system_prompt: Optional[str] = None,
+    ) -> str:
+        """Generate answer using OpenAI-compatible API."""
+        if system_prompt is None:
+            system_prompt = (
+                "You are a helpful assistant that answers questions based on the provided context. "
+                "Always cite your sources by mentioning which source number you're referring to. "
+                "If the context doesn't contain enough information to answer the question, say so. "
+                "Be concise and accurate."
+            )
+
+        user_prompt = f"""Context from documents:
+
+{context}
+
+Question: {query}
+
+Please answer the question based on the context provided above. Cite your sources."""
+
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": user_prompt},
+                ],
+                temperature=0.3,
+            )
+
+            return response.choices[0].message.content or ""
+
+        except Exception as e:
+            raise RuntimeError(f"Failed to generate answer: {e}")
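As a hedged sketch, `Retriever` can be driven with a hand-built citation rather than retrieved ones. This assumes an Ollama server listening on http://localhost:11434/v1 with the llama3.2 model pulled, matching the defaults above; the citation text is invented for illustration.

```python
from ragi.retrieval import Retriever
from ragi.types import Citation

retriever = Retriever()  # defaults to Ollama at http://localhost:11434/v1, model llama3.2

citations = [
    Citation(source="notes.md", chunk="Ragi persists its index under .ragi/.", score=0.92),
]

answer = retriever.generate_answer("Where does Ragi store its index?", citations)
print(answer)                       # Answer.__str__ returns the generated text
print(answer.citations[0].preview)  # first 100 characters of the cited chunk
```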
ragi/store.py
ADDED
@@ -0,0 +1,177 @@
+"""Vector store using LanceDB."""
+
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import lancedb
+from lancedb.pydantic import LanceModel, Vector
+
+from .types import Chunk, Citation
+
+
+class ChunkModel(LanceModel):
+    """LanceDB schema for chunks."""
+
+    text: str
+    source: str
+    chunk_index: int
+    metadata: Dict[str, Any]
+    vector: Vector(4096)  # nvidia/llama-embed-nemotron-8b dimension
+
+
+class SourceMetadata(LanceModel):
+    """Metadata for tracking source changes."""
+
+    source: str  # File path or URL
+    last_checked: float  # Unix timestamp
+    content_hash: str  # SHA256 hash of content
+    mtime: Optional[float] = None  # File modification time (for files)
+    etag: Optional[str] = None  # HTTP ETag (for URLs)
+    last_modified: Optional[str] = None  # HTTP Last-Modified (for URLs)
+    check_interval: float = 300.0  # Seconds between checks (default: 5 min)
+
+
+class VectorStore:
+    """Vector store using LanceDB."""
+
+    def __init__(self, persist_dir: str = ".ragi") -> None:
+        """
+        Initialize the vector store.
+
+        Args:
+            persist_dir: Directory to persist the vector database
+        """
+        self.persist_dir = persist_dir
+        Path(persist_dir).mkdir(parents=True, exist_ok=True)
+
+        self.db = lancedb.connect(persist_dir)
+        self.table_name = "chunks"
+        self.metadata_table_name = "source_metadata"
+        self.table: Optional[Any] = None
+        self.metadata_table: Optional[Any] = None
+
+        # Initialize tables if they exist
+        if self.table_name in self.db.table_names():
+            self.table = self.db.open_table(self.table_name)
+
+        if self.metadata_table_name in self.db.table_names():
+            self.metadata_table = self.db.open_table(self.metadata_table_name)
+
+    def add_chunks(self, chunks: List[Chunk]) -> None:
+        """
+        Add chunks to the vector store.
+
+        Args:
+            chunks: List of chunks with embeddings
+        """
+        if not chunks:
+            return
+
+        # Validate embeddings
+        for chunk in chunks:
+            if chunk.embedding is None:
+                raise ValueError("All chunks must have embeddings before adding to store")
+
+        # Convert chunks to LanceDB format
+        data = [
+            {
+                "text": chunk.text,
+                "source": chunk.source,
+                "chunk_index": chunk.chunk_index,
+                "metadata": chunk.metadata,
+                "vector": chunk.embedding,
+            }
+            for chunk in chunks
+        ]
+
+        # Create or update table
+        if self.table is None:
+            self.table = self.db.create_table(self.table_name, data=data, mode="overwrite")
+        else:
+            self.table.add(data)
+
+    def search(
+        self,
+        query_embedding: List[float],
+        top_k: int = 5,
+        filters: Optional[Dict[str, Any]] = None,
+    ) -> List[Citation]:
+        """
+        Search for similar chunks.
+
+        Args:
+            query_embedding: Query vector
+            top_k: Number of results to return
+            filters: Optional metadata filters
+
+        Returns:
+            List of citations
+        """
+        if self.table is None:
+            return []
+
+        # Build query
+        query = self.table.search(query_embedding).limit(top_k)
+
+        # Apply filters if provided
+        if filters:
+            filter_conditions = []
+            for key, value in filters.items():
+                # Handle nested metadata filters
+                filter_conditions.append(f"metadata['{key}'] = '{value}'")
+
+            if filter_conditions:
+                query = query.where(" AND ".join(filter_conditions))
+
+        # Execute search
+        results = query.to_list()
+
+        # Convert to citations
+        citations = []
+        for result in results:
+            citation = Citation(
+                source=result["source"],
+                chunk=result["text"],
+                score=1.0 - result["_distance"],  # Convert distance to similarity score
+                metadata=result["metadata"],
+            )
+            citations.append(citation)
+
+        return citations
+
+    def count(self) -> int:
+        """Return the number of chunks in the store."""
+        if self.table is None:
+            return 0
+        return self.table.count_rows()
+
+    def delete_by_source(self, source: str) -> int:
+        """
+        Delete all chunks from a specific source.
+
+        Args:
+            source: Source file path or URL to delete
+
+        Returns:
+            Number of chunks deleted
+        """
+        if self.table is None:
+            return 0
+
+        # Count chunks before deletion
+        count_before = self.table.count_rows()
+
+        # Delete chunks matching the source
+        self.table.delete(f"source = '{source}'")
+
+        # Count chunks after deletion
+        count_after = self.table.count_rows()
+
+        return count_before - count_after
+
+    def clear(self) -> None:
+        """Clear all data from the store."""
+        if self.table_name in self.db.table_names():
+            self.db.drop_table(self.table_name)
+        self.table = None
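A minimal round-trip sketch for `VectorStore` (not part of the wheel): the 4-dimensional vectors and the `.ragi-demo` directory are illustrative stand-ins, since in normal use the embeddings come from `EmbeddingGenerator` and would match the 4096-dimension schema noted above.

```python
from ragi.store import VectorStore
from ragi.types import Chunk

store = VectorStore(persist_dir=".ragi-demo")

# Tiny toy embeddings; the table schema is inferred from this data.
store.add_chunks([
    Chunk(text="alpha", source="a.md", chunk_index=0,
          metadata={"file_type": "md"}, embedding=[0.1, 0.2, 0.3, 0.4]),
    Chunk(text="beta", source="b.md", chunk_index=0,
          metadata={"file_type": "md"}, embedding=[0.4, 0.3, 0.2, 0.1]),
])

hits = store.search([0.1, 0.2, 0.3, 0.4], top_k=1)
print(store.count(), hits[0].source, round(hits[0].score, 3))

store.delete_by_source("a.md")
store.clear()
```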
ragi/types.py
ADDED
@@ -0,0 +1,54 @@
+"""Type definitions for Ragi."""
+
+from typing import Any, Dict, List, Optional
+from pydantic import BaseModel, Field
+
+
+class Citation(BaseModel):
+    """A single citation with source information."""
+
+    source: str = Field(description="Source file path or URL")
+    chunk: str = Field(description="The actual text chunk")
+    score: float = Field(description="Relevance score (0-1)")
+    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
+
+    @property
+    def preview(self) -> str:
+        """Return a preview of the chunk (first 100 chars)."""
+        return self.chunk[:100] + "..." if len(self.chunk) > 100 else self.chunk
+
+
+class Answer(BaseModel):
+    """Answer with citations from the RAG system."""
+
+    text: str = Field(description="The generated answer")
+    citations: List[Citation] = Field(default_factory=list, description="Source citations")
+    query: str = Field(description="Original query")
+
+    def __str__(self) -> str:
+        """Return the answer text."""
+        return self.text
+
+    def __repr__(self) -> str:
+        """Return detailed representation."""
+        return f"Answer(text='{self.text[:50]}...', citations={len(self.citations)})"
+
+
+class Document(BaseModel):
+    """Internal document representation."""
+
+    content: str = Field(description="Document content in markdown")
+    source: str = Field(description="Source file path or URL")
+    metadata: Dict[str, Any] = Field(default_factory=dict, description="Document metadata")
+    content_hash: Optional[str] = Field(default=None, description="Hash of content for change detection")
+    last_modified: Optional[float] = Field(default=None, description="Last modification timestamp")
+
+
+class Chunk(BaseModel):
+    """A chunk of a document with metadata."""
+
+    text: str = Field(description="Chunk text")
+    source: str = Field(description="Source document")
+    chunk_index: int = Field(description="Index of chunk in document")
+    metadata: Dict[str, Any] = Field(default_factory=dict, description="Chunk metadata")
+    embedding: Optional[List[float]] = Field(default=None, description="Vector embedding")