stabilize 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61):
  1. stabilize/__init__.py +29 -0
  2. stabilize/cli.py +1193 -0
  3. stabilize/context/__init__.py +7 -0
  4. stabilize/context/stage_context.py +170 -0
  5. stabilize/dag/__init__.py +15 -0
  6. stabilize/dag/graph.py +215 -0
  7. stabilize/dag/topological.py +199 -0
  8. stabilize/examples/__init__.py +1 -0
  9. stabilize/examples/docker-example.py +759 -0
  10. stabilize/examples/golden-standard-expected-result.txt +1 -0
  11. stabilize/examples/golden-standard.py +488 -0
  12. stabilize/examples/http-example.py +606 -0
  13. stabilize/examples/llama-example.py +662 -0
  14. stabilize/examples/python-example.py +731 -0
  15. stabilize/examples/shell-example.py +399 -0
  16. stabilize/examples/ssh-example.py +603 -0
  17. stabilize/handlers/__init__.py +53 -0
  18. stabilize/handlers/base.py +226 -0
  19. stabilize/handlers/complete_stage.py +209 -0
  20. stabilize/handlers/complete_task.py +75 -0
  21. stabilize/handlers/complete_workflow.py +150 -0
  22. stabilize/handlers/run_task.py +369 -0
  23. stabilize/handlers/start_stage.py +262 -0
  24. stabilize/handlers/start_task.py +74 -0
  25. stabilize/handlers/start_workflow.py +136 -0
  26. stabilize/launcher.py +307 -0
  27. stabilize/migrations/01KDQ4N9QPJ6Q4MCV3V9GHWPV4_initial_schema.sql +97 -0
  28. stabilize/migrations/01KDRK3TXW4R2GERC1WBCQYJGG_rag_embeddings.sql +25 -0
  29. stabilize/migrations/__init__.py +1 -0
  30. stabilize/models/__init__.py +15 -0
  31. stabilize/models/stage.py +389 -0
  32. stabilize/models/status.py +146 -0
  33. stabilize/models/task.py +125 -0
  34. stabilize/models/workflow.py +317 -0
  35. stabilize/orchestrator.py +113 -0
  36. stabilize/persistence/__init__.py +28 -0
  37. stabilize/persistence/connection.py +185 -0
  38. stabilize/persistence/factory.py +136 -0
  39. stabilize/persistence/memory.py +214 -0
  40. stabilize/persistence/postgres.py +655 -0
  41. stabilize/persistence/sqlite.py +674 -0
  42. stabilize/persistence/store.py +235 -0
  43. stabilize/queue/__init__.py +59 -0
  44. stabilize/queue/messages.py +377 -0
  45. stabilize/queue/processor.py +312 -0
  46. stabilize/queue/queue.py +526 -0
  47. stabilize/queue/sqlite_queue.py +354 -0
  48. stabilize/rag/__init__.py +19 -0
  49. stabilize/rag/assistant.py +459 -0
  50. stabilize/rag/cache.py +294 -0
  51. stabilize/stages/__init__.py +11 -0
  52. stabilize/stages/builder.py +253 -0
  53. stabilize/tasks/__init__.py +19 -0
  54. stabilize/tasks/interface.py +335 -0
  55. stabilize/tasks/registry.py +255 -0
  56. stabilize/tasks/result.py +283 -0
  57. stabilize-0.9.2.dist-info/METADATA +301 -0
  58. stabilize-0.9.2.dist-info/RECORD +61 -0
  59. stabilize-0.9.2.dist-info/WHEEL +4 -0
  60. stabilize-0.9.2.dist-info/entry_points.txt +2 -0
  61. stabilize-0.9.2.dist-info/licenses/LICENSE +201 -0
stabilize/cli.py ADDED
@@ -0,0 +1,1193 @@
1
+ """Stabilize CLI for database migrations and developer tools."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import hashlib
7
+ import os
8
+ import re
9
+ import sys
10
+ from importlib.resources import files
11
+ from pathlib import Path
12
+ from typing import TYPE_CHECKING
13
+
14
+ if TYPE_CHECKING:
15
+ from typing import Any
16
+
17
+
18
+ # =============================================================================
19
+ # RAG PROMPT - Comprehensive documentation for AI coding agents
20
+ # =============================================================================
21
+
22
# Reference text aimed at AI coding agents. The embedded Python examples are
# meant to be copied verbatim, so they must stay syntactically valid
# (indentation and imports included).
PROMPT_TEXT = r'''
===============================================================================
STABILIZE WORKFLOW ENGINE - COMPLETE REFERENCE FOR CODE GENERATION
===============================================================================

Stabilize is a Python DAG-based workflow orchestration engine. Workflows consist
of Stages (nodes in the DAG) containing Tasks (atomic work units). Stages can
run sequentially or in parallel based on their dependencies.

===============================================================================
1. COMPLETE WORKING EXAMPLE - COPY THIS AS YOUR STARTING TEMPLATE
===============================================================================

#!/usr/bin/env python3
"""Minimal working Stabilize workflow example."""

from stabilize import Workflow, StageExecution, TaskExecution, WorkflowStatus
from stabilize.persistence.sqlite import SqliteWorkflowStore
from stabilize.queue.sqlite_queue import SqliteQueue
from stabilize.queue.processor import QueueProcessor
from stabilize.queue.queue import Queue
from stabilize.persistence.store import WorkflowStore
from stabilize.orchestrator import Orchestrator
from stabilize.tasks.interface import Task
from stabilize.tasks.result import TaskResult
from stabilize.tasks.registry import TaskRegistry
from stabilize.handlers.complete_workflow import CompleteWorkflowHandler
from stabilize.handlers.complete_stage import CompleteStageHandler
from stabilize.handlers.complete_task import CompleteTaskHandler
from stabilize.handlers.run_task import RunTaskHandler
from stabilize.handlers.start_workflow import StartWorkflowHandler
from stabilize.handlers.start_stage import StartStageHandler
from stabilize.handlers.start_task import StartTaskHandler


# Step 1: Define your custom Task
class MyTask(Task):
    """Custom task implementation."""

    def execute(self, stage: StageExecution) -> TaskResult:
        # Read inputs from stage.context
        input_value = stage.context.get("my_input", "default")

        # Do your work here
        result = f"Processed: {input_value}"

        # Return success with outputs for downstream stages
        return TaskResult.success(outputs={"result": result})


# Step 2: Setup infrastructure
def setup_pipeline_runner(store: WorkflowStore, queue: Queue) -> tuple[QueueProcessor, Orchestrator]:
    """Create processor and orchestrator with task registered."""
    task_registry = TaskRegistry()
    task_registry.register("my_task", MyTask)  # Register task with name

    processor = QueueProcessor(queue)

    # Register all handlers in order
    handlers = [
        StartWorkflowHandler(queue, store),
        StartStageHandler(queue, store),
        StartTaskHandler(queue, store),
        RunTaskHandler(queue, store, task_registry),
        CompleteTaskHandler(queue, store),
        CompleteStageHandler(queue, store),
        CompleteWorkflowHandler(queue, store),
    ]
    for handler in handlers:
        processor.register_handler(handler)

    orchestrator = Orchestrator(queue)
    return processor, orchestrator


# Step 3: Create and run workflow
def main():
    # Initialize storage (in-memory SQLite for development)
    store = SqliteWorkflowStore("sqlite:///:memory:", create_tables=True)
    queue = SqliteQueue("sqlite:///:memory:", table_name="queue_messages")
    queue._create_table()
    processor, orchestrator = setup_pipeline_runner(store, queue)

    # Create workflow with stages
    workflow = Workflow.create(
        application="my-app",
        name="My Pipeline",
        stages=[
            StageExecution(
                ref_id="1",
                type="my_task",
                name="First Stage",
                context={"my_input": "hello world"},
                tasks=[
                    TaskExecution.create(
                        name="Run MyTask",
                        implementing_class="my_task",  # Must match registered name
                        stage_start=True,  # REQUIRED for first task
                        stage_end=True,  # REQUIRED for last task
                    ),
                ],
            ),
        ],
    )

    # Execute workflow
    store.store(workflow)
    orchestrator.start(workflow)
    processor.process_all(timeout=30.0)

    # Check result
    result = store.retrieve(workflow.id)
    print(f"Status: {result.status}")
    print(f"Output: {result.stages[0].outputs}")


if __name__ == "__main__":
    main()

===============================================================================
2. CORE CLASSES API
===============================================================================

2.1 Workflow
-------------
Factory: Workflow.create(application, name, stages, trigger=None, pipeline_config_id=None)

Fields:
    id: str - Unique ULID identifier (auto-generated)
    status: WorkflowStatus - Current execution status
    stages: list[StageExecution] - All stages in the workflow
    application: str - Application name
    name: str - Pipeline name

Methods:
    stage_by_id(stage_id) -> StageExecution - Get stage by internal ID
    stage_by_ref_id(ref_id) -> StageExecution - Get stage by reference ID
    get_context() -> dict - Get merged outputs from all stages


2.2 StageExecution
-------------------
Constructor: StageExecution(ref_id, type, name, context, tasks, requisite_stage_ref_ids=set())

Fields:
    ref_id: str - UNIQUE reference ID for DAG (e.g., "1", "deploy", "build")
    type: str - Stage type (usually matches task name)
    name: str - Human-readable name
    context: dict[str, Any] - INPUT parameters for this stage
    outputs: dict[str, Any] - OUTPUT values for downstream stages (populated by tasks)
    tasks: list[TaskExecution] - Tasks to execute (sequentially)
    requisite_stage_ref_ids: set[str] - Dependencies (ref_ids of upstream stages)
    status: WorkflowStatus - Current status

DAG Dependencies:
    - Empty set: Stage runs immediately (initial stage)
    - {"A"}: Stage runs after stage with ref_id="A" completes
    - {"A", "B"}: Stage waits for BOTH A and B to complete (join point)


2.3 TaskExecution
------------------
Factory: TaskExecution.create(name, implementing_class, stage_start=False, stage_end=False)

Fields:
    name: str - Human-readable task name
    implementing_class: str - MUST match the name used in TaskRegistry.register()
    stage_start: bool - MUST be True for first task in stage
    stage_end: bool - MUST be True for last task in stage
    status: WorkflowStatus - Current status

CRITICAL: If a stage has only one task, set BOTH stage_start=True AND stage_end=True


2.4 WorkflowStatus
-------------------
All status values:
    NOT_STARTED - Not yet started
    RUNNING - Currently executing
    PAUSED - Paused, can be resumed
    SUSPENDED - Waiting for external trigger
    SUCCEEDED - Completed successfully
    FAILED_CONTINUE - Failed but pipeline continues
    TERMINAL - Failed, pipeline halts
    CANCELED - Execution was canceled
    STOPPED - Execution was stopped
    SKIPPED - Stage/task was skipped
    REDIRECT - Decision branch redirect
    BUFFERED - Buffered, waiting

Properties:
    .is_complete: bool - Has finished executing
    .is_halt: bool - Blocks downstream stages
    .is_successful: bool - SUCCEEDED, STOPPED, or SKIPPED
    .is_failure: bool - TERMINAL, STOPPED, or FAILED_CONTINUE

===============================================================================
3. TASK IMPLEMENTATION
===============================================================================

3.1 Task Interface (Abstract Base Class)
-----------------------------------------
from stabilize.tasks.interface import Task

class MyTask(Task):
    def execute(self, stage: StageExecution) -> TaskResult:
        # Read from stage.context (includes upstream outputs)
        value = stage.context.get("key")

        # Return TaskResult
        return TaskResult.success(outputs={"output_key": "value"})

    # Optional: Handle timeout (for RetryableTask)
    def on_timeout(self, stage: StageExecution) -> TaskResult | None:
        return TaskResult.terminal(error="Task timed out")

    # Optional: Handle cancellation
    def on_cancel(self, stage: StageExecution) -> TaskResult | None:
        return TaskResult.canceled()


3.2 TaskResult Factory Methods - CRITICAL REFERENCE
----------------------------------------------------
from stabilize.tasks.result import TaskResult

SUCCESS - Task completed successfully, pipeline continues:
    TaskResult.success(outputs=None, context=None)
    Parameters:
        outputs: dict - Values available to downstream stages
        context: dict - Values stored in stage.context (stage-scoped)

RUNNING - Task needs to poll again (for RetryableTask):
    TaskResult.running(context=None)
    Parameters:
        context: dict - Updated state for next poll iteration

TERMINAL - Task failed, pipeline HALTS:
    TaskResult.terminal(error, context=None)
    Parameters:
        error: str - Error message (REQUIRED)
        context: dict - Additional context data
    WARNING: Does NOT accept 'outputs' parameter!

FAILED_CONTINUE - Task failed but pipeline continues:
    TaskResult.failed_continue(error, outputs=None, context=None)
    Parameters:
        error: str - Error message (REQUIRED)
        outputs: dict - Values still available downstream
        context: dict - Additional context data

SKIPPED - Task was skipped:
    TaskResult.skipped()

CANCELED - Task was canceled:
    TaskResult.canceled(outputs=None)

STOPPED - Task was stopped:
    TaskResult.stopped(outputs=None)


3.3 RetryableTask - For Polling Operations
-------------------------------------------
import time
from datetime import timedelta
from stabilize.tasks.interface import RetryableTask

class PollTask(RetryableTask):
    def get_timeout(self) -> timedelta:
        """Maximum time before task times out."""
        return timedelta(minutes=30)

    def get_backoff_period(self, stage: StageExecution, duration: timedelta) -> timedelta:
        """Time to wait between poll attempts."""
        return timedelta(seconds=10)

    def execute(self, stage: StageExecution) -> TaskResult:
        status = check_external_system()  # Placeholder: replace with your own status check

        if status == "complete":
            return TaskResult.success(outputs={"status": "done"})
        elif status == "failed":
            return TaskResult.terminal(error="External system failed")
        else:
            # Keep polling - will be called again after backoff
            return TaskResult.running(context={"last_check": time.time()})


3.4 SkippableTask - Conditional Execution
------------------------------------------
from stabilize.tasks.interface import SkippableTask

class ConditionalTask(SkippableTask):
    def is_enabled(self, stage: StageExecution) -> bool:
        """Return False to skip this task."""
        return stage.context.get("should_run", True)

    def do_execute(self, stage: StageExecution) -> TaskResult:
        """Actual task logic (only called if is_enabled returns True)."""
        return TaskResult.success()

===============================================================================
4. TASK REGISTRY
===============================================================================

from stabilize.tasks.registry import TaskRegistry

registry = TaskRegistry()

# Register a task class
registry.register("my_task", MyTask)

# Register with aliases
registry.register("http", HTTPTask, aliases=["http_request", "web_request"])

# The implementing_class in TaskExecution MUST match the registered name:
TaskExecution.create(
    name="Do something",
    implementing_class="my_task",  # Must match registry.register() name
    stage_start=True,
    stage_end=True,
)

===============================================================================
5. DAG PATTERNS
===============================================================================

5.1 Sequential Stages (A -> B -> C)
------------------------------------
stages=[
    StageExecution(ref_id="A", ..., requisite_stage_ref_ids=set()),  # Initial
    StageExecution(ref_id="B", ..., requisite_stage_ref_ids={"A"}),  # After A
    StageExecution(ref_id="C", ..., requisite_stage_ref_ids={"B"}),  # After B
]


5.2 Parallel Stages
--------------------
      A
     / \
    B   C    <- B and C run in parallel after A
     \ /
      D

stages=[
    StageExecution(ref_id="A", ..., requisite_stage_ref_ids=set()),
    StageExecution(ref_id="B", ..., requisite_stage_ref_ids={"A"}),  # Parallel
    StageExecution(ref_id="C", ..., requisite_stage_ref_ids={"A"}),  # Parallel
    StageExecution(ref_id="D", ..., requisite_stage_ref_ids={"B", "C"}),  # Join
]


5.3 Complex DAG
----------------
      A
     /|\
    B C D      <- All parallel after A
    |/ \|
    E   F      <- E waits for B,C; F waits for C,D
     \ /
      G        <- G waits for E and F

stages=[
    StageExecution(ref_id="A", ..., requisite_stage_ref_ids=set()),
    StageExecution(ref_id="B", ..., requisite_stage_ref_ids={"A"}),
    StageExecution(ref_id="C", ..., requisite_stage_ref_ids={"A"}),
    StageExecution(ref_id="D", ..., requisite_stage_ref_ids={"A"}),
    StageExecution(ref_id="E", ..., requisite_stage_ref_ids={"B", "C"}),
    StageExecution(ref_id="F", ..., requisite_stage_ref_ids={"C", "D"}),
    StageExecution(ref_id="G", ..., requisite_stage_ref_ids={"E", "F"}),
]

===============================================================================
6. CONTEXT AND OUTPUTS DATA FLOW
===============================================================================

stage.context - INPUT: Parameters passed when creating the stage
                Also includes outputs from upstream stages (automatic lookup)

stage.outputs - OUTPUT: Values produced by tasks for downstream stages
                Set via TaskResult.success(outputs={...})

Example flow:
    Stage A context: {"input": "hello"}
    Stage A task returns: TaskResult.success(outputs={"result": "processed"})
    Stage B context: {"input": "hello", "result": "processed"}  <- Includes A's output

Accessing in tasks:
    def execute(self, stage):
        # Read from context (includes upstream outputs)
        upstream_result = stage.context.get("result")  # From upstream stage

        # Write to outputs (available downstream)
        return TaskResult.success(outputs={"my_output": "value"})

===============================================================================
7. COMMON MISTAKES AND HOW TO FIX THEM
===============================================================================

MISTAKE 1: Using 'outputs' parameter with TaskResult.terminal()
---------------------------------------------------------------
WRONG:
    return TaskResult.terminal(error="Failed", outputs={"data": value})

RIGHT:
    return TaskResult.terminal(error="Failed", context={"data": value})

terminal() only accepts: error (required), context (optional)


MISTAKE 2: Forgetting stage_start and stage_end on tasks
---------------------------------------------------------
WRONG:
    TaskExecution.create(name="X", implementing_class="y")

RIGHT:
    TaskExecution.create(name="X", implementing_class="y", stage_start=True, stage_end=True)


MISTAKE 3: implementing_class doesn't match registered name
------------------------------------------------------------
WRONG:
    registry.register("http_task", HTTPTask)
    TaskExecution.create(..., implementing_class="HTTPTask")  # Class name, not registered name

RIGHT:
    registry.register("http_task", HTTPTask)
    TaskExecution.create(..., implementing_class="http_task")  # Matches registered name


MISTAKE 4: Duplicate ref_id values
-----------------------------------
WRONG:
    StageExecution(ref_id="1", name="Stage A", ...)
    StageExecution(ref_id="1", name="Stage B", ...)  # Same ref_id!

RIGHT:
    StageExecution(ref_id="1", name="Stage A", ...)
    StageExecution(ref_id="2", name="Stage B", ...)  # Unique ref_ids


MISTAKE 5: Missing handlers
----------------------------
All 7 handlers are REQUIRED for the engine to work:
    StartWorkflowHandler, StartStageHandler, StartTaskHandler,
    RunTaskHandler, CompleteTaskHandler, CompleteStageHandler, CompleteWorkflowHandler

===============================================================================
8. COMPLETE EXAMPLE: SEQUENTIAL PIPELINE WITH ERROR HANDLING
===============================================================================

#!/usr/bin/env python3
from stabilize import Workflow, StageExecution, TaskExecution, WorkflowStatus
from stabilize.persistence.sqlite import SqliteWorkflowStore
from stabilize.queue.sqlite_queue import SqliteQueue
from stabilize.queue.processor import QueueProcessor
from stabilize.orchestrator import Orchestrator
from stabilize.tasks.interface import Task
from stabilize.tasks.result import TaskResult
from stabilize.tasks.registry import TaskRegistry
from stabilize.handlers.complete_workflow import CompleteWorkflowHandler
from stabilize.handlers.complete_stage import CompleteStageHandler
from stabilize.handlers.complete_task import CompleteTaskHandler
from stabilize.handlers.run_task import RunTaskHandler
from stabilize.handlers.start_workflow import StartWorkflowHandler
from stabilize.handlers.start_stage import StartStageHandler
from stabilize.handlers.start_task import StartTaskHandler


def send_notification(message):
    """Stand-in for a real notifier (email, chat, webhook, ...)."""
    print(f"Notification: {message}")


class ValidateTask(Task):
    def execute(self, stage: StageExecution) -> TaskResult:
        data = stage.context.get("data")
        if not data:
            return TaskResult.terminal(error="No data provided")
        return TaskResult.success(outputs={"validated": True, "data": data})


class ProcessTask(Task):
    def execute(self, stage: StageExecution) -> TaskResult:
        data = stage.context.get("data")
        validated = stage.context.get("validated")
        if not validated:
            return TaskResult.terminal(error="Data not validated")
        result = data.upper()
        return TaskResult.success(outputs={"processed_data": result})


class NotifyTask(Task):
    def execute(self, stage: StageExecution) -> TaskResult:
        processed = stage.context.get("processed_data")
        # Even if notification fails, we don't want to fail the pipeline
        try:
            send_notification(processed)
            return TaskResult.success(outputs={"notified": True})
        except Exception as e:
            # Use failed_continue to not halt the pipeline
            return TaskResult.failed_continue(
                error=f"Notification failed: {e}",
                outputs={"notified": False}
            )


def setup_pipeline_runner(store, queue):
    registry = TaskRegistry()
    registry.register("validate", ValidateTask)
    registry.register("process", ProcessTask)
    registry.register("notify", NotifyTask)

    processor = QueueProcessor(queue)
    handlers = [
        StartWorkflowHandler(queue, store),
        StartStageHandler(queue, store),
        StartTaskHandler(queue, store),
        RunTaskHandler(queue, store, registry),
        CompleteTaskHandler(queue, store),
        CompleteStageHandler(queue, store),
        CompleteWorkflowHandler(queue, store),
    ]
    for h in handlers:
        processor.register_handler(h)

    return processor, Orchestrator(queue)


def main():
    store = SqliteWorkflowStore("sqlite:///:memory:", create_tables=True)
    queue = SqliteQueue("sqlite:///:memory:", table_name="queue_messages")
    queue._create_table()
    processor, orchestrator = setup_pipeline_runner(store, queue)

    workflow = Workflow.create(
        application="data-pipeline",
        name="Process Data",
        stages=[
            StageExecution(
                ref_id="validate",
                type="validate",
                name="Validate Input",
                context={"data": "hello world"},
                tasks=[TaskExecution.create("Validate", "validate", stage_start=True, stage_end=True)],
            ),
            StageExecution(
                ref_id="process",
                type="process",
                name="Process Data",
                requisite_stage_ref_ids={"validate"},
                context={},  # Will receive 'data' from upstream
                tasks=[TaskExecution.create("Process", "process", stage_start=True, stage_end=True)],
            ),
            StageExecution(
                ref_id="notify",
                type="notify",
                name="Send Notification",
                requisite_stage_ref_ids={"process"},
                context={},
                tasks=[TaskExecution.create("Notify", "notify", stage_start=True, stage_end=True)],
            ),
        ],
    )

    store.store(workflow)
    orchestrator.start(workflow)
    processor.process_all(timeout=30.0)

    result = store.retrieve(workflow.id)
    print(f"Final status: {result.status}")
    for stage in result.stages:
        print(f" {stage.name}: {stage.status} - {stage.outputs}")


if __name__ == "__main__":
    main()

===============================================================================
9. COMPLETE EXAMPLE: PARALLEL STAGES WITH JOIN
===============================================================================

#!/usr/bin/env python3
from stabilize import Workflow, StageExecution, TaskExecution
from stabilize.persistence.sqlite import SqliteWorkflowStore
from stabilize.queue.sqlite_queue import SqliteQueue
from stabilize.queue.processor import QueueProcessor
from stabilize.orchestrator import Orchestrator
from stabilize.tasks.interface import Task
from stabilize.tasks.result import TaskResult
from stabilize.tasks.registry import TaskRegistry
from stabilize.handlers.complete_workflow import CompleteWorkflowHandler
from stabilize.handlers.complete_stage import CompleteStageHandler
from stabilize.handlers.complete_task import CompleteTaskHandler
from stabilize.handlers.run_task import RunTaskHandler
from stabilize.handlers.start_workflow import StartWorkflowHandler
from stabilize.handlers.start_stage import StartStageHandler
from stabilize.handlers.start_task import StartTaskHandler


class FetchDataTask(Task):
    def execute(self, stage: StageExecution) -> TaskResult:
        source = stage.context.get("source")
        # Simulate fetching data from different sources
        data = f"data_from_{source}"
        return TaskResult.success(outputs={f"{source}_data": data})


class AggregateTask(Task):
    def execute(self, stage: StageExecution) -> TaskResult:
        # Collect data from all upstream parallel stages
        api_data = stage.context.get("api_data")
        db_data = stage.context.get("db_data")
        cache_data = stage.context.get("cache_data")
        combined = f"{api_data} + {db_data} + {cache_data}"
        return TaskResult.success(outputs={"combined_data": combined})


def setup_pipeline_runner(store, queue):
    registry = TaskRegistry()
    registry.register("fetch", FetchDataTask)
    registry.register("aggregate", AggregateTask)

    processor = QueueProcessor(queue)
    for h in [
        StartWorkflowHandler(queue, store),
        StartStageHandler(queue, store),
        StartTaskHandler(queue, store),
        RunTaskHandler(queue, store, registry),
        CompleteTaskHandler(queue, store),
        CompleteStageHandler(queue, store),
        CompleteWorkflowHandler(queue, store),
    ]:
        processor.register_handler(h)

    return processor, Orchestrator(queue)


def main():
    store = SqliteWorkflowStore("sqlite:///:memory:", create_tables=True)
    queue = SqliteQueue("sqlite:///:memory:", table_name="queue_messages")
    queue._create_table()
    processor, orchestrator = setup_pipeline_runner(store, queue)

    #        Start
    #      /   |   \
    #    API   DB   Cache    <- Run in parallel
    #      \   |   /
    #      Aggregate         <- Join point

    workflow = Workflow.create(
        application="parallel-fetch",
        name="Parallel Data Fetch",
        stages=[
            StageExecution(
                ref_id="api",
                type="fetch",
                name="Fetch from API",
                context={"source": "api"},
                tasks=[TaskExecution.create("Fetch API", "fetch", stage_start=True, stage_end=True)],
            ),
            StageExecution(
                ref_id="db",
                type="fetch",
                name="Fetch from Database",
                context={"source": "db"},
                tasks=[TaskExecution.create("Fetch DB", "fetch", stage_start=True, stage_end=True)],
            ),
            StageExecution(
                ref_id="cache",
                type="fetch",
                name="Fetch from Cache",
                context={"source": "cache"},
                tasks=[TaskExecution.create("Fetch Cache", "fetch", stage_start=True, stage_end=True)],
            ),
            StageExecution(
                ref_id="aggregate",
                type="aggregate",
                name="Aggregate Results",
                requisite_stage_ref_ids={"api", "db", "cache"},  # Wait for ALL three
                context={},
                tasks=[TaskExecution.create("Aggregate", "aggregate", stage_start=True, stage_end=True)],
            ),
        ],
    )

    store.store(workflow)
    orchestrator.start(workflow)
    processor.process_all(timeout=30.0)

    result = store.retrieve(workflow.id)
    print(f"Final status: {result.status}")
    print(f"Combined data: {result.stages[-1].outputs.get('combined_data')}")


if __name__ == "__main__":
    main()

===============================================================================
10. COMPLETE IMPORTS REFERENCE
===============================================================================

# Core models
from stabilize import Workflow, StageExecution, TaskExecution, WorkflowStatus

# Persistence
from stabilize.persistence.sqlite import SqliteWorkflowStore
from stabilize.persistence.store import WorkflowStore

# Queue
from stabilize.queue.sqlite_queue import SqliteQueue
from stabilize.queue.queue import Queue
from stabilize.queue.processor import QueueProcessor

# Orchestration
from stabilize.orchestrator import Orchestrator

# Tasks
from stabilize.tasks.interface import Task, RetryableTask, SkippableTask
from stabilize.tasks.result import TaskResult
from stabilize.tasks.registry import TaskRegistry

# Handlers (all 7 required)
from stabilize.handlers.start_workflow import StartWorkflowHandler
from stabilize.handlers.start_stage import StartStageHandler
from stabilize.handlers.start_task import StartTaskHandler
from stabilize.handlers.run_task import RunTaskHandler
from stabilize.handlers.complete_task import CompleteTaskHandler
from stabilize.handlers.complete_stage import CompleteStageHandler
from stabilize.handlers.complete_workflow import CompleteWorkflowHandler

===============================================================================
END OF REFERENCE
===============================================================================
'''
750
+
751
# Name of the PostgreSQL table that records applied migrations (one row per
# migration file, storing its name, checksum and application time).
MIGRATION_TABLE: str = "stabilize_migrations"
753
+
754
+
755
def load_config() -> dict[str, Any]:
    """Load PostgreSQL connection settings for the migration commands.

    Sources are tried in order:

    1. The ``MG_DATABASE_URL`` environment variable, parsed by ``parse_db_url``.
    2. The ``database`` mapping of ``mg.yaml`` in the current working directory.

    Returns:
        A dict of connection parameters. A config taken from ``mg.yaml`` is
        guaranteed to be a non-empty mapping containing ``host`` and
        ``dbname``, which the migration commands index directly.

    Exits the process with status 1 when PyYAML is missing, when ``mg.yaml``
    cannot be read or parsed, when its ``database`` section is missing or
    incomplete, or when no configuration source exists at all.
    """
    db_url = os.environ.get("MG_DATABASE_URL")
    if db_url:
        return parse_db_url(db_url)

    mg_yaml = Path("mg.yaml")
    if mg_yaml.exists():
        # PyYAML is optional; only import it when mg.yaml is actually used.
        try:
            import yaml  # type: ignore[import-untyped]
        except ImportError:
            print("Warning: PyYAML not installed, cannot read mg.yaml")
            print("Set MG_DATABASE_URL environment variable instead")
            sys.exit(1)

        try:
            with open(mg_yaml, encoding="utf-8") as f:
                config = yaml.safe_load(f)
        except (OSError, yaml.YAMLError) as e:
            print(f"Error: Could not read mg.yaml: {e}")
            sys.exit(1)

        db_config = config.get("database") if isinstance(config, dict) else None
        if not isinstance(db_config, dict) or not db_config:
            # Returning {} here used to surface later as a KeyError on
            # config['host']; fail with a clear message instead.
            print("Error: mg.yaml has no 'database' section")
            sys.exit(1)

        missing = [key for key in ("host", "dbname") if not db_config.get(key)]
        if missing:
            print(f"Error: mg.yaml 'database' section is missing: {', '.join(missing)}")
            sys.exit(1)
        return db_config

    print("Error: No database configuration found")
    print("Either create mg.yaml or set MG_DATABASE_URL environment variable")
    sys.exit(1)
779
+
780
+
781
def parse_db_url(url: str) -> dict[str, Any]:
    """Parse a PostgreSQL URL into connection parameters.

    Accepts ``postgres://`` and ``postgresql://`` URLs of the form
    ``postgres://[user[:password]@]host[:port]/dbname[?query]``.
    Percent-encoded characters in the user, password and database name are
    decoded, so a password containing ``:`` or ``/`` can be written as
    ``%3A`` / ``%2F``. The host is split off at the last ``@``, so an
    unencoded ``@`` inside the password also works. Any query string is
    ignored. The host name is returned lower-cased (as ``urlsplit`` reports
    it).

    Returns:
        Dict with ``host``, ``port`` (int, default 5432), ``user`` (default
        ``"postgres"``), ``password`` (default ``""``) and ``dbname``.

    Exits the process with status 1 if the URL is not a valid PostgreSQL URL.
    """
    from urllib.parse import unquote, urlsplit

    try:
        parts = urlsplit(url)
        port = parts.port  # raises ValueError for a non-numeric/out-of-range port
    except ValueError:
        parts, port = None, None

    dbname = unquote(parts.path.lstrip("/")) if parts is not None else ""
    if (
        parts is None
        or parts.scheme not in ("postgres", "postgresql")
        or not parts.hostname
        or not dbname
    ):
        print(f"Error: Invalid database URL: {url}")
        sys.exit(1)

    return {
        "host": parts.hostname,
        "port": port if port is not None else 5432,
        "user": unquote(parts.username) if parts.username else "postgres",
        "password": unquote(parts.password) if parts.password else "",
        "dbname": dbname,
    }
797
+
798
+
799
def get_migrations() -> list[tuple[str, str]]:
    """Return every ``.sql`` migration bundled in ``stabilize.migrations``.

    Returns:
        ``(filename, content)`` pairs sorted by filename. Filenames carry a
        ULID prefix, so lexical order is chronological order.
    """
    migrations_pkg = files("stabilize.migrations")
    migrations = [
        # Decode explicitly as UTF-8: the default text encoding is
        # locale-dependent, and this text is what gets checksummed and
        # compared against checksums stored in the database, so it must be
        # identical on every platform.
        (item.name, item.read_text(encoding="utf-8"))
        for item in migrations_pkg.iterdir()
        if item.is_file() and item.name.endswith(".sql")
    ]
    migrations.sort(key=lambda entry: entry[0])
    return migrations
812
+
813
+
814
def extract_up_migration(content: str) -> str:
    """Return the SQL to execute when applying a migration.

    Migration files may contain ``-- migrate: up`` and ``-- migrate: down``
    marker comments (matched case-insensitively). The result is:

    * the stripped text between the up marker and the down marker (or the end
      of the file) when an up marker is present;
    * the stripped text before the down marker when only a down marker is
      present, so rollback SQL is never executed as part of an upgrade;
    * *content* unchanged when neither marker is present.
    """
    up_match = re.search(
        r"--\s*migrate:\s*up\s*\n(.*?)(?:--\s*migrate:\s*down|$)",
        content,
        re.DOTALL | re.IGNORECASE,
    )
    if up_match:
        return up_match.group(1).strip()

    # No up marker: everything after a down marker is rollback SQL and must
    # not be returned as part of the "up" script.
    down_match = re.search(r"--\s*migrate:\s*down", content, re.IGNORECASE)
    if down_match:
        return content[: down_match.start()].strip()
    return content
825
+
826
+
827
def compute_checksum(content: str) -> str:
    """Return the hex MD5 digest of *content* encoded as UTF-8.

    MD5 is used only to detect edits to already-applied migration files, not
    for security. The 32-character hex digest is what gets stored in the
    tracking table's ``checksum VARCHAR(32)`` column, so the algorithm must
    not change.
    """
    # usedforsecurity=False keeps this working on FIPS-enabled Python builds,
    # where a plain hashlib.md5() call raises; the digest is unchanged.
    return hashlib.md5(content.encode("utf-8"), usedforsecurity=False).hexdigest()
830
+
831
+
832
def mg_up(db_url: str | None = None) -> None:
    """Apply pending PostgreSQL migrations bundled with the package.

    Args:
        db_url: Optional ``postgres://`` URL. When omitted, connection settings
            come from ``load_config()``.

    Creates the ``MIGRATION_TABLE`` tracking table if needed, then executes the
    "up" section of every packaged migration not yet recorded there, in
    filename order, recording each with its checksum. Already-applied
    migrations whose checksum differs from the packaged file are reported but
    not re-run.

    Exits with status 1 if psycopg is not installed or a database error occurs.
    """
    try:
        import psycopg
    except ImportError:
        print("Error: psycopg not installed")
        print("Install with: pip install stabilize[postgres]")
        sys.exit(1)

    config = parse_db_url(db_url) if db_url else load_config()

    # Pass connection parameters as keywords so psycopg quotes them. A
    # hand-built "key=value ..." string breaks on values containing spaces or
    # quotes, and an empty password ("password= dbname=x") makes libpq take
    # "dbname=x" as the password value and drop the database name.
    connect_kwargs: dict[str, Any] = {
        "host": config["host"],
        "port": config.get("port", 5432),
        "user": config.get("user", "postgres"),
        "dbname": config["dbname"],
    }
    if config.get("password"):
        connect_kwargs["password"] = config["password"]

    try:
        with psycopg.connect(**connect_kwargs) as conn:
            with conn.cursor() as cur:
                # Ensure migration tracking table exists
                cur.execute(
                    f"""
                    CREATE TABLE IF NOT EXISTS {MIGRATION_TABLE} (
                        id SERIAL PRIMARY KEY,
                        name VARCHAR(255) NOT NULL UNIQUE,
                        checksum VARCHAR(32) NOT NULL,
                        applied_at TIMESTAMP DEFAULT NOW()
                    )
                    """
                )
                conn.commit()

                cur.execute(f"SELECT name, checksum FROM {MIGRATION_TABLE}")
                applied = {entry[0]: entry[1] for entry in cur.fetchall()}

                migrations = get_migrations()
                if not migrations:
                    print("No migrations found in package")
                    return

                pending = 0
                for name, content in migrations:
                    checksum = compute_checksum(content)
                    if name in applied:
                        # Already applied: only report drift, never re-run.
                        if applied[name] != checksum:
                            print(f"Warning: Checksum mismatch for {name}")
                        continue

                    pending += 1
                    print(f"Applying: {name}")

                    up_sql = extract_up_migration(content)
                    # Skip comment-only/empty up sections rather than sending
                    # an empty query to the server.
                    if up_sql:
                        cur.execute(up_sql)
                    # Recorded and committed together with the migration SQL,
                    # so a failing migration is never marked as applied.
                    cur.execute(
                        f"INSERT INTO {MIGRATION_TABLE} (name, checksum) VALUES (%s, %s)",
                        (name, checksum),
                    )
                    conn.commit()

                if pending == 0:
                    print("All migrations already applied")
                else:
                    print(f"Applied {pending} migration(s)")

    except psycopg.Error as e:
        print(f"Database error: {e}")
        sys.exit(1)
912
+
913
+
914
def prompt() -> None:
    """Write the full ``PROMPT_TEXT`` reference to standard output.

    The text is meant to be piped into RAG systems or coding agents; a single
    trailing newline is appended.
    """
    sys.stdout.write(f"{PROMPT_TEXT}\n")
917
+
918
+
919
+ def rag_init(
920
+ db_url: str | None = None,
921
+ force: bool = False,
922
+ additional_context: list[str] | None = None,
923
+ ) -> None:
924
+ """Initialize RAG embeddings from examples, documentation, and additional context."""
925
+ try:
926
+ from stabilize.rag import StabilizeRAG, get_cache
927
+ except ImportError:
928
+ print("Error: RAG support requires: pip install stabilize[rag]")
929
+ sys.exit(1)
930
+
931
+ cache = get_cache(db_url)
932
+ rag = StabilizeRAG(cache)
933
+
934
+ print("Initializing embeddings...")
935
+ count = rag.init(force=force, additional_paths=additional_context)
936
+ if count > 0:
937
+ print(f"Cached {count} embeddings")
938
+ else:
939
+ print("Embeddings already initialized (use --force to regenerate)")
940
+
941
+
942
+ def rag_clear(db_url: str | None = None) -> None:
943
+ """Clear all cached embeddings."""
944
+ try:
945
+ from stabilize.rag import get_cache
946
+ except ImportError:
947
+ print("Error: RAG support requires: pip install stabilize[rag]")
948
+ sys.exit(1)
949
+
950
+ cache = get_cache(db_url)
951
+ cache.clear()
952
+ print("Embedding cache cleared")
953
+
954
+
955
+ def rag_generate(
956
+ prompt_text: str,
957
+ db_url: str | None = None,
958
+ execute: bool = False,
959
+ top_k: int = 5,
960
+ temperature: float = 0.3,
961
+ llm_model: str | None = None,
962
+ ) -> None:
963
+ """Generate pipeline code from natural language prompt."""
964
+ try:
965
+ from stabilize.rag import StabilizeRAG, get_cache
966
+ except ImportError:
967
+ print("Error: RAG support requires: pip install stabilize[rag]")
968
+ sys.exit(1)
969
+
970
+ cache = get_cache(db_url)
971
+ rag = StabilizeRAG(cache, llm_model=llm_model)
972
+
973
+ try:
974
+ code = rag.generate(prompt_text, top_k=top_k, temperature=temperature)
975
+ except RuntimeError as e:
976
+ print(f"Error: {e}")
977
+ sys.exit(1)
978
+
979
+ print(code)
980
+
981
+ if execute:
982
+ print("\n--- Executing generated code ---\n")
983
+ try:
984
+ exec(code, {"__name__": "__main__"})
985
+ except ImportError as e:
986
+ print("\n--- Execution failed: Import error ---")
987
+ print(f"Error: {e}")
988
+ print("\nThe generated code has incorrect imports.")
989
+ print("Review the imports above and compare with examples/shell-example.py")
990
+ sys.exit(1)
991
+ except Exception as e:
992
+ print("\n--- Execution failed ---")
993
+ print(f"Error: {type(e).__name__}: {e}")
994
+ print("\nThe generated code may need manual adjustments.")
995
+ sys.exit(1)
996
+
997
+
998
def mg_status(db_url: str | None = None) -> None:
    """Print the state of every packaged migration.

    Args:
        db_url: Optional ``postgres://`` URL. When omitted, connection settings
            come from ``load_config()``.

    One row is printed per migration file, in filename order:

    * ``applied`` with its timestamp,
    * ``MISMATCH`` if the recorded checksum differs from the packaged file, or
    * ``pending`` if it has not been applied.

    A missing tracking table is treated as "nothing applied".

    Exits with status 1 if psycopg is not installed or a database error occurs.
    """
    try:
        import psycopg
    except ImportError:
        print("Error: psycopg not installed")
        print("Install with: pip install stabilize[postgres]")
        sys.exit(1)

    config = parse_db_url(db_url) if db_url else load_config()

    # Pass connection parameters as keywords so psycopg quotes them. A
    # hand-built "key=value ..." string breaks on values containing spaces or
    # quotes, and an empty password ("password= dbname=x") makes libpq take
    # "dbname=x" as the password value and drop the database name.
    connect_kwargs: dict[str, Any] = {
        "host": config["host"],
        "port": config.get("port", 5432),
        "user": config.get("user", "postgres"),
        "dbname": config["dbname"],
    }
    if config.get("password"):
        connect_kwargs["password"] = config["password"]

    try:
        with psycopg.connect(**connect_kwargs) as conn:
            with conn.cursor() as cur:
                # to_regclass() resolves the name through search_path, exactly
                # like the unqualified SELECT below. An unfiltered
                # information_schema.tables lookup would also match a
                # same-named table in another schema.
                cur.execute("SELECT to_regclass(%s) IS NOT NULL", (MIGRATION_TABLE,))
                row = cur.fetchone()
                table_exists = bool(row[0]) if row else False

                applied: dict[str, tuple[Any, Any]] = {}
                if table_exists:
                    cur.execute(f"SELECT name, checksum, applied_at FROM {MIGRATION_TABLE} ORDER BY applied_at")
                    applied = {entry[0]: (entry[1], entry[2]) for entry in cur.fetchall()}

                migrations = get_migrations()

                print(f"{'Status':<10} {'Migration':<50} {'Applied At'}")
                print("-" * 80)

                for name, content in migrations:
                    if name in applied:
                        checksum, applied_at = applied[name]
                        expected = compute_checksum(content)
                        status = "applied" if checksum == expected else "MISMATCH"
                        print(f"{status:<10} {name:<50} {applied_at}")
                    else:
                        print(f"{'pending':<10} {name:<50} -")

    except psycopg.Error as e:
        print(f"Database error: {e}")
        sys.exit(1)
1057
+
1058
+
1059
def main() -> None:
    """Command-line entry point for the ``stabilize`` program.

    Builds the argument parser, parses ``sys.argv`` and dispatches to the
    selected command:

    * ``mg-up`` / ``mg-status``: PostgreSQL migration management.
    * ``prompt``: print the bundled documentation text.
    * ``rag init|generate|clear``: manage the embedding cache and generate
      pipeline code.

    With no command, or ``rag`` without a subcommand, the corresponding help
    text is printed and the process exits with status 1.
    """
    postgres_url_help = "Database URL (postgres://user:pass@host:port/dbname)"

    parser = argparse.ArgumentParser(
        prog="stabilize",
        description="Stabilize - Workflow Engine CLI",
    )
    commands = parser.add_subparsers(dest="command", help="Available commands")

    # Migration commands
    mg_up_cmd = commands.add_parser("mg-up", help="Apply pending PostgreSQL migrations")
    mg_up_cmd.add_argument("--db-url", help=postgres_url_help)

    mg_status_cmd = commands.add_parser("mg-status", help="Show migration status")
    mg_status_cmd.add_argument("--db-url", help=postgres_url_help)

    # Documentation dump
    commands.add_parser(
        "prompt",
        help="Output comprehensive RAG context for pipeline code generation",
    )

    # RAG command group with its own subcommands
    rag_cmd = commands.add_parser("rag", help="RAG-powered pipeline generation")
    rag_commands = rag_cmd.add_subparsers(dest="rag_command")

    rag_init_cmd = rag_commands.add_parser(
        "init",
        help="Initialize embeddings cache from examples and documentation",
    )
    rag_init_cmd.add_argument(
        "--db-url",
        help="Database URL for caching (postgres://... or sqlite path)",
    )
    rag_init_cmd.add_argument(
        "--force",
        action="store_true",
        help="Force regeneration even if cache exists",
    )
    rag_init_cmd.add_argument(
        "--additional-context",
        action="append",
        metavar="PATH",
        help="Additional file or directory to include in training context (can be specified multiple times)",
    )

    rag_generate_cmd = rag_commands.add_parser(
        "generate",
        help="Generate pipeline code from natural language prompt",
    )
    rag_generate_cmd.add_argument(
        "prompt",
        help="Natural language description of the desired pipeline",
    )
    rag_generate_cmd.add_argument("--db-url", help="Database URL for caching")
    rag_generate_cmd.add_argument(
        "-x",
        "--execute",
        action="store_true",
        help="Execute the generated code after displaying it",
    )
    rag_generate_cmd.add_argument(
        "--top-k",
        type=int,
        default=10,
        help="Number of context chunks to retrieve (default: 10)",
    )
    rag_generate_cmd.add_argument(
        "--temperature",
        type=float,
        default=0.3,
        help="LLM temperature for generation (default: 0.3)",
    )
    rag_generate_cmd.add_argument(
        "--llm-model",
        default=None,
        help="LLM model for generation (default: qwen3-vl:235b)",
    )

    rag_clear_cmd = rag_commands.add_parser("clear", help="Clear all cached embeddings")
    rag_clear_cmd.add_argument("--db-url", help="Database URL for caching")

    opts = parser.parse_args()

    # Handlers are wrapped in lambdas so both the attribute reads and the
    # target-function lookups happen only for the command actually chosen.
    if opts.command == "rag":
        rag_dispatch = {
            "init": lambda: rag_init(opts.db_url, opts.force, opts.additional_context),
            "generate": lambda: rag_generate(
                opts.prompt,
                opts.db_url,
                opts.execute,
                opts.top_k,
                opts.temperature,
                opts.llm_model,
            ),
            "clear": lambda: rag_clear(opts.db_url),
        }
        rag_handler = rag_dispatch.get(opts.rag_command)
        if rag_handler is None:
            rag_cmd.print_help()
            sys.exit(1)
        rag_handler()
        return

    dispatch = {
        "mg-up": lambda: mg_up(opts.db_url),
        "mg-status": lambda: mg_status(opts.db_url),
        "prompt": lambda: prompt(),
    }
    handler = dispatch.get(opts.command)
    if handler is None:
        parser.print_help()
        sys.exit(1)
    handler()
1190
+
1191
+
1192
if __name__ == "__main__":
    # Run the CLI when this module is executed directly as a script.
    main()