mrmd-ai 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mrmd_ai/__init__.py +3 -0
- mrmd_ai/juice.py +416 -0
- mrmd_ai/metrics/__init__.py +1 -0
- mrmd_ai/modules/__init__.py +74 -0
- mrmd_ai/modules/code.py +141 -0
- mrmd_ai/modules/correct.py +53 -0
- mrmd_ai/modules/document.py +41 -0
- mrmd_ai/modules/finish.py +95 -0
- mrmd_ai/modules/fix.py +52 -0
- mrmd_ai/modules/notebook.py +15 -0
- mrmd_ai/modules/text.py +69 -0
- mrmd_ai/optimizers/__init__.py +1 -0
- mrmd_ai/server.py +429 -0
- mrmd_ai/signatures/__init__.py +27 -0
- mrmd_ai/signatures/code.py +279 -0
- mrmd_ai/signatures/correct.py +72 -0
- mrmd_ai/signatures/document.py +57 -0
- mrmd_ai/signatures/finish.py +134 -0
- mrmd_ai/signatures/fix.py +72 -0
- mrmd_ai/signatures/notebook.py +37 -0
- mrmd_ai/signatures/text.py +134 -0
- mrmd_ai/utils/__init__.py +1 -0
- mrmd_ai-0.1.0.dist-info/METADATA +45 -0
- mrmd_ai-0.1.0.dist-info/RECORD +26 -0
- mrmd_ai-0.1.0.dist-info/WHEEL +4 -0
- mrmd_ai-0.1.0.dist-info/entry_points.txt +2 -0
mrmd_ai/__init__.py
ADDED
mrmd_ai/juice.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Juice Level System for MRMD AI Programs.
|
|
3
|
+
|
|
4
|
+
Juice levels control the quality/cost tradeoff of AI responses:
|
|
5
|
+
- Level 0: Kimi K2 on Groq (fast, cheap, default)
|
|
6
|
+
- Level 1: Claude Sonnet 4.5 (better quality)
|
|
7
|
+
- Level 2: Gemini 3 Pro with thinking (deep reasoning)
|
|
8
|
+
- Level 3: Claude Opus 4.5 with high thinking (maximum single-model quality)
|
|
9
|
+
- Level 4: Multi-model merger (Grok 4 + GPT-5.1 + Gemini 3 + Opus 4.5 run in parallel; their outputs are merged field-by-field)
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from enum import IntEnum
|
|
13
|
+
from typing import Any, Callable
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
import dspy
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class JuiceLevel(IntEnum):
    """Quality/cost tiers for AI calls; a higher value means more capable and more expensive.

    Being an ``IntEnum``, members compare equal to their integer values, so
    callers may pass either ``JuiceLevel.DEEP`` or ``2``.
    """

    QUICK = 0      # single model: Kimi K2 served by Groq
    BALANCED = 1   # single model: Claude Sonnet 4.5
    DEEP = 2       # single model: Gemini 3 Pro, high reasoning effort
    MAXIMUM = 3    # single model: Claude Opus 4.5, high reasoning effort
    ULTIMATE = 4   # several models run in parallel, outputs merged
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
|
|
38
|
+
class ModelConfig:
|
|
39
|
+
"""Configuration for a model at a specific juice level."""
|
|
40
|
+
model: str
|
|
41
|
+
temperature: float = 0.7
|
|
42
|
+
max_tokens: int = 4096
|
|
43
|
+
reasoning_effort: str | None = None
|
|
44
|
+
thinking: dict | None = None
|
|
45
|
+
extra_kwargs: dict = field(default_factory=dict)
|
|
46
|
+
|
|
47
|
+
def to_lm_kwargs(self) -> dict:
|
|
48
|
+
"""Convert to dspy.LM kwargs."""
|
|
49
|
+
kwargs = {
|
|
50
|
+
"model": self.model,
|
|
51
|
+
"temperature": self.temperature,
|
|
52
|
+
"max_tokens": self.max_tokens,
|
|
53
|
+
**self.extra_kwargs,
|
|
54
|
+
}
|
|
55
|
+
if self.reasoning_effort:
|
|
56
|
+
kwargs["reasoning_effort"] = self.reasoning_effort
|
|
57
|
+
if self.thinking:
|
|
58
|
+
kwargs["thinking"] = self.thinking
|
|
59
|
+
return kwargs
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# Single-model settings for juice levels QUICK..MAXIMUM. ULTIMATE has no
# entry here; it is served from ULTIMATE_MODELS instead.
JUICE_MODELS: dict[JuiceLevel, ModelConfig] = {
    # Level 0 - Kimi K2 hosted on Groq.
    JuiceLevel.QUICK: ModelConfig(
        model="groq/moonshotai/kimi-k2-instruct-0905",
        max_tokens=4096,
        temperature=0.7,
    ),
    # Level 1 - Claude Sonnet 4.5.
    JuiceLevel.BALANCED: ModelConfig(
        model="anthropic/claude-sonnet-4-5",
        max_tokens=4096,
        temperature=0.7,
    ),
    # Level 2 - Gemini 3 Pro with high reasoning effort.
    JuiceLevel.DEEP: ModelConfig(
        model="gemini/gemini-3-pro-preview",
        reasoning_effort="high",
        max_tokens=16000,
        temperature=1.0,
    ),
    # Level 3 - Claude Opus 4.5 with high reasoning effort.
    JuiceLevel.MAXIMUM: ModelConfig(
        model="anthropic/claude-opus-4-5",
        reasoning_effort="high",
        max_tokens=16000,
        temperature=1.0,
    ),
}

# Models queried in parallel at the ULTIMATE level: Grok 4 (no
# reasoning_effort set), then GPT-5.1, Gemini 3 and Opus 4.5, each with
# reasoning_effort="high".
# NOTE: Anthropic requires temperature=1 when using extended thinking.
ULTIMATE_MODELS: list[ModelConfig] = [
    ModelConfig(
        model="openrouter/x-ai/grok-4",
        max_tokens=8192,
        temperature=0.7,
    ),
    ModelConfig(
        model="openai/gpt-5.1",
        reasoning_effort="high",
        max_tokens=16000,
        temperature=1.0,
    ),
    ModelConfig(
        model="gemini/gemini-3-pro-preview",
        reasoning_effort="high",
        max_tokens=16000,
        temperature=1.0,
    ),
    ModelConfig(
        model="anthropic/claude-opus-4-5",
        reasoning_effort="high",
        max_tokens=16000,
        temperature=1.0,  # Must be 1 for extended thinking
    ),
]

# Settings intended for a Gemini 3 synthesizer at the ULTIMATE level.
# NOTE(review): nothing in juice.py reads SYNTHESIZER_MODEL (JuicedProgram
# merges the model outputs field-by-field) - confirm whether another module
# uses it before changing or removing it.
SYNTHESIZER_MODEL = ModelConfig(
    model="gemini/gemini-3-pro-preview",
    reasoning_effort="high",
    max_tokens=32000,
    temperature=0.7,
)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def get_lm(juice: JuiceLevel | int = JuiceLevel.QUICK) -> dspy.LM:
    """Build a ``dspy.LM`` for one of the single-model juice levels.

    Args:
        juice: A ``JuiceLevel`` or its integer value. Only QUICK through
            MAXIMUM (0-3) map to a single model.

    Returns:
        A ``dspy.LM`` constructed from ``JUICE_MODELS[juice]``.

    Raises:
        ValueError: If ``juice`` is ULTIMATE (that level needs
            ``JuicedProgram``), or if an int does not name a ``JuiceLevel``.
    """
    level = JuiceLevel(juice) if isinstance(juice, int) else juice

    if level == JuiceLevel.ULTIMATE:
        raise ValueError("ULTIMATE juice level requires multi-model merger. Use JuicedProgram instead.")

    lm_kwargs = JUICE_MODELS[level].to_lm_kwargs()
    return dspy.LM(**lm_kwargs)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
# DSPy signature for combining several models' answers into one. In DSPy a
# Signature's docstring and field descs are part of the prompt, so that text
# is runtime behavior and is intentionally left untouched here.
# NOTE(review): nothing in juice.py instantiates this signature
# (JuicedProgram merges outputs field-by-field) - confirm whether another
# module uses it.
class SynthesizeResponses(dspy.Signature):
    """Synthesize multiple AI model responses into an optimal final answer.

    You are given the original input and responses from multiple AI models.
    Analyze all responses, identify the best insights from each, resolve
    any contradictions, and produce the ultimate synthesized response.
    """

    # The task input the individual models answered.
    original_input: str = dspy.InputField(desc="The original input/question")
    # A single string holding the models' answers; the desc asks for them to
    # be labeled by model name (formatting is up to the caller).
    model_responses: str = dspy.InputField(desc="Responses from multiple AI models, labeled by model name")
    # The merged answer produced by the LM.
    synthesized_response: str = dspy.OutputField(desc="The optimal synthesized response combining the best from all models")
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class JuicedProgram:
    """Wrapper that runs any DSPy program with configurable juice levels.

    For levels 0-3 (QUICK..MAXIMUM) the program runs once under the single
    model configured in ``JUICE_MODELS``. For level 4 (ULTIMATE) it runs once
    per entry in ``ULTIMATE_MODELS``, in parallel worker threads, and the
    outputs are merged by :meth:`_merge_results`.
    """

    def __init__(
        self,
        program: dspy.Module,
        juice: JuiceLevel | int = JuiceLevel.QUICK,
        progress_callback: Callable[[str, dict], None] | None = None
    ):
        """Initialize a juiced program.

        Args:
            program: The DSPy program/module to wrap.
            juice: Juice level (0-4), as a ``JuiceLevel`` or its int value.
            progress_callback: Optional callback for progress events.
                Called with (event_type, data) where event_type is:
                - "status": General status update
                - "model_start": A model is starting (ultimate mode)
                - "model_complete": A model finished, successfully or with
                  an error (ultimate mode)

        Raises:
            ValueError: If an int ``juice`` is not a valid ``JuiceLevel``.
        """
        self.program = program
        self.juice = JuiceLevel(juice) if isinstance(juice, int) else juice
        self.progress_callback = progress_callback

    def _emit(self, event_type: str, data: dict):
        """Emit a progress event if a callback is set."""
        if self.progress_callback:
            self.progress_callback(event_type, data)

    def __call__(self, **kwargs) -> Any:
        """Run the program with the configured juice level."""
        if self.juice == JuiceLevel.ULTIMATE:
            return self._run_ultimate(**kwargs)
        return self._run_single(**kwargs)

    def _run_single(self, **kwargs) -> Any:
        """Run the program once under the model configured for ``self.juice``.

        Emits a "status" event before ("calling_model") and after
        ("model_complete") the call.
        """
        config = JUICE_MODELS[self.juice]
        model_name = config.model.split("/")[-1]

        self._emit("status", {
            "step": "calling_model",
            "model": model_name,
            "model_full": config.model
        })

        lm = get_lm(self.juice)
        with dspy.context(lm=lm):
            result = self.program(**kwargs)

        self._emit("status", {
            "step": "model_complete",
            "model": model_name
        })

        return result

    def _run_ultimate(self, **kwargs) -> Any:
        """Run every model in ``ULTIMATE_MODELS`` in parallel and merge results.

        Each model runs in its own worker thread. A model that raises
        (including while its ``dspy.LM`` is being built) is recorded as an
        error rather than aborting the run; :meth:`_merge_results` raises
        only if every model failed.

        Results are gathered in ``ULTIMATE_MODELS`` order, not completion
        order, so the merge base (the first successful result) is
        deterministic instead of depending on which model answered fastest.
        """
        from concurrent.futures import ThreadPoolExecutor
        import threading

        # Short display names, e.g. "openrouter/x-ai/grok-4" -> "grok-4".
        model_names = [cfg.model.split("/")[-1] for cfg in ULTIMATE_MODELS]
        models_status = {name: "pending" for name in model_names}
        # Guards models_status, which the worker threads update concurrently;
        # events are emitted under the lock so each snapshot matches its event.
        status_lock = threading.Lock()

        self._emit("status", {
            "step": "starting_multi_model",
            "models": model_names,
            "total": len(model_names)
        })

        def run_model(config):
            """Run the wrapped program under one model; model errors are returned, not raised."""
            model_name = config.model.split("/")[-1]

            with status_lock:
                models_status[model_name] = "running"
                self._emit("model_start", {
                    "model": model_name,
                    "models_status": dict(models_status)
                })

            try:
                # Built inside the try so a bad config is reported as this
                # model's error instead of aborting the whole ULTIMATE run.
                lm = dspy.LM(**config.to_lm_kwargs())
                with dspy.context(lm=lm):
                    result = self.program(**kwargs)

                with status_lock:
                    models_status[model_name] = "complete"
                    self._emit("model_complete", {
                        "model": model_name,
                        "success": True,
                        "models_status": dict(models_status)
                    })

                return {"model": model_name, "result": result, "error": None}
            except Exception as e:
                with status_lock:
                    models_status[model_name] = "error"
                    self._emit("model_complete", {
                        "model": model_name,
                        "success": False,
                        "error": str(e),
                        "models_status": dict(models_status)
                    })
                return {"model": model_name, "result": None, "error": str(e)}

        # One worker per model so all of them run concurrently
        # (ThreadPoolExecutor rejects max_workers=0, hence the max()).
        with ThreadPoolExecutor(max_workers=max(1, len(ULTIMATE_MODELS))) as executor:
            futures = [executor.submit(run_model, config) for config in ULTIMATE_MODELS]
            # Progress is reported from inside run_model, so waiting on the
            # futures in submission order loses nothing and keeps the
            # results aligned with ULTIMATE_MODELS.
            model_results = [future.result() for future in futures]

        self._emit("status", {
            "step": "synthesizing",
            "models_completed": sum(1 for r in model_results if r["result"] is not None)
        })

        return self._merge_results(model_results)

    def _merge_results(self, model_results: list) -> Any:
        """Merge results from multiple models into a single response.

        The first successful result (in ``model_results`` order) is the base.
        Fields from the remaining successful results are folded in:

        - a field that is a list in both: new unique items are appended,
          preserving order (unhashable items are compared by equality);
        - a field missing from the base: added as-is;
        - any other field (e.g. strings): the base value is kept.

        Lists are copied before merging, so the individual model results are
        never mutated.

        Args:
            model_results: Dicts with keys "model", "result" and "error", as
                produced by ``_run_ultimate``.

        Returns:
            An object exposing every merged field as an attribute, plus
            ``_store`` (the merged dict) and ``_individual_responses`` (one
            dict per model with "model", "response" and "error").

        Raises:
            RuntimeError: If no model produced a result.
        """
        successful = [r for r in model_results if r["result"] is not None]
        if not successful:
            errors = [r["error"] for r in model_results if r["error"]]
            raise RuntimeError(f"All models failed: {errors}")

        def own_copy(value):
            # Lists are extended in place below; copy them so we never write
            # into a model's own prediction store.
            return list(value) if isinstance(value, list) else value

        # DSPy predictions keep their output fields in `_store`.
        base_result = successful[0]["result"]
        if hasattr(base_result, "_store"):
            merged = {key: own_copy(value) for key, value in base_result._store.items()}
        else:
            merged = {}

        # Per-model responses for display.
        individual_responses = []
        for r in model_results:
            model_name = r["model"]
            if r["result"] is not None and hasattr(r["result"], "_store"):
                store = r["result"]._store
                # Heuristic: show the first string output longer than 10 chars,
                # falling back to the whole store.
                output_text = next(
                    (value for value in store.values() if isinstance(value, str) and len(value) > 10),
                    None,
                )
                individual_responses.append({
                    "model": model_name,
                    "response": output_text or str(store),
                    "error": None
                })
            elif r["error"]:
                individual_responses.append({
                    "model": model_name,
                    "response": None,
                    "error": r["error"]
                })

        # Fold in fields from the other successful models.
        for r in successful[1:]:
            result = r["result"]
            if not hasattr(result, "_store"):
                continue
            for key, value in result._store.items():
                if key not in merged:
                    merged[key] = own_copy(value)
                elif isinstance(value, list) and isinstance(merged[key], list):
                    self._extend_unique(merged[key], value)
                # Otherwise (strings etc.) keep the first (base) value.

        class MergedResult:
            """Plain attribute container for the merged fields."""

        result = MergedResult()
        for key, value in merged.items():
            setattr(result, key, value)
        result._store = merged  # For extract_result in server.py
        result._individual_responses = individual_responses  # For UI display

        return result

    @staticmethod
    def _extend_unique(target: list, items: list) -> None:
        """Append each item of ``items`` not already in ``target``, preserving order.

        Existing duplicates inside ``target`` are left alone. Hashable items
        are tracked in a set; unhashable ones (dicts, lists, ...) fall back
        to an equality scan instead of raising ``TypeError``.
        """
        seen_hashable = set()
        seen_unhashable = []

        def already_seen(item) -> bool:
            """Record ``item``; return True if an equal item was recorded before."""
            try:
                if item in seen_hashable:
                    return True
                seen_hashable.add(item)
            except TypeError:
                if item in seen_unhashable:
                    return True
                seen_unhashable.append(item)
            return False

        for existing in target:
            already_seen(existing)
        for item in items:
            if not already_seen(item):
                target.append(item)

    def _format_input(self, kwargs: dict) -> str:
        """Format input kwargs as "key: value" lines."""
        return "\n".join(f"{key}: {value}" for key, value in kwargs.items())
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def juiced(juice: JuiceLevel | int = JuiceLevel.QUICK):
    """Decorator that turns a program factory into a ``JuicedProgram`` factory.

    The decorated function must build and return a DSPy module. The wrapper
    calls it with the same arguments and wraps the result in a
    ``JuicedProgram`` at the given juice level. ``functools.wraps`` keeps the
    factory's name and docstring on the wrapper.

    Usage:
        @juiced(JuiceLevel.DEEP)
        def my_program():
            return dspy.ChainOfThought(MySignature)
    """
    import functools

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            program = func(*args, **kwargs)
            return JuicedProgram(program, juice)
        return wrapper
    return decorator
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
def run_with_juice(program: dspy.Module, juice: JuiceLevel | int, **kwargs) -> Any:
    """Run ``program`` once at the given juice level.

    Shorthand for ``JuicedProgram(program, juice)(**kwargs)``.

    Args:
        program: The DSPy program to run.
        juice: Juice level (0-4), as a ``JuiceLevel`` or its int value.
        **kwargs: Inputs forwarded to the program.

    Returns:
        Whatever the ``JuicedProgram`` call returns.
    """
    return JuicedProgram(program, juice)(**kwargs)
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
# Human-readable label for every juice level, shown by the CLI/UI.
JUICE_DESCRIPTIONS: dict[JuiceLevel, str] = {
    JuiceLevel.QUICK:    "Quick (Kimi K2) - Fast & cheap",
    JuiceLevel.BALANCED: "Balanced (Sonnet 4.5) - Good quality",
    JuiceLevel.DEEP:     "Deep (Gemini 3 thinking) - Thorough reasoning",
    JuiceLevel.MAXIMUM:  "Maximum (Opus 4.5 thinking) - Best single model",
    JuiceLevel.ULTIMATE: "Ultimate (Multi-model merger) - All models synthesized",
}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Metrics for evaluating and optimizing MRMD AI programs."""
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""DSPy module implementations for MRMD AI programs."""
|
|
2
|
+
|
|
3
|
+
from .finish import (
|
|
4
|
+
FinishSentencePredict,
|
|
5
|
+
FinishParagraphPredict,
|
|
6
|
+
FinishCodeLinePredict,
|
|
7
|
+
FinishCodeSectionPredict,
|
|
8
|
+
)
|
|
9
|
+
from .fix import (
|
|
10
|
+
FixGrammarPredict,
|
|
11
|
+
FixTranscriptionPredict,
|
|
12
|
+
)
|
|
13
|
+
from .correct import (
|
|
14
|
+
CorrectAndFinishLinePredict,
|
|
15
|
+
CorrectAndFinishSectionPredict,
|
|
16
|
+
)
|
|
17
|
+
from .code import (
|
|
18
|
+
DocumentCodePredict,
|
|
19
|
+
CompleteCodePredict,
|
|
20
|
+
AddTypeHintsPredict,
|
|
21
|
+
ImproveNamesPredict,
|
|
22
|
+
ExplainCodePredict,
|
|
23
|
+
RefactorCodePredict,
|
|
24
|
+
FormatCodePredict,
|
|
25
|
+
ProgramCodePredict,
|
|
26
|
+
)
|
|
27
|
+
from .text import (
|
|
28
|
+
GetSynonymsPredict,
|
|
29
|
+
GetPhraseSynonymsPredict,
|
|
30
|
+
ReformatMarkdownPredict,
|
|
31
|
+
IdentifyReplacementPredict,
|
|
32
|
+
)
|
|
33
|
+
from .document import (
|
|
34
|
+
DocumentResponsePredict,
|
|
35
|
+
DocumentSummaryPredict,
|
|
36
|
+
DocumentAnalysisPredict,
|
|
37
|
+
)
|
|
38
|
+
from .notebook import (
|
|
39
|
+
NotebookNamePredict,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
# Public API of this package, grouped by the kind of task each program handles.
__all__ = [
    # Completion of prose and code
    "FinishSentencePredict", "FinishParagraphPredict",
    "FinishCodeLinePredict", "FinishCodeSectionPredict",
    # Error fixing
    "FixGrammarPredict", "FixTranscriptionPredict",
    # Fix-then-complete
    "CorrectAndFinishLinePredict", "CorrectAndFinishSectionPredict",
    # Code transformations
    "DocumentCodePredict", "CompleteCodePredict",
    "AddTypeHintsPredict", "ImproveNamesPredict",
    "ExplainCodePredict", "RefactorCodePredict",
    "FormatCodePredict", "ProgramCodePredict",
    # Text transformations
    "GetSynonymsPredict", "GetPhraseSynonymsPredict",
    "ReformatMarkdownPredict", "IdentifyReplacementPredict",
    # Whole-document programs
    "DocumentResponsePredict", "DocumentSummaryPredict", "DocumentAnalysisPredict",
    # Notebook programs
    "NotebookNamePredict",
]
|
mrmd_ai/modules/code.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""Code transformation modules."""
|
|
2
|
+
|
|
3
|
+
import dspy
|
|
4
|
+
from ..signatures.code import (
|
|
5
|
+
DocumentCodeSignature,
|
|
6
|
+
CompleteCodeSignature,
|
|
7
|
+
AddTypeHintsSignature,
|
|
8
|
+
ImproveNamesSignature,
|
|
9
|
+
ExplainCodeSignature,
|
|
10
|
+
RefactorCodeSignature,
|
|
11
|
+
FormatCodeSignature,
|
|
12
|
+
ProgramCodeSignature,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DocumentCodePredict(dspy.Module):
    """Add documentation to code.

    Thin wrapper around ``dspy.Predict(DocumentCodeSignature)``.
    """

    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict(DocumentCodeSignature)

    def forward(
        self,
        code: str,
        language: str,
        local_context: str,
        document_context: str | None = None,
    ) -> dspy.Prediction:
        """Run the predictor.

        Args:
            code: The code to document.
            language: Language of ``code``.
            local_context: Passed unchanged to the signature's ``local_context`` input.
            document_context: Optional; passed unchanged (may be ``None``).

        Returns:
            The prediction produced by ``dspy.Predict``.
        """
        return self.predict(
            code=code,
            language=language,
            local_context=local_context,
            document_context=document_context,
        )
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class CompleteCodePredict(dspy.Module):
    """Complete incomplete code.

    Thin wrapper around ``dspy.Predict(CompleteCodeSignature)``.
    """

    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict(CompleteCodeSignature)

    def forward(
        self,
        code: str,
        language: str,
        local_context: str,
        document_context: str | None = None,
    ) -> dspy.Prediction:
        """Run the predictor.

        Args:
            code: The incomplete code.
            language: Language of ``code``.
            local_context: Passed unchanged to the signature's ``local_context`` input.
            document_context: Optional; passed unchanged (may be ``None``).

        Returns:
            The prediction produced by ``dspy.Predict``.
        """
        return self.predict(
            code=code,
            language=language,
            local_context=local_context,
            document_context=document_context,
        )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class AddTypeHintsPredict(dspy.Module):
    """Add type hints to code.

    Thin wrapper around ``dspy.Predict(AddTypeHintsSignature)``.
    """

    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict(AddTypeHintsSignature)

    def forward(
        self,
        code: str,
        language: str,
        local_context: str,
        document_context: str | None = None,
    ) -> dspy.Prediction:
        """Run the predictor.

        Args:
            code: The code to annotate.
            language: Language of ``code``.
            local_context: Passed unchanged to the signature's ``local_context`` input.
            document_context: Optional; passed unchanged (may be ``None``).

        Returns:
            The prediction produced by ``dspy.Predict``.
        """
        return self.predict(
            code=code,
            language=language,
            local_context=local_context,
            document_context=document_context,
        )
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class ImproveNamesPredict(dspy.Module):
    """Improve variable and function names.

    Thin wrapper around ``dspy.Predict(ImproveNamesSignature)``.
    """

    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict(ImproveNamesSignature)

    def forward(
        self,
        code: str,
        language: str,
        local_context: str,
        document_context: str | None = None,
    ) -> dspy.Prediction:
        """Run the predictor.

        Args:
            code: The code whose names should be improved.
            language: Language of ``code``.
            local_context: Passed unchanged to the signature's ``local_context`` input.
            document_context: Optional; passed unchanged (may be ``None``).

        Returns:
            The prediction produced by ``dspy.Predict``.
        """
        return self.predict(
            code=code,
            language=language,
            local_context=local_context,
            document_context=document_context,
        )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class ExplainCodePredict(dspy.Module):
    """Add explanatory comments to code.

    Thin wrapper around ``dspy.Predict(ExplainCodeSignature)``.
    """

    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict(ExplainCodeSignature)

    def forward(
        self,
        code: str,
        language: str,
        local_context: str,
        document_context: str | None = None,
    ) -> dspy.Prediction:
        """Run the predictor.

        Args:
            code: The code to explain.
            language: Language of ``code``.
            local_context: Passed unchanged to the signature's ``local_context`` input.
            document_context: Optional; passed unchanged (may be ``None``).

        Returns:
            The prediction produced by ``dspy.Predict``.
        """
        return self.predict(
            code=code,
            language=language,
            local_context=local_context,
            document_context=document_context,
        )
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class RefactorCodePredict(dspy.Module):
    """Refactor and simplify code.

    Thin wrapper around ``dspy.Predict(RefactorCodeSignature)``.
    """

    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict(RefactorCodeSignature)

    def forward(
        self,
        code: str,
        language: str,
        local_context: str,
        document_context: str | None = None,
    ) -> dspy.Prediction:
        """Run the predictor.

        Args:
            code: The code to refactor.
            language: Language of ``code``.
            local_context: Passed unchanged to the signature's ``local_context`` input.
            document_context: Optional; passed unchanged (may be ``None``).

        Returns:
            The prediction produced by ``dspy.Predict``.
        """
        return self.predict(
            code=code,
            language=language,
            local_context=local_context,
            document_context=document_context,
        )
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class FormatCodePredict(dspy.Module):
    """Format and prettify code.

    Thin wrapper around ``dspy.Predict(FormatCodeSignature)``.
    """

    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict(FormatCodeSignature)

    def forward(
        self,
        code: str,
        language: str,
        local_context: str,
        document_context: str | None = None,
    ) -> dspy.Prediction:
        """Run the predictor.

        Args:
            code: The code to format.
            language: Language of ``code``.
            local_context: Passed unchanged to the signature's ``local_context`` input.
            document_context: Optional; passed unchanged (may be ``None``).

        Returns:
            The prediction produced by ``dspy.Predict``.
        """
        return self.predict(
            code=code,
            language=language,
            local_context=local_context,
            document_context=document_context,
        )
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class ProgramCodePredict(dspy.Module):
    """Transform pseudo-code into real code.

    Thin wrapper around ``dspy.Predict(ProgramCodeSignature)``.
    """

    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict(ProgramCodeSignature)

    def forward(
        self,
        pseudo_code: str,
        language: str,
        local_context: str,
        document_context: str | None = None,
    ) -> dspy.Prediction:
        """Run the predictor.

        Args:
            pseudo_code: The pseudo-code to turn into real code.
            language: Target language.
            local_context: Passed unchanged to the signature's ``local_context`` input.
            document_context: Optional; passed unchanged (may be ``None``).

        Returns:
            The prediction produced by ``dspy.Predict``.
        """
        return self.predict(
            pseudo_code=pseudo_code,
            language=language,
            local_context=local_context,
            document_context=document_context,
        )
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
"""Correct-and-finish modules - fix errors then complete."""
|
|
2
|
+
|
|
3
|
+
import dspy
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from mrmd_ai.signatures.correct import (
|
|
7
|
+
CorrectAndFinishLineSignature,
|
|
8
|
+
CorrectAndFinishSectionSignature,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class CorrectAndFinishLinePredict(dspy.Module):
    """Fix mistakes in the line being edited, then finish it.

    Delegates to a single ``dspy.Predict`` over ``CorrectAndFinishLineSignature``.
    """

    def __init__(self):
        super().__init__()
        self.predictor = dspy.Predict(CorrectAndFinishLineSignature)

    def forward(
        self,
        text_to_fix: str,
        content_type: str,
        local_context: str,
        document_context: Optional[str] = None,
    ) -> dspy.Prediction:
        """Correct and complete ``text_to_fix``.

        All arguments are forwarded unchanged to the predictor;
        ``document_context`` may be ``None``.
        """
        signature_inputs = dict(
            document_context=document_context,
            local_context=local_context,
            text_to_fix=text_to_fix,
            content_type=content_type,
        )
        return self.predictor(**signature_inputs)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class CorrectAndFinishSectionPredict(dspy.Module):
    """Fix mistakes in the section being edited, then finish it.

    Delegates to a single ``dspy.Predict`` over ``CorrectAndFinishSectionSignature``.
    """

    def __init__(self):
        super().__init__()
        self.predictor = dspy.Predict(CorrectAndFinishSectionSignature)

    def forward(
        self,
        text_to_fix: str,
        content_type: str,
        local_context: str,
        document_context: Optional[str] = None,
    ) -> dspy.Prediction:
        """Correct and complete ``text_to_fix``.

        All arguments are forwarded unchanged to the predictor;
        ``document_context`` may be ``None``.
        """
        signature_inputs = dict(
            document_context=document_context,
            local_context=local_context,
            text_to_fix=text_to_fix,
            content_type=content_type,
        )
        return self.predictor(**signature_inputs)
|