mrmd-ai 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -38,6 +38,12 @@ from .document import (
38
38
  from .notebook import (
39
39
  NotebookNamePredict,
40
40
  )
41
+ from .edit import (
42
+ EditAtCursorPredict,
43
+ AddressCommentPredict,
44
+ AddressAllCommentsPredict,
45
+ AddressNearbyCommentPredict,
46
+ )
41
47
 
42
48
  __all__ = [
43
49
  # Finish programs
@@ -71,4 +77,9 @@ __all__ = [
71
77
  "DocumentAnalysisPredict",
72
78
  # Notebook programs
73
79
  "NotebookNamePredict",
80
+ # Edit programs (Ctrl-K and comments)
81
+ "EditAtCursorPredict",
82
+ "AddressCommentPredict",
83
+ "AddressAllCommentsPredict",
84
+ "AddressNearbyCommentPredict",
74
85
  ]
@@ -0,0 +1,102 @@
1
+ """Edit modules for cursor-based editing and comment processing."""
2
+
3
+ import dspy
4
+ from typing import List, Optional
5
+ from ..signatures.edit import (
6
+ Edit,
7
+ CommentInfo,
8
+ EditAtCursorSignature,
9
+ AddressCommentSignature,
10
+ AddressAllCommentsSignature,
11
+ AddressNearbyCommentSignature,
12
+ )
13
+
14
+
15
class EditAtCursorPredict(dspy.Module):
    """Carry out a user instruction as precise find/replace edits.

    Thin DSPy wrapper around ``dspy.Predict(EditAtCursorSignature)``: every
    ``forward`` argument is passed through unchanged as the same-named input.
    """

    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict(EditAtCursorSignature)

    def forward(
        self,
        text_before: str,
        text_after: str,
        selection: str,
        full_document: str,
        instruction: str,
    ):
        """Run the predictor on the cursor context and instruction.

        Returns:
            Whatever the underlying ``dspy.Predict`` call returns.
        """
        inputs = dict(
            text_before=text_before,
            text_after=text_after,
            selection=selection,
            full_document=full_document,
            instruction=instruction,
        )
        return self.predict(**inputs)
37
+
38
+
39
class AddressCommentPredict(dspy.Module):
    """Resolve one comment embedded in the document.

    Thin DSPy wrapper around ``dspy.Predict(AddressCommentSignature)``; the
    ``forward`` arguments are forwarded unchanged as same-named inputs.
    """

    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict(AddressCommentSignature)

    def forward(
        self,
        full_document: str,
        comment_text: str,
        comment_context_before: str,
        comment_context_after: str,
        comment_raw: str,
    ):
        """Run the predictor for a single comment.

        Returns:
            Whatever the underlying ``dspy.Predict`` call returns.
        """
        inputs = dict(
            full_document=full_document,
            comment_text=comment_text,
            comment_context_before=comment_context_before,
            comment_context_after=comment_context_after,
            comment_raw=comment_raw,
        )
        return self.predict(**inputs)
61
+
62
+
63
class AddressAllCommentsPredict(dspy.Module):
    """Resolve every comment in a document in one prediction.

    Thin DSPy wrapper around ``dspy.Predict(AddressAllCommentsSignature)``.
    """

    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict(AddressAllCommentsSignature)

    def forward(
        self,
        full_document: str,
        comments: List[CommentInfo],
    ):
        """Run the predictor over the document and its collected comments.

        Returns:
            Whatever the underlying ``dspy.Predict`` call returns.
        """
        inputs = dict(full_document=full_document, comments=comments)
        return self.predict(**inputs)
79
+
80
+
81
class AddressNearbyCommentPredict(dspy.Module):
    """Resolve the comment closest to the cursor.

    Thin DSPy wrapper around ``dspy.Predict(AddressNearbyCommentSignature)``;
    the ``forward`` arguments are forwarded unchanged as same-named inputs.
    """

    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict(AddressNearbyCommentSignature)

    def forward(
        self,
        full_document: str,
        cursor_context_before: str,
        cursor_context_after: str,
        nearby_comment: CommentInfo,
        nearby_comment_raw: str,
    ):
        """Run the predictor for the comment near the cursor.

        Returns:
            Whatever the underlying ``dspy.Predict`` call returns.
        """
        inputs = dict(
            full_document=full_document,
            cursor_context_before=cursor_context_before,
            cursor_context_after=cursor_context_after,
            nearby_comment=nearby_comment,
            nearby_comment_raw=nearby_comment_raw,
        )
        return self.predict(**inputs)
mrmd_ai/server.py CHANGED
@@ -23,7 +23,8 @@ import json
23
23
  # Thread pool for running blocking DSPy calls
24
24
  _executor = ThreadPoolExecutor(max_workers=10)
25
25
 
26
- from .juice import JuiceLevel, JuicedProgram, get_lm, JUICE_MODELS
26
+ from .juice import JuiceLevel, ReasoningLevel, JuicedProgram, get_lm, JUICE_MODELS, REASONING_DESCRIPTIONS
27
+ from .custom_programs import custom_registry, register_custom_programs
27
28
  from .modules import (
28
29
  # Finish
29
30
  FinishSentencePredict,
@@ -56,6 +57,11 @@ from .modules import (
56
57
  DocumentAnalysisPredict,
57
58
  # Notebook
58
59
  NotebookNamePredict,
60
+ # Edit (Ctrl-K and comments)
61
+ EditAtCursorPredict,
62
+ AddressCommentPredict,
63
+ AddressAllCommentsPredict,
64
+ AddressNearbyCommentPredict,
59
65
  )
60
66
 
61
67
 
@@ -92,21 +98,101 @@ PROGRAMS = {
92
98
  "DocumentAnalysisPredict": DocumentAnalysisPredict,
93
99
  # Notebook
94
100
  "NotebookNamePredict": NotebookNamePredict,
101
+ # Edit (Ctrl-K and comments)
102
+ "EditAtCursorPredict": EditAtCursorPredict,
103
+ "AddressCommentPredict": AddressCommentPredict,
104
+ "AddressAllCommentsPredict": AddressAllCommentsPredict,
105
+ "AddressNearbyCommentPredict": AddressNearbyCommentPredict,
95
106
  }
96
107
 
97
# JuicedProgram instances keyed by (program name, juice level, reasoning level).
# NOTE: get_program only consults this cache for built-in programs requested
# without custom API keys or a model override.
_program_cache: dict[tuple[str, int, int | None], JuicedProgram] = {}


# Provider name -> HTTP header that carries that provider's API key.
API_KEY_HEADERS = {
    "anthropic": "X-Api-Key-Anthropic",
    "openai": "X-Api-Key-Openai",
    "groq": "X-Api-Key-Groq",
    "gemini": "X-Api-Key-Gemini",
    "openrouter": "X-Api-Key-Openrouter",
}


def extract_api_keys(request: Request) -> dict | None:
    """Collect per-provider API keys sent as request headers.

    Each provider in ``API_KEY_HEADERS`` is looked up by its header name;
    absent or empty header values are skipped.

    Returns:
        ``{provider: key}`` for every provider with a non-empty header, or
        ``None`` when the request carried no key at all.
    """
    found = {
        provider: value
        for provider, header_name in API_KEY_HEADERS.items()
        if (value := request.headers.get(header_name))
    }
    return found or None
143
+
144
+
145
def get_program(
    name: str,
    juice: int = 0,
    reasoning: int | None = None,
    api_keys: dict | None = None,
    model_override: str | None = None,
) -> JuicedProgram:
    """Return a JuicedProgram for the requested program configuration.

    Program lookup checks the built-in ``PROGRAMS`` table first and then the
    custom program registry.

    Args:
        name: Program name (built-in or custom).
        juice: Juice level (0-4).
        reasoning: Optional reasoning level (0-5).
        api_keys: Optional mapping of provider -> API key.
        model_override: Optional model to use instead of the default.

    Returns:
        A configured JuicedProgram instance.

    Raises:
        ValueError: If ``name`` is neither a built-in nor a registered custom
            program.

    Note:
        Only built-in programs requested without API keys or a model override
        are cached, keyed by ``(name, juice, reasoning)``. Custom programs and
        per-request configurations always get a fresh instance.
    """
    program_class = PROGRAMS.get(name)
    if program_class is None:
        program_class = custom_registry.get(name)
    if program_class is None:
        raise ValueError(f"Unknown program: {name}")

    cacheable = not (api_keys or model_override) and name in PROGRAMS
    if not cacheable:
        # Per-request configuration (or a custom program): never share the instance.
        return JuicedProgram(
            program_class(),
            juice=juice,
            reasoning=reasoning,
            api_keys=api_keys,
            model_override=model_override,
        )

    key = (name, juice, reasoning)
    try:
        return _program_cache[key]
    except KeyError:
        fresh = JuicedProgram(program_class(), juice=juice, reasoning=reasoning)
        _program_cache[key] = fresh
        return fresh
111
197
 
112
198
 
@@ -153,12 +239,31 @@ async def list_programs():
153
239
 
154
240
@app.get("/juice")
async def get_juice_levels():
    """Get available juice levels with their capabilities.

    Returns:
        ``{"levels": [...]}``; each entry has ``level`` (the enum value),
        ``description`` and ``supports_reasoning``.
    """
    from .juice import JUICE_DESCRIPTIONS, JUICE_MODELS, JuiceLevel
    levels = []
    for level, desc in JUICE_DESCRIPTIONS.items():
        if level == JuiceLevel.ULTIMATE:
            # ULTIMATE level supports reasoning (all its sub-models do).
            supports_reasoning = True
        elif level in JUICE_MODELS:
            supports_reasoning = JUICE_MODELS[level].supports_reasoning
        else:
            # No model configuration for this level: previously this fell into
            # the ULTIMATE branch and was wrongly advertised as reasoning-capable.
            supports_reasoning = False
        levels.append({
            "level": level.value,
            "description": desc,
            "supports_reasoning": supports_reasoning,
        })
    return {"levels": levels}


@app.get("/reasoning")
async def get_reasoning_levels():
    """Get available reasoning levels.

    Returns:
        ``{"levels": [{"level": <enum value>, "description": <str>}, ...]}``
        built from ``REASONING_DESCRIPTIONS``.
    """
    return {
        "levels": [
            {"level": level.value, "description": desc}
            for level, desc in REASONING_DESCRIPTIONS.items()
        ]
    }
164
269
 
@@ -184,6 +289,7 @@ def extract_result(prediction: Any) -> dict:
184
289
  "reformatted_text", "text_to_replace", "replacement",
185
290
  "response", "summary", "analysis", # Document-level fields
186
291
  "code", # ProgramCodePredict output
292
+ "edits", # EditAtCursor and AddressComment outputs
187
293
  ]
188
294
 
189
295
  for field in output_fields:
@@ -210,6 +316,22 @@ async def run_program(program_name: str, request: Request):
210
316
  except ValueError:
211
317
  juice_level = 0
212
318
 
319
+ # Get reasoning level from header (optional)
320
+ reasoning_header = request.headers.get("X-Reasoning-Level")
321
+ reasoning_level = None
322
+ if reasoning_header is not None:
323
+ try:
324
+ reasoning_level = int(reasoning_header)
325
+ reasoning_level = max(0, min(5, reasoning_level)) # Clamp to 0-5
326
+ except ValueError:
327
+ reasoning_level = None
328
+
329
+ # Extract API keys from headers (optional)
330
+ api_keys = extract_api_keys(request)
331
+
332
+ # Get model override from header (optional)
333
+ model_override = request.headers.get("X-Model-Override")
334
+
213
335
  # Get request body
214
336
  try:
215
337
  params = await request.json()
@@ -218,14 +340,23 @@ async def run_program(program_name: str, request: Request):
218
340
 
219
341
  # Get program
220
342
  try:
221
- juiced_program = get_program(program_name, juice_level)
343
+ juiced_program = get_program(
344
+ program_name,
345
+ juice_level,
346
+ reasoning_level,
347
+ api_keys=api_keys,
348
+ model_override=model_override,
349
+ )
222
350
  except ValueError as e:
223
351
  raise HTTPException(status_code=404, detail=str(e))
224
352
 
225
353
  # Log the call and get model info
226
- from .juice import JUICE_DESCRIPTIONS, JUICE_MODELS, ULTIMATE_MODELS, JuiceLevel
354
+ from .juice import JUICE_DESCRIPTIONS, JUICE_MODELS, ULTIMATE_MODELS, JuiceLevel, ReasoningLevel
227
355
  juice_desc = JUICE_DESCRIPTIONS.get(JuiceLevel(juice_level), f"Level {juice_level}")
228
- print(f"[AI] {program_name} @ {juice_desc}", flush=True)
356
+ reasoning_desc = ""
357
+ if reasoning_level is not None:
358
+ reasoning_desc = f" | {REASONING_DESCRIPTIONS.get(ReasoningLevel(reasoning_level), f'Reasoning {reasoning_level}')}"
359
+ print(f"[AI] {program_name} @ {juice_desc}{reasoning_desc}", flush=True)
229
360
 
230
361
  # Get the model name for this juice level
231
362
  if juice_level == JuiceLevel.ULTIMATE:
@@ -245,16 +376,35 @@ async def run_program(program_name: str, request: Request):
245
376
  # Add model metadata to response
246
377
  response["_model"] = model_name
247
378
  response["_juice_level"] = juice_level
248
- return response
379
+ response["_reasoning_level"] = reasoning_level
380
+ # Serialize any Pydantic models to dicts for JSON compatibility
381
+ return serialize_for_json(response)
249
382
  except Exception as e:
250
383
  import traceback
251
384
  traceback.print_exc()
252
385
  raise HTTPException(status_code=500, detail=str(e))
253
386
 
254
387
 
388
def serialize_for_json(obj):
    """Recursively convert Pydantic models and containers to JSON-serializable form.

    - Pydantic v2 instances (callable ``model_dump``) and v1 instances (callable
      ``dict``) are dumped to plain dicts.
    - ``dict`` values are converted recursively.
    - ``list``/``tuple``/``set``/``frozenset`` items are converted recursively,
      and the container becomes a list.
    - Anything else is returned unchanged.

    Args:
        obj: Value to convert.

    Returns:
        The converted value.
    """
    # A class (e.g. a Pydantic model class itself) exposes model_dump/dict as
    # unbound functions that would fail when called, so only dump instances.
    # Attributes that merely share the name but aren't callable are ignored.
    if not isinstance(obj, type):
        dump = getattr(obj, "model_dump", None)  # Pydantic v2
        if callable(dump):
            return dump()
        legacy_dump = getattr(obj, "dict", None)  # Pydantic v1
        if callable(legacy_dump):
            return legacy_dump()
    if isinstance(obj, dict):
        return {k: serialize_for_json(v) for k, v in obj.items()}
    # json.dumps rejects sets, so normalise them to lists like tuples.
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [serialize_for_json(item) for item in obj]
    return obj


def sse_event(event: str, data: dict) -> str:
    """Format a Server-Sent Event whose data line is ``data`` encoded as JSON."""
    serialized = serialize_for_json(data)
    return f"event: {event}\ndata: {json.dumps(serialized)}\n\n"
258
408
 
259
409
 
260
410
  @app.post("/{program_name}/stream")
@@ -275,6 +425,22 @@ async def run_program_stream(program_name: str, request: Request):
275
425
  except ValueError:
276
426
  juice_level = 0
277
427
 
428
+ # Get reasoning level from header (optional)
429
+ reasoning_header = request.headers.get("X-Reasoning-Level")
430
+ reasoning_level = None
431
+ if reasoning_header is not None:
432
+ try:
433
+ reasoning_level = int(reasoning_header)
434
+ reasoning_level = max(0, min(5, reasoning_level)) # Clamp to 0-5
435
+ except ValueError:
436
+ reasoning_level = None
437
+
438
+ # Extract API keys from headers (optional)
439
+ api_keys = extract_api_keys(request)
440
+
441
+ # Get model override from header (optional)
442
+ model_override = request.headers.get("X-Model-Override")
443
+
278
444
  # Get request body
279
445
  try:
280
446
  params = await request.json()
@@ -286,9 +452,12 @@ async def run_program_stream(program_name: str, request: Request):
286
452
  raise HTTPException(status_code=404, detail=f"Unknown program: {program_name}")
287
453
 
288
454
  # Get model info
289
- from .juice import JUICE_DESCRIPTIONS, JUICE_MODELS, ULTIMATE_MODELS, JuiceLevel, JuicedProgram
455
+ from .juice import JUICE_DESCRIPTIONS, JUICE_MODELS, ULTIMATE_MODELS, JuiceLevel, ReasoningLevel, JuicedProgram
290
456
  juice_desc = JUICE_DESCRIPTIONS.get(JuiceLevel(juice_level), f"Level {juice_level}")
291
- print(f"[AI Stream] {program_name} @ {juice_desc}", flush=True)
457
+ reasoning_desc = ""
458
+ if reasoning_level is not None:
459
+ reasoning_desc = f" | {REASONING_DESCRIPTIONS.get(ReasoningLevel(reasoning_level), f'Reasoning {reasoning_level}')}"
460
+ print(f"[AI Stream] {program_name} @ {juice_desc}{reasoning_desc}", flush=True)
292
461
 
293
462
  # Get model name(s) for display
294
463
  if juice_level == JuiceLevel.ULTIMATE:
@@ -315,17 +484,26 @@ async def run_program_stream(program_name: str, request: Request):
315
484
  def run_with_progress():
316
485
  """Run the program in a thread, emitting progress events."""
317
486
  try:
318
- # Create program with progress callback
487
+ # Create program with progress callback and optional API keys
319
488
  program_class = PROGRAMS[program_name]
320
489
  program = program_class()
321
- juiced = JuicedProgram(program, juice=juice_level, progress_callback=progress_callback)
490
+ juiced = JuicedProgram(
491
+ program,
492
+ juice=juice_level,
493
+ reasoning=reasoning_level,
494
+ progress_callback=progress_callback,
495
+ api_keys=api_keys,
496
+ model_override=model_override,
497
+ )
322
498
 
323
499
  # Emit starting event
324
500
  progress_callback("status", {
325
501
  "step": "starting",
326
502
  "model": model_name,
327
503
  "juice_level": juice_level,
328
- "juice_name": juice_desc
504
+ "juice_name": juice_desc,
505
+ "reasoning_level": reasoning_level,
506
+ "reasoning_name": reasoning_desc.strip(" |") if reasoning_desc else None,
329
507
  })
330
508
 
331
509
  # Run the program
@@ -357,6 +535,7 @@ async def run_program_stream(program_name: str, request: Request):
357
535
  response = extract_result(result_holder["result"])
358
536
  response["_model"] = model_name
359
537
  response["_juice_level"] = juice_level
538
+ response["_reasoning_level"] = reasoning_level
360
539
  yield sse_event("result", response)
361
540
  break
362
541
 
@@ -390,6 +569,79 @@ async def run_program_stream(program_name: str, request: Request):
390
569
  )
391
570
 
392
571
 
572
+ # =============================================================================
573
+ # CUSTOM PROGRAMS API
574
+ # =============================================================================
575
+
576
class CustomProgramConfig(BaseModel):
    """Configuration for a custom program, as sent by the frontend."""
    id: str  # Unique program ID (e.g., "Custom_cmd-123")
    name: str
    instructions: str
    inputType: str = "selection"  # selection | cursor | fullDoc
    outputType: str = "replace"  # replace | insert


class RegisterCustomProgramsRequest(BaseModel):
    """Request to register multiple custom programs."""
    programs: list[CustomProgramConfig]


@app.post("/api/custom-programs/register")
async def register_custom_programs_endpoint(request: RegisterCustomProgramsRequest):
    """Replace the registered custom programs with the ones in the request.

    Called when the app starts or when settings change. The registry is cleared
    first, then each program is registered individually. Registration is
    best-effort: a program that fails is logged and skipped so the others still
    register. IDs equal to a built-in program name are rejected because
    get_program resolves built-ins first, so such a program could never run.

    Returns:
        ``{"registered": [ids...], "count": n}`` for the programs that registered.
    """
    custom_registry.clear()

    registered = []
    for prog in request.programs:
        if prog.id in PROGRAMS:
            print(f"[Custom] Failed to register {prog.id}: conflicts with a built-in program", flush=True)
            continue
        try:
            custom_registry.register(
                program_id=prog.id,
                config={
                    "name": prog.name,
                    "instructions": prog.instructions,
                    "inputType": prog.inputType,
                    "outputType": prog.outputType,
                },
            )
        except Exception as e:
            # Deliberate best-effort: one bad program must not block the rest.
            print(f"[Custom] Failed to register {prog.id}: {e}", flush=True)
            continue
        registered.append(prog.id)
        print(f"[Custom] Registered: {prog.name} ({prog.id})", flush=True)

    return {"registered": registered, "count": len(registered)}
619
+
620
+
621
@app.get("/api/custom-programs")
async def list_custom_programs():
    """List all registered custom programs.

    Each entry carries ``id``, ``name``, ``inputType`` and ``outputType``
    (falling back to the id, "selection" and "replace" respectively). Programs
    whose stored config is falsy are left out.
    """
    return {
        "programs": [
            {
                "id": program_id,
                "name": config.get("name", program_id),
                "inputType": config.get("inputType", "selection"),
                "outputType": config.get("outputType", "replace"),
            }
            for program_id in custom_registry.list_programs()
            if (config := custom_registry.get_config(program_id))
        ]
    }
635
+
636
+
637
@app.delete("/api/custom-programs/{program_id}")
async def unregister_custom_program(program_id: str):
    """Unregister one custom program.

    Raises:
        HTTPException: 404 when the registry reports no such program.
    """
    if not custom_registry.unregister(program_id):
        raise HTTPException(status_code=404, detail=f"Program not found: {program_id}")
    return {"success": True, "id": program_id}
643
+
644
+
393
645
  def main():
394
646
  """Run the AI server."""
395
647
  import argparse
@@ -14,6 +14,14 @@ from .correct import (
14
14
  CorrectAndFinishLineSignature,
15
15
  CorrectAndFinishSectionSignature,
16
16
  )
17
+ from .edit import (
18
+ Edit,
19
+ CommentInfo,
20
+ EditAtCursorSignature,
21
+ AddressCommentSignature,
22
+ AddressAllCommentsSignature,
23
+ AddressNearbyCommentSignature,
24
+ )
17
25
 
18
26
  __all__ = [
19
27
  "FinishSentenceSignature",
@@ -24,4 +32,11 @@ __all__ = [
24
32
  "FixTranscriptionSignature",
25
33
  "CorrectAndFinishLineSignature",
26
34
  "CorrectAndFinishSectionSignature",
35
+ # Edit signatures
36
+ "Edit",
37
+ "CommentInfo",
38
+ "EditAtCursorSignature",
39
+ "AddressCommentSignature",
40
+ "AddressAllCommentsSignature",
41
+ "AddressNearbyCommentSignature",
27
42
  ]