mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic. Click here for more details.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# RAG Adapter Dependencies
|
|
2
|
+
# Install these dependencies based on which vector store you want to use
|
|
3
|
+
|
|
4
|
+
# Common dependency for all vector stores
|
|
5
|
+
litellm>=1.64.0
|
|
6
|
+
|
|
7
|
+
# ChromaDB vector store
|
|
8
|
+
chromadb>=0.4.0
|
|
9
|
+
|
|
10
|
+
# Weaviate vector store
|
|
11
|
+
weaviate-client>=4.0.0
|
|
12
|
+
|
|
13
|
+
# Qdrant vector store
|
|
14
|
+
qdrant-client>=1.15.0
|
|
15
|
+
|
|
16
|
+
# Milvus vector store
|
|
17
|
+
pymilvus>=2.6.0
|
|
18
|
+
|
|
19
|
+
# LanceDB vector store
|
|
20
|
+
lancedb>=0.22.0
|
|
21
|
+
pyarrow>=10.0.0
|
|
22
|
+
|
|
23
|
+
# Installation examples:
|
|
24
|
+
# For ChromaDB: pip install litellm>=1.64.0 chromadb>=0.4.0
|
|
25
|
+
# For Weaviate: pip install litellm>=1.64.0 weaviate-client>=4.0.0
|
|
26
|
+
# For Qdrant: pip install litellm>=1.64.0 qdrant-client>=1.15.0
|
|
27
|
+
# For Milvus: pip install litellm>=1.64.0 pymilvus>=2.6.0
|
|
28
|
+
# For LanceDB: pip install litellm>=1.64.0 lancedb>=0.22.0 pyarrow>=10.0.0
|
|
29
|
+
# For all: pip install -r requirements-rag.txt
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
|
|
2
|
+
You are an AI assistant tasked with solving command-line tasks in a Linux environment. You will be given a task instruction and the output from previously executed commands. Your goal is to solve the task by providing batches of shell commands.
|
|
3
|
+
|
|
4
|
+
For each response:
|
|
5
|
+
1. Analyze the current state based on any terminal output provided
|
|
6
|
+
2. Determine the next set of commands needed to make progress
|
|
7
|
+
3. Decide if you need to see the output of these commands before proceeding
|
|
8
|
+
|
|
9
|
+
Don't include markdown formatting.
|
|
10
|
+
|
|
11
|
+
Note that you operate directly on the terminal from inside a tmux session. Use tmux keystrokes like `C-x` or `Escape` to interactively navigate the terminal. If you would like to execute a command that you have written you will need to append a newline character to the end of your command.
|
|
12
|
+
|
|
13
|
+
For example, if you write "ls -la" you will need to append a newline character to the end of your command like this: `ls -la
|
|
14
|
+
`.
|
|
15
|
+
|
|
16
|
+
One thing to be very careful about is handling interactive sessions like less, vim, or git diff. In these cases, you should not wait for the output of the command. Instead, you should send the keystrokes to the terminal as if you were typing them.
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import json
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import litellm
|
|
6
|
+
from terminal_bench.agents.terminus_1 import AgentResult, Chat, FailureMode, Terminus
|
|
7
|
+
from terminal_bench.dataset.dataset import Dataset
|
|
8
|
+
from terminal_bench.terminal.tmux_session import TmuxSession
|
|
9
|
+
|
|
10
|
+
from mantisdk.algorithm.gepa.lib import optimize
|
|
11
|
+
from mantisdk.algorithm.gepa.lib.adapters.terminal_bench_adapter.terminal_bench_adapter import (
|
|
12
|
+
TerminalBenchTask,
|
|
13
|
+
TerminusAdapter,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
# Seed instruction prompt that TerminusWrapper prefixes onto the Terminus template.
INSTRUCTION_PROMPT_PATH = Path(__file__).parent / "prompt-templates/instruction_prompt.txt"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TerminusWrapper(Terminus):
    """Terminus agent whose system prompt is prefixed with a tunable instruction prompt.

    The instruction prefix is read from ``INSTRUCTION_PROMPT_PATH`` at
    construction time, which is the component GEPA optimizes.
    """

    def __init__(
        self,
        model_name: str,
        max_episodes: int = 50,
        api_base: str | None = None,
        **kwargs,
    ):
        # Point the base agent at our local prompt template and load the
        # optimizable instruction prefix before delegating construction.
        self.PROMPT_TEMPLATE_PATH = Path(__file__).parent / "prompt-templates/terminus.txt"
        self.instruction_prompt = INSTRUCTION_PROMPT_PATH.read_text()
        super().__init__(model_name, max_episodes, api_base, **kwargs)

    def perform_task(
        self,
        instruction: str,
        session: TmuxSession,
        logging_dir: Path | None = None,
    ):
        """Run the agent loop for one task and report token usage via AgentResult."""
        chat = Chat(self._llm)

        filled_template = self._prompt_template.format(
            response_schema=self._response_schema,
            instruction=instruction,
            history="",
            terminal_state=session.capture_pane(),
        )
        self._run_agent_loop(self.instruction_prompt + filled_template, session, chat, logging_dir)

        return AgentResult(
            total_input_tokens=chat.total_input_tokens,
            total_output_tokens=chat.total_output_tokens,
            failure_mode=FailureMode.NONE,
            timestamped_markers=self._timestamped_markers,
        )
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="gpt-4o-mini")
    parser.add_argument("--n_concurrent", type=int, default=6)
    args = parser.parse_args()

    # Baseline prompt shipped with Terminus; also used as the GEPA seed candidate.
    initial_prompt_from_terminus = """
You are an AI assistant tasked with solving command-line tasks in a Linux environment. You will be given a task instruction and the output from previously executed commands. Your goal is to solve the task by providing batches of shell commands.

For each response:
1. Analyze the current state based on any terminal output provided
2. Determine the next set of commands needed to make progress
3. Decide if you need to see the output of these commands before proceeding

Don't include markdown formatting.

Note that you operate directly on the terminal from inside a tmux session. Use tmux keystrokes like `C-x` or `Escape` to interactively navigate the terminal. If you would like to execute a command that you have written you will need to append a newline character to the end of your command.

For example, if you write "ls -la" you will need to append a newline character to the end of your command like this: `ls -la\n`.

One thing to be very careful about is handling interactive sessions like less, vim, or git diff. In these cases, you should not wait for the output of the command. Instead, you should send the keystrokes to the terminal as if you were typing them.
"""

    terminal_bench_dataset = Dataset(name="terminal-bench-core", version="head")
    terminal_bench_dataset.sort_by_duration()

    # Reverse the duration-sorted task list before slicing into splits.
    terminal_bench_tasks = terminal_bench_dataset._tasks[::-1]

    trainset = [
        TerminalBenchTask(task_id=task.name, model_name=args.model_name) for task in terminal_bench_tasks[30:50]
    ]
    valset = [TerminalBenchTask(task_id=task.name, model_name=args.model_name) for task in terminal_bench_tasks[:30]]
    testset = [
        TerminalBenchTask(task_id=task.name, model_name=args.model_name)
        for task in terminal_bench_tasks[50:]
        if task.name != "chem-rf"
    ]

    reflection_lm_name = "openai/gpt-5"

    def reflection_lm(prompt):
        """Query the reflection model once and return the completion text."""
        response = litellm.completion(
            model=reflection_lm_name,
            messages=[{"role": "user", "content": prompt}],
            reasoning_effort="high",
        )
        return response.choices[0].message.content

    adapter = TerminusAdapter(n_concurrent=args.n_concurrent, instruction_prompt_path=INSTRUCTION_PROMPT_PATH)

    # Fix: the result dumps below run before optimize() could create this
    # directory, so open(...) would fail with FileNotFoundError on a fresh
    # checkout. Create it up front.
    run_dir = Path("gepa_terminus")
    run_dir.mkdir(parents=True, exist_ok=True)

    def dump_results(path, results):
        """Write {score, trajectories} for one evaluation batch to `path`."""
        with open(path, "w") as f:
            json.dump(
                {
                    "score": sum(trajectory["success"] for trajectory in results.trajectories),
                    "trajectories": results.trajectories,
                },
                f,
                indent=4,
            )

    testset_results_no_prompt = adapter.evaluate(testset, {"instruction_prompt": ""}, capture_traces=True)
    testset_results_before_opt = adapter.evaluate(
        testset,
        {"instruction_prompt": initial_prompt_from_terminus},
        capture_traces=True,
    )

    dump_results(run_dir / "testset_results_no_prompt.json", testset_results_no_prompt)
    dump_results(run_dir / "testset_results_before_opt.json", testset_results_before_opt)

    optimized_results = optimize(
        seed_candidate={"instruction_prompt": initial_prompt_from_terminus},
        trainset=trainset,
        valset=valset,
        adapter=adapter,
        reflection_lm=reflection_lm,
        use_wandb=True,
        max_metric_calls=400,
        reflection_minibatch_size=3,
        perfect_score=1,
        skip_perfect_score=False,
        run_dir="gepa_terminus",
    )

    testset_results_after_opt = adapter.evaluate(
        testset,
        {"instruction_prompt": optimized_results.best_candidate["instruction_prompt"]},
        capture_traces=True,
    )

    dump_results(run_dir / "optimized_results.json", testset_results_after_opt)
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
|
|
2
|
+
# https://github.com/gepa-ai/gepa
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
import random
|
|
6
|
+
from typing import Any, Mapping
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def json_default(x):
    """Fallback encoder for ``json.dump``: mapping-like values become plain
    dicts, everything else falls back to its ``repr``."""
    try:
        as_mapping = {**x}
    except Exception:
        return repr(x)
    return as_mapping
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def idxmax(lst: list[float]) -> int:
    """Return the index of the maximum value (first occurrence on ties)."""
    return max(range(len(lst)), key=lst.__getitem__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def is_dominated(y, programs, program_at_pareto_front_valset):
    """True iff every pareto front containing `y` also contains a member of `programs`.

    Vacuously true when `y` appears in no front.
    """
    for front in program_at_pareto_front_valset.values():
        if y not in front:
            continue
        # A front where y stands alone (w.r.t. `programs`) means y is not dominated.
        if not any(candidate in programs for candidate in front):
            return False
    return True
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def remove_dominated_programs(program_at_pareto_front_valset, scores=None):
    """Prune dominated programs from per-instance pareto fronts.

    A program is removed when every front it belongs to also contains some
    other surviving program (see ``is_dominated``). Removal is iterative and
    order-sensitive: candidates are visited in ascending score order, so on
    ties the lower-scored program is dropped first.

    Args:
        program_at_pareto_front_valset: mapping of validation-instance id ->
            set of program indices on that instance's pareto front.
        scores: optional mapping program index -> aggregate score; defaults
            to a uniform score of 1 for every program.

    Returns:
        A new mapping with the same keys where each front contains only the
        surviving (non-dominated) programs.
    """
    # Count in how many fronts each program appears (also collects the program set).
    freq = {}
    for front in program_at_pareto_front_valset.values():
        for p in front:
            freq[p] = freq.get(p, 0) + 1

    dominated = set()
    programs = list(freq.keys())

    if scores is None:
        scores = dict.fromkeys(programs, 1)

    # Ascending score order: the weakest programs are considered for removal first.
    programs = sorted(programs, key=lambda x: scores[x], reverse=False)

    # Remove one dominated program per pass and restart, since each removal
    # can change the domination status of the remaining programs.
    found_to_remove = True
    while found_to_remove:
        found_to_remove = False
        for y in programs:
            if y in dominated:
                continue
            if is_dominated(y, set(programs).difference({y}).difference(dominated), program_at_pareto_front_valset):
                dominated.add(y)
                found_to_remove = True
                break

    dominators = [p for p in programs if p not in dominated]
    # Sanity check: every non-empty front must retain at least one survivor.
    for front in program_at_pareto_front_valset.values():
        if not front:
            continue
        assert any(p in front for p in dominators)

    new_program_at_pareto_front_valset = {
        val_id: {prog_idx for prog_idx in front if prog_idx in dominators}
        for val_id, front in program_at_pareto_front_valset.items()
    }
    # Pruning must only shrink fronts, never add to them.
    for val_id, front_new in new_program_at_pareto_front_valset.items():
        assert front_new.issubset(program_at_pareto_front_valset[val_id])

    return new_program_at_pareto_front_valset
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def find_dominator_programs(pareto_front_programs, train_val_weighted_agg_scores_for_all_programs):
    """Return the program indices that survive pareto-front pruning."""
    pruned_fronts = remove_dominated_programs(
        pareto_front_programs, scores=train_val_weighted_agg_scores_for_all_programs
    )
    surviving = set()
    for front in pruned_fronts.values():
        surviving.update(front)
    return list(surviving)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def select_program_candidate_from_pareto_front(
    pareto_front_programs: Mapping[Any, set[int]],
    train_val_weighted_agg_scores_for_all_programs: list[float],
    rng: random.Random,
) -> int:
    """Sample a candidate program, weighted by how often it sits on a pareto front.

    Dominated programs are pruned first; each survivor is then sampled with
    probability proportional to the number of per-instance fronts it appears on.
    """
    pruned_fronts = remove_dominated_programs(
        pareto_front_programs, scores=train_val_weighted_agg_scores_for_all_programs
    )

    # Count front memberships; dict insertion order is preserved so the
    # sampling list (and hence rng draws) is deterministic.
    frequency: dict = {}
    for front in pruned_fronts.values():
        for prog_idx in front:
            frequency[prog_idx] = frequency.get(prog_idx, 0) + 1

    sampling_list = [prog_idx for prog_idx, count in frequency.items() for _ in range(count)]

    # TODO: Determine if we need this fallback
    # if not sampling_list:
    #     # No Pareto programs survived; fall back to the globally highest-scoring program.
    #     return idxmax(train_val_weighted_agg_scores_for_all_programs)
    assert len(sampling_list) > 0

    return rng.choice(sampling_list)
|
|
File without changes
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
|
|
2
|
+
# https://github.com/gepa-ai/gepa
|
|
3
|
+
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ExperimentTracker:
|
|
8
|
+
"""
|
|
9
|
+
Unified experiment tracking that supports both wandb and mlflow.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
def __enter__(self):
|
|
13
|
+
"""Context manager entry."""
|
|
14
|
+
self.initialize()
|
|
15
|
+
self.start_run()
|
|
16
|
+
return self
|
|
17
|
+
|
|
18
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
19
|
+
"""Context manager exit - always end the run."""
|
|
20
|
+
self.end_run()
|
|
21
|
+
return False # Don't suppress exceptions
|
|
22
|
+
|
|
23
|
+
def __init__(
|
|
24
|
+
self,
|
|
25
|
+
use_wandb: bool = False,
|
|
26
|
+
wandb_api_key: str | None = None,
|
|
27
|
+
wandb_init_kwargs: dict[str, Any] | None = None,
|
|
28
|
+
use_mlflow: bool = False,
|
|
29
|
+
mlflow_tracking_uri: str | None = None,
|
|
30
|
+
mlflow_experiment_name: str | None = None,
|
|
31
|
+
):
|
|
32
|
+
self.use_wandb = use_wandb
|
|
33
|
+
self.use_mlflow = use_mlflow
|
|
34
|
+
|
|
35
|
+
self.wandb_api_key = wandb_api_key
|
|
36
|
+
self.wandb_init_kwargs = wandb_init_kwargs or {}
|
|
37
|
+
self.mlflow_tracking_uri = mlflow_tracking_uri
|
|
38
|
+
self.mlflow_experiment_name = mlflow_experiment_name
|
|
39
|
+
|
|
40
|
+
self._created_mlflow_run = False
|
|
41
|
+
|
|
42
|
+
def initialize(self):
|
|
43
|
+
"""Initialize the logging backends."""
|
|
44
|
+
if self.use_wandb:
|
|
45
|
+
self._initialize_wandb()
|
|
46
|
+
if self.use_mlflow:
|
|
47
|
+
self._initialize_mlflow()
|
|
48
|
+
|
|
49
|
+
def _initialize_wandb(self):
|
|
50
|
+
"""Initialize wandb."""
|
|
51
|
+
try:
|
|
52
|
+
import wandb # type: ignore
|
|
53
|
+
|
|
54
|
+
if self.wandb_api_key:
|
|
55
|
+
wandb.login(key=self.wandb_api_key, verify=True)
|
|
56
|
+
else:
|
|
57
|
+
wandb.login()
|
|
58
|
+
except ImportError:
|
|
59
|
+
raise ImportError("wandb is not installed. Please install it or set backend='mlflow' or 'none'.")
|
|
60
|
+
except Exception as e:
|
|
61
|
+
raise RuntimeError(f"Error logging into wandb: {e}")
|
|
62
|
+
|
|
63
|
+
def _initialize_mlflow(self):
|
|
64
|
+
"""Initialize mlflow."""
|
|
65
|
+
try:
|
|
66
|
+
import mlflow # type: ignore
|
|
67
|
+
|
|
68
|
+
if self.mlflow_tracking_uri:
|
|
69
|
+
mlflow.set_tracking_uri(self.mlflow_tracking_uri)
|
|
70
|
+
if self.mlflow_experiment_name:
|
|
71
|
+
mlflow.set_experiment(self.mlflow_experiment_name)
|
|
72
|
+
except ImportError:
|
|
73
|
+
raise ImportError("mlflow is not installed. Please install it or set backend='wandb' or 'none'.")
|
|
74
|
+
except Exception as e:
|
|
75
|
+
raise RuntimeError(f"Error setting up mlflow: {e}")
|
|
76
|
+
|
|
77
|
+
def start_run(self):
|
|
78
|
+
"""Start a new run."""
|
|
79
|
+
if self.use_wandb:
|
|
80
|
+
import wandb # type: ignore
|
|
81
|
+
|
|
82
|
+
wandb.init(**self.wandb_init_kwargs)
|
|
83
|
+
if self.use_mlflow:
|
|
84
|
+
import mlflow # type: ignore
|
|
85
|
+
|
|
86
|
+
# Only start a new run if there's no active run
|
|
87
|
+
if mlflow.active_run() is None:
|
|
88
|
+
mlflow.start_run()
|
|
89
|
+
self._created_mlflow_run = True
|
|
90
|
+
else:
|
|
91
|
+
self._created_mlflow_run = False
|
|
92
|
+
|
|
93
|
+
def log_metrics(self, metrics: dict[str, Any], step: int | None = None):
|
|
94
|
+
"""Log metrics to the active backends."""
|
|
95
|
+
if self.use_wandb:
|
|
96
|
+
try:
|
|
97
|
+
import wandb # type: ignore
|
|
98
|
+
|
|
99
|
+
wandb.log(metrics, step=step)
|
|
100
|
+
except Exception as e:
|
|
101
|
+
print(f"Warning: Failed to log to wandb: {e}")
|
|
102
|
+
|
|
103
|
+
if self.use_mlflow:
|
|
104
|
+
try:
|
|
105
|
+
import mlflow # type: ignore
|
|
106
|
+
|
|
107
|
+
mlflow.log_metrics(metrics, step=step)
|
|
108
|
+
except Exception as e:
|
|
109
|
+
print(f"Warning: Failed to log to mlflow: {e}")
|
|
110
|
+
|
|
111
|
+
def end_run(self):
|
|
112
|
+
"""End the current run."""
|
|
113
|
+
if self.use_wandb:
|
|
114
|
+
try:
|
|
115
|
+
import wandb # type: ignore
|
|
116
|
+
|
|
117
|
+
if wandb.run is not None:
|
|
118
|
+
wandb.finish()
|
|
119
|
+
except Exception as e:
|
|
120
|
+
print(f"Warning: Failed to end wandb run: {e}")
|
|
121
|
+
|
|
122
|
+
if self.use_mlflow:
|
|
123
|
+
try:
|
|
124
|
+
import mlflow # type: ignore
|
|
125
|
+
|
|
126
|
+
if self._created_mlflow_run and mlflow.active_run() is not None:
|
|
127
|
+
mlflow.end_run()
|
|
128
|
+
self._created_mlflow_run = False
|
|
129
|
+
except Exception as e:
|
|
130
|
+
print(f"Warning: Failed to end mlflow run: {e}")
|
|
131
|
+
|
|
132
|
+
def is_active(self) -> bool:
|
|
133
|
+
"""Check if any backend has an active run."""
|
|
134
|
+
if self.use_wandb:
|
|
135
|
+
try:
|
|
136
|
+
import wandb # type: ignore
|
|
137
|
+
|
|
138
|
+
if wandb.run is not None:
|
|
139
|
+
return True
|
|
140
|
+
except Exception:
|
|
141
|
+
pass
|
|
142
|
+
|
|
143
|
+
if self.use_mlflow:
|
|
144
|
+
try:
|
|
145
|
+
import mlflow # type: ignore
|
|
146
|
+
|
|
147
|
+
if mlflow.active_run() is not None:
|
|
148
|
+
return True
|
|
149
|
+
except Exception:
|
|
150
|
+
pass
|
|
151
|
+
|
|
152
|
+
return False
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def create_experiment_tracker(
    use_wandb: bool = False,
    wandb_api_key: str | None = None,
    wandb_init_kwargs: dict[str, Any] | None = None,
    use_mlflow: bool = False,
    mlflow_tracking_uri: str | None = None,
    mlflow_experiment_name: str | None = None,
) -> ExperimentTracker:
    """Build an :class:`ExperimentTracker` for the requested backends.

    Args:
        use_wandb: Whether to use wandb.
        wandb_api_key: API key for wandb.
        wandb_init_kwargs: Extra kwargs forwarded to ``wandb.init()``.
        use_mlflow: Whether to use mlflow.
        mlflow_tracking_uri: Tracking URI for mlflow.
        mlflow_experiment_name: Experiment name for mlflow.

    Returns:
        A configured ExperimentTracker instance.

    Note:
        Both wandb and mlflow can be enabled at the same time.
    """
    return ExperimentTracker(
        use_wandb=use_wandb,
        wandb_api_key=wandb_api_key,
        wandb_init_kwargs=wandb_init_kwargs,
        use_mlflow=use_mlflow,
        mlflow_tracking_uri=mlflow_tracking_uri,
        mlflow_experiment_name=mlflow_experiment_name,
    )
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
|
|
2
|
+
# https://github.com/gepa-ai/gepa
|
|
3
|
+
|
|
4
|
+
import sys
|
|
5
|
+
from typing import Protocol
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class LoggerProtocol(Protocol):
    """Structural interface: any object exposing a ``log(message)`` method."""

    def log(self, message: str): ...
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class StdOutLogger(LoggerProtocol):
    """LoggerProtocol implementation that prints each message to stdout."""

    def log(self, message: str):
        print(message)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Tee:
    """File-like object that fans every write out to several underlying streams.

    Used to mirror ``sys.stdout``/``sys.stderr`` into log files while still
    writing to the console. All delegation is best-effort: streams lacking
    ``flush``/``close``/``isatty`` are simply skipped.
    """

    def __init__(self, *files):
        self.files = files

    def write(self, obj):
        """Write `obj` to every underlying stream."""
        for stream in self.files:
            stream.write(obj)

    def flush(self):
        """Flush each stream that supports flushing."""
        for stream in self.files:
            if hasattr(stream, "flush"):
                stream.flush()

    def isatty(self):
        """Report a terminal if any underlying stream is one."""
        return any(hasattr(stream, "isatty") and stream.isatty() for stream in self.files)

    def close(self):
        """Close each stream that can be closed."""
        for stream in self.files:
            if hasattr(stream, "close"):
                stream.close()

    def fileno(self):
        """Delegate to the first stream exposing a file descriptor."""
        for stream in self.files:
            if hasattr(stream, "fileno"):
                return stream.fileno()
        raise OSError("No underlying file object with fileno")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class Logger(LoggerProtocol):
    """File-backed logger that can also tee process-wide stdout/stderr.

    Two files are opened: ``filename`` for the stdout side and a sibling
    stderr file whose name is derived by replacing ``"run_log."`` with
    ``"run_log_stderr."``. Used as a context manager, it redirects
    ``sys.stdout``/``sys.stderr`` through :class:`Tee` so console output is
    mirrored into the files; outside the context, ``log()`` writes to the
    files directly.
    """

    def __init__(self, filename, mode="a"):
        self.file_handle = open(filename, mode)
        # NOTE(review): if `filename` does not contain "run_log.", both
        # handles point at the same path — assumed callers always pass a
        # run_log.* filename.
        self.file_handle_stderr = open(filename.replace("run_log.", "run_log_stderr."), mode)
        self.modified_sys = False

    def __enter__(self):
        self.original_stdout = sys.stdout
        self.original_stderr = sys.stderr
        sys.stdout = Tee(sys.stdout, self.file_handle)
        sys.stderr = Tee(sys.stderr, self.file_handle_stderr)
        self.modified_sys = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        sys.stdout = self.original_stdout
        sys.stderr = self.original_stderr
        self.file_handle.close()
        self.file_handle_stderr.close()
        self.modified_sys = False

    def log(self, *args, **kwargs):
        if self.modified_sys:
            # stdout is already a Tee into file_handle; plain print suffices.
            print(*args, **kwargs)
        else:
            # Emulate print(*args, **kwargs) behavior but write to the file
            print(*args, **kwargs)
            # Fix: the message previously never reached file_handle (it was
            # only flushed), leaving the stdout log file empty outside the
            # context manager. Mirror it into both files explicitly.
            print(*args, file=self.file_handle, **kwargs)
            print(*args, file=self.file_handle_stderr, **kwargs)
            self.file_handle.flush()
            self.file_handle_stderr.flush()
|