piragi-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ragi/embeddings.py ADDED
@@ -0,0 +1,150 @@
+ """Embedding generation using local or remote models."""
+
+ import os
+ from typing import List
+
+ from .types import Chunk
+
+
+ class EmbeddingGenerator:
+     """Generate embeddings using local sentence-transformers or a remote API."""
+
+     def __init__(
+         self,
+         model: str = "nvidia/llama-embed-nemotron-8b",
+         device: str | None = None,
+         base_url: str | None = None,
+         api_key: str | None = None,
+     ) -> None:
+         """
+         Initialize the embedding generator.
+
+         Args:
+             model: Embedding model to use (default: nvidia/llama-embed-nemotron-8b)
+             device: Device to run on ('cuda', 'cpu', or None for auto-detect); local models only
+             base_url: Optional API base URL for remote embeddings (e.g., https://api.openai.com/v1)
+             api_key: Optional API key for remote embeddings
+         """
+         self.model_name = model
+         self.base_url = base_url
+         self.api_key = api_key
+         self.use_remote = base_url is not None
+
+         if self.use_remote:
+             # Use an OpenAI-compatible API client
+             from openai import OpenAI
+
+             if self.api_key is None:
+                 self.api_key = os.getenv("EMBEDDING_API_KEY", "not-needed")
+
+             self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
+             self.model = None
+         else:
+             # Use local sentence-transformers
+             from sentence_transformers import SentenceTransformer
+
+             self.model = SentenceTransformer(
+                 model,
+                 trust_remote_code=True,
+                 device=device,
+             )
+             self.client = None
+
+     def embed_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
+         """
+         Generate embeddings for a list of chunks.
+
+         Args:
+             chunks: List of chunks to embed
+
+         Returns:
+             Chunks with embeddings added
+         """
+         if not chunks:
+             return chunks
+
+         # Extract texts
+         texts = [chunk.text for chunk in chunks]
+
+         # Generate embeddings in batch (documents don't need an instruction prefix)
+         embeddings = self._generate_embeddings(texts)
+
+         # Add embeddings to chunks
+         for chunk, embedding in zip(chunks, embeddings):
+             chunk.embedding = embedding
+
+         return chunks
+
+     def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
+         """
+         Generate embeddings for a list of texts.
+
+         Args:
+             texts: List of text strings
+
+         Returns:
+             List of embedding vectors
+         """
+         try:
+             if self.use_remote:
+                 # Use an OpenAI-compatible API; items already arrive as plain lists
+                 response = self.client.embeddings.create(
+                     input=texts,
+                     model=self.model_name,
+                 )
+                 return [item.embedding for item in response.data]
+             else:
+                 # Use local sentence-transformers; prefer encode_document when available
+                 if hasattr(self.model, "encode_document"):
+                     embeddings = self.model.encode_document(texts)
+                 else:
+                     embeddings = self.model.encode(texts)
+                 # Normalize numpy output to plain lists so both branches match the annotation
+                 return [e.tolist() for e in embeddings]
+
+         except Exception as e:
+             raise RuntimeError(f"Failed to generate embeddings: {e}") from e
+
+     def embed_query(self, query: str, task_instruction: str | None = None) -> List[float]:
+         """
+         Generate embedding for a single query.
+
+         Args:
+             query: Query text
+             task_instruction: Optional task instruction for the query
+                 (e.g., "Retrieve relevant documents for this question")
+
+         Returns:
+             Embedding vector
+         """
+         try:
+             if self.use_remote:
+                 # Use an OpenAI-compatible API
+                 query_text = query
+                 if task_instruction:
+                     query_text = f"{task_instruction}\n{query}"
+
+                 response = self.client.embeddings.create(
+                     input=query_text,
+                     model=self.model_name,
+                 )
+                 return response.data[0].embedding
+             else:
+                 # Use local sentence-transformers; prefer encode_query
+                 # for search queries when the model provides it
+                 if hasattr(self.model, "encode_query"):
+                     if task_instruction:
+                         query_with_instruction = f"{task_instruction}\n{query}"
+                         embedding = self.model.encode_query(query_with_instruction)
+                     else:
+                         embedding = self.model.encode_query(query)
+                 else:
+                     query_text = query
+                     if task_instruction:
+                         query_text = f"{task_instruction}\n{query}"
+                     embedding = self.model.encode(query_text)
+
+                 return embedding.tolist()
+
+         except Exception as e:
+             raise RuntimeError(f"Failed to generate query embedding: {e}") from e
ragi/loader.py ADDED
@@ -0,0 +1,131 @@
+ """Document loading using markitdown."""
+
+ import glob
+ import os
+ from pathlib import Path
+ from typing import List, Union
+ from urllib.parse import urlparse
+
+ from markitdown import MarkItDown
+
+ from .types import Document
+
+
+ class DocumentLoader:
+     """Load documents from various sources using markitdown."""
+
+     def __init__(self) -> None:
+         """Initialize the document loader."""
+         self.converter = MarkItDown()
+
+     def load(self, source: Union[str, List[str]]) -> List[Document]:
+         """
+         Load documents from file paths, URLs, or glob patterns.
+
+         Args:
+             source: Single path/URL, list of paths/URLs, or glob pattern
+
+         Returns:
+             List of loaded documents
+         """
+         if isinstance(source, str):
+             sources = [source]
+         else:
+             sources = source
+
+         documents = []
+         for src in sources:
+             documents.extend(self._load_single(src))
+
+         return documents
+
+     def _load_single(self, source: str) -> List[Document]:
+         """Load from a single source (file, URL, or glob pattern)."""
+         # Check if it's a URL
+         if self._is_url(source):
+             return [self._load_url(source)]
+
+         # Check if it's a glob pattern
+         if any(char in source for char in ["*", "?", "[", "]"]):
+             return self._load_glob(source)
+
+         # Single file
+         if os.path.isfile(source):
+             return [self._load_file(source)]
+
+         # Directory - load all files
+         if os.path.isdir(source):
+             return self._load_directory(source)
+
+         raise ValueError(f"Invalid source: {source}")
+
+     def _is_url(self, source: str) -> bool:
+         """Check if source is a URL."""
+         try:
+             result = urlparse(source)
+             return all([result.scheme, result.netloc])
+         except Exception:
+             return False
+
+     def _load_file(self, file_path: str) -> Document:
+         """Load a single file."""
+         try:
+             result = self.converter.convert(file_path)
+             content = result.text_content
+
+             # Extract metadata
+             metadata = {
+                 "filename": os.path.basename(file_path),
+                 "file_type": Path(file_path).suffix.lstrip("."),
+                 "file_path": os.path.abspath(file_path),
+             }
+
+             return Document(content=content, source=file_path, metadata=metadata)
+
+         except Exception as e:
+             raise RuntimeError(f"Failed to load file {file_path}: {e}") from e
+
+     def _load_url(self, url: str) -> Document:
+         """Load content from a URL."""
+         try:
+             result = self.converter.convert(url)
+             content = result.text_content
+
+             metadata = {
+                 "url": url,
+                 "source_type": "url",
+             }
+
+             return Document(content=content, source=url, metadata=metadata)
+
+         except Exception as e:
+             raise RuntimeError(f"Failed to load URL {url}: {e}") from e
+
+     def _load_glob(self, pattern: str) -> List[Document]:
+         """Load files matching a glob pattern."""
+         files = glob.glob(pattern, recursive=True)
+         files = [f for f in files if os.path.isfile(f)]
+
+         if not files:
+             raise ValueError(f"No files found matching pattern: {pattern}")
+
+         return [self._load_file(f) for f in files]
+
+     def _load_directory(self, directory: str) -> List[Document]:
+         """Load all files from a directory recursively."""
+         pattern = os.path.join(directory, "**", "*")
+         files = glob.glob(pattern, recursive=True)
+         files = [f for f in files if os.path.isfile(f)]
+
+         if not files:
+             raise ValueError(f"No files found in directory: {directory}")
+
+         documents = []
+         for f in files:
+             try:
+                 documents.append(self._load_file(f))
+             except Exception:
+                 # Skip files that can't be processed
+                 continue
+
+         return documents
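
For reference, _load_single above dispatches on four source shapes. A brief sketch of the resulting call patterns (paths and URL are placeholders):

    from ragi.loader import DocumentLoader

    loader = DocumentLoader()
    docs = loader.load("README.md")                               # single file
    docs += loader.load("docs/**/*.pdf")                          # recursive glob pattern
    docs += loader.load(["notes/", "https://example.com/post"])   # directory + URL
    for doc in docs:
        print(doc.source, doc.metadata.get("file_type"), len(doc.content))
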
ragi/retrieval.py ADDED
@@ -0,0 +1,125 @@
+ """Retrieval and answer generation using OpenAI-compatible APIs."""
+
+ import os
+ from typing import List, Optional
+
+ from openai import OpenAI
+
+ from .types import Answer, Citation
+
+
+ class Retriever:
+     """Generate answers from retrieved chunks using OpenAI-compatible LLM APIs."""
+
+     def __init__(
+         self,
+         model: str = "llama3.2",
+         api_key: str | None = None,
+         base_url: str | None = None,
+     ) -> None:
+         """
+         Initialize the retriever.
+
+         Args:
+             model: Model name to use (default: llama3.2 for Ollama)
+             api_key: API key (optional for local models like Ollama)
+             base_url: Base URL for OpenAI-compatible API (e.g., http://localhost:11434/v1 for Ollama)
+         """
+         self.model = model
+
+         # Default to Ollama if no base_url provided
+         if base_url is None:
+             base_url = os.getenv("LLM_BASE_URL", "http://localhost:11434/v1")
+
+         # API key is optional for local models
+         if api_key is None:
+             api_key = os.getenv("LLM_API_KEY", "not-needed")
+
+         self.client = OpenAI(
+             api_key=api_key,
+             base_url=base_url,
+         )
+
+     def generate_answer(
+         self,
+         query: str,
+         citations: List[Citation],
+         system_prompt: Optional[str] = None,
+     ) -> Answer:
+         """
+         Generate an answer from retrieved citations.
+
+         Args:
+             query: User's question
+             citations: Retrieved citations
+             system_prompt: Optional custom system prompt
+
+         Returns:
+             Answer with citations
+         """
+         if not citations:
+             return Answer(
+                 text="I couldn't find any relevant information to answer your question.",
+                 citations=[],
+                 query=query,
+             )
+
+         # Build context from citations
+         context = self._build_context(citations)
+
+         # Generate answer
+         answer_text = self._generate_with_llm(query, context, system_prompt)
+
+         return Answer(
+             text=answer_text,
+             citations=citations,
+             query=query,
+         )
+
+     def _build_context(self, citations: List[Citation]) -> str:
+         """Build context string from citations."""
+         context_parts = []
+
+         for i, citation in enumerate(citations, 1):
+             source_info = f"Source {i} ({citation.source}):"
+             context_parts.append(f"{source_info}\n{citation.chunk}\n")
+
+         return "\n".join(context_parts)
+
+     def _generate_with_llm(
+         self,
+         query: str,
+         context: str,
+         system_prompt: Optional[str] = None,
+     ) -> str:
+         """Generate answer using OpenAI-compatible API."""
+         if system_prompt is None:
+             system_prompt = (
+                 "You are a helpful assistant that answers questions based on the provided context. "
+                 "Always cite your sources by mentioning which source number you're referring to. "
+                 "If the context doesn't contain enough information to answer the question, say so. "
+                 "Be concise and accurate."
+             )
+
+         user_prompt = f"""Context from documents:
+
+ {context}
+
+ Question: {query}
+
+ Please answer the question based on the context provided above. Cite your sources."""
+
+         try:
+             response = self.client.chat.completions.create(
+                 model=self.model,
+                 messages=[
+                     {"role": "system", "content": system_prompt},
+                     {"role": "user", "content": user_prompt},
+                 ],
+                 temperature=0.3,
+             )
+
+             return response.choices[0].message.content or ""
+
+         except Exception as e:
+             raise RuntimeError(f"Failed to generate answer: {e}") from e
ragi/store.py ADDED
@@ -0,0 +1,177 @@
+ """Vector store using LanceDB."""
+
+ import os
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ import lancedb
+ from lancedb.pydantic import LanceModel, Vector
+
+ from .types import Chunk, Citation
+
+
+ class ChunkModel(LanceModel):
+     """LanceDB schema for chunks."""
+
+     text: str
+     source: str
+     chunk_index: int
+     metadata: Dict[str, Any]
+     vector: Vector(4096)  # nvidia/llama-embed-nemotron-8b dimension
+
+
+ class SourceMetadata(LanceModel):
+     """Metadata for tracking source changes."""
+
+     source: str  # File path or URL
+     last_checked: float  # Unix timestamp
+     content_hash: str  # SHA256 hash of content
+     mtime: Optional[float] = None  # File modification time (for files)
+     etag: Optional[str] = None  # HTTP ETag (for URLs)
+     last_modified: Optional[str] = None  # HTTP Last-Modified (for URLs)
+     check_interval: float = 300.0  # Seconds between checks (default: 5 min)
+
+
+ class VectorStore:
+     """Vector store using LanceDB."""
+
+     def __init__(self, persist_dir: str = ".ragi") -> None:
+         """
+         Initialize the vector store.
+
+         Args:
+             persist_dir: Directory to persist the vector database
+         """
+         self.persist_dir = persist_dir
+         Path(persist_dir).mkdir(parents=True, exist_ok=True)
+
+         self.db = lancedb.connect(persist_dir)
+         self.table_name = "chunks"
+         self.metadata_table_name = "source_metadata"
+         self.table: Optional[Any] = None
+         self.metadata_table: Optional[Any] = None
+
+         # Open tables if they already exist
+         if self.table_name in self.db.table_names():
+             self.table = self.db.open_table(self.table_name)
+
+         if self.metadata_table_name in self.db.table_names():
+             self.metadata_table = self.db.open_table(self.metadata_table_name)
+
+     def add_chunks(self, chunks: List[Chunk]) -> None:
+         """
+         Add chunks to the vector store.
+
+         Args:
+             chunks: List of chunks with embeddings
+         """
+         if not chunks:
+             return
+
+         # Validate embeddings
+         for chunk in chunks:
+             if chunk.embedding is None:
+                 raise ValueError("All chunks must have embeddings before adding to store")
+
+         # Convert chunks to LanceDB format
+         data = [
+             {
+                 "text": chunk.text,
+                 "source": chunk.source,
+                 "chunk_index": chunk.chunk_index,
+                 "metadata": chunk.metadata,
+                 "vector": chunk.embedding,
+             }
+             for chunk in chunks
+         ]
+
+         # Create or update table
+         if self.table is None:
+             self.table = self.db.create_table(self.table_name, data=data, mode="overwrite")
+         else:
+             self.table.add(data)
+
+     def search(
+         self,
+         query_embedding: List[float],
+         top_k: int = 5,
+         filters: Optional[Dict[str, Any]] = None,
+     ) -> List[Citation]:
+         """
+         Search for similar chunks.
+
+         Args:
+             query_embedding: Query vector
+             top_k: Number of results to return
+             filters: Optional metadata filters
+
+         Returns:
+             List of citations
+         """
+         if self.table is None:
+             return []
+
+         # Build query
+         query = self.table.search(query_embedding).limit(top_k)
+
+         # Apply filters if provided
+         if filters:
+             filter_conditions = []
+             for key, value in filters.items():
+                 escaped = str(value).replace("'", "''")  # escape quotes so values can't break the SQL literal
+                 filter_conditions.append(f"metadata['{key}'] = '{escaped}'")
+
+             if filter_conditions:
+                 query = query.where(" AND ".join(filter_conditions))
+
+         # Execute search
+         results = query.to_list()
+
+         # Convert to citations
+         citations = []
+         for result in results:
+             citation = Citation(
+                 source=result["source"],
+                 chunk=result["text"],
+                 score=1.0 - result["_distance"],  # Convert distance to a similarity score
+                 metadata=result["metadata"],
+             )
+             citations.append(citation)
+
+         return citations
+
+     def count(self) -> int:
+         """Return the number of chunks in the store."""
+         if self.table is None:
+             return 0
+         return self.table.count_rows()
+
+     def delete_by_source(self, source: str) -> int:
+         """
+         Delete all chunks from a specific source.
+
+         Args:
+             source: Source file path or URL to delete
+
+         Returns:
+             Number of chunks deleted
+         """
+         if self.table is None:
+             return 0
+
+         # Count chunks before deletion
+         count_before = self.table.count_rows()
+
+         escaped = source.replace("'", "''")  # escape quotes for the SQL filter
+         self.table.delete(f"source = '{escaped}'")
+
+         # Count chunks after deletion
+         count_after = self.table.count_rows()
+
+         return count_before - count_after
+
+     def clear(self) -> None:
+         """Clear all data from the store."""
+         if self.table_name in self.db.table_names():
+             self.db.drop_table(self.table_name)
+         self.table = None
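
An illustrative round trip through the store (values are made up; the zero vectors stand in for real embeddings, sized to match the 4096-dim comment in ChunkModel above):

    from ragi.store import VectorStore
    from ragi.types import Chunk

    store = VectorStore(persist_dir=".ragi")
    store.add_chunks([
        Chunk(
            text="LanceDB is an embedded vector database.",
            source="notes.md",
            chunk_index=0,
            metadata={"file_type": "md"},
            embedding=[0.0] * 4096,  # normally produced by EmbeddingGenerator
        )
    ])
    hits = store.search([0.0] * 4096, top_k=3, filters={"file_type": "md"})
    print(store.count(), [c.score for c in hits])
    print(store.delete_by_source("notes.md"), "chunks deleted")
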
ragi/types.py ADDED
@@ -0,0 +1,54 @@
+ """Type definitions for Ragi."""
+
+ from typing import Any, Dict, List, Optional
+ from pydantic import BaseModel, Field
+
+
+ class Citation(BaseModel):
+     """A single citation with source information."""
+
+     source: str = Field(description="Source file path or URL")
+     chunk: str = Field(description="The actual text chunk")
+     score: float = Field(description="Relevance score (0-1)")
+     metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
+
+     @property
+     def preview(self) -> str:
+         """Return a preview of the chunk (first 100 chars)."""
+         return self.chunk[:100] + "..." if len(self.chunk) > 100 else self.chunk
+
+
+ class Answer(BaseModel):
+     """Answer with citations from the RAG system."""
+
+     text: str = Field(description="The generated answer")
+     citations: List[Citation] = Field(default_factory=list, description="Source citations")
+     query: str = Field(description="Original query")
+
+     def __str__(self) -> str:
+         """Return the answer text."""
+         return self.text
+
+     def __repr__(self) -> str:
+         """Return detailed representation."""
+         return f"Answer(text='{self.text[:50]}...', citations={len(self.citations)})"
+
+
+ class Document(BaseModel):
+     """Internal document representation."""
+
+     content: str = Field(description="Document content in markdown")
+     source: str = Field(description="Source file path or URL")
+     metadata: Dict[str, Any] = Field(default_factory=dict, description="Document metadata")
+     content_hash: Optional[str] = Field(default=None, description="Hash of content for change detection")
+     last_modified: Optional[float] = Field(default=None, description="Last modification timestamp")
+
+
+ class Chunk(BaseModel):
+     """A chunk of a document with metadata."""
+
+     text: str = Field(description="Chunk text")
+     source: str = Field(description="Source document")
+     chunk_index: int = Field(description="Index of chunk in document")
+     metadata: Dict[str, Any] = Field(default_factory=dict, description="Chunk metadata")
+     embedding: Optional[List[float]] = Field(default=None, description="Vector embedding")
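
To make the data flow between these models concrete, a small constructed example (all values fabricated):

    from ragi.types import Answer, Chunk, Citation, Document

    doc = Document(content="LanceDB is an embedded vector database.", source="notes.md")
    chunk = Chunk(text=doc.content, source=doc.source, chunk_index=0)
    citation = Citation(source=chunk.source, chunk=chunk.text, score=0.92)
    answer = Answer(
        text="LanceDB is an embedded vector database. [Source 1]",
        citations=[citation],
        query="What is LanceDB?",
    )
    print(answer)            # __str__ -> the answer text
    print(repr(answer))      # Answer(text='...', citations=1)
    print(citation.preview)  # chunk preview, truncated at 100 chars
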