mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mantisdk might be problematic. Click here for more details.

Files changed (190) hide show
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
@@ -0,0 +1,29 @@
1
+ # RAG Adapter Dependencies
2
+ # Install these dependencies based on which vector store you want to use
3
+
4
+ # Common dependency for all vector stores
5
+ litellm>=1.64.0
6
+
7
+ # ChromaDB vector store
8
+ chromadb>=0.4.0
9
+
10
+ # Weaviate vector store
11
+ weaviate-client>=4.0.0
12
+
13
+ # Qdrant vector store
14
+ qdrant-client>=1.15.0
15
+
16
+ # Milvus vector store
17
+ pymilvus>=2.6.0
18
+
19
+ # LanceDB vector store
20
+ lancedb>=0.22.0
21
+ pyarrow>=10.0.0
22
+
23
+ # Installation examples:
24
+ # For ChromaDB: pip install "litellm>=1.64.0" "chromadb>=0.4.0"
25
+ # For Weaviate: pip install "litellm>=1.64.0" "weaviate-client>=4.0.0"
26
+ # For Qdrant: pip install "litellm>=1.64.0" "qdrant-client>=1.15.0"
27
+ # For Milvus: pip install "litellm>=1.64.0" "pymilvus>=2.6.0"
28
+ # For LanceDB: pip install "litellm>=1.64.0" "lancedb>=0.22.0" "pyarrow>=10.0.0"
29
+ # For all: pip install -r requirements-rag.txt
@@ -0,0 +1,16 @@
1
+
2
+ You are an AI assistant tasked with solving command-line tasks in a Linux environment. You will be given a task instruction and the output from previously executed commands. Your goal is to solve the task by providing batches of shell commands.
3
+
4
+ For each response:
5
+ 1. Analyze the current state based on any terminal output provided
6
+ 2. Determine the next set of commands needed to make progress
7
+ 3. Decide if you need to see the output of these commands before proceeding
8
+
9
+ Don't include markdown formatting.
10
+
11
+ Note that you operate directly on the terminal from inside a tmux session. Use tmux keystrokes like `C-x` or `Escape` to interactively navigate the terminal. If you would like to execute a command that you have written you will need to append a newline character to the end of your command.
12
+
13
+ For example, if you write "ls -la" you will need to append a newline character to the end of your command like this: `ls -la
14
+ `.
15
+
16
+ One thing to be very careful about is handling interactive sessions like less, vim, or git diff. In these cases, you should not wait for the output of the command. Instead, you should send the keystrokes to the terminal as if you were typing them.
@@ -0,0 +1,9 @@
1
+ Instruction:
2
+ {instruction}
3
+
4
+ Your response must be a JSON object that matches this schema:
5
+
6
+ {response_schema}
7
+
8
+ The current terminal state is:
9
+ {terminal_state}
@@ -0,0 +1,161 @@
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+
5
+ import litellm
6
+ from terminal_bench.agents.terminus_1 import AgentResult, Chat, FailureMode, Terminus
7
+ from terminal_bench.dataset.dataset import Dataset
8
+ from terminal_bench.terminal.tmux_session import TmuxSession
9
+
10
+ from mantisdk.algorithm.gepa.lib import optimize
11
+ from mantisdk.algorithm.gepa.lib.adapters.terminal_bench_adapter.terminal_bench_adapter import (
12
+ TerminalBenchTask,
13
+ TerminusAdapter,
14
+ )
15
+
16
+ INSTRUCTION_PROMPT_PATH = Path(__file__).parent / "prompt-templates/instruction_prompt.txt"
17
+
18
+
19
class TerminusWrapper(Terminus):
    """Terminus agent that prepends a file-backed instruction prompt to each task.

    The instruction prompt is read from ``INSTRUCTION_PROMPT_PATH`` when the
    wrapper is constructed and is concatenated in front of the formatted
    Terminus prompt template for every task.
    """

    def __init__(
        self,
        model_name: str,
        max_episodes: int = 50,
        api_base: str | None = None,
        **kwargs,
    ):
        """Set the prompt locations, then delegate to the base initializer.

        Args:
            model_name: Forwarded to ``Terminus.__init__``.
            max_episodes: Forwarded to ``Terminus.__init__``.
            api_base: Forwarded to ``Terminus.__init__``.
            **kwargs: Extra keyword arguments forwarded unchanged.
        """
        # Assigned before super().__init__ runs, as in the original ordering
        # (presumably the base class reads PROMPT_TEMPLATE_PATH — TODO confirm).
        here = Path(__file__).parent
        self.PROMPT_TEMPLATE_PATH = here / "prompt-templates/terminus.txt"
        self.instruction_prompt = INSTRUCTION_PROMPT_PATH.read_text()
        super().__init__(model_name, max_episodes, api_base, **kwargs)

    def perform_task(
        self,
        instruction: str,
        session: TmuxSession,
        logging_dir: Path | None = None,
    ):
        """Run the agent loop for one task and report token usage.

        Args:
            instruction: Task instruction substituted into the prompt template.
            session: Tmux session whose captured pane seeds the terminal state.
            logging_dir: Passed through to the agent loop.

        Returns:
            An ``AgentResult`` carrying the chat's token totals, the recorded
            timestamped markers, and ``FailureMode.NONE``.
        """
        chat = Chat(self._llm)

        task_prompt = self._prompt_template.format(
            response_schema=self._response_schema,
            instruction=instruction,
            history="",
            terminal_state=session.capture_pane(),
        )
        initial_prompt = self.instruction_prompt + task_prompt

        self._run_agent_loop(initial_prompt, session, chat, logging_dir)

        result = AgentResult(
            total_input_tokens=chat.total_input_tokens,
            total_output_tokens=chat.total_output_tokens,
            failure_mode=FailureMode.NONE,
            timestamped_markers=self._timestamped_markers,
        )
        return result
54
+
55
+
56
if __name__ == "__main__":
    # Example driver: evaluate Terminus on terminal-bench with no prompt and with
    # the stock prompt, optimize the instruction prompt with GEPA, then evaluate
    # the optimized prompt. Each evaluation is saved as JSON under the run directory.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="gpt-4o-mini")
    parser.add_argument("--n_concurrent", type=int, default=6)
    args = parser.parse_args()

    # Result files are written before optimize() is called, so the run directory
    # has to exist up front; otherwise the first open() fails on a fresh checkout
    # after the (expensive) evaluations have already completed.
    run_dir = Path("gepa_terminus")
    run_dir.mkdir(parents=True, exist_ok=True)

    def save_testset_results(filename, results):
        """Write the success count and trajectories of ``results`` to ``run_dir/filename``."""
        with open(run_dir / filename, "w") as f:
            json.dump(
                {
                    "score": sum(trajectory["success"] for trajectory in results.trajectories),
                    "trajectories": results.trajectories,
                },
                f,
                indent=4,
            )

    initial_prompt_from_terminus = """
You are an AI assistant tasked with solving command-line tasks in a Linux environment. You will be given a task instruction and the output from previously executed commands. Your goal is to solve the task by providing batches of shell commands.

For each response:
1. Analyze the current state based on any terminal output provided
2. Determine the next set of commands needed to make progress
3. Decide if you need to see the output of these commands before proceeding

Don't include markdown formatting.

Note that you operate directly on the terminal from inside a tmux session. Use tmux keystrokes like `C-x` or `Escape` to interactively navigate the terminal. If you would like to execute a command that you have written you will need to append a newline character to the end of your command.

For example, if you write "ls -la" you will need to append a newline character to the end of your command like this: `ls -la\n`.

One thing to be very careful about is handling interactive sessions like less, vim, or git diff. In these cases, you should not wait for the output of the command. Instead, you should send the keystrokes to the terminal as if you were typing them.
"""

    terminal_bench_dataset = Dataset(name="terminal-bench-core", version="head")
    terminal_bench_dataset.sort_by_duration()

    # NOTE(review): reads the private ``_tasks`` attribute; the list is reversed
    # after sort_by_duration() before being sliced into the three splits below.
    terminal_bench_tasks = terminal_bench_dataset._tasks[::-1]

    trainset = [
        TerminalBenchTask(task_id=task.name, model_name=args.model_name) for task in terminal_bench_tasks[30:50]
    ]
    valset = [TerminalBenchTask(task_id=task.name, model_name=args.model_name) for task in terminal_bench_tasks[:30]]

    testset = [
        TerminalBenchTask(task_id=task.name, model_name=args.model_name)
        for task in terminal_bench_tasks[50:]
        if task.name != "chem-rf"
    ]

    reflection_lm_name = "openai/gpt-5"
    reflection_lm = (
        lambda prompt: litellm.completion(
            model=reflection_lm_name,
            messages=[{"role": "user", "content": prompt}],
            reasoning_effort="high",
        )
        .choices[0]
        .message.content
    )

    adapter = TerminusAdapter(n_concurrent=args.n_concurrent, instruction_prompt_path=INSTRUCTION_PROMPT_PATH)

    # Baselines: each result is saved as soon as it is available so a failure in
    # a later stage does not discard it.
    testset_results_no_prompt = adapter.evaluate(testset, {"instruction_prompt": ""}, capture_traces=True)
    save_testset_results("testset_results_no_prompt.json", testset_results_no_prompt)

    testset_results_before_opt = adapter.evaluate(
        testset,
        {"instruction_prompt": initial_prompt_from_terminus},
        capture_traces=True,
    )
    save_testset_results("testset_results_before_opt.json", testset_results_before_opt)

    optimized_results = optimize(
        seed_candidate={"instruction_prompt": initial_prompt_from_terminus},
        trainset=trainset,
        valset=valset,
        adapter=adapter,
        reflection_lm=reflection_lm,
        use_wandb=True,
        max_metric_calls=400,
        reflection_minibatch_size=3,
        perfect_score=1,
        skip_perfect_score=False,
        run_dir=str(run_dir),
    )

    testset_results_after_opt = adapter.evaluate(
        testset,
        {"instruction_prompt": optimized_results.best_candidate["instruction_prompt"]},
        capture_traces=True,
    )
    save_testset_results("optimized_results.json", testset_results_after_opt)
@@ -0,0 +1,117 @@
1
+ # Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
2
+ # https://github.com/gepa-ai/gepa
3
+
4
+
5
+ import random
6
+ from typing import Any, Mapping
7
+
8
+
9
def json_default(x):
    """Fallback encoder for ``json.dump(..., default=json_default)``.

    Mapping-like objects are expanded into a plain ``dict``; anything that
    cannot be unpacked that way is represented by its ``repr``.
    """
    try:
        as_mapping = {**x}
    except Exception:
        return repr(x)
    return as_mapping
15
+
16
+
17
def idxmax(lst: list[float]) -> int:
    """Return the position of the largest value in ``lst`` (first one on ties).

    Raises:
        ValueError: If ``lst`` is empty.
    """
    return max(range(len(lst)), key=lst.__getitem__)
21
+
22
+
23
def is_dominated(y, programs, program_at_pareto_front_valset):
    """Report whether ``y`` is covered by ``programs`` on every front it appears in.

    Args:
        y: The program being checked.
        programs: Collection of candidate dominators (membership-tested with ``in``).
        program_at_pareto_front_valset: Mapping whose values are the fronts
            (collections of programs).

    Returns:
        ``True`` when every front containing ``y`` also contains at least one
        member of ``programs``; this is vacuously ``True`` if no front contains ``y``.
    """
    fronts_with_y = (front for front in program_at_pareto_front_valset.values() if y in front)
    return all(any(candidate in programs for candidate in front) for front in fronts_with_y)
35
+
36
+
37
def remove_dominated_programs(program_at_pareto_front_valset, scores=None):
    """Drop programs that are redundant across the per-instance Pareto fronts.

    A program is treated as dominated when every front it appears on also
    contains some other program that has not itself been removed. Programs are
    considered in ascending score order, so lower-scoring programs are discarded
    in preference to higher-scoring ones.

    Args:
        program_at_pareto_front_valset: Mapping from a validation id to the
            collection of programs on that id's front.
        scores: Optional container indexable by program (``scores[program]``)
            used only to order the candidates. When omitted, every program gets
            the same score and the first-seen order is kept.

    Returns:
        A new mapping with the same keys whose values are sets containing only
        the surviving programs of each original front.
    """
    # Record, once, which fronts each program appears on. Insertion order of
    # this dict is the first-seen order of the programs.
    fronts_by_program = {}
    for front in program_at_pareto_front_valset.values():
        for prog in front:
            fronts_by_program.setdefault(prog, []).append(front)

    programs = list(fronts_by_program)
    if scores is None:
        scores = dict.fromkeys(programs, 1)

    # Stable ascending sort: weakest candidates are examined (and removed) first.
    programs.sort(key=lambda prog: scores[prog])

    # One pass is enough: removing a program only shrinks the set of potential
    # dominators, so a program that was not dominated earlier can never become
    # dominated later. This yields the same result as restarting the scan after
    # each removal, without the repeated set rebuilds.
    survivors = set(programs)
    for prog in programs:
        survivors.discard(prog)
        covered_everywhere = all(
            any(other in survivors for other in front) for front in fronts_by_program[prog]
        )
        if not covered_everywhere:
            survivors.add(prog)

    # Sanity check: every non-empty front must keep at least one representative.
    for front in program_at_pareto_front_valset.values():
        if not front:
            continue
        assert any(prog in survivors for prog in front)

    return {
        val_id: {prog for prog in front if prog in survivors}
        for val_id, front in program_at_pareto_front_valset.items()
    }
76
+
77
+
78
def find_dominator_programs(pareto_front_programs, train_val_weighted_agg_scores_for_all_programs):
    """Return the distinct programs left on the fronts after dominated ones are removed.

    Args:
        pareto_front_programs: Mapping of fronts passed to ``remove_dominated_programs``.
        train_val_weighted_agg_scores_for_all_programs: Scores passed through as
            the ``scores`` argument of ``remove_dominated_programs``.

    Returns:
        A list of the unique surviving programs (order follows set iteration).
    """
    pruned_fronts = remove_dominated_programs(
        pareto_front_programs, scores=train_val_weighted_agg_scores_for_all_programs
    )
    survivors = [prog for front in pruned_fronts.values() for prog in front]
    return list(set(survivors))
88
+
89
+
90
def select_program_candidate_from_pareto_front(
    pareto_front_programs: Mapping[Any, set[int]],
    train_val_weighted_agg_scores_for_all_programs: list[float],
    rng: random.Random,
) -> int:
    """Sample one program, weighted by how many pruned fronts it appears on.

    Args:
        pareto_front_programs: Mapping of fronts; dominated programs are removed
            from it via ``remove_dominated_programs`` before sampling.
        train_val_weighted_agg_scores_for_all_programs: Scores forwarded as the
            ``scores`` argument of ``remove_dominated_programs``.
        rng: Random generator used for the final draw.

    Returns:
        The index of the chosen program.
    """
    pruned_fronts = remove_dominated_programs(
        pareto_front_programs, scores=train_val_weighted_agg_scores_for_all_programs
    )

    # Count, per program, the number of fronts it still sits on.
    occurrences = {}
    for front in pruned_fronts.values():
        for prog_idx in front:
            occurrences[prog_idx] = occurrences.get(prog_idx, 0) + 1

    # Repeat each program once per front so rng.choice is frequency-weighted.
    sampling_list = []
    for prog_idx, count in occurrences.items():
        sampling_list.extend([prog_idx] * count)

    # TODO: Determine if we need a fallback here (e.g. returning
    # idxmax(train_val_weighted_agg_scores_for_all_programs) when no Pareto
    # program survives) instead of asserting.
    assert len(sampling_list) > 0

    return rng.choice(sampling_list)
File without changes
@@ -0,0 +1,187 @@
1
+ # Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
2
+ # https://github.com/gepa-ai/gepa
3
+
4
+ from typing import Any
5
+
6
+
7
+ class ExperimentTracker:
8
+ """
9
+ Unified experiment tracking that supports both wandb and mlflow.
10
+ """
11
+
12
+ def __enter__(self):
13
+ """Context manager entry."""
14
+ self.initialize()
15
+ self.start_run()
16
+ return self
17
+
18
+ def __exit__(self, exc_type, exc_val, exc_tb):
19
+ """Context manager exit - always end the run."""
20
+ self.end_run()
21
+ return False # Don't suppress exceptions
22
+
23
+ def __init__(
24
+ self,
25
+ use_wandb: bool = False,
26
+ wandb_api_key: str | None = None,
27
+ wandb_init_kwargs: dict[str, Any] | None = None,
28
+ use_mlflow: bool = False,
29
+ mlflow_tracking_uri: str | None = None,
30
+ mlflow_experiment_name: str | None = None,
31
+ ):
32
+ self.use_wandb = use_wandb
33
+ self.use_mlflow = use_mlflow
34
+
35
+ self.wandb_api_key = wandb_api_key
36
+ self.wandb_init_kwargs = wandb_init_kwargs or {}
37
+ self.mlflow_tracking_uri = mlflow_tracking_uri
38
+ self.mlflow_experiment_name = mlflow_experiment_name
39
+
40
+ self._created_mlflow_run = False
41
+
42
+ def initialize(self):
43
+ """Initialize the logging backends."""
44
+ if self.use_wandb:
45
+ self._initialize_wandb()
46
+ if self.use_mlflow:
47
+ self._initialize_mlflow()
48
+
49
+ def _initialize_wandb(self):
50
+ """Initialize wandb."""
51
+ try:
52
+ import wandb # type: ignore
53
+
54
+ if self.wandb_api_key:
55
+ wandb.login(key=self.wandb_api_key, verify=True)
56
+ else:
57
+ wandb.login()
58
+ except ImportError:
59
+ raise ImportError("wandb is not installed. Please install it or set backend='mlflow' or 'none'.")
60
+ except Exception as e:
61
+ raise RuntimeError(f"Error logging into wandb: {e}")
62
+
63
+ def _initialize_mlflow(self):
64
+ """Initialize mlflow."""
65
+ try:
66
+ import mlflow # type: ignore
67
+
68
+ if self.mlflow_tracking_uri:
69
+ mlflow.set_tracking_uri(self.mlflow_tracking_uri)
70
+ if self.mlflow_experiment_name:
71
+ mlflow.set_experiment(self.mlflow_experiment_name)
72
+ except ImportError:
73
+ raise ImportError("mlflow is not installed. Please install it or set backend='wandb' or 'none'.")
74
+ except Exception as e:
75
+ raise RuntimeError(f"Error setting up mlflow: {e}")
76
+
77
+ def start_run(self):
78
+ """Start a new run."""
79
+ if self.use_wandb:
80
+ import wandb # type: ignore
81
+
82
+ wandb.init(**self.wandb_init_kwargs)
83
+ if self.use_mlflow:
84
+ import mlflow # type: ignore
85
+
86
+ # Only start a new run if there's no active run
87
+ if mlflow.active_run() is None:
88
+ mlflow.start_run()
89
+ self._created_mlflow_run = True
90
+ else:
91
+ self._created_mlflow_run = False
92
+
93
+ def log_metrics(self, metrics: dict[str, Any], step: int | None = None):
94
+ """Log metrics to the active backends."""
95
+ if self.use_wandb:
96
+ try:
97
+ import wandb # type: ignore
98
+
99
+ wandb.log(metrics, step=step)
100
+ except Exception as e:
101
+ print(f"Warning: Failed to log to wandb: {e}")
102
+
103
+ if self.use_mlflow:
104
+ try:
105
+ import mlflow # type: ignore
106
+
107
+ mlflow.log_metrics(metrics, step=step)
108
+ except Exception as e:
109
+ print(f"Warning: Failed to log to mlflow: {e}")
110
+
111
+ def end_run(self):
112
+ """End the current run."""
113
+ if self.use_wandb:
114
+ try:
115
+ import wandb # type: ignore
116
+
117
+ if wandb.run is not None:
118
+ wandb.finish()
119
+ except Exception as e:
120
+ print(f"Warning: Failed to end wandb run: {e}")
121
+
122
+ if self.use_mlflow:
123
+ try:
124
+ import mlflow # type: ignore
125
+
126
+ if self._created_mlflow_run and mlflow.active_run() is not None:
127
+ mlflow.end_run()
128
+ self._created_mlflow_run = False
129
+ except Exception as e:
130
+ print(f"Warning: Failed to end mlflow run: {e}")
131
+
132
+ def is_active(self) -> bool:
133
+ """Check if any backend has an active run."""
134
+ if self.use_wandb:
135
+ try:
136
+ import wandb # type: ignore
137
+
138
+ if wandb.run is not None:
139
+ return True
140
+ except Exception:
141
+ pass
142
+
143
+ if self.use_mlflow:
144
+ try:
145
+ import mlflow # type: ignore
146
+
147
+ if mlflow.active_run() is not None:
148
+ return True
149
+ except Exception:
150
+ pass
151
+
152
+ return False
153
+
154
+
155
def create_experiment_tracker(
    use_wandb: bool = False,
    wandb_api_key: str | None = None,
    wandb_init_kwargs: dict[str, Any] | None = None,
    use_mlflow: bool = False,
    mlflow_tracking_uri: str | None = None,
    mlflow_experiment_name: str | None = None,
) -> ExperimentTracker:
    """
    Build an ``ExperimentTracker`` configured with the given backends.

    Every argument is forwarded unchanged to the ``ExperimentTracker``
    constructor. wandb and mlflow may be enabled at the same time.

    Args:
        use_wandb: Whether to use wandb
        wandb_api_key: API key for wandb
        wandb_init_kwargs: Additional kwargs for wandb.init()
        use_mlflow: Whether to use mlflow
        mlflow_tracking_uri: Tracking URI for mlflow
        mlflow_experiment_name: Experiment name for mlflow

    Returns:
        The newly constructed ExperimentTracker instance
    """
    tracker = ExperimentTracker(
        use_wandb=use_wandb,
        wandb_api_key=wandb_api_key,
        wandb_init_kwargs=wandb_init_kwargs,
        use_mlflow=use_mlflow,
        mlflow_tracking_uri=mlflow_tracking_uri,
        mlflow_experiment_name=mlflow_experiment_name,
    )
    return tracker
@@ -0,0 +1,75 @@
1
+ # Copyright (c) 2025 Lakshya A Agrawal and the GEPA contributors
2
+ # https://github.com/gepa-ai/gepa
3
+
4
+ import sys
5
+ from typing import Protocol
6
+
7
+
8
class LoggerProtocol(Protocol):
    """Structural interface for loggers: any object exposing ``log(message)``."""

    def log(self, message: str):
        """Record ``message``; the destination is up to the implementation."""
        ...
10
+
11
+
12
class StdOutLogger(LoggerProtocol):
    """Logger that emits each message with ``print``."""

    def log(self, message: str):
        """Print ``message`` followed by a newline."""
        print(message)
15
+
16
+
17
class Tee:
    """Write-only stream that duplicates everything written to several files.

    Intended as a stand-in for ``sys.stdout`` / ``sys.stderr``: each ``write``
    is forwarded to every wrapped file object in the order they were given.
    """

    def __init__(self, *files):
        """Wrap ``files``, the file-like objects that receive every write."""
        self.files = files

    def write(self, obj):
        """Forward ``obj`` to every wrapped file.

        Returns:
            ``len(obj)``, mirroring the return value of text-stream ``write``.
        """
        for f in self.files:
            f.write(obj)
        return len(obj)

    def flush(self):
        """Flush every wrapped file that has a ``flush`` method."""
        for f in self.files:
            if hasattr(f, "flush"):
                f.flush()

    def isatty(self):
        """Return True if any of the wrapped files is a terminal."""
        return any(hasattr(f, "isatty") and f.isatty() for f in self.files)

    def close(self):
        """Close every wrapped file that has a ``close`` method.

        Note that this also closes streams the Tee does not own (for example a
        wrapped ``sys.stdout``), so callers should only use it deliberately.
        """
        for f in self.files:
            if hasattr(f, "close"):
                f.close()

    def fileno(self):
        """Return the descriptor of the first wrapped file that really has one.

        Raises:
            OSError: If none of the wrapped files exposes a usable descriptor.
        """
        for f in self.files:
            if not hasattr(f, "fileno"):
                continue
            try:
                return f.fileno()
            except (OSError, ValueError):
                # In-memory streams define fileno() but raise
                # io.UnsupportedOperation (an OSError and ValueError), and closed
                # files raise ValueError; try the next file instead of giving up.
                continue
        raise OSError("No underlying file object with fileno")
44
+
45
+
46
class Logger(LoggerProtocol):
    """File-backed logger that can also mirror stdout/stderr into log files.

    Two files are opened on construction: ``filename`` and a sibling whose name
    is ``filename`` with ``"run_log."`` replaced by ``"run_log_stderr."``. Used as
    a context manager, ``sys.stdout`` and ``sys.stderr`` are replaced by ``Tee``
    objects that write to the original stream and to the respective file; on
    exit the original streams are restored and both files are closed.
    """

    def __init__(self, filename, mode="a"):
        """Open both log files.

        Args:
            filename: Path of the main log file, as a string (``str.replace`` is
                used to derive the stderr log path).
            mode: Mode passed to ``open`` for both files.
        """
        self.file_handle = open(filename, mode)
        try:
            self.file_handle_stderr = open(filename.replace("run_log.", "run_log_stderr."), mode)
        except Exception:
            # Don't leak the first handle when the second file cannot be opened.
            self.file_handle.close()
            raise
        self.modified_sys = False

    def __enter__(self):
        """Start mirroring ``sys.stdout`` / ``sys.stderr`` into the log files."""
        self.original_stdout = sys.stdout
        self.original_stderr = sys.stderr
        sys.stdout = Tee(sys.stdout, self.file_handle)
        sys.stderr = Tee(sys.stderr, self.file_handle_stderr)
        self.modified_sys = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the original streams and close both log files."""
        sys.stdout = self.original_stdout
        sys.stderr = self.original_stderr
        self.modified_sys = False
        try:
            self.file_handle.close()
        finally:
            # Close the second file even if closing the first one fails.
            self.file_handle_stderr.close()

    def log(self, *args, **kwargs):
        """Print a message and make sure it lands in the main log file.

        Accepts the same arguments as ``print``. Inside the context manager the
        redirected ``sys.stdout`` already copies the output to the main log
        file. Outside of it, the message is printed normally and then written
        to the main log file as well, so both modes record messages in the same
        place.
        """
        print(*args, **kwargs)
        if not self.modified_sys:
            # stdout is not being mirrored, so write the copy explicitly.
            # Overriding any caller-supplied ``file`` avoids a duplicate
            # keyword argument and keeps the copy pointed at the log file.
            file_kwargs = {**kwargs, "file": self.file_handle}
            print(*args, **file_kwargs)
        self.file_handle.flush()
        self.file_handle_stderr.flush()