mantisdk-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mantisdk has been flagged as possibly problematic.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py
@@ -0,0 +1,111 @@
+from train_anymaths import init_dataset
+
+from mantisdk.algorithm.gepa.lib.adapters.anymaths_adapter.anymaths_adapter import AnyMathsStructuredOutput
+
+if __name__ == "__main__":
+    import argparse
+    import ast
+    from pathlib import Path
+
+    import litellm
+    from tqdm import tqdm
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--anymaths_dset_name", type=str, default="openai/gsm8k")
+    parser.add_argument("--model", type=str, default="ollama/qwen3:4b", help="The model to evaluate.")
+    parser.add_argument("--use_api_url", action="store_true", help="Whether to use the API URL.")
+    parser.add_argument("--api_url", type=str, default="http://localhost:11434", help="The API URL to use.")
+    parser.add_argument("--batch_size", type=int, default=8, help="The batch size for evaluation.")
+    parser.add_argument(
+        "--max_litellm_workers", type=int, default=1, help="The maximum number of LiteLLM workers to use."
+    )
+    parser.add_argument(
+        "--which_prompt",
+        type=str,
+        default="seed",
+        choices=["seed", "optimized"],
+        help="The prompt to use for evaluation.",
+    )
+
+    args = parser.parse_args()
+
+    dataset = args.anymaths_dset_name
+
+    use_api_url = args.use_api_url
+    if not use_api_url:
+        api_url = ""
+    else:
+        api_url = args.api_url
+
+    model = args.model
+    max_litellm_workers = args.max_litellm_workers
+
+    _, _, testset = init_dataset(dataset)
+
+    if args.which_prompt == "seed":
+        INSTRUCTION_PROMPT_PATH = Path(__file__).parent / "prompt-templates/instruction_prompt.txt"
+    else:
+        INSTRUCTION_PROMPT_PATH = Path(__file__).parent / "prompt-templates/optimal_prompt.txt"
+
+    instruction = INSTRUCTION_PROMPT_PATH.read_text()
+
+    batched_testset = []
+    batch_size = args.batch_size
+
+    for i in range(0, len(testset), batch_size):
+        batched_testset.append(testset[i : i + batch_size])
+
+    total_score = 0.0
+
+    print("-" * 100)
+    print(f"Evaluating model: {model}")
+    print(f"Using API URL: {api_url if api_url else 'No API URL'}")
+    print(f"Batch size: {batch_size}")
+    print(f"Max LiteLLM workers: {max_litellm_workers}")
+    print(f"Using prompt: {args.which_prompt}")
+    print("-" * 100)
+
+    with tqdm(total=len(testset), desc="Evaluating") as pbar:
+        for batch in batched_testset:
+            litellm_requests = []
+
+            for item in batch:
+                user_content = f"{item['input']}"
+                messages = [{"role": "system", "content": instruction}, {"role": "user", "content": user_content}]
+
+                litellm_requests.append(messages)
+
+            try:
+                responses = litellm.batch_completion(
+                    model=model,
+                    messages=litellm_requests,
+                    api_base=api_url,
+                    max_workers=max_litellm_workers,
+                    format=AnyMathsStructuredOutput.model_json_schema(),
+                    response_format={
+                        "type": "json_object",
+                        "response_schema": AnyMathsStructuredOutput.model_json_schema(),
+                        "enforce_validation": True,
+                    },
+                )
+            except litellm.exceptions.JSONSchemaValidationError as e:
+                raise e
+
+            for response, item in zip(responses, batch, strict=False):
+                correct_output_format = True
+                try:
+                    assistant_response = ast.literal_eval(response.choices[0].message.content.strip())
+                    assistant_final_answer = assistant_response["final_answer"]
+                    ground_truth = item["answer"]
+                    score = 1.0 if ground_truth in assistant_final_answer else 0.0
+                    total_score += score
+                except Exception:
+                    correct_output_format = False
+                    continue
+
+            pbar.update(len(batch))
+            pbar.set_postfix({"Score": f"{total_score} / {len(testset):.4f}"})
+
+    print("-" * 100)
+    print(f"Final score >> {total_score} / {len(testset):.4f}")
+    print("-" * 100)
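Note: eval_default.py builds its JSON schema from `AnyMathsStructuredOutput`, which is defined in `anymaths_adapter.py` and not reproduced in this excerpt. Judging from the `.model_json_schema()` calls above and the two fields required by the prompt templates below, a minimal Pydantic sketch of such a model could look like the following; the actual field types, descriptions, and ordering in the package may differ:

```python
from pydantic import BaseModel, Field


class AnyMathsStructuredOutput(BaseModel):
    # Sketch only: field names follow the prompt templates ("solution_pad", "final_answer");
    # the real definition lives in anymaths_adapter.py and may differ.
    solution_pad: str = Field(description="The step-by-step solution to the problem.")
    final_answer: str = Field(description="The final answer to the question.")


# This is the schema object that eval_default.py passes to litellm for validation.
schema = AnyMathsStructuredOutput.model_json_schema()
```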
mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt
@@ -0,0 +1,9 @@
+You are an AI assistant that solves mathematical word problems. You will be given a question and you need to provide a step-by-step solution to the problem. Finally, you will provide the answer to the question. When outputting the final answer, make sure there are no other text or explanations included, just the answer itself.
+
+The expected output must be a JSON object with the following format:
+{
+    "final_answer": <the final answer to the question>,
+    "solution_pad": <the step-by-step solution to the problem>
+}
+
+Strictly follow the format provided above and ensure that your output is a valid JSON object. Any deviation from this format will result in an error.
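For reference, eval_default.py parses the model's reply with `ast.literal_eval` and then reads the `final_answer` key, so a reply satisfying this prompt is a single JSON object along the following lines. The question wording and numbers here are invented for illustration, not taken from the dataset:

```python
import ast

# Illustrative reply for a GSM8K-style question ("48 clips sold in April, half as many in May").
reply = '{"solution_pad": "April: 48 clips. May: 48 / 2 = 24 clips. Total: 48 + 24 = 72.", "final_answer": "72"}'

parsed = ast.literal_eval(reply)                          # same parsing step as eval_default.py
score = 1.0 if "72" in parsed["final_answer"] else 0.0    # substring match against the ground truth
print(score)  # 1.0
```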
mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt
@@ -0,0 +1,24 @@
+You are an AI assistant that solves mathematical word problems. You will be given a question and you need to provide a step-by-step solution to the problem. Finally, you will provide the answer to the question.
+
+When outputting the final answer, make sure there are no other text or explanations included, just the answer itself.
+
+The following fields are what you need to include in your response:
+- final_answer: The final answer to the question.
+- solution_pad: The step-by-step solution to the problem.
+
+Here are specific guidelines for generating your response:
+
+1. **Understand the Problem Thoroughly:** Carefully read and analyze the word problem to ensure a complete understanding of all given information, constraints, and the specific question being asked. Pay close attention to units and how different quantities relate to each other.
+
+2. **Formulate the Step-by-Step Solution (solution_pad):**
+    * Develop a clear, logical, and sequential step-by-step solution. Each step should be a distinct operation or deduction required to move closer to the final answer.
+    * Clearly state what is being calculated or determined in each step.
+    * Perform all necessary calculations with high precision and accuracy. Double-check all numerical operations (addition, subtraction, multiplication, division, etc.) to prevent errors.
+    * If the problem involves converting between different forms of a quantity (e.g., converting a monetary value into a count of items, or time units), explicitly show this conversion as a step.
+    * **Domain-Specific Interpretation Example:** If Barry has "$10.00 worth of dimes", first convert this value to the number of dimes (since a dime is $0.10, Barry has $10.00 / $0.10 = 100 dimes). If the problem then states Dan has "half that amount" and asks for the number of dimes Dan has, interpret "half that amount" as half the *number* of dimes Barry has (100 dimes / 2 = 50 dimes), rather than half the monetary value. Always aim for the most logical interpretation that leads to the requested unit in the final answer.
+    * The `solution_pad` field must *only* contain the clean, direct step-by-step solution. Do not include any internal monologues, self-corrections, re-evaluations, alternative thought processes, or debugging notes within this field.
+
+3. **Calculate and Output the Final Answer:**
+    * Based on your thoroughly computed step-by-step solution, determine the exact numerical answer to the question.
+    * The `final_answer` field must contain *only* the numerical value. Do not include any currency symbols (e.g., "$"), units (e.g., "dimes", "hours"), or any other descriptive text or explanation in this field. For example, if the answer is 4625 dollars, output `4625`. If the answer is 52 dimes, output `52`.
+    * Ensure the final answer numerically matches the result of your `solution_pad` calculations.
mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py
@@ -0,0 +1,177 @@
+def init_dataset(anymaths_dset_name: str = "openai/gsm8k"):
+    import random
+
+    from datasets import load_dataset
+
+    train_split = []
+    test_split = []
+    match anymaths_dset_name:
+        case "openai/gsm8k":
+            train_load_dataset = load_dataset(anymaths_dset_name, "main", split="train")
+            for item in train_load_dataset:
+                answer = item["answer"].split("####")[-1].strip()
+                solution = item["answer"].split("####")[0].strip()
+                question = item["question"]
+
+                train_split.append({"input": question, "additional_context": {"solution": solution}, "answer": answer})
+
+            random.Random(0).shuffle(train_split)
+
+            test_load_dataset = load_dataset(anymaths_dset_name, "main", split="test")
+            for item in test_load_dataset:
+                answer = item["answer"].split("####")[-1].strip()
+                solution = item["answer"].split("####")[0].strip()
+                question = item["question"]
+
+                test_split.append({"input": question, "answer": answer})
+
+        case "MathArena/aime_2025":
+            train_load_dataset = load_dataset("AI-MO/aimo-validation-aime", "default", split="train")
+            for item in train_load_dataset:
+                question = item["problem"]
+                solution = item["solution"]
+                answer = item["answer"]
+
+                train_split.append({"input": question, "additional_context": {"solution": solution}, "answer": answer})
+
+            random.Random(0).shuffle(train_split)
+
+            test_load_dataset = load_dataset("MathArena/aime_2025", "default", split="train")
+            for item in test_load_dataset:
+                question = item["problem"]
+                answer = item["answer"]
+
+                test_split.append({"input": question, "answer": answer})
+        case _:
+            raise ValueError(f"Unknown dataset name: {anymaths_dset_name}")
+
+    trainset = train_split[: len(train_split) // 2]
+    valset = train_split[len(train_split) // 2 :]
+    testset = test_split
+
+    return trainset, valset, testset
+
+
+if __name__ == "__main__":
+    import argparse
+    from functools import partial
+    from pathlib import Path
+
+    import litellm
+
+    from mantisdk.algorithm.gepa.lib import optimize
+    from mantisdk.algorithm.gepa.lib.adapters.anymaths_adapter import AnyMathsAdapter
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--anymaths_dset_name", type=str, default="openai/gsm8k")
+    parser.add_argument("--train_size", type=int, default=1, help="The size of the training set to use.")
+    parser.add_argument("--val_size", type=int, default=1, help="The size of the validation set to use.")
+    parser.add_argument("--test_size", type=int, default=1, help="The size of the test set to use.")
+    parser.add_argument("--base_lm", type=str, default="ollama/qwen3:4b")
+    parser.add_argument("--use_api_base", action="store_true", help="Use API base URL")
+    parser.add_argument("--api_base_url", type=str, default="http://localhost:11434")
+    parser.add_argument(
+        "--reflection_lm", type=str, default="ollama/qwen3:8b", help="The name of the reflection LM to use."
+    )
+    parser.add_argument("--use_api_reflection", action="store_true", help="Use API reflection URL")
+    parser.add_argument(
+        "--api_reflection_url",
+        type=str,
+        default="http://localhost:11434",
+        help="The API base URL for the reflection LM.",
+    )
+    parser.add_argument(
+        "--reflection_minibatch_size", type=int, default=8, help="The size of the minibatch for the reflection LM."
+    )
+    parser.add_argument("--max_litellm_workers", type=int, default=10)
+    parser.add_argument("--budget", type=int, default=500, help="The budget for the optimization process.")
+    parser.add_argument(
+        "--seed", type=int, default=0, help="The seed for the random number generator for reproducibility."
+    )
+    args = parser.parse_args()
+
+    INSTRUCTION_PROMPT_PATH = Path(__file__).parent / "prompt-templates/instruction_prompt.txt"
+
+    seed_instruction = INSTRUCTION_PROMPT_PATH.read_text()
+
+    trainset, valset, testset = init_dataset(args.anymaths_dset_name)
+
+    train_size = args.train_size
+    val_size = args.val_size
+    test_size = args.test_size
+
+    for size in map(int, [train_size, val_size, test_size]):
+        if size <= 0:
+            raise ValueError("Train, val, and test sizes must be positive integers.")
+
+    trainset = trainset[:train_size]
+    valset = valset[:val_size]
+    testset = testset[:test_size]
+
+    print("-" * 100)
+    print(f"Using dataset: {args.anymaths_dset_name}")
+    print(f"Training set size: {len(trainset)}")
+    print(f"Validation set size: {len(valset)}")
+    print(f"Test set size: {len(testset)}")
+    print("-" * 100)
+
+    base_lm = args.base_lm
+
+    reflection_lm_name = args.reflection_lm
+
+    _reflection = {"model": reflection_lm_name}
+
+    use_api_base = args.use_api_base
+    use_api_reflection = args.use_api_reflection
+
+    if use_api_base:
+        api_base = args.api_base_url
+    else:
+        api_base = None
+
+    if use_api_reflection:
+        api_reflection = args.api_reflection_url
+        _reflection["base_url"] = api_reflection
+    else:
+        api_reflection = None
+
+    _reflection_completion = partial(litellm.completion, **_reflection)
+
+    def reflection_lm(prompt: str):
+        """Call the reflection language model with the given prompt and return its content string."""
+        response = _reflection_completion(messages=[{"role": "user", "content": prompt}])
+        return response.choices[0].message.content
+
+    max_litellm_workers = args.max_litellm_workers
+    budget = args.budget
+    reflection_minibatch_size = args.reflection_minibatch_size
+    seed = args.seed
+
+    print(f"Using base LM: {base_lm}")
+    print(f"Using reflection LM: {reflection_lm_name}")
+    print(f"Using API base URL: {api_base}")
+    print(f"Using API reflection URL: {api_reflection}")
+    print(f"Reflection minibatch size: {reflection_minibatch_size}")
+    print(f"Max LiteLLM workers: {max_litellm_workers}")
+    print(f"Budget: {budget}")
+    print(f"Seed: {seed}")
+    print("-" * 100)
+
+    optimized_results = optimize(
+        seed_candidate={"instruction_prompt": seed_instruction},
+        trainset=trainset,
+        valset=valset,
+        adapter=AnyMathsAdapter(model=base_lm, api_base=api_base, max_litellm_workers=max_litellm_workers),
+        reflection_lm=reflection_lm,
+        reflection_minibatch_size=reflection_minibatch_size,
+        perfect_score=1,
+        skip_perfect_score=False,
+        use_wandb=False,
+        max_metric_calls=budget,
+        seed=seed,
+        display_progress_bar=True,
+    )
+
+    print("-" * 100)
+    print(f"Best prompt >>> {optimized_results.best_candidate}")
+    print("-" * 100)
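Note: train_anymaths.py only prints the best candidate at the end of the run, while eval_default.py expects the optimized instruction to live in prompt-templates/optimal_prompt.txt when invoked with --which_prompt optimized. A possible follow-up step appended to the script, assuming best_candidate is a dict keyed by component name like the seed_candidate passed to optimize(), might be:

```python
    from pathlib import Path

    # Hypothetical post-processing (not part of the package): persist the optimized
    # instruction so eval_default.py --which_prompt optimized can read it from disk.
    best_prompt = optimized_results.best_candidate["instruction_prompt"]
    out_path = Path(__file__).parent / "prompt-templates/optimal_prompt.txt"
    out_path.write_text(best_prompt)
    print(f"Wrote optimized prompt ({len(best_prompt)} characters) to {out_path}")
```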