stabilize-0.9.2-py3-none-any.whl
This diff shows the contents of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the package exactly as it appears in the public registry.
- stabilize/__init__.py +29 -0
- stabilize/cli.py +1193 -0
- stabilize/context/__init__.py +7 -0
- stabilize/context/stage_context.py +170 -0
- stabilize/dag/__init__.py +15 -0
- stabilize/dag/graph.py +215 -0
- stabilize/dag/topological.py +199 -0
- stabilize/examples/__init__.py +1 -0
- stabilize/examples/docker-example.py +759 -0
- stabilize/examples/golden-standard-expected-result.txt +1 -0
- stabilize/examples/golden-standard.py +488 -0
- stabilize/examples/http-example.py +606 -0
- stabilize/examples/llama-example.py +662 -0
- stabilize/examples/python-example.py +731 -0
- stabilize/examples/shell-example.py +399 -0
- stabilize/examples/ssh-example.py +603 -0
- stabilize/handlers/__init__.py +53 -0
- stabilize/handlers/base.py +226 -0
- stabilize/handlers/complete_stage.py +209 -0
- stabilize/handlers/complete_task.py +75 -0
- stabilize/handlers/complete_workflow.py +150 -0
- stabilize/handlers/run_task.py +369 -0
- stabilize/handlers/start_stage.py +262 -0
- stabilize/handlers/start_task.py +74 -0
- stabilize/handlers/start_workflow.py +136 -0
- stabilize/launcher.py +307 -0
- stabilize/migrations/01KDQ4N9QPJ6Q4MCV3V9GHWPV4_initial_schema.sql +97 -0
- stabilize/migrations/01KDRK3TXW4R2GERC1WBCQYJGG_rag_embeddings.sql +25 -0
- stabilize/migrations/__init__.py +1 -0
- stabilize/models/__init__.py +15 -0
- stabilize/models/stage.py +389 -0
- stabilize/models/status.py +146 -0
- stabilize/models/task.py +125 -0
- stabilize/models/workflow.py +317 -0
- stabilize/orchestrator.py +113 -0
- stabilize/persistence/__init__.py +28 -0
- stabilize/persistence/connection.py +185 -0
- stabilize/persistence/factory.py +136 -0
- stabilize/persistence/memory.py +214 -0
- stabilize/persistence/postgres.py +655 -0
- stabilize/persistence/sqlite.py +674 -0
- stabilize/persistence/store.py +235 -0
- stabilize/queue/__init__.py +59 -0
- stabilize/queue/messages.py +377 -0
- stabilize/queue/processor.py +312 -0
- stabilize/queue/queue.py +526 -0
- stabilize/queue/sqlite_queue.py +354 -0
- stabilize/rag/__init__.py +19 -0
- stabilize/rag/assistant.py +459 -0
- stabilize/rag/cache.py +294 -0
- stabilize/stages/__init__.py +11 -0
- stabilize/stages/builder.py +253 -0
- stabilize/tasks/__init__.py +19 -0
- stabilize/tasks/interface.py +335 -0
- stabilize/tasks/registry.py +255 -0
- stabilize/tasks/result.py +283 -0
- stabilize-0.9.2.dist-info/METADATA +301 -0
- stabilize-0.9.2.dist-info/RECORD +61 -0
- stabilize-0.9.2.dist-info/WHEEL +4 -0
- stabilize-0.9.2.dist-info/entry_points.txt +2 -0
- stabilize-0.9.2.dist-info/licenses/LICENSE +201 -0
stabilize/rag/assistant.py
@@ -0,0 +1,459 @@
+"""StabilizeRAG - RAG assistant for generating Stabilize pipelines."""
+
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, TypedDict
+
+import numpy as np
+
+from .cache import CachedEmbedding, EmbeddingCache
+
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+
+class ChunkDict(TypedDict):
+    """Type for document chunk dictionary."""
+
+    doc_id: str
+    content: str
+    chunk_index: int
+
+
+# Load .env file if present (ragit does this too, but ensure it's loaded early)
+try:
+    from dotenv import load_dotenv
+
+    _env_path = Path.cwd() / ".env"
+    if _env_path.exists():
+        load_dotenv(_env_path)
+except ImportError:
+    pass  # dotenv not required if env vars are set directly
+
+
+class StabilizeRAG:
+    """RAG assistant for generating Stabilize pipelines.
+
+    Uses ragit for embeddings and LLM generation, with custom caching layer
+    to persist embeddings in database.
+
+    Configuration:
+    - LLM generation uses ollama.com cloud (requires OLLAMA_API_KEY)
+    - Embeddings use local Ollama (ollama.com doesn't support embeddings)
+
+    Environment Variables:
+        OLLAMA_API_KEY: Required API key for ollama.com cloud
+        OLLAMA_BASE_URL: Override LLM URL (default: https://ollama.com)
+        OLLAMA_EMBEDDING_URL: Override embedding URL (default: http://localhost:11434)
+    """
+
+    # Default URLs
+    DEFAULT_LLM_URL = "https://ollama.com"
+    DEFAULT_EMBEDDING_URL = "http://localhost:11434"
+
+    # Default models
+    DEFAULT_EMBEDDING_MODEL = "nomic-embed-text:latest"
+    DEFAULT_LLM_MODEL = "qwen3-vl:235b"
+
+    # Chunking defaults
+    DEFAULT_CHUNK_SIZE = 512
+    DEFAULT_CHUNK_OVERLAP = 100  # Increased overlap for better context continuity
+    DEFAULT_TOP_K = 10  # Retrieve more context chunks for better accuracy
+
+    def __init__(
+        self,
+        cache: EmbeddingCache,
+        embedding_model: str | None = None,
+        llm_model: str | None = None,
+        chunk_size: int = DEFAULT_CHUNK_SIZE,
+        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
+    ):
+        self.cache = cache
+        self.embedding_model = embedding_model or self.DEFAULT_EMBEDDING_MODEL
+        self.llm_model = llm_model or self.DEFAULT_LLM_MODEL
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+
+        # Lazily initialized
+        self._provider = None
+        self._cached_embeddings: list[CachedEmbedding] | None = None
+        self._embedding_matrix: NDArray[np.float64] | None = None
+
+    def _get_provider(self) -> Any:
+        """Get or create OllamaProvider with cloud LLM and local embeddings."""
+        if self._provider is None:
+            try:
+                from ragit import OllamaProvider  # type: ignore[import-untyped]
+            except ImportError as e:
+                raise ImportError("RAG support requires: pip install stabilize[rag]") from e
+
+            # Get configuration from environment or use defaults
+            llm_url = os.environ.get("OLLAMA_BASE_URL", self.DEFAULT_LLM_URL)
+            embedding_url = os.environ.get("OLLAMA_EMBEDDING_URL", self.DEFAULT_EMBEDDING_URL)
+            api_key = os.environ.get("OLLAMA_API_KEY")
+
+            # Validate API key if using cloud
+            if "ollama.com" in llm_url and not api_key:
+                raise RuntimeError(
+                    "OLLAMA_API_KEY environment variable is required for ollama.com.\n"
+                    "Set it in your .env file or export it:\n"
+                    "  export OLLAMA_API_KEY=your_api_key"
+                )
+
+            self._provider = OllamaProvider(
+                base_url=llm_url,
+                embedding_url=embedding_url,
+                api_key=api_key,
+            )
+        return self._provider
+
+    def init(self, force: bool = False, additional_paths: list[str] | None = None) -> int:
+        """Initialize embeddings from PROMPT_TEXT + examples/ + additional context.
+
+        Args:
+            force: If True, regenerate even if cache exists.
+            additional_paths: Optional list of file/directory paths to include as
+                additional training context.
+
+        Returns:
+            Number of embeddings cached.
+        """
+        if self.cache.is_initialized(self.embedding_model) and not force:
+            print(f"Cache already initialized for {self.embedding_model}")
+            return 0
+
+        # Load documents
+        documents = self._load_documents(additional_paths)
+        if not documents:
+            raise RuntimeError("No documents found to index")
+
+        print(f"Loaded {len(documents)} documents")
+
+        # Chunk documents
+        chunks = self._chunk_documents(documents)
+        print(f"Created {len(chunks)} chunks")
+
+        # Generate embeddings
+        print("Generating embeddings...")
+        provider = self._get_provider()
+        texts = [chunk["content"] for chunk in chunks]
+
+        try:
+            responses = provider.embed_batch(texts, self.embedding_model)
+        except ConnectionError as e:
+            embedding_url = os.environ.get("OLLAMA_EMBEDDING_URL", self.DEFAULT_EMBEDDING_URL)
+            raise RuntimeError(
+                f"Cannot connect to Ollama for embeddings at {embedding_url}\n\n"
+                "Embeddings require a local Ollama instance (ollama.com doesn't support embeddings).\n\n"
+                "To fix this:\n"
+                "  1. Install Ollama: https://ollama.com/download\n"
+                "  2. Start Ollama: ollama serve\n"
+                "  3. Pull embedding model: ollama pull nomic-embed-text\n\n"
+                "Or set OLLAMA_EMBEDDING_URL to point to your Ollama instance."
+            ) from e
+
+        # Build cached embeddings
+        cached = []
+        for i, (chunk, response) in enumerate(zip(chunks, responses)):
+            cached.append(
+                CachedEmbedding(
+                    doc_id=chunk["doc_id"],
+                    content=chunk["content"],
+                    embedding=list(response.embedding),
+                    embedding_model=self.embedding_model,
+                    chunk_index=chunk["chunk_index"],
+                )
+            )
+
+        # Store in cache
+        self.cache.store(cached)
+        print(f"Cached {len(cached)} embeddings")
+
+        return len(cached)
+
+    def generate(self, prompt: str, top_k: int | None = None, temperature: float = 0.3) -> str:
+        """Generate pipeline code from natural language prompt.
+
+        Args:
+            prompt: Natural language description of desired pipeline.
+            top_k: Number of context chunks to retrieve (default: 10).
+            temperature: LLM temperature for generation.
+
+        Returns:
+            Generated Python code.
+        """
+        if top_k is None:
+            top_k = self.DEFAULT_TOP_K
+
+        if not self.cache.is_initialized(self.embedding_model):
+            raise RuntimeError("Run 'stabilize rag init' first to initialize embeddings")
+
+        # Load cached embeddings
+        self._load_cache()
+
+        # Get relevant context
+        context = self._get_context(prompt, top_k=top_k)
+
+        # Generate code
+        system_prompt = """You are a Stabilize workflow engine expert.
+Generate ONLY valid Python code that creates a working Stabilize pipeline.
+
+CRITICAL: Follow these EXACT patterns - do not invent your own API calls.
+
+=== IMPORTS (copy exactly) ===
+from stabilize import Workflow, StageExecution, TaskExecution, WorkflowStatus
+from stabilize.persistence.sqlite import SqliteWorkflowStore
+from stabilize.queue.sqlite_queue import SqliteQueue
+from stabilize.queue.processor import QueueProcessor
+from stabilize.orchestrator import Orchestrator
+from stabilize.tasks.interface import Task
+from stabilize.tasks.result import TaskResult
+from stabilize.tasks.registry import TaskRegistry
+from stabilize.handlers.complete_workflow import CompleteWorkflowHandler
+from stabilize.handlers.complete_stage import CompleteStageHandler
+from stabilize.handlers.complete_task import CompleteTaskHandler
+from stabilize.handlers.run_task import RunTaskHandler
+from stabilize.handlers.start_workflow import StartWorkflowHandler
+from stabilize.handlers.start_stage import StartStageHandler
+from stabilize.handlers.start_task import StartTaskHandler
+
+=== TASK PATTERN (execute takes stage, not context) ===
+class MyTask(Task):
+    def execute(self, stage: StageExecution) -> TaskResult:
+        value = stage.context.get("key")  # Read from stage.context
+        return TaskResult.success(outputs={"result": value})  # Use factory methods
+
+=== TASKRESULT FACTORY METHODS ===
+TaskResult.success(outputs={"key": "value"})  # Success with outputs
+TaskResult.terminal(error="Error message")  # Failure, halts pipeline
+
+=== WORKFLOWSTATUS ENUM VALUES (use exactly) ===
+WorkflowStatus.NOT_STARTED  # Initial state
+WorkflowStatus.RUNNING  # Currently executing
+WorkflowStatus.SUCCEEDED  # Completed successfully (NOT "COMPLETED")
+WorkflowStatus.TERMINAL  # Failed/halted
+
+=== SETUP PATTERN ===
+store = SqliteWorkflowStore("sqlite:///:memory:", create_tables=True)
+queue = SqliteQueue("sqlite:///:memory:", table_name="queue_messages")
+queue._create_table()
+
+registry = TaskRegistry()
+registry.register("my_task", MyTask)
+
+processor = QueueProcessor(queue)
+handlers = [
+    StartWorkflowHandler(queue, store),
+    StartStageHandler(queue, store),
+    StartTaskHandler(queue, store),
+    RunTaskHandler(queue, store, registry),
+    CompleteTaskHandler(queue, store),
+    CompleteStageHandler(queue, store),
+    CompleteWorkflowHandler(queue, store),
+]
+for h in handlers:
+    processor.register_handler(h)
+
+orchestrator = Orchestrator(queue)  # Only takes queue!
+
+=== WORKFLOW PATTERN ===
+workflow = Workflow.create(
+    application="my-app",
+    name="My Pipeline",
+    stages=[
+        StageExecution(
+            ref_id="1",
+            type="my_task",
+            name="My Stage",
+            context={"key": "value"},
+            tasks=[
+                TaskExecution.create(
+                    name="Run Task",
+                    implementing_class="my_task",  # Must match registry.register()
+                    stage_start=True,  # REQUIRED for first task
+                    stage_end=True,  # REQUIRED for last task
+                ),
+            ],
+        ),
+    ],
+)
+
+=== EXECUTION ===
+store.store(workflow)
+orchestrator.start(workflow)
+processor.process_all(timeout=30.0)
+result = store.retrieve(workflow.id)
+
+Output ONLY valid Python code. No markdown, no explanations."""
+
+        provider = self._get_provider()
+        response = provider.generate(
+            prompt=f"""Based on the following reference documentation and examples:
+
+{context}
+
+Generate a complete, runnable Python script that: {prompt}
+
+Remember: Output ONLY valid Python code, no markdown, no explanations.""",
+            model=self.llm_model,
+            system_prompt=system_prompt,
+            temperature=temperature,
+        )
+
+        # Clean up response (remove any markdown if present)
+        code: str = response.text.strip()
+        if code.startswith("```python"):
+            code = code[9:]
+        if code.startswith("```"):
+            code = code[3:]
+        if code.endswith("```"):
+            code = code[:-3]
+
+        return code.strip()
+
+    def _load_documents(self, additional_paths: list[str] | None = None) -> list[dict[str, str]]:
+        """Load PROMPT_TEXT + bundled examples/*.py + additional context as documents.
+
+        Args:
+            additional_paths: Optional list of file/directory paths to include as
+                additional training context.
+
+        Returns:
+            List of documents with 'id' and 'content' keys.
+        """
+        from stabilize.cli import PROMPT_TEXT
+
+        docs = [{"id": "prompt_reference", "content": PROMPT_TEXT}]
+
+        # Try bundled examples from package
+        try:
+            from importlib.resources import files
+
+            examples_pkg = files("stabilize.examples")
+            for item in examples_pkg.iterdir():
+                if item.name.endswith(".py") and item.name != "__init__.py":
+                    content = item.read_text()
+                    docs.append({"id": item.name[:-3], "content": content})
+        except (TypeError, FileNotFoundError, ModuleNotFoundError):
+            pass
+
+        # Fallback: try local examples/ directory (for development)
+        examples_dir = Path(__file__).parent.parent.parent.parent / "examples"
+        if examples_dir.exists():
+            for py_file in examples_dir.glob("*.py"):
+                if py_file.name == "__init__.py":
+                    continue
+                # Skip if already loaded from package
+                doc_id = py_file.stem
+                if any(d["id"] == doc_id for d in docs):
+                    continue
+                content = py_file.read_text()
+                docs.append({"id": doc_id, "content": content})
+
+        # Load additional context from user-provided paths
+        if additional_paths:
+            for path_str in additional_paths:
+                path = Path(path_str)
+                if path.is_file():
+                    # Single file
+                    try:
+                        content = path.read_text()
+                        docs.append({"id": f"additional:{path.name}", "content": content})
+                    except Exception as e:
+                        print(f"Warning: Could not read {path}: {e}")
+                elif path.is_dir():
+                    # Directory - load all .py files recursively
+                    for py_file in path.rglob("*.py"):
+                        if py_file.name == "__init__.py":
+                            continue
+                        try:
+                            content = py_file.read_text()
+                            rel_path = py_file.relative_to(path)
+                            docs.append({"id": f"additional:{rel_path}", "content": content})
+                        except Exception as e:
+                            print(f"Warning: Could not read {py_file}: {e}")
+                else:
+                    print(f"Warning: Path does not exist: {path}")
+
+        return docs
+
+    def _chunk_documents(self, documents: list[dict[str, str]]) -> list[ChunkDict]:
+        """Split documents into overlapping chunks."""
+        try:
+            from ragit import chunk_text
+        except ImportError as e:
+            raise ImportError("RAG support requires: pip install stabilize[rag]") from e
+
+        all_chunks: list[ChunkDict] = []
+        for doc in documents:
+            chunks = chunk_text(
+                doc["content"],
+                chunk_size=self.chunk_size,
+                chunk_overlap=self.chunk_overlap,
+                doc_id=doc["id"],
+            )
+            for i, chunk in enumerate(chunks):
+                all_chunks.append(
+                    ChunkDict(
+                        doc_id=doc["id"],
+                        content=chunk.content,
+                        chunk_index=i,
+                    )
+                )
+
+        return all_chunks
+
+    def _load_cache(self) -> None:
+        """Load embeddings from cache and build embedding matrix."""
+        if self._cached_embeddings is not None:
+            return
+
+        self._cached_embeddings = self.cache.load(self.embedding_model)
+        if not self._cached_embeddings:
+            raise RuntimeError(f"No embeddings found for model {self.embedding_model}")
+
+        # Build normalized embedding matrix for fast similarity search
+        embeddings = [e.embedding for e in self._cached_embeddings]
+        matrix = np.array(embeddings, dtype=np.float64)
+
+        # Normalize for cosine similarity via dot product
+        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
+        norms = np.where(norms == 0, 1, norms)  # Avoid division by zero
+        self._embedding_matrix = matrix / norms
+
+    def _get_context(self, query: str, top_k: int = 5) -> str:
+        """Retrieve relevant context for a query."""
+        if self._cached_embeddings is None or self._embedding_matrix is None:
+            raise RuntimeError("Cache not loaded")
+
+        # Get query embedding
+        provider = self._get_provider()
+        response = provider.embed(query, self.embedding_model)
+        query_embedding = np.array(response.embedding, dtype=np.float64)
+
+        # Normalize query
+        query_norm = np.linalg.norm(query_embedding)
+        if query_norm > 0:
+            query_embedding = query_embedding / query_norm
+
+        # Cosine similarity via dot product (embeddings are pre-normalized)
+        similarities = self._embedding_matrix @ query_embedding
+
+        # Get top-k indices
+        if top_k >= len(similarities):
+            top_indices = np.argsort(similarities)[::-1]
+        else:
+            # Use argpartition for O(n) partial sort
+            top_indices = np.argpartition(similarities, -top_k)[-top_k:]
+            top_indices = top_indices[np.argsort(similarities[top_indices])[::-1]]
+
+        # Build context string
+        context_parts = []
+        for idx in top_indices:
+            emb = self._cached_embeddings[idx]
+            score = similarities[idx]
+            context_parts.append(f"--- {emb.doc_id} (relevance: {score:.3f}) ---\n{emb.content}")
+
+        return "\n\n".join(context_parts)
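
For orientation, here is a minimal usage sketch of the StabilizeRAG class added above. This is editorial, not part of the package: it exercises only the API visible in this diff, and the EmbeddingCache construction is a placeholder because stabilize/rag/cache.py is not shown here.

import os

from stabilize.rag.assistant import StabilizeRAG
from stabilize.rag.cache import EmbeddingCache

# Generation goes to ollama.com cloud and needs an API key; embeddings
# go to a local Ollama instance (see the class docstring above).
os.environ.setdefault("OLLAMA_API_KEY", "your_api_key")
os.environ.setdefault("OLLAMA_EMBEDDING_URL", "http://localhost:11434")

# HYPOTHETICAL: EmbeddingCache's constructor arguments are defined in
# stabilize/rag/cache.py, which this diff does not include.
cache = EmbeddingCache(...)

rag = StabilizeRAG(cache=cache)

# Embeds PROMPT_TEXT plus the bundled examples/ and stores them via the
# cache; returns the number of embeddings cached (0 if already initialized).
rag.init(force=False)

# Retrieves the top-k most similar chunks, then asks the LLM for pipeline code.
code = rag.generate("a two-stage pipeline: run a shell task, then an HTTP task")
print(code)

In practice the `stabilize rag init` command referenced in generate()'s error message performs this wiring from the CLI; the sketch only makes the call order explicit.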