stabilize-0.9.2-py3-none-any.whl

This diff shows the content of publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
Files changed (61)
  1. stabilize/__init__.py +29 -0
  2. stabilize/cli.py +1193 -0
  3. stabilize/context/__init__.py +7 -0
  4. stabilize/context/stage_context.py +170 -0
  5. stabilize/dag/__init__.py +15 -0
  6. stabilize/dag/graph.py +215 -0
  7. stabilize/dag/topological.py +199 -0
  8. stabilize/examples/__init__.py +1 -0
  9. stabilize/examples/docker-example.py +759 -0
  10. stabilize/examples/golden-standard-expected-result.txt +1 -0
  11. stabilize/examples/golden-standard.py +488 -0
  12. stabilize/examples/http-example.py +606 -0
  13. stabilize/examples/llama-example.py +662 -0
  14. stabilize/examples/python-example.py +731 -0
  15. stabilize/examples/shell-example.py +399 -0
  16. stabilize/examples/ssh-example.py +603 -0
  17. stabilize/handlers/__init__.py +53 -0
  18. stabilize/handlers/base.py +226 -0
  19. stabilize/handlers/complete_stage.py +209 -0
  20. stabilize/handlers/complete_task.py +75 -0
  21. stabilize/handlers/complete_workflow.py +150 -0
  22. stabilize/handlers/run_task.py +369 -0
  23. stabilize/handlers/start_stage.py +262 -0
  24. stabilize/handlers/start_task.py +74 -0
  25. stabilize/handlers/start_workflow.py +136 -0
  26. stabilize/launcher.py +307 -0
  27. stabilize/migrations/01KDQ4N9QPJ6Q4MCV3V9GHWPV4_initial_schema.sql +97 -0
  28. stabilize/migrations/01KDRK3TXW4R2GERC1WBCQYJGG_rag_embeddings.sql +25 -0
  29. stabilize/migrations/__init__.py +1 -0
  30. stabilize/models/__init__.py +15 -0
  31. stabilize/models/stage.py +389 -0
  32. stabilize/models/status.py +146 -0
  33. stabilize/models/task.py +125 -0
  34. stabilize/models/workflow.py +317 -0
  35. stabilize/orchestrator.py +113 -0
  36. stabilize/persistence/__init__.py +28 -0
  37. stabilize/persistence/connection.py +185 -0
  38. stabilize/persistence/factory.py +136 -0
  39. stabilize/persistence/memory.py +214 -0
  40. stabilize/persistence/postgres.py +655 -0
  41. stabilize/persistence/sqlite.py +674 -0
  42. stabilize/persistence/store.py +235 -0
  43. stabilize/queue/__init__.py +59 -0
  44. stabilize/queue/messages.py +377 -0
  45. stabilize/queue/processor.py +312 -0
  46. stabilize/queue/queue.py +526 -0
  47. stabilize/queue/sqlite_queue.py +354 -0
  48. stabilize/rag/__init__.py +19 -0
  49. stabilize/rag/assistant.py +459 -0
  50. stabilize/rag/cache.py +294 -0
  51. stabilize/stages/__init__.py +11 -0
  52. stabilize/stages/builder.py +253 -0
  53. stabilize/tasks/__init__.py +19 -0
  54. stabilize/tasks/interface.py +335 -0
  55. stabilize/tasks/registry.py +255 -0
  56. stabilize/tasks/result.py +283 -0
  57. stabilize-0.9.2.dist-info/METADATA +301 -0
  58. stabilize-0.9.2.dist-info/RECORD +61 -0
  59. stabilize-0.9.2.dist-info/WHEEL +4 -0
  60. stabilize-0.9.2.dist-info/entry_points.txt +2 -0
  61. stabilize-0.9.2.dist-info/licenses/LICENSE +201 -0
stabilize/rag/assistant.py
@@ -0,0 +1,459 @@
+"""StabilizeRAG - RAG assistant for generating Stabilize pipelines."""
+
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, TypedDict
+
+import numpy as np
+
+from .cache import CachedEmbedding, EmbeddingCache
+
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+
+class ChunkDict(TypedDict):
+    """Type for a document chunk dictionary."""
+
+    doc_id: str
+    content: str
+    chunk_index: int
+
+
+# Load the .env file if present (ragit does this too, but ensure it's loaded early)
+try:
+    from dotenv import load_dotenv
+
+    _env_path = Path.cwd() / ".env"
+    if _env_path.exists():
+        load_dotenv(_env_path)
+except ImportError:
+    pass  # dotenv is not required if env vars are set directly
+
+
+class StabilizeRAG:
+    """RAG assistant for generating Stabilize pipelines.
+
+    Uses ragit for embeddings and LLM generation, with a custom caching layer
+    that persists embeddings in the database.
+
+    Configuration:
+        - LLM generation uses ollama.com cloud (requires OLLAMA_API_KEY)
+        - Embeddings use local Ollama (ollama.com doesn't support embeddings)
+
+    Environment Variables:
+        OLLAMA_API_KEY: Required API key for ollama.com cloud
+        OLLAMA_BASE_URL: Override the LLM URL (default: https://ollama.com)
+        OLLAMA_EMBEDDING_URL: Override the embedding URL (default: http://localhost:11434)
+    """
+
+    # Default URLs
+    DEFAULT_LLM_URL = "https://ollama.com"
+    DEFAULT_EMBEDDING_URL = "http://localhost:11434"
+
+    # Default models
+    DEFAULT_EMBEDDING_MODEL = "nomic-embed-text:latest"
+    DEFAULT_LLM_MODEL = "qwen3-vl:235b"
+
+    # Chunking defaults
+    DEFAULT_CHUNK_SIZE = 512
+    DEFAULT_CHUNK_OVERLAP = 100  # Increased overlap for better context continuity
+    DEFAULT_TOP_K = 10  # Retrieve more context chunks for better accuracy
+
+    def __init__(
+        self,
+        cache: EmbeddingCache,
+        embedding_model: str | None = None,
+        llm_model: str | None = None,
+        chunk_size: int = DEFAULT_CHUNK_SIZE,
+        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
+    ):
+        self.cache = cache
+        self.embedding_model = embedding_model or self.DEFAULT_EMBEDDING_MODEL
+        self.llm_model = llm_model or self.DEFAULT_LLM_MODEL
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+
+        # Lazily initialized
+        self._provider = None
+        self._cached_embeddings: list[CachedEmbedding] | None = None
+        self._embedding_matrix: NDArray[np.float64] | None = None
+
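Read together with the defaults above, constructing the assistant looks roughly like the sketch below. The EmbeddingCache constructor arguments are an assumption here; its real signature lives in stabilize/rag/cache.py, which is not shown in this hunk:

    import os
    from stabilize.rag.assistant import StabilizeRAG
    from stabilize.rag.cache import EmbeddingCache

    os.environ.setdefault("OLLAMA_API_KEY", "your_api_key")  # needed for the ollama.com default

    cache = EmbeddingCache()   # hypothetical no-arg call; the real constructor may take a store/connection
    rag = StabilizeRAG(cache)  # defaults: nomic-embed-text:latest, qwen3-vl:235b, 512/100 chunking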
+    def _get_provider(self) -> Any:
+        """Get or create OllamaProvider with cloud LLM and local embeddings."""
+        if self._provider is None:
+            try:
+                from ragit import OllamaProvider  # type: ignore[import-untyped]
+            except ImportError as e:
+                raise ImportError("RAG support requires: pip install stabilize[rag]") from e
+
+            # Get configuration from environment or use defaults
+            llm_url = os.environ.get("OLLAMA_BASE_URL", self.DEFAULT_LLM_URL)
+            embedding_url = os.environ.get("OLLAMA_EMBEDDING_URL", self.DEFAULT_EMBEDDING_URL)
+            api_key = os.environ.get("OLLAMA_API_KEY")
+
+            # Validate API key if using cloud
+            if "ollama.com" in llm_url and not api_key:
+                raise RuntimeError(
+                    "OLLAMA_API_KEY environment variable is required for ollama.com.\n"
+                    "Set it in your .env file or export it:\n"
+                    "  export OLLAMA_API_KEY=your_api_key"
+                )
+
+            self._provider = OllamaProvider(
+                base_url=llm_url,
+                embedding_url=embedding_url,
+                api_key=api_key,
+            )
+        return self._provider
+
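Note that the API-key check above only fires when "ollama.com" appears in the LLM URL, so a fully local setup needs no key at all; a sketch of that override:

    import os

    # Point both generation and embeddings at a local Ollama daemon;
    # _get_provider() then skips the OLLAMA_API_KEY validation entirely.
    os.environ["OLLAMA_BASE_URL"] = "http://localhost:11434"
    os.environ["OLLAMA_EMBEDDING_URL"] = "http://localhost:11434"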
+    def init(self, force: bool = False, additional_paths: list[str] | None = None) -> int:
+        """Initialize embeddings from PROMPT_TEXT + examples/ + additional context.
+
+        Args:
+            force: If True, regenerate even if cache exists.
+            additional_paths: Optional list of file/directory paths to include as
+                additional training context.
+
+        Returns:
+            Number of embeddings cached.
+        """
+        if self.cache.is_initialized(self.embedding_model) and not force:
+            print(f"Cache already initialized for {self.embedding_model}")
+            return 0
+
+        # Load documents
+        documents = self._load_documents(additional_paths)
+        if not documents:
+            raise RuntimeError("No documents found to index")
+
+        print(f"Loaded {len(documents)} documents")
+
+        # Chunk documents
+        chunks = self._chunk_documents(documents)
+        print(f"Created {len(chunks)} chunks")
+
+        # Generate embeddings
+        print("Generating embeddings...")
+        provider = self._get_provider()
+        texts = [chunk["content"] for chunk in chunks]
+
+        try:
+            responses = provider.embed_batch(texts, self.embedding_model)
+        except ConnectionError as e:
+            embedding_url = os.environ.get("OLLAMA_EMBEDDING_URL", self.DEFAULT_EMBEDDING_URL)
+            raise RuntimeError(
+                f"Cannot connect to Ollama for embeddings at {embedding_url}\n\n"
+                "Embeddings require a local Ollama instance (ollama.com doesn't support embeddings).\n\n"
+                "To fix this:\n"
+                "  1. Install Ollama: https://ollama.com/download\n"
+                "  2. Start Ollama: ollama serve\n"
+                "  3. Pull embedding model: ollama pull nomic-embed-text\n\n"
+                "Or set OLLAMA_EMBEDDING_URL to point to your Ollama instance."
+            ) from e
+
+        # Build cached embeddings
+        cached = []
+        for chunk, response in zip(chunks, responses):
+            cached.append(
+                CachedEmbedding(
+                    doc_id=chunk["doc_id"],
+                    content=chunk["content"],
+                    embedding=list(response.embedding),
+                    embedding_model=self.embedding_model,
+                    chunk_index=chunk["chunk_index"],
+                )
+            )
+
+        # Store in cache
+        self.cache.store(cached)
+        print(f"Cached {len(cached)} embeddings")
+
+        return len(cached)
+
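A typical first-time indexing call, as a sketch reusing the rag instance from the earlier sketch (the path is hypothetical; init() returns the number of embeddings it cached):

    count = rag.init(force=True, additional_paths=["./my_pipelines"])
    print(f"indexed {count} chunks")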
+    def generate(self, prompt: str, top_k: int | None = None, temperature: float = 0.3) -> str:
+        """Generate pipeline code from natural language prompt.
+
+        Args:
+            prompt: Natural language description of desired pipeline.
+            top_k: Number of context chunks to retrieve (default: 10).
+            temperature: LLM temperature for generation.
+
+        Returns:
+            Generated Python code.
+        """
+        if top_k is None:
+            top_k = self.DEFAULT_TOP_K
+
+        if not self.cache.is_initialized(self.embedding_model):
+            raise RuntimeError("Run 'stabilize rag init' first to initialize embeddings")
+
+        # Load cached embeddings
+        self._load_cache()
+
+        # Get relevant context
+        context = self._get_context(prompt, top_k=top_k)
+
+        # Generate code
+        system_prompt = """You are a Stabilize workflow engine expert.
+Generate ONLY valid Python code that creates a working Stabilize pipeline.
+
+CRITICAL: Follow these EXACT patterns - do not invent your own API calls.
+
+=== IMPORTS (copy exactly) ===
+from stabilize import Workflow, StageExecution, TaskExecution, WorkflowStatus
+from stabilize.persistence.sqlite import SqliteWorkflowStore
+from stabilize.queue.sqlite_queue import SqliteQueue
+from stabilize.queue.processor import QueueProcessor
+from stabilize.orchestrator import Orchestrator
+from stabilize.tasks.interface import Task
+from stabilize.tasks.result import TaskResult
+from stabilize.tasks.registry import TaskRegistry
+from stabilize.handlers.complete_workflow import CompleteWorkflowHandler
+from stabilize.handlers.complete_stage import CompleteStageHandler
+from stabilize.handlers.complete_task import CompleteTaskHandler
+from stabilize.handlers.run_task import RunTaskHandler
+from stabilize.handlers.start_workflow import StartWorkflowHandler
+from stabilize.handlers.start_stage import StartStageHandler
+from stabilize.handlers.start_task import StartTaskHandler
+
+=== TASK PATTERN (execute takes stage, not context) ===
+class MyTask(Task):
+    def execute(self, stage: StageExecution) -> TaskResult:
+        value = stage.context.get("key")  # Read from stage.context
+        return TaskResult.success(outputs={"result": value})  # Use factory methods
+
+=== TASKRESULT FACTORY METHODS ===
+TaskResult.success(outputs={"key": "value"})  # Success with outputs
+TaskResult.terminal(error="Error message")  # Failure, halts pipeline
+
+=== WORKFLOWSTATUS ENUM VALUES (use exactly) ===
+WorkflowStatus.NOT_STARTED  # Initial state
+WorkflowStatus.RUNNING  # Currently executing
+WorkflowStatus.SUCCEEDED  # Completed successfully (NOT "COMPLETED")
+WorkflowStatus.TERMINAL  # Failed/halted
+
+=== SETUP PATTERN ===
+store = SqliteWorkflowStore("sqlite:///:memory:", create_tables=True)
+queue = SqliteQueue("sqlite:///:memory:", table_name="queue_messages")
+queue._create_table()
+
+registry = TaskRegistry()
+registry.register("my_task", MyTask)
+
+processor = QueueProcessor(queue)
+handlers = [
+    StartWorkflowHandler(queue, store),
+    StartStageHandler(queue, store),
+    StartTaskHandler(queue, store),
+    RunTaskHandler(queue, store, registry),
+    CompleteTaskHandler(queue, store),
+    CompleteStageHandler(queue, store),
+    CompleteWorkflowHandler(queue, store),
+]
+for h in handlers:
+    processor.register_handler(h)
+
+orchestrator = Orchestrator(queue)  # Only takes queue!
+
+=== WORKFLOW PATTERN ===
+workflow = Workflow.create(
+    application="my-app",
+    name="My Pipeline",
+    stages=[
+        StageExecution(
+            ref_id="1",
+            type="my_task",
+            name="My Stage",
+            context={"key": "value"},
+            tasks=[
+                TaskExecution.create(
+                    name="Run Task",
+                    implementing_class="my_task",  # Must match registry.register()
+                    stage_start=True,  # REQUIRED for first task
+                    stage_end=True,  # REQUIRED for last task
+                ),
+            ],
+        ),
+    ],
+)
+
+=== EXECUTION ===
+store.store(workflow)
+orchestrator.start(workflow)
+processor.process_all(timeout=30.0)
+result = store.retrieve(workflow.id)
+
+Output ONLY valid Python code. No markdown, no explanations."""
+
+        provider = self._get_provider()
+        response = provider.generate(
+            prompt=f"""Based on the following reference documentation and examples:
+
+{context}
+
+Generate a complete, runnable Python script that: {prompt}
+
+Remember: Output ONLY valid Python code, no markdown, no explanations.""",
+            model=self.llm_model,
+            system_prompt=system_prompt,
+            temperature=temperature,
+        )
+
+        # Clean up response (remove any markdown if present)
+        code: str = response.text.strip()
+        if code.startswith("```python"):
+            code = code[9:]
+        if code.startswith("```"):
+            code = code[3:]
+        if code.endswith("```"):
+            code = code[:-3]
+
+        return code.strip()
+
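Since generate() strips any markdown fences and returns bare Python, its output can be written straight to a script file; a sketch with a hypothetical prompt, again reusing the rag instance from above:

    from pathlib import Path

    code = rag.generate("fetch a URL and store the response body", top_k=10)
    Path("generated_pipeline.py").write_text(code)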
+    def _load_documents(self, additional_paths: list[str] | None = None) -> list[dict[str, str]]:
+        """Load PROMPT_TEXT + bundled examples/*.py + additional context as documents.
+
+        Args:
+            additional_paths: Optional list of file/directory paths to include as
+                additional training context.
+
+        Returns:
+            List of documents with 'id' and 'content' keys.
+        """
+        from stabilize.cli import PROMPT_TEXT
+
+        docs = [{"id": "prompt_reference", "content": PROMPT_TEXT}]
+
+        # Try bundled examples from package
+        try:
+            from importlib.resources import files
+
+            examples_pkg = files("stabilize.examples")
+            for item in examples_pkg.iterdir():
+                if item.name.endswith(".py") and item.name != "__init__.py":
+                    content = item.read_text()
+                    docs.append({"id": item.name[:-3], "content": content})
+        except (TypeError, FileNotFoundError, ModuleNotFoundError):
+            pass
+
+        # Fallback: try local examples/ directory (for development)
+        examples_dir = Path(__file__).parent.parent.parent.parent / "examples"
+        if examples_dir.exists():
+            for py_file in examples_dir.glob("*.py"):
+                if py_file.name == "__init__.py":
+                    continue
+                # Skip if already loaded from package
+                doc_id = py_file.stem
+                if any(d["id"] == doc_id for d in docs):
+                    continue
+                content = py_file.read_text()
+                docs.append({"id": doc_id, "content": content})
+
+        # Load additional context from user-provided paths
+        if additional_paths:
+            for path_str in additional_paths:
+                path = Path(path_str)
+                if path.is_file():
+                    # Single file
+                    try:
+                        content = path.read_text()
+                        docs.append({"id": f"additional:{path.name}", "content": content})
+                    except Exception as e:
+                        print(f"Warning: Could not read {path}: {e}")
+                elif path.is_dir():
+                    # Directory - load all .py files recursively
+                    for py_file in path.rglob("*.py"):
+                        if py_file.name == "__init__.py":
+                            continue
+                        try:
+                            content = py_file.read_text()
+                            rel_path = py_file.relative_to(path)
+                            docs.append({"id": f"additional:{rel_path}", "content": content})
+                        except Exception as e:
+                            print(f"Warning: Could not read {py_file}: {e}")
+                else:
+                    print(f"Warning: Path does not exist: {path}")
+
+        return docs
+
+    def _chunk_documents(self, documents: list[dict[str, str]]) -> list[ChunkDict]:
+        """Split documents into overlapping chunks."""
+        try:
+            from ragit import chunk_text
+        except ImportError as e:
+            raise ImportError("RAG support requires: pip install stabilize[rag]") from e
+
+        all_chunks: list[ChunkDict] = []
+        for doc in documents:
+            chunks = chunk_text(
+                doc["content"],
+                chunk_size=self.chunk_size,
+                chunk_overlap=self.chunk_overlap,
+                doc_id=doc["id"],
+            )
+            for i, chunk in enumerate(chunks):
+                all_chunks.append(
+                    ChunkDict(
+                        doc_id=doc["id"],
+                        content=chunk.content,
+                        chunk_index=i,
+                    )
+                )
+
+        return all_chunks
+
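With the defaults above, consecutive chunks advance by chunk_size - chunk_overlap = 412 units, so neighbors share roughly 100 units of text. Assuming ragit's chunk_text steps by that stride (its exact unit, characters versus tokens, is not visible in this hunk), the chunk count for a document is approximately:

    import math

    chunk_size, chunk_overlap = 512, 100
    stride = chunk_size - chunk_overlap  # 412
    doc_len = 2000  # e.g. a 2000-character example file
    n_chunks = max(1, math.ceil((doc_len - chunk_overlap) / stride))
    print(n_chunks)  # 5 under this stride assumption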
+    def _load_cache(self) -> None:
+        """Load embeddings from cache and build embedding matrix."""
+        if self._cached_embeddings is not None:
+            return
+
+        self._cached_embeddings = self.cache.load(self.embedding_model)
+        if not self._cached_embeddings:
+            raise RuntimeError(f"No embeddings found for model {self.embedding_model}")
+
+        # Build normalized embedding matrix for fast similarity search
+        embeddings = [e.embedding for e in self._cached_embeddings]
+        matrix = np.array(embeddings, dtype=np.float64)
+
+        # Normalize for cosine similarity via dot product
+        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
+        norms = np.where(norms == 0, 1, norms)  # Avoid division by zero
+        self._embedding_matrix = matrix / norms
+
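The row normalization above is what lets _get_context (next) reduce cosine similarity to a single matrix-vector product, via the identity cos(u, v) = (u/|u|) . (v/|v|); a quick check:

    import numpy as np

    u = np.array([3.0, 4.0])
    v = np.array([1.0, 0.0])
    cosine = (u @ v) / (np.linalg.norm(u) * np.linalg.norm(v))
    normalized_dot = (u / np.linalg.norm(u)) @ (v / np.linalg.norm(v))
    assert np.isclose(cosine, normalized_dot)  # both equal 0.6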
+    def _get_context(self, query: str, top_k: int = 5) -> str:
+        """Retrieve relevant context for a query."""
+        if self._cached_embeddings is None or self._embedding_matrix is None:
+            raise RuntimeError("Cache not loaded")
+
+        # Get query embedding
+        provider = self._get_provider()
+        response = provider.embed(query, self.embedding_model)
+        query_embedding = np.array(response.embedding, dtype=np.float64)
+
+        # Normalize query
+        query_norm = np.linalg.norm(query_embedding)
+        if query_norm > 0:
+            query_embedding = query_embedding / query_norm
+
+        # Cosine similarity via dot product (embeddings are pre-normalized)
+        similarities = self._embedding_matrix @ query_embedding
+
+        # Get top-k indices
+        if top_k >= len(similarities):
+            top_indices = np.argsort(similarities)[::-1]
+        else:
+            # Use argpartition for O(n) partial sort
+            top_indices = np.argpartition(similarities, -top_k)[-top_k:]
+            top_indices = top_indices[np.argsort(similarities[top_indices])[::-1]]
+
+        # Build context string
+        context_parts = []
+        for idx in top_indices:
+            emb = self._cached_embeddings[idx]
+            score = similarities[idx]
+            context_parts.append(f"--- {emb.doc_id} (relevance: {score:.3f}) ---\n{emb.content}")
+
+        return "\n\n".join(context_parts)
+ return "\n\n".join(context_parts)