mrmd-ai 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mrmd_ai/server.py ADDED
@@ -0,0 +1,429 @@
1
+ """
2
+ MRMD AI Server - Custom server with Juice Level support.
3
+
4
+ Replaces dspy-cli with a FastAPI server that supports juice levels
5
+ for progressive quality/cost tradeoff.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import asyncio
11
+ from concurrent.futures import ThreadPoolExecutor
12
+ from typing import Any
13
+ from contextlib import asynccontextmanager
14
+
15
+ import dspy
16
+ from fastapi import FastAPI, Request, HTTPException
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from fastapi.responses import StreamingResponse
19
+ from pydantic import BaseModel
20
+ import uvicorn
21
+ import json
22
+
23
# Thread pool for running blocking DSPy calls.
# Shared by every request handler; max_workers caps concurrent model calls at 10.
_executor = ThreadPoolExecutor(max_workers=10)
25
+
26
+ from .juice import JuiceLevel, JuicedProgram, get_lm, JUICE_MODELS
27
+ from .modules import (
28
+ # Finish
29
+ FinishSentencePredict,
30
+ FinishParagraphPredict,
31
+ FinishCodeLinePredict,
32
+ FinishCodeSectionPredict,
33
+ # Fix
34
+ FixGrammarPredict,
35
+ FixTranscriptionPredict,
36
+ # Correct
37
+ CorrectAndFinishLinePredict,
38
+ CorrectAndFinishSectionPredict,
39
+ # Code
40
+ DocumentCodePredict,
41
+ CompleteCodePredict,
42
+ AddTypeHintsPredict,
43
+ ImproveNamesPredict,
44
+ ExplainCodePredict,
45
+ RefactorCodePredict,
46
+ FormatCodePredict,
47
+ ProgramCodePredict,
48
+ # Text
49
+ GetSynonymsPredict,
50
+ GetPhraseSynonymsPredict,
51
+ ReformatMarkdownPredict,
52
+ IdentifyReplacementPredict,
53
+ # Document
54
+ DocumentResponsePredict,
55
+ DocumentSummaryPredict,
56
+ DocumentAnalysisPredict,
57
+ # Notebook
58
+ NotebookNamePredict,
59
+ )
60
+
61
+
62
# Program registry - maps program names to their classes.
# Keys double as URL path segments (see the /{program_name} endpoints), so
# they must match the names clients send exactly.
PROGRAMS = {
    # Finish
    "FinishSentencePredict": FinishSentencePredict,
    "FinishParagraphPredict": FinishParagraphPredict,
    "FinishCodeLinePredict": FinishCodeLinePredict,
    "FinishCodeSectionPredict": FinishCodeSectionPredict,
    # Fix
    "FixGrammarPredict": FixGrammarPredict,
    "FixTranscriptionPredict": FixTranscriptionPredict,
    # Correct
    "CorrectAndFinishLinePredict": CorrectAndFinishLinePredict,
    "CorrectAndFinishSectionPredict": CorrectAndFinishSectionPredict,
    # Code
    "DocumentCodePredict": DocumentCodePredict,
    "CompleteCodePredict": CompleteCodePredict,
    "AddTypeHintsPredict": AddTypeHintsPredict,
    "ImproveNamesPredict": ImproveNamesPredict,
    "ExplainCodePredict": ExplainCodePredict,
    "RefactorCodePredict": RefactorCodePredict,
    "FormatCodePredict": FormatCodePredict,
    "ProgramCodePredict": ProgramCodePredict,
    # Text
    "GetSynonymsPredict": GetSynonymsPredict,
    "GetPhraseSynonymsPredict": GetPhraseSynonymsPredict,
    "ReformatMarkdownPredict": ReformatMarkdownPredict,
    "IdentifyReplacementPredict": IdentifyReplacementPredict,
    # Document
    "DocumentResponsePredict": DocumentResponsePredict,
    "DocumentSummaryPredict": DocumentSummaryPredict,
    "DocumentAnalysisPredict": DocumentAnalysisPredict,
    # Notebook
    "NotebookNamePredict": NotebookNamePredict,
}

# Cached program instances per juice level.
# Keyed by (program name, juice level); populated lazily by get_program()
# and never evicted for the lifetime of the process.
_program_cache: dict[tuple[str, int], JuicedProgram] = {}
99
+
100
+
101
def get_program(name: str, juice: int = 0) -> JuicedProgram:
    """Return the cached JuicedProgram for (name, juice), creating it on first use.

    Args:
        name: Key into the PROGRAMS registry.
        juice: Juice level (0-4) the wrapper should run at.

    Raises:
        ValueError: If *name* is not a registered program.
    """
    key = (name, juice)
    if key in _program_cache:
        return _program_cache[key]
    if name not in PROGRAMS:
        raise ValueError(f"Unknown program: {name}")
    instance = PROGRAMS[name]()
    juiced = JuicedProgram(instance, juice=juice)
    _program_cache[key] = juiced
    return juiced
111
+
112
+
113
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan - configure default LM on startup."""
    # Juice level 0 (QUICK) is the process-wide default model; individual
    # requests override it via JuicedProgram.
    quick_lm = get_lm(JuiceLevel.QUICK)
    dspy.configure(lm=quick_lm)
    print(f"[AI Server] Configured default LM: {JUICE_MODELS[JuiceLevel.QUICK].model}")
    yield
    print("[AI Server] Shutting down...")
122
+
123
+
124
# Create FastAPI app.
app = FastAPI(
    title="MRMD AI Server",
    description="AI server with Juice Level support for progressive quality/cost tradeoff",
    version="1.0.0",
    lifespan=lifespan,  # configures the default LM on startup
)

# Add CORS middleware.
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True per the CORS spec; confirm whether credentialed
# requests are actually needed, otherwise drop allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
140
+
141
+
142
@app.get("/programs")
async def list_programs():
    """List available programs and the endpoint each one is served at.

    Returns:
        A dict with a "programs" list of {"name", "endpoint"} entries.
    """
    # PROGRAMS maps name -> class; only the names are needed here, so iterate
    # the keys directly (the original iterated .items() and ignored the class).
    return {
        "programs": [
            {"name": name, "endpoint": f"/{name}"} for name in PROGRAMS
        ]
    }
152
+
153
+
154
@app.get("/juice")
async def get_juice_levels():
    """Describe every available juice level."""
    from .juice import JUICE_DESCRIPTIONS

    levels = []
    for level, desc in JUICE_DESCRIPTIONS.items():
        levels.append({"level": level.value, "description": desc})
    return {"levels": levels}
164
+
165
+
166
def extract_result(prediction: Any) -> dict:
    """Flatten a DSPy prediction object into a plain dict of output fields.

    Prefers the prediction's ``_store`` mapping when present and non-empty;
    otherwise probes a fixed list of known output attribute names.
    """
    out: dict = {}

    # The Ultimate juice level attaches a merged answer under this attribute.
    if hasattr(prediction, "synthesized_response"):
        out["synthesized_response"] = prediction.synthesized_response

    # DSPy Prediction objects keep their outputs in _store.
    store = getattr(prediction, "_store", None)
    if store:
        out.update(dict(store))
    else:
        # Fallback: probe known output field names via direct attribute access.
        known_fields = (
            "synonyms", "original", "alternatives",
            "completion", "fixed_text", "corrected_completion",
            "documented_code", "typed_code", "improved_code",
            "explained_code", "refactored_code", "formatted_code",
            "reformatted_text", "text_to_replace", "replacement",
            "response", "summary", "analysis",  # Document-level fields
            "code",  # ProgramCodePredict output
        )
        missing = object()
        for field in known_fields:
            value = getattr(prediction, field, missing)
            # Present-but-None attributes are skipped, same as absent ones.
            if value is not missing and value is not None:
                out[field] = value

    # Per-model answers are surfaced for the Ultimate level (juice=4).
    if hasattr(prediction, "_individual_responses"):
        out["_individual_responses"] = prediction._individual_responses

    return out
200
+
201
+
202
@app.post("/{program_name}")
async def run_program(program_name: str, request: Request):
    """Run a named program once and return its extracted result.

    The juice level (0-4) is read from the ``X-Juice-Level`` header; the JSON
    request body supplies the program's keyword arguments. The response dict
    carries the program outputs plus ``_model`` / ``_juice_level`` metadata.

    Raises:
        HTTPException: 400 for an invalid JSON body, 404 for an unknown
            program, 500 if execution fails.
    """
    # Juice level comes from a header; malformed values fall back to 0.
    juice_header = request.headers.get("X-Juice-Level", "0")
    try:
        juice_level = int(juice_header)
        juice_level = max(0, min(4, juice_level))  # Clamp to 0-4
    except ValueError:
        juice_level = 0

    # Program keyword arguments arrive as the JSON request body.
    try:
        params = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON body")

    # Resolve (program, juice) to a cached JuicedProgram instance.
    try:
        juiced_program = get_program(program_name, juice_level)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))

    # Log the call and get model info.
    from .juice import JUICE_DESCRIPTIONS, JUICE_MODELS, ULTIMATE_MODELS, JuiceLevel
    juice_desc = JUICE_DESCRIPTIONS.get(JuiceLevel(juice_level), f"Level {juice_level}")
    print(f"[AI] {program_name} @ {juice_desc}", flush=True)

    # Resolve the model name for this juice level (response metadata only).
    # NOTE(review): the int-vs-enum comparison below only holds if JuiceLevel
    # is an IntEnum - confirm against juice.py.
    if juice_level == JuiceLevel.ULTIMATE:
        model_name = "multi-model"  # Ultimate uses multiple models
    else:
        model_config = JUICE_MODELS.get(JuiceLevel(juice_level))
        model_name = model_config.model if model_config else "unknown"

    # DSPy calls are blocking; run them on the shared thread pool so the
    # event loop stays responsive.
    def run_sync():
        return juiced_program(**params)

    try:
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated in this context since Python 3.10.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(_executor, run_sync)
        response = extract_result(result)
        # Add model metadata to the response so clients can show what ran.
        response["_model"] = model_name
        response["_juice_level"] = juice_level
        return response
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
253
+
254
+
255
def sse_event(event: str, data: dict) -> str:
    """Render one Server-Sent Event frame: event name plus JSON payload."""
    payload = json.dumps(data)
    return "event: " + event + "\ndata: " + payload + "\n\n"
258
+
259
+
260
@app.post("/{program_name}/stream")
async def run_program_stream(program_name: str, request: Request):
    """Run a program with SSE streaming for status updates.

    The program executes on a background thread; progress events flow back
    through a thread-safe queue and are re-emitted as SSE frames.

    Emits events:
    - status: Progress updates (step, model, etc.)
    - model_complete: When a model finishes (for ultimate mode)
    - result: Final result
    - error: If an error occurs
    """
    # Get juice level from header; malformed values fall back to 0.
    juice_header = request.headers.get("X-Juice-Level", "0")
    try:
        juice_level = int(juice_header)
        juice_level = max(0, min(4, juice_level))  # Clamp to 0-4
    except ValueError:
        juice_level = 0

    # Get request body (program keyword arguments).
    try:
        params = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON body")

    # Check program exists before starting any work.
    if program_name not in PROGRAMS:
        raise HTTPException(status_code=404, detail=f"Unknown program: {program_name}")

    # Get model info for logging and event payloads.
    from .juice import JUICE_DESCRIPTIONS, JUICE_MODELS, ULTIMATE_MODELS, JuiceLevel, JuicedProgram
    juice_desc = JUICE_DESCRIPTIONS.get(JuiceLevel(juice_level), f"Level {juice_level}")
    print(f"[AI Stream] {program_name} @ {juice_desc}", flush=True)

    # Get model name(s) for display.
    # NOTE(review): model_names is computed here but never used below -
    # dead code, or intended for a future status event? Confirm.
    if juice_level == JuiceLevel.ULTIMATE:
        model_name = "multi-model"
        model_names = [cfg.model.split("/")[-1] for cfg in ULTIMATE_MODELS]
    else:
        model_config = JUICE_MODELS.get(JuiceLevel(juice_level))
        model_name = model_config.model if model_config else "unknown"
        model_names = [model_name.split("/")[-1] if "/" in model_name else model_name]

    async def event_generator():
        """Generate SSE events for the AI call."""
        import queue
        import threading

        # Queue for progress events from the sync execution; result_holder
        # carries the final result/error across the thread boundary.
        progress_queue = queue.Queue()
        result_holder = {"result": None, "error": None}

        def progress_callback(event_type: str, data: dict):
            """Called by JuicedProgram (from the worker thread) to report progress."""
            progress_queue.put((event_type, data))

        def run_with_progress():
            """Run the program in a thread, emitting progress events."""
            try:
                # Create program with progress callback. A fresh (uncached)
                # JuicedProgram is built here so the callback is per-request.
                program_class = PROGRAMS[program_name]
                program = program_class()
                juiced = JuicedProgram(program, juice=juice_level, progress_callback=progress_callback)

                # Emit starting event before the (potentially long) model call.
                progress_callback("status", {
                    "step": "starting",
                    "model": model_name,
                    "juice_level": juice_level,
                    "juice_name": juice_desc
                })

                # Run the program (blocking).
                result = juiced(**params)
                result_holder["result"] = result

                # Signal completion to the streaming loop.
                progress_queue.put(("done", None))

            except Exception as e:
                import traceback
                traceback.print_exc()
                result_holder["error"] = str(e)
                progress_queue.put(("error", {"message": str(e)}))

        # Start execution in a daemonless worker thread.
        thread = threading.Thread(target=run_with_progress)
        thread.start()

        # Stream progress events until "done"/"error" or the thread dies.
        while True:
            try:
                # Wait for events with timeout to allow checking if thread is done.
                # NOTE(review): this synchronous get() blocks the event loop for
                # up to 100ms per poll - consider loop.run_in_executor or an
                # asyncio.Queue fed via call_soon_threadsafe.
                event_type, data = progress_queue.get(timeout=0.1)

                if event_type == "done":
                    # Send final result.
                    # NOTE(review): if the program legitimately returned None,
                    # the stream ends without a "result" event - confirm that
                    # is intended.
                    if result_holder["result"] is not None:
                        response = extract_result(result_holder["result"])
                        response["_model"] = model_name
                        response["_juice_level"] = juice_level
                        yield sse_event("result", response)
                    break

                elif event_type == "error":
                    yield sse_event("error", data)
                    break

                else:
                    # Status or model_complete event - forward as-is.
                    yield sse_event(event_type, data)

            except queue.Empty:
                # Check if thread is still running.
                if not thread.is_alive():
                    # Thread finished but didn't put done event - check for error.
                    if result_holder["error"]:
                        yield sse_event("error", {"message": result_holder["error"]})
                    break
                continue

        thread.join(timeout=1.0)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
        }
    )
391
+
392
+
393
def main():
    """Run the AI server.

    Parses --host/--port/--reload from the command line, prints a startup
    banner, then blocks in uvicorn.run() until interrupted.
    """
    import argparse
    parser = argparse.ArgumentParser(description="MRMD AI Server")
    parser.add_argument("--host", default="127.0.0.1", help="Host to bind to")
    parser.add_argument("--port", type=int, default=51790, help="Port to listen on")
    parser.add_argument("--reload", action="store_true", help="Enable auto-reload")
    args = parser.parse_args()

    print(f"""
╔══════════════════════════════════════════╗
║ MRMD AI Server ║
║ with Juice Level Support ║
╚══════════════════════════════════════════╝

URL: http://{args.host}:{args.port}/

Juice Levels:
0 = ⚡ Quick (Grok 4.1)
1 = ⚖️ Balanced (Sonnet 4.5)
2 = 🧠 Deep (Gemini 3 thinking)
3 = 🚀 Maximum (Opus 4.5 thinking)
4 = 🔥 Ultimate (Multi-model merger)

Press Ctrl+C to stop
""")

    # App is passed as an import string so --reload can re-import it.
    uvicorn.run(
        "mrmd_ai.server:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
    )


if __name__ == "__main__":
    main()
@@ -0,0 +1,27 @@
1
+ """DSPy signature definitions for MRMD AI programs."""
2
+
3
+ from .finish import (
4
+ FinishSentenceSignature,
5
+ FinishParagraphSignature,
6
+ FinishCodeLineSignature,
7
+ FinishCodeSectionSignature,
8
+ )
9
+ from .fix import (
10
+ FixGrammarSignature,
11
+ FixTranscriptionSignature,
12
+ )
13
+ from .correct import (
14
+ CorrectAndFinishLineSignature,
15
+ CorrectAndFinishSectionSignature,
16
+ )
17
+
18
+ __all__ = [
19
+ "FinishSentenceSignature",
20
+ "FinishParagraphSignature",
21
+ "FinishCodeLineSignature",
22
+ "FinishCodeSectionSignature",
23
+ "FixGrammarSignature",
24
+ "FixTranscriptionSignature",
25
+ "CorrectAndFinishLineSignature",
26
+ "CorrectAndFinishSectionSignature",
27
+ ]