mini-swe-agent 1.17.4__py3-none-any.whl → 2.0.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mini_swe_agent-1.17.4.dist-info → mini_swe_agent-2.0.0a1.dist-info}/METADATA +36 -52
- mini_swe_agent-2.0.0a1.dist-info/RECORD +70 -0
- {mini_swe_agent-1.17.4.dist-info → mini_swe_agent-2.0.0a1.dist-info}/WHEEL +1 -1
- mini_swe_agent-2.0.0a1.dist-info/entry_points.txt +5 -0
- minisweagent/__init__.py +19 -26
- minisweagent/agents/default.py +128 -113
- minisweagent/agents/interactive.py +119 -58
- minisweagent/config/README.md +3 -4
- minisweagent/config/__init__.py +36 -1
- minisweagent/config/benchmarks/swebench.yaml +156 -0
- minisweagent/config/{extra/swebench.yaml → benchmarks/swebench_backticks.yaml} +69 -64
- minisweagent/config/benchmarks/swebench_modal.yaml +47 -0
- minisweagent/config/{extra → benchmarks}/swebench_xml.yaml +73 -70
- minisweagent/config/default.yaml +24 -21
- minisweagent/config/inspector.tcss +42 -0
- minisweagent/config/mini.yaml +53 -71
- minisweagent/config/{github_issue.yaml → mini_textbased.yaml} +43 -29
- minisweagent/environments/__init__.py +1 -0
- minisweagent/environments/docker.py +67 -20
- minisweagent/environments/extra/bubblewrap.py +86 -47
- minisweagent/environments/extra/swerex_docker.py +53 -20
- minisweagent/environments/extra/swerex_modal.py +90 -0
- minisweagent/environments/local.py +62 -21
- minisweagent/environments/singularity.py +59 -18
- minisweagent/exceptions.py +22 -0
- minisweagent/models/__init__.py +6 -7
- minisweagent/models/extra/roulette.py +20 -17
- minisweagent/models/litellm_model.py +90 -44
- minisweagent/models/litellm_response_model.py +80 -0
- minisweagent/models/litellm_textbased_model.py +45 -0
- minisweagent/models/openrouter_model.py +87 -45
- minisweagent/models/openrouter_response_model.py +123 -0
- minisweagent/models/openrouter_textbased_model.py +76 -0
- minisweagent/models/portkey_model.py +84 -42
- minisweagent/models/portkey_response_model.py +163 -0
- minisweagent/models/requesty_model.py +91 -41
- minisweagent/models/test_models.py +246 -19
- minisweagent/models/utils/actions_text.py +60 -0
- minisweagent/models/utils/actions_toolcall.py +102 -0
- minisweagent/models/utils/actions_toolcall_response.py +110 -0
- minisweagent/models/utils/anthropic_utils.py +28 -0
- minisweagent/models/utils/cache_control.py +15 -2
- minisweagent/models/utils/content_string.py +74 -0
- minisweagent/models/utils/openai_multimodal.py +50 -0
- minisweagent/models/utils/retry.py +25 -0
- minisweagent/run/benchmarks/__init__.py +1 -0
- minisweagent/run/{extra → benchmarks}/swebench.py +57 -36
- minisweagent/run/benchmarks/swebench_single.py +89 -0
- minisweagent/run/{extra → benchmarks}/utils/batch_progress.py +1 -1
- minisweagent/run/hello_world.py +6 -0
- minisweagent/run/mini.py +54 -63
- minisweagent/run/utilities/__init__.py +1 -0
- minisweagent/run/{extra → utilities}/config.py +2 -0
- minisweagent/run/{inspector.py → utilities/inspector.py} +90 -11
- minisweagent/run/{mini_extra.py → utilities/mini_extra.py} +9 -5
- minisweagent/utils/serialize.py +26 -0
- mini_swe_agent-1.17.4.dist-info/RECORD +0 -61
- mini_swe_agent-1.17.4.dist-info/entry_points.txt +0 -5
- minisweagent/agents/interactive_textual.py +0 -450
- minisweagent/config/extra/swebench_roulette.yaml +0 -233
- minisweagent/config/mini.tcss +0 -86
- minisweagent/models/anthropic.py +0 -35
- minisweagent/models/litellm_response_api_model.py +0 -82
- minisweagent/models/portkey_response_api_model.py +0 -75
- minisweagent/models/utils/key_per_thread.py +0 -20
- minisweagent/models/utils/openai_utils.py +0 -41
- minisweagent/run/extra/swebench_single.py +0 -79
- minisweagent/run/github_issue.py +0 -87
- minisweagent/run/utils/__init__.py +0 -0
- minisweagent/run/utils/save.py +0 -78
- {mini_swe_agent-1.17.4.dist-info → mini_swe_agent-2.0.0a1.dist-info}/licenses/LICENSE.md +0 -0
- {mini_swe_agent-1.17.4.dist-info → mini_swe_agent-2.0.0a1.dist-info}/top_level.txt +0 -0
- /minisweagent/config/{extra → benchmarks}/__init__.py +0 -0
- /minisweagent/run/{extra → benchmarks}/utils/__init__.py +0 -0
|
@@ -1,14 +1,17 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import platform
|
|
3
3
|
import subprocess
|
|
4
|
-
from dataclasses import asdict, dataclass, field
|
|
5
4
|
from typing import Any
|
|
6
5
|
|
|
6
|
+
from pydantic import BaseModel
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
|
|
8
|
+
from minisweagent.exceptions import Submitted
|
|
9
|
+
from minisweagent.utils.serialize import recursive_merge
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class LocalEnvironmentConfig(BaseModel):
|
|
10
13
|
cwd: str = ""
|
|
11
|
-
env: dict[str, str] =
|
|
14
|
+
env: dict[str, str] = {}
|
|
12
15
|
timeout: int = 30
|
|
13
16
|
|
|
14
17
|
|
|
@@ -17,22 +20,60 @@ class LocalEnvironment:
|
|
|
17
20
|
"""This class executes bash commands directly on the local machine."""
|
|
18
21
|
self.config = config_class(**kwargs)
|
|
19
22
|
|
|
20
|
-
def execute(self,
|
|
23
|
+
def execute(self, action: dict, cwd: str = "", *, timeout: int | None = None) -> dict[str, Any]:
|
|
21
24
|
"""Execute a command in the local environment and return the result as a dict."""
|
|
25
|
+
command = action.get("command", "")
|
|
22
26
|
cwd = cwd or self.config.cwd or os.getcwd()
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
27
|
+
try:
|
|
28
|
+
result = subprocess.run(
|
|
29
|
+
command,
|
|
30
|
+
shell=True,
|
|
31
|
+
text=True,
|
|
32
|
+
cwd=cwd,
|
|
33
|
+
env=os.environ | self.config.env,
|
|
34
|
+
timeout=timeout or self.config.timeout,
|
|
35
|
+
encoding="utf-8",
|
|
36
|
+
errors="replace",
|
|
37
|
+
stdout=subprocess.PIPE,
|
|
38
|
+
stderr=subprocess.STDOUT,
|
|
39
|
+
)
|
|
40
|
+
output = {"output": result.stdout, "returncode": result.returncode, "exception_info": ""}
|
|
41
|
+
except Exception as e:
|
|
42
|
+
raw_output = getattr(e, "output", None)
|
|
43
|
+
raw_output = (
|
|
44
|
+
raw_output.decode("utf-8", errors="replace") if isinstance(raw_output, bytes) else (raw_output or "")
|
|
45
|
+
)
|
|
46
|
+
output = {
|
|
47
|
+
"output": raw_output,
|
|
48
|
+
"returncode": -1,
|
|
49
|
+
"exception_info": f"An error occurred while executing the command: {e}",
|
|
50
|
+
"extra": {"exception_type": type(e).__name__, "exception": str(e)},
|
|
51
|
+
}
|
|
52
|
+
self._check_finished(output)
|
|
53
|
+
return output
|
|
54
|
+
|
|
55
|
+
def _check_finished(self, output: dict):
|
|
56
|
+
"""Raises Submitted if the output indicates task completion."""
|
|
57
|
+
lines = output.get("output", "").lstrip().splitlines(keepends=True)
|
|
58
|
+
if lines and lines[0].strip() == "COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT" and output["returncode"] == 0:
|
|
59
|
+
submission = "".join(lines[1:])
|
|
60
|
+
raise Submitted(
|
|
61
|
+
{
|
|
62
|
+
"role": "exit",
|
|
63
|
+
"content": submission,
|
|
64
|
+
"extra": {"exit_status": "Submitted", "submission": submission},
|
|
65
|
+
}
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
def get_template_vars(self, **kwargs) -> dict[str, Any]:
|
|
69
|
+
return recursive_merge(self.config.model_dump(), platform.uname()._asdict(), os.environ, kwargs)
|
|
70
|
+
|
|
71
|
+
def serialize(self) -> dict:
|
|
72
|
+
return {
|
|
73
|
+
"info": {
|
|
74
|
+
"config": {
|
|
75
|
+
"environment": self.config.model_dump(mode="json"),
|
|
76
|
+
"environment_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
|
|
77
|
+
}
|
|
78
|
+
}
|
|
79
|
+
}
|
|
@@ -6,18 +6,21 @@ import shutil
|
|
|
6
6
|
import subprocess
|
|
7
7
|
import tempfile
|
|
8
8
|
import uuid
|
|
9
|
-
from dataclasses import asdict, dataclass, field
|
|
10
9
|
from pathlib import Path
|
|
11
10
|
from typing import Any
|
|
12
11
|
|
|
12
|
+
from pydantic import BaseModel
|
|
13
13
|
|
|
14
|
-
|
|
15
|
-
|
|
14
|
+
from minisweagent.exceptions import Submitted
|
|
15
|
+
from minisweagent.utils.serialize import recursive_merge
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SingularityEnvironmentConfig(BaseModel):
|
|
16
19
|
image: str
|
|
17
20
|
cwd: str = "/"
|
|
18
|
-
env: dict[str, str] =
|
|
21
|
+
env: dict[str, str] = {}
|
|
19
22
|
"""Environment variables to set in the container."""
|
|
20
|
-
forward_env: list[str] =
|
|
23
|
+
forward_env: list[str] = []
|
|
21
24
|
"""Environment variables to forward to the container."""
|
|
22
25
|
timeout: int = 30
|
|
23
26
|
"""Timeout for executing commands in the container."""
|
|
@@ -57,11 +60,22 @@ class SingularityEnvironment:
|
|
|
57
60
|
raise
|
|
58
61
|
return sandbox_dir
|
|
59
62
|
|
|
60
|
-
def get_template_vars(self) -> dict[str, Any]:
|
|
61
|
-
return
|
|
63
|
+
def get_template_vars(self, **kwargs) -> dict[str, Any]:
|
|
64
|
+
return recursive_merge(self.config.model_dump(), kwargs)
|
|
62
65
|
|
|
63
|
-
def
|
|
66
|
+
def serialize(self) -> dict:
|
|
67
|
+
return {
|
|
68
|
+
"info": {
|
|
69
|
+
"config": {
|
|
70
|
+
"environment": self.config.model_dump(mode="json"),
|
|
71
|
+
"environment_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
|
|
72
|
+
}
|
|
73
|
+
}
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
def execute(self, action: dict, cwd: str = "", *, timeout: int | None = None) -> dict[str, Any]:
|
|
64
77
|
"""Execute a command in a Singularity container and return the result as a dict."""
|
|
78
|
+
command = action.get("command", "")
|
|
65
79
|
cmd = [self.config.executable, "exec"]
|
|
66
80
|
|
|
67
81
|
# Do not inherit directories and env vars from host
|
|
@@ -78,16 +92,43 @@ class SingularityEnvironment:
|
|
|
78
92
|
cmd.extend(["--env", f"{key}={value}"])
|
|
79
93
|
|
|
80
94
|
cmd.extend(["--writable", str(self.sandbox_dir), "bash", "-c", command])
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
95
|
+
try:
|
|
96
|
+
result = subprocess.run(
|
|
97
|
+
cmd,
|
|
98
|
+
text=True,
|
|
99
|
+
timeout=timeout or self.config.timeout,
|
|
100
|
+
encoding="utf-8",
|
|
101
|
+
errors="replace",
|
|
102
|
+
stdout=subprocess.PIPE,
|
|
103
|
+
stderr=subprocess.STDOUT,
|
|
104
|
+
)
|
|
105
|
+
output = {"output": result.stdout, "returncode": result.returncode, "exception_info": ""}
|
|
106
|
+
except Exception as e:
|
|
107
|
+
raw_output = getattr(e, "output", None)
|
|
108
|
+
raw_output = (
|
|
109
|
+
raw_output.decode("utf-8", errors="replace") if isinstance(raw_output, bytes) else (raw_output or "")
|
|
110
|
+
)
|
|
111
|
+
output = {
|
|
112
|
+
"output": raw_output,
|
|
113
|
+
"returncode": -1,
|
|
114
|
+
"exception_info": f"An error occurred while executing the command: {e}",
|
|
115
|
+
"extra": {"exception_type": type(e).__name__, "exception": str(e)},
|
|
116
|
+
}
|
|
117
|
+
self._check_finished(output)
|
|
118
|
+
return output
|
|
119
|
+
|
|
120
|
+
def _check_finished(self, output: dict):
|
|
121
|
+
"""Raises Submitted if the output indicates task completion."""
|
|
122
|
+
lines = output.get("output", "").lstrip().splitlines(keepends=True)
|
|
123
|
+
if lines and lines[0].strip() == "COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT" and output["returncode"] == 0:
|
|
124
|
+
submission = "".join(lines[1:])
|
|
125
|
+
raise Submitted(
|
|
126
|
+
{
|
|
127
|
+
"role": "exit",
|
|
128
|
+
"content": submission,
|
|
129
|
+
"extra": {"exit_status": "Submitted", "submission": submission},
|
|
130
|
+
}
|
|
131
|
+
)
|
|
91
132
|
|
|
92
133
|
def cleanup(self):
|
|
93
134
|
shutil.rmtree(self.sandbox_dir, ignore_errors=True)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
class InterruptAgentFlow(Exception):
|
|
2
|
+
"""Raised to interrupt the agent flow and add messages."""
|
|
3
|
+
|
|
4
|
+
def __init__(self, *messages: dict):
|
|
5
|
+
self.messages = messages
|
|
6
|
+
super().__init__()
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Submitted(InterruptAgentFlow):
|
|
10
|
+
"""Raised when the agent has completed its task."""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class LimitsExceeded(InterruptAgentFlow):
|
|
14
|
+
"""Raised when the agent has exceeded its cost or step limit."""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class UserInterruption(InterruptAgentFlow):
|
|
18
|
+
"""Raised when the user interrupts the agent."""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class FormatError(InterruptAgentFlow):
|
|
22
|
+
"""Raised when the LM's output is not in the expected format."""
|
minisweagent/models/__init__.py
CHANGED
|
@@ -28,7 +28,7 @@ class GlobalModelStats:
|
|
|
28
28
|
self._cost += cost
|
|
29
29
|
self._n_calls += 1
|
|
30
30
|
if 0 < self.cost_limit < self._cost or 0 < self.call_limit < self._n_calls + 1:
|
|
31
|
-
raise RuntimeError(f"Global cost/call limit exceeded: ${self._cost:.4f} / {self._n_calls
|
|
31
|
+
raise RuntimeError(f"Global cost/call limit exceeded: ${self._cost:.4f} / {self._n_calls}")
|
|
32
32
|
|
|
33
33
|
@property
|
|
34
34
|
def cost(self) -> float:
|
|
@@ -52,9 +52,6 @@ def get_model(input_model_name: str | None = None, config: dict | None = None) -
|
|
|
52
52
|
|
|
53
53
|
model_class = get_model_class(resolved_model_name, config.pop("model_class", ""))
|
|
54
54
|
|
|
55
|
-
if (from_env := os.getenv("MSWEA_MODEL_API_KEY")) and not str(type(model_class)).endswith("DeterministicModel"):
|
|
56
|
-
config.setdefault("model_kwargs", {})["api_key"] = from_env
|
|
57
|
-
|
|
58
55
|
if (
|
|
59
56
|
any(s in resolved_model_name.lower() for s in ["anthropic", "sonnet", "opus", "claude"])
|
|
60
57
|
and "set_cache_control" not in config
|
|
@@ -79,12 +76,14 @@ def get_model_name(input_model_name: str | None = None, config: dict | None = No
|
|
|
79
76
|
|
|
80
77
|
|
|
81
78
|
_MODEL_CLASS_MAPPING = {
|
|
82
|
-
"anthropic": "minisweagent.models.anthropic.AnthropicModel",
|
|
83
79
|
"litellm": "minisweagent.models.litellm_model.LitellmModel",
|
|
84
|
-
"
|
|
80
|
+
"litellm_textbased": "minisweagent.models.litellm_textbased_model.LitellmTextbasedModel",
|
|
81
|
+
"litellm_response": "minisweagent.models.litellm_response_model.LitellmResponseModel",
|
|
85
82
|
"openrouter": "minisweagent.models.openrouter_model.OpenRouterModel",
|
|
83
|
+
"openrouter_textbased": "minisweagent.models.openrouter_textbased_model.OpenRouterTextbasedModel",
|
|
84
|
+
"openrouter_response": "minisweagent.models.openrouter_response_model.OpenRouterResponseModel",
|
|
86
85
|
"portkey": "minisweagent.models.portkey_model.PortkeyModel",
|
|
87
|
-
"portkey_response": "minisweagent.models.
|
|
86
|
+
"portkey_response": "minisweagent.models.portkey_response_model.PortkeyResponseAPIModel",
|
|
88
87
|
"requesty": "minisweagent.models.requesty_model.RequestyModel",
|
|
89
88
|
"deterministic": "minisweagent.models.test_models.DeterministicModel",
|
|
90
89
|
}
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
import random
|
|
2
|
-
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
3
4
|
|
|
4
5
|
from minisweagent import Model
|
|
5
6
|
from minisweagent.models import get_model
|
|
6
7
|
|
|
7
8
|
|
|
8
|
-
|
|
9
|
-
class RouletteModelConfig:
|
|
9
|
+
class RouletteModelConfig(BaseModel):
|
|
10
10
|
model_kwargs: list[dict]
|
|
11
11
|
"""The models to choose from"""
|
|
12
12
|
model_name: str = "roulette"
|
|
@@ -17,30 +17,33 @@ class RouletteModel:
|
|
|
17
17
|
"""This "meta"-model randomly selects one of the models at every call"""
|
|
18
18
|
self.config = config_class(**kwargs)
|
|
19
19
|
self.models = [get_model(config=config) for config in self.config.model_kwargs]
|
|
20
|
+
self._n_calls = 0
|
|
20
21
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
return sum(model.cost for model in self.models)
|
|
24
|
-
|
|
25
|
-
@property
|
|
26
|
-
def n_calls(self) -> int:
|
|
27
|
-
return sum(model.n_calls for model in self.models)
|
|
28
|
-
|
|
29
|
-
def get_template_vars(self) -> dict:
|
|
30
|
-
return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}
|
|
22
|
+
def get_template_vars(self, **kwargs) -> dict:
|
|
23
|
+
return self.config.model_dump()
|
|
31
24
|
|
|
32
25
|
def select_model(self) -> Model:
|
|
33
26
|
return random.choice(self.models)
|
|
34
27
|
|
|
35
28
|
def query(self, *args, **kwargs) -> dict:
|
|
36
29
|
model = self.select_model()
|
|
30
|
+
self._n_calls += 1
|
|
37
31
|
response = model.query(*args, **kwargs)
|
|
38
32
|
response["model_name"] = model.config.model_name
|
|
39
33
|
return response
|
|
40
34
|
|
|
35
|
+
def serialize(self) -> dict:
|
|
36
|
+
return {
|
|
37
|
+
"info": {
|
|
38
|
+
"config": {
|
|
39
|
+
"model": self.config.model_dump(mode="json"),
|
|
40
|
+
"model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
|
|
41
|
+
},
|
|
42
|
+
}
|
|
43
|
+
}
|
|
44
|
+
|
|
41
45
|
|
|
42
|
-
|
|
43
|
-
class InterleavingModelConfig:
|
|
46
|
+
class InterleavingModelConfig(BaseModel):
|
|
44
47
|
model_kwargs: list[dict]
|
|
45
48
|
sequence: list[int] | None = None
|
|
46
49
|
"""If set to 0, 0, 1, we will return the first model 2 times, then the second model 1 time,
|
|
@@ -55,7 +58,7 @@ class InterleavingModel(RouletteModel):
|
|
|
55
58
|
|
|
56
59
|
def select_model(self) -> Model:
|
|
57
60
|
if self.config.sequence is None:
|
|
58
|
-
i_model = self.
|
|
61
|
+
i_model = self._n_calls % len(self.models)
|
|
59
62
|
else:
|
|
60
|
-
i_model = self.config.sequence[self.
|
|
63
|
+
i_model = self.config.sequence[self._n_calls % len(self.config.sequence)]
|
|
61
64
|
return self.models[i_model]
|
|
@@ -1,75 +1,98 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
|
+
import time
|
|
4
5
|
from collections.abc import Callable
|
|
5
|
-
from dataclasses import asdict, dataclass, field
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
from typing import Any, Literal
|
|
8
8
|
|
|
9
9
|
import litellm
|
|
10
|
-
from
|
|
11
|
-
before_sleep_log,
|
|
12
|
-
retry,
|
|
13
|
-
retry_if_not_exception_type,
|
|
14
|
-
stop_after_attempt,
|
|
15
|
-
wait_exponential,
|
|
16
|
-
)
|
|
10
|
+
from pydantic import BaseModel
|
|
17
11
|
|
|
18
12
|
from minisweagent.models import GLOBAL_MODEL_STATS
|
|
13
|
+
from minisweagent.models.utils.actions_toolcall import (
|
|
14
|
+
BASH_TOOL,
|
|
15
|
+
format_toolcall_observation_messages,
|
|
16
|
+
parse_toolcall_actions,
|
|
17
|
+
)
|
|
18
|
+
from minisweagent.models.utils.anthropic_utils import _reorder_anthropic_thinking_blocks
|
|
19
19
|
from minisweagent.models.utils.cache_control import set_cache_control
|
|
20
|
+
from minisweagent.models.utils.openai_multimodal import expand_multimodal_content
|
|
21
|
+
from minisweagent.models.utils.retry import retry
|
|
20
22
|
|
|
21
23
|
logger = logging.getLogger("litellm_model")
|
|
22
24
|
|
|
23
25
|
|
|
24
|
-
|
|
25
|
-
class LitellmModelConfig:
|
|
26
|
+
class LitellmModelConfig(BaseModel):
|
|
26
27
|
model_name: str
|
|
27
|
-
|
|
28
|
+
"""Model name. Highly recommended to include the provider in the model name, e.g., `anthropic/claude-sonnet-4-5-20250929`."""
|
|
29
|
+
model_kwargs: dict[str, Any] = {}
|
|
30
|
+
"""Additional arguments passed to the API."""
|
|
28
31
|
litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
|
|
32
|
+
"""Model registry for cost tracking and model metadata. See the local model guide (https://mini-swe-agent.com/latest/models/local_models/) for more details."""
|
|
29
33
|
set_cache_control: Literal["default_end"] | None = None
|
|
30
34
|
"""Set explicit cache control markers, for example for Anthropic models"""
|
|
31
35
|
cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
|
|
32
36
|
"""Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""
|
|
37
|
+
format_error_template: str = "{{ error }}"
|
|
38
|
+
"""Template used when the LM's output is not in the expected format."""
|
|
39
|
+
observation_template: str = (
|
|
40
|
+
"{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
|
|
41
|
+
"<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
|
|
42
|
+
)
|
|
43
|
+
"""Template used to render the observation after executing an action."""
|
|
44
|
+
multimodal_regex: str = ""
|
|
45
|
+
"""Regex to extract multimodal content. Empty string disables multimodal processing."""
|
|
33
46
|
|
|
34
47
|
|
|
35
48
|
class LitellmModel:
|
|
49
|
+
abort_exceptions: list[type[Exception]] = [
|
|
50
|
+
litellm.exceptions.UnsupportedParamsError,
|
|
51
|
+
litellm.exceptions.NotFoundError,
|
|
52
|
+
litellm.exceptions.PermissionDeniedError,
|
|
53
|
+
litellm.exceptions.ContextWindowExceededError,
|
|
54
|
+
litellm.exceptions.AuthenticationError,
|
|
55
|
+
KeyboardInterrupt,
|
|
56
|
+
]
|
|
57
|
+
|
|
36
58
|
def __init__(self, *, config_class: Callable = LitellmModelConfig, **kwargs):
|
|
37
59
|
self.config = config_class(**kwargs)
|
|
38
|
-
self.cost = 0.0
|
|
39
|
-
self.n_calls = 0
|
|
40
60
|
if self.config.litellm_model_registry and Path(self.config.litellm_model_registry).is_file():
|
|
41
61
|
litellm.utils.register_model(json.loads(Path(self.config.litellm_model_registry).read_text()))
|
|
42
62
|
|
|
43
|
-
@retry(
|
|
44
|
-
reraise=True,
|
|
45
|
-
stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
|
|
46
|
-
wait=wait_exponential(multiplier=1, min=4, max=60),
|
|
47
|
-
before_sleep=before_sleep_log(logger, logging.WARNING),
|
|
48
|
-
retry=retry_if_not_exception_type(
|
|
49
|
-
(
|
|
50
|
-
litellm.exceptions.UnsupportedParamsError,
|
|
51
|
-
litellm.exceptions.NotFoundError,
|
|
52
|
-
litellm.exceptions.PermissionDeniedError,
|
|
53
|
-
litellm.exceptions.ContextWindowExceededError,
|
|
54
|
-
litellm.exceptions.APIError,
|
|
55
|
-
litellm.exceptions.AuthenticationError,
|
|
56
|
-
KeyboardInterrupt,
|
|
57
|
-
)
|
|
58
|
-
),
|
|
59
|
-
)
|
|
60
63
|
def _query(self, messages: list[dict[str, str]], **kwargs):
|
|
61
64
|
try:
|
|
62
65
|
return litellm.completion(
|
|
63
|
-
model=self.config.model_name,
|
|
66
|
+
model=self.config.model_name,
|
|
67
|
+
messages=messages,
|
|
68
|
+
tools=[BASH_TOOL],
|
|
69
|
+
**(self.config.model_kwargs | kwargs),
|
|
64
70
|
)
|
|
65
71
|
except litellm.exceptions.AuthenticationError as e:
|
|
66
72
|
e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
|
|
67
73
|
raise e
|
|
68
74
|
|
|
75
|
+
def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
|
|
76
|
+
prepared = [{k: v for k, v in msg.items() if k != "extra"} for msg in messages]
|
|
77
|
+
prepared = _reorder_anthropic_thinking_blocks(prepared)
|
|
78
|
+
return set_cache_control(prepared, mode=self.config.set_cache_control)
|
|
79
|
+
|
|
69
80
|
def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
81
|
+
for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
|
|
82
|
+
with attempt:
|
|
83
|
+
response = self._query(self._prepare_messages_for_api(messages), **kwargs)
|
|
84
|
+
cost_output = self._calculate_cost(response)
|
|
85
|
+
GLOBAL_MODEL_STATS.add(cost_output["cost"])
|
|
86
|
+
message = response.choices[0].message.model_dump()
|
|
87
|
+
message["extra"] = {
|
|
88
|
+
"actions": self._parse_actions(response),
|
|
89
|
+
"response": response.model_dump(),
|
|
90
|
+
**cost_output,
|
|
91
|
+
"timestamp": time.time(),
|
|
92
|
+
}
|
|
93
|
+
return message
|
|
94
|
+
|
|
95
|
+
def _calculate_cost(self, response) -> dict[str, float]:
|
|
73
96
|
try:
|
|
74
97
|
cost = litellm.cost_calculator.completion_cost(response, model=self.config.model_name)
|
|
75
98
|
if cost <= 0.0:
|
|
@@ -87,15 +110,38 @@ class LitellmModel:
|
|
|
87
110
|
)
|
|
88
111
|
logger.critical(msg)
|
|
89
112
|
raise RuntimeError(msg) from e
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
113
|
+
return {"cost": cost}
|
|
114
|
+
|
|
115
|
+
def _parse_actions(self, response) -> list[dict]:
|
|
116
|
+
"""Parse tool calls from the response. Raises FormatError if unknown tool."""
|
|
117
|
+
tool_calls = response.choices[0].message.tool_calls or []
|
|
118
|
+
return parse_toolcall_actions(tool_calls, format_error_template=self.config.format_error_template)
|
|
119
|
+
|
|
120
|
+
def format_message(self, **kwargs) -> dict:
|
|
121
|
+
return expand_multimodal_content(kwargs, pattern=self.config.multimodal_regex)
|
|
122
|
+
|
|
123
|
+
def format_observation_messages(
|
|
124
|
+
self, message: dict, outputs: list[dict], template_vars: dict | None = None
|
|
125
|
+
) -> list[dict]:
|
|
126
|
+
"""Format execution outputs into tool result messages."""
|
|
127
|
+
actions = message.get("extra", {}).get("actions", [])
|
|
128
|
+
return format_toolcall_observation_messages(
|
|
129
|
+
actions=actions,
|
|
130
|
+
outputs=outputs,
|
|
131
|
+
observation_template=self.config.observation_template,
|
|
132
|
+
template_vars=template_vars,
|
|
133
|
+
multimodal_regex=self.config.multimodal_regex,
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
def get_template_vars(self, **kwargs) -> dict[str, Any]:
|
|
137
|
+
return self.config.model_dump()
|
|
138
|
+
|
|
139
|
+
def serialize(self) -> dict:
|
|
93
140
|
return {
|
|
94
|
-
"
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
141
|
+
"info": {
|
|
142
|
+
"config": {
|
|
143
|
+
"model": self.config.model_dump(mode="json"),
|
|
144
|
+
"model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
|
|
145
|
+
},
|
|
146
|
+
}
|
|
98
147
|
}
|
|
99
|
-
|
|
100
|
-
def get_template_vars(self) -> dict[str, Any]:
|
|
101
|
-
return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import time
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
|
|
5
|
+
import litellm
|
|
6
|
+
|
|
7
|
+
from minisweagent.models import GLOBAL_MODEL_STATS
|
|
8
|
+
from minisweagent.models.litellm_model import LitellmModel, LitellmModelConfig
|
|
9
|
+
from minisweagent.models.utils.actions_toolcall_response import (
|
|
10
|
+
BASH_TOOL_RESPONSE_API,
|
|
11
|
+
format_toolcall_observation_messages,
|
|
12
|
+
parse_toolcall_actions_response,
|
|
13
|
+
)
|
|
14
|
+
from minisweagent.models.utils.retry import retry
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger("litellm_response_model")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class LitellmResponseModelConfig(LitellmModelConfig):
|
|
20
|
+
pass
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class LitellmResponseModel(LitellmModel):
|
|
24
|
+
def __init__(self, *, config_class: Callable = LitellmResponseModelConfig, **kwargs):
|
|
25
|
+
super().__init__(config_class=config_class, **kwargs)
|
|
26
|
+
|
|
27
|
+
def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
|
|
28
|
+
"""Flatten response objects into their output items for stateless API calls."""
|
|
29
|
+
result = []
|
|
30
|
+
for msg in messages:
|
|
31
|
+
if msg.get("object") == "response":
|
|
32
|
+
for item in msg.get("output", []):
|
|
33
|
+
result.append({k: v for k, v in item.items() if k != "extra"})
|
|
34
|
+
else:
|
|
35
|
+
result.append({k: v for k, v in msg.items() if k != "extra"})
|
|
36
|
+
return result
|
|
37
|
+
|
|
38
|
+
def _query(self, messages: list[dict[str, str]], **kwargs):
|
|
39
|
+
try:
|
|
40
|
+
return litellm.responses(
|
|
41
|
+
model=self.config.model_name,
|
|
42
|
+
input=messages,
|
|
43
|
+
tools=[BASH_TOOL_RESPONSE_API],
|
|
44
|
+
**(self.config.model_kwargs | kwargs),
|
|
45
|
+
)
|
|
46
|
+
except litellm.exceptions.AuthenticationError as e:
|
|
47
|
+
e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
|
|
48
|
+
raise e
|
|
49
|
+
|
|
50
|
+
def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
|
|
51
|
+
for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
|
|
52
|
+
with attempt:
|
|
53
|
+
response = self._query(self._prepare_messages_for_api(messages), **kwargs)
|
|
54
|
+
cost_output = self._calculate_cost(response)
|
|
55
|
+
GLOBAL_MODEL_STATS.add(cost_output["cost"])
|
|
56
|
+
message = response.model_dump() if hasattr(response, "model_dump") else dict(response)
|
|
57
|
+
message["extra"] = {
|
|
58
|
+
"actions": self._parse_actions(response),
|
|
59
|
+
**cost_output,
|
|
60
|
+
"timestamp": time.time(),
|
|
61
|
+
}
|
|
62
|
+
return message
|
|
63
|
+
|
|
64
|
+
def _parse_actions(self, response) -> list[dict]:
|
|
65
|
+
return parse_toolcall_actions_response(
|
|
66
|
+
getattr(response, "output", []), format_error_template=self.config.format_error_template
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
def format_observation_messages(
|
|
70
|
+
self, message: dict, outputs: list[dict], template_vars: dict | None = None
|
|
71
|
+
) -> list[dict]:
|
|
72
|
+
"""Format execution outputs into tool result messages."""
|
|
73
|
+
actions = message.get("extra", {}).get("actions", [])
|
|
74
|
+
return format_toolcall_observation_messages(
|
|
75
|
+
actions=actions,
|
|
76
|
+
outputs=outputs,
|
|
77
|
+
observation_template=self.config.observation_template,
|
|
78
|
+
template_vars=template_vars,
|
|
79
|
+
multimodal_regex=self.config.multimodal_regex,
|
|
80
|
+
)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import litellm
|
|
2
|
+
|
|
3
|
+
from minisweagent.models.litellm_model import LitellmModel, LitellmModelConfig
|
|
4
|
+
from minisweagent.models.utils.actions_text import format_observation_messages, parse_regex_actions
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class LitellmTextbasedModelConfig(LitellmModelConfig):
|
|
8
|
+
action_regex: str = r"```mswea_bash_command\s*\n(.*?)\n```"
|
|
9
|
+
"""Regex to extract the action from the LM's output."""
|
|
10
|
+
format_error_template: str = (
|
|
11
|
+
"Please always provide EXACTLY ONE action in triple backticks, found {{actions|length}} actions."
|
|
12
|
+
)
|
|
13
|
+
"""Template used when the LM's output is not in the expected format."""
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class LitellmTextbasedModel(LitellmModel):
|
|
17
|
+
def __init__(self, **kwargs):
|
|
18
|
+
super().__init__(config_class=LitellmTextbasedModelConfig, **kwargs)
|
|
19
|
+
|
|
20
|
+
def _query(self, messages: list[dict[str, str]], **kwargs):
|
|
21
|
+
try:
|
|
22
|
+
return litellm.completion(
|
|
23
|
+
model=self.config.model_name, messages=messages, **(self.config.model_kwargs | kwargs)
|
|
24
|
+
)
|
|
25
|
+
except litellm.exceptions.AuthenticationError as e:
|
|
26
|
+
e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
|
|
27
|
+
raise e
|
|
28
|
+
|
|
29
|
+
def _parse_actions(self, response: dict) -> list[dict]:
|
|
30
|
+
"""Parse actions from the model response. Raises FormatError if not exactly one action."""
|
|
31
|
+
content = response.choices[0].message.content or ""
|
|
32
|
+
return parse_regex_actions(
|
|
33
|
+
content, action_regex=self.config.action_regex, format_error_template=self.config.format_error_template
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
def format_observation_messages(
|
|
37
|
+
self, message: dict, outputs: list[dict], template_vars: dict | None = None
|
|
38
|
+
) -> list[dict]:
|
|
39
|
+
"""Format execution outputs into observation messages."""
|
|
40
|
+
return format_observation_messages(
|
|
41
|
+
outputs,
|
|
42
|
+
observation_template=self.config.observation_template,
|
|
43
|
+
template_vars=template_vars,
|
|
44
|
+
multimodal_regex=self.config.multimodal_regex,
|
|
45
|
+
)
|