mini-swe-agent 1.17.5__py3-none-any.whl → 2.0.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mini_swe_agent-1.17.5.dist-info → mini_swe_agent-2.0.0a1.dist-info}/METADATA +36 -52
- mini_swe_agent-2.0.0a1.dist-info/RECORD +70 -0
- mini_swe_agent-2.0.0a1.dist-info/entry_points.txt +5 -0
- minisweagent/__init__.py +19 -26
- minisweagent/agents/default.py +128 -113
- minisweagent/agents/interactive.py +119 -58
- minisweagent/config/README.md +3 -4
- minisweagent/config/__init__.py +36 -1
- minisweagent/config/benchmarks/swebench.yaml +156 -0
- minisweagent/config/{extra/swebench.yaml → benchmarks/swebench_backticks.yaml} +69 -64
- minisweagent/config/benchmarks/swebench_modal.yaml +47 -0
- minisweagent/config/{extra → benchmarks}/swebench_xml.yaml +73 -70
- minisweagent/config/default.yaml +24 -21
- minisweagent/config/inspector.tcss +42 -0
- minisweagent/config/mini.yaml +53 -71
- minisweagent/config/{github_issue.yaml → mini_textbased.yaml} +43 -29
- minisweagent/environments/__init__.py +1 -0
- minisweagent/environments/docker.py +67 -20
- minisweagent/environments/extra/bubblewrap.py +86 -47
- minisweagent/environments/extra/swerex_docker.py +53 -20
- minisweagent/environments/extra/swerex_modal.py +90 -0
- minisweagent/environments/local.py +62 -21
- minisweagent/environments/singularity.py +59 -18
- minisweagent/exceptions.py +22 -0
- minisweagent/models/__init__.py +6 -7
- minisweagent/models/extra/roulette.py +20 -17
- minisweagent/models/litellm_model.py +90 -44
- minisweagent/models/litellm_response_model.py +80 -0
- minisweagent/models/litellm_textbased_model.py +45 -0
- minisweagent/models/openrouter_model.py +87 -45
- minisweagent/models/openrouter_response_model.py +123 -0
- minisweagent/models/openrouter_textbased_model.py +76 -0
- minisweagent/models/portkey_model.py +84 -42
- minisweagent/models/portkey_response_model.py +163 -0
- minisweagent/models/requesty_model.py +91 -41
- minisweagent/models/test_models.py +246 -19
- minisweagent/models/utils/actions_text.py +60 -0
- minisweagent/models/utils/actions_toolcall.py +102 -0
- minisweagent/models/utils/actions_toolcall_response.py +110 -0
- minisweagent/models/utils/anthropic_utils.py +28 -0
- minisweagent/models/utils/cache_control.py +15 -2
- minisweagent/models/utils/content_string.py +74 -0
- minisweagent/models/utils/openai_multimodal.py +50 -0
- minisweagent/models/utils/retry.py +25 -0
- minisweagent/run/benchmarks/__init__.py +1 -0
- minisweagent/run/{extra → benchmarks}/swebench.py +56 -35
- minisweagent/run/{extra → benchmarks}/swebench_single.py +36 -26
- minisweagent/run/{extra → benchmarks}/utils/batch_progress.py +1 -1
- minisweagent/run/hello_world.py +6 -0
- minisweagent/run/mini.py +54 -63
- minisweagent/run/utilities/__init__.py +1 -0
- minisweagent/run/{extra → utilities}/config.py +2 -0
- minisweagent/run/{inspector.py → utilities/inspector.py} +90 -11
- minisweagent/run/{mini_extra.py → utilities/mini_extra.py} +9 -5
- minisweagent/utils/serialize.py +26 -0
- mini_swe_agent-1.17.5.dist-info/RECORD +0 -61
- mini_swe_agent-1.17.5.dist-info/entry_points.txt +0 -5
- minisweagent/agents/interactive_textual.py +0 -450
- minisweagent/config/extra/swebench_roulette.yaml +0 -233
- minisweagent/config/mini.tcss +0 -86
- minisweagent/models/anthropic.py +0 -35
- minisweagent/models/litellm_response_api_model.py +0 -82
- minisweagent/models/portkey_response_api_model.py +0 -75
- minisweagent/models/utils/key_per_thread.py +0 -20
- minisweagent/models/utils/openai_utils.py +0 -41
- minisweagent/run/github_issue.py +0 -87
- minisweagent/run/utils/__init__.py +0 -0
- minisweagent/run/utils/save.py +0 -78
- {mini_swe_agent-1.17.5.dist-info → mini_swe_agent-2.0.0a1.dist-info}/WHEEL +0 -0
- {mini_swe_agent-1.17.5.dist-info → mini_swe_agent-2.0.0a1.dist-info}/licenses/LICENSE.md +0 -0
- {mini_swe_agent-1.17.5.dist-info → mini_swe_agent-2.0.0a1.dist-info}/top_level.txt +0 -0
- /minisweagent/config/{extra → benchmarks}/__init__.py +0 -0
- /minisweagent/run/{extra → benchmarks}/utils/__init__.py +0 -0
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Literal
|
|
7
|
+
|
|
8
|
+
import litellm
|
|
9
|
+
from pydantic import BaseModel
|
|
10
|
+
|
|
11
|
+
from minisweagent.models import GLOBAL_MODEL_STATS
|
|
12
|
+
from minisweagent.models.utils.actions_toolcall_response import (
|
|
13
|
+
BASH_TOOL_RESPONSE_API,
|
|
14
|
+
format_toolcall_observation_messages,
|
|
15
|
+
parse_toolcall_actions_response,
|
|
16
|
+
)
|
|
17
|
+
from minisweagent.models.utils.retry import retry
|
|
18
|
+
|
|
19
|
+
# Logger for this module (explicit name rather than __name__).
logger = logging.getLogger("portkey_response_model")

try:
    from portkey_ai import Portkey
except ImportError as err:
    # Fail with an actionable install hint, keeping the original import error as the cause
    # so a broken (rather than missing) installation is still diagnosable.
    raise ImportError(
        "The portkey-ai package is required to use PortkeyResponseAPIModel. Please install it with: pip install portkey-ai"
    ) from err
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class PortkeyResponseAPIModelConfig(BaseModel):
    """Settings accepted by ``PortkeyResponseAPIModel``."""

    model_name: str
    """Model identifier sent with every Responses API request."""
    model_kwargs: dict[str, Any] = {}
    """Extra request arguments; per-call keyword arguments take precedence on conflicts."""
    litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
    """Optional JSON registry file registered with litellm (defaults to $LITELLM_MODEL_REGISTRY_PATH)."""
    litellm_model_name_override: str = ""
    """When non-empty, used instead of ``model_name`` for litellm cost calculation."""
    cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
    """``"ignore_errors"`` turns cost-calculation failures into a zero cost instead of an error."""
    format_error_template: str = "{{ error }}"
    """Template handed to the tool-call parser for malformed model output."""
    observation_template: str = (
        "{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
        "<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
    )
    """Template used to render command output sent back to the model."""
    multimodal_regex: str = ""
    """Regex for extracting multimodal content; the empty string disables it."""
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class PortkeyResponseAPIModel:
|
|
44
|
+
"""Portkey model using the Responses API with native tool calling.
|
|
45
|
+
|
|
46
|
+
Note: This implementation is stateless - each request must include
|
|
47
|
+
the full conversation history. previous_response_id is not used.
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
abort_exceptions: list[type[Exception]] = [KeyboardInterrupt, TypeError, ValueError]
|
|
51
|
+
|
|
52
|
+
def __init__(self, **kwargs):
|
|
53
|
+
self.config = PortkeyResponseAPIModelConfig(**kwargs)
|
|
54
|
+
if self.config.litellm_model_registry and Path(self.config.litellm_model_registry).is_file():
|
|
55
|
+
litellm.utils.register_model(json.loads(Path(self.config.litellm_model_registry).read_text()))
|
|
56
|
+
|
|
57
|
+
self._api_key = os.getenv("PORTKEY_API_KEY")
|
|
58
|
+
if not self._api_key:
|
|
59
|
+
raise ValueError(
|
|
60
|
+
"Portkey API key is required. Set it via the "
|
|
61
|
+
"PORTKEY_API_KEY environment variable. You can permanently set it with "
|
|
62
|
+
"`mini-extra config set PORTKEY_API_KEY YOUR_KEY`."
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
virtual_key = os.getenv("PORTKEY_VIRTUAL_KEY")
|
|
66
|
+
client_kwargs = {"api_key": self._api_key}
|
|
67
|
+
if virtual_key:
|
|
68
|
+
client_kwargs["virtual_key"] = virtual_key
|
|
69
|
+
|
|
70
|
+
self.client = Portkey(**client_kwargs)
|
|
71
|
+
|
|
72
|
+
def _query(self, messages: list[dict[str, str]], **kwargs):
|
|
73
|
+
return self.client.responses.create(
|
|
74
|
+
model=self.config.model_name,
|
|
75
|
+
input=messages,
|
|
76
|
+
tools=[BASH_TOOL_RESPONSE_API],
|
|
77
|
+
**(self.config.model_kwargs | kwargs),
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
|
|
81
|
+
"""Prepare messages for Portkey's stateless Responses API.
|
|
82
|
+
|
|
83
|
+
Flattens response objects into their output items.
|
|
84
|
+
"""
|
|
85
|
+
result = []
|
|
86
|
+
for msg in messages:
|
|
87
|
+
if msg.get("object") == "response":
|
|
88
|
+
for item in msg.get("output", []):
|
|
89
|
+
result.append({k: v for k, v in item.items() if k != "extra"})
|
|
90
|
+
else:
|
|
91
|
+
result.append({k: v for k, v in msg.items() if k != "extra"})
|
|
92
|
+
return result
|
|
93
|
+
|
|
94
|
+
def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
|
|
95
|
+
for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
|
|
96
|
+
with attempt:
|
|
97
|
+
response = self._query(self._prepare_messages_for_api(messages), **kwargs)
|
|
98
|
+
cost_output = self._calculate_cost(response)
|
|
99
|
+
GLOBAL_MODEL_STATS.add(cost_output["cost"])
|
|
100
|
+
message = response.model_dump() if hasattr(response, "model_dump") else dict(response)
|
|
101
|
+
message["extra"] = {
|
|
102
|
+
"actions": self._parse_actions(response),
|
|
103
|
+
**cost_output,
|
|
104
|
+
"timestamp": time.time(),
|
|
105
|
+
}
|
|
106
|
+
return message
|
|
107
|
+
|
|
108
|
+
def _parse_actions(self, response) -> list[dict]:
|
|
109
|
+
"""Parse tool calls from the response API response."""
|
|
110
|
+
output = response.output if hasattr(response, "output") else response.get("output", [])
|
|
111
|
+
return parse_toolcall_actions_response(output, format_error_template=self.config.format_error_template)
|
|
112
|
+
|
|
113
|
+
def _calculate_cost(self, response) -> dict[str, float]:
|
|
114
|
+
try:
|
|
115
|
+
cost = litellm.cost_calculator.completion_cost(
|
|
116
|
+
response, model=self.config.litellm_model_name_override or self.config.model_name
|
|
117
|
+
)
|
|
118
|
+
assert cost > 0.0, f"Cost is not positive: {cost}"
|
|
119
|
+
except Exception as e:
|
|
120
|
+
if self.config.cost_tracking != "ignore_errors":
|
|
121
|
+
raise RuntimeError(
|
|
122
|
+
f"Error calculating cost for model {self.config.model_name}: {e}. "
|
|
123
|
+
"You can ignore this issue from your config file with cost_tracking: 'ignore_errors' or "
|
|
124
|
+
"globally with export MSWEA_COST_TRACKING='ignore_errors' to ignore this error. "
|
|
125
|
+
) from e
|
|
126
|
+
cost = 0.0
|
|
127
|
+
return {"cost": cost}
|
|
128
|
+
|
|
129
|
+
def format_message(self, **kwargs) -> dict:
|
|
130
|
+
role = kwargs.get("role", "user")
|
|
131
|
+
content = kwargs.get("content", "")
|
|
132
|
+
extra = kwargs.get("extra")
|
|
133
|
+
content_items = [{"type": "input_text", "text": content}] if isinstance(content, str) else content
|
|
134
|
+
msg = {"type": "message", "role": role, "content": content_items}
|
|
135
|
+
if extra:
|
|
136
|
+
msg["extra"] = extra
|
|
137
|
+
return msg
|
|
138
|
+
|
|
139
|
+
def format_observation_messages(
|
|
140
|
+
self, message: dict, outputs: list[dict], template_vars: dict | None = None
|
|
141
|
+
) -> list[dict]:
|
|
142
|
+
"""Format execution outputs into tool result messages."""
|
|
143
|
+
actions = message.get("extra", {}).get("actions", [])
|
|
144
|
+
return format_toolcall_observation_messages(
|
|
145
|
+
actions=actions,
|
|
146
|
+
outputs=outputs,
|
|
147
|
+
observation_template=self.config.observation_template,
|
|
148
|
+
template_vars=template_vars,
|
|
149
|
+
multimodal_regex=self.config.multimodal_regex,
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
def get_template_vars(self, **kwargs) -> dict:
|
|
153
|
+
return self.config.model_dump()
|
|
154
|
+
|
|
155
|
+
def serialize(self) -> dict:
|
|
156
|
+
return {
|
|
157
|
+
"info": {
|
|
158
|
+
"config": {
|
|
159
|
+
"model": self.config.model_dump(mode="json"),
|
|
160
|
+
"model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
|
|
161
|
+
},
|
|
162
|
+
}
|
|
163
|
+
}
|
|
@@ -1,27 +1,40 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
|
-
|
|
5
|
-
from typing import Any
|
|
4
|
+
import time
|
|
5
|
+
from typing import Any, Literal
|
|
6
6
|
|
|
7
7
|
import requests
|
|
8
|
-
from
|
|
9
|
-
before_sleep_log,
|
|
10
|
-
retry,
|
|
11
|
-
retry_if_not_exception_type,
|
|
12
|
-
stop_after_attempt,
|
|
13
|
-
wait_exponential,
|
|
14
|
-
)
|
|
8
|
+
from pydantic import BaseModel
|
|
15
9
|
|
|
16
10
|
from minisweagent.models import GLOBAL_MODEL_STATS
|
|
11
|
+
from minisweagent.models.utils.actions_toolcall import (
|
|
12
|
+
BASH_TOOL,
|
|
13
|
+
format_toolcall_observation_messages,
|
|
14
|
+
parse_toolcall_actions,
|
|
15
|
+
)
|
|
16
|
+
from minisweagent.models.utils.anthropic_utils import _reorder_anthropic_thinking_blocks
|
|
17
|
+
from minisweagent.models.utils.cache_control import set_cache_control
|
|
18
|
+
from minisweagent.models.utils.openai_multimodal import expand_multimodal_content
|
|
19
|
+
from minisweagent.models.utils.retry import retry
|
|
17
20
|
|
|
18
21
|
# Logger for this module (explicit name rather than __name__).
logger = logging.getLogger("requesty_model")
|
|
19
22
|
|
|
20
23
|
|
|
21
|
-
|
|
22
|
-
class RequestyModelConfig(BaseModel):
    """Settings accepted by ``RequestyModel``."""

    model_name: str
    """Model identifier placed in the request payload."""
    model_kwargs: dict[str, Any] = {}
    """Extra payload fields; per-call keyword arguments take precedence on conflicts."""
    set_cache_control: Literal["default_end"] | None = None
    """Explicit cache-control marker mode (e.g. for Anthropic models); None disables it."""
    format_error_template: str = "{{ error }}"
    """Template rendered when the model's output does not have the expected format."""
    observation_template: str = (
        "{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
        "<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
    )
    """Template rendering an executed action's result for the model."""
    multimodal_regex: str = ""
    """Regex for extracting multimodal content; the empty string turns this off."""
|
|
25
38
|
|
|
26
39
|
|
|
27
40
|
class RequestyAPIError(Exception):
|
|
@@ -43,25 +56,13 @@ class RequestyRateLimitError(Exception):
|
|
|
43
56
|
|
|
44
57
|
|
|
45
58
|
class RequestyModel:
|
|
59
|
+
# Passed to `retry(...)` in `query` as `abort_exceptions`; presumably these are re-raised
# without retrying (an auth failure will not fix itself) — retry helper not visible here.
abort_exceptions: list[type[Exception]] = [RequestyAuthenticationError, KeyboardInterrupt]
|
|
60
|
+
|
|
46
61
|
def __init__(self, **kwargs):
    """Validate ``kwargs`` into a ``RequestyModelConfig`` and record endpoint and credentials."""
    config = RequestyModelConfig(**kwargs)
    self.config = config
    self._api_url = "https://router.requesty.ai/v1/chat/completions"
    # Read once at construction; an unset variable gives an empty key (not validated here).
    self._api_key = os.getenv("REQUESTY_API_KEY", "")
|
|
52
65
|
|
|
53
|
-
@retry(
|
|
54
|
-
reraise=True,
|
|
55
|
-
stop=stop_after_attempt(10),
|
|
56
|
-
wait=wait_exponential(multiplier=1, min=4, max=60),
|
|
57
|
-
before_sleep=before_sleep_log(logger, logging.WARNING),
|
|
58
|
-
retry=retry_if_not_exception_type(
|
|
59
|
-
(
|
|
60
|
-
RequestyAuthenticationError,
|
|
61
|
-
KeyboardInterrupt,
|
|
62
|
-
)
|
|
63
|
-
),
|
|
64
|
-
)
|
|
65
66
|
def _query(self, messages: list[dict[str, str]], **kwargs):
|
|
66
67
|
headers = {
|
|
67
68
|
"Authorization": f"Bearer {self._api_key}",
|
|
@@ -73,6 +74,7 @@ class RequestyModel:
|
|
|
73
74
|
payload = {
|
|
74
75
|
"model": self.config.model_name,
|
|
75
76
|
"messages": messages,
|
|
77
|
+
"tools": [BASH_TOOL],
|
|
76
78
|
**(self.config.model_kwargs | kwargs),
|
|
77
79
|
}
|
|
78
80
|
|
|
@@ -91,30 +93,78 @@ class RequestyModel:
|
|
|
91
93
|
except requests.exceptions.RequestException as e:
|
|
92
94
|
raise RequestyAPIError(f"Request failed: {e}") from e
|
|
93
95
|
|
|
96
|
+
def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
    """Drop the bookkeeping ``extra`` key, reorder Anthropic thinking blocks, then apply cache control."""
    stripped = []
    for message in messages:
        stripped.append({key: value for key, value in message.items() if key != "extra"})
    reordered = _reorder_anthropic_thinking_blocks(stripped)
    return set_cache_control(reordered, mode=self.config.set_cache_control)
|
|
100
|
+
|
|
94
101
|
def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
    """Call the API (with retries) and return the assistant message with an ``extra`` entry.

    ``extra`` carries the parsed actions, the raw response, the cost fields and a timestamp.
    """
    for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
        with attempt:
            response = self._query(self._prepare_messages_for_api(messages), **kwargs)
    cost_info = self._calculate_cost(response)
    GLOBAL_MODEL_STATS.add(cost_info["cost"])
    assistant_message = dict(response["choices"][0]["message"])
    extra = {"actions": self._parse_actions(response), "response": response}
    extra.update(cost_info)
    extra["timestamp"] = time.time()
    assistant_message["extra"] = extra
    return assistant_message
|
|
96
115
|
|
|
97
|
-
|
|
116
|
+
def _calculate_cost(self, response) -> dict[str, float]:
    """Return ``{"cost": <cost>}`` using the cost Requesty reports in ``response["usage"]``.

    A missing *or null* ``usage`` / ``cost`` field is treated as no cost information.

    Raises:
        RequestyAPIError: if no non-zero cost is reported (cost tracking is mandatory).
    """
    # `or` also covers explicit JSON nulls, which `.get(key, default)` would pass through
    # (``None.get`` crash for usage, ``{"cost": None}`` leaking into the global stats for cost).
    usage = response.get("usage") or {}
    cost = usage.get("cost") or 0.0
    if cost == 0.0:
        raise RequestyAPIError(
            f"No cost information available from Requesty API for model {self.config.model_name}. "
            "Cost tracking is required but not provided by the API response."
        )
    return {"cost": cost}
|
|
125
|
+
|
|
126
|
+
def _parse_actions(self, response: dict) -> list[dict]:
    """Parse the first choice's tool calls into actions.

    Unknown tools are handled by ``parse_toolcall_actions`` (documented as raising FormatError).
    """
    raw_calls = response["choices"][0]["message"].get("tool_calls") or []
    # The parser expects attribute access, so wrap each plain dict.
    wrapped_calls = [_DictToObj(call) for call in raw_calls]
    return parse_toolcall_actions(wrapped_calls, format_error_template=self.config.format_error_template)
|
|
131
|
+
|
|
132
|
+
def format_message(self, **kwargs) -> dict:
    """Build a chat message from ``kwargs``, expanding multimodal content per ``config.multimodal_regex``."""
    pattern = self.config.multimodal_regex
    return expand_multimodal_content(kwargs, pattern=pattern)
|
|
134
|
+
|
|
135
|
+
def format_observation_messages(
    self, message: dict, outputs: list[dict], template_vars: dict | None = None
) -> list[dict]:
    """Render execution ``outputs`` as tool-result messages using the actions stored in ``message["extra"]``."""
    parsed_actions = message.get("extra", {}).get("actions", [])
    cfg = self.config
    return format_toolcall_observation_messages(
        actions=parsed_actions,
        outputs=outputs,
        observation_template=cfg.observation_template,
        template_vars=template_vars,
        multimodal_regex=cfg.multimodal_regex,
    )
|
|
147
|
+
|
|
148
|
+
def get_template_vars(self, **kwargs) -> dict[str, Any]:
    """Expose every config field as a template variable (``kwargs`` are ignored)."""
    config_fields = self.config.model_dump()
    return config_fields
|
|
150
|
+
|
|
151
|
+
def serialize(self) -> dict:
    """Return a JSON-friendly description: the dumped config plus the fully-qualified class name."""
    cls = type(self)
    model_config = {
        "model": self.config.model_dump(mode="json"),
        "model_type": f"{cls.__module__}.{cls.__name__}",
    }
    return {"info": {"config": model_config}}
|
|
118
160
|
|
|
119
|
-
|
|
120
|
-
|
|
161
|
+
|
|
162
|
+
class _DictToObj:
    """Simple wrapper to convert dict to object with attribute access.

    Exposes ``id``, ``name`` and ``arguments`` (None when absent) and wraps a nested
    ``function`` dict recursively; the original dict stays available as ``_d``.
    """

    def __init__(self, d: dict):
        self._d = d
        self.id = d.get("id")
        # A missing key and an explicit null both mean "no function"; previously a null
        # value crashed on ``None.get``.
        function = d.get("function")
        self.function = _DictToObj(function) if function is not None else None
        self.name = d.get("name")
        self.arguments = d.get("arguments")
|
|
@@ -1,42 +1,269 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import time
|
|
3
|
-
from dataclasses import asdict, dataclass
|
|
4
3
|
from typing import Any
|
|
5
4
|
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
6
7
|
from minisweagent.models import GLOBAL_MODEL_STATS
|
|
8
|
+
from minisweagent.models.utils.actions_text import format_observation_messages
|
|
9
|
+
from minisweagent.models.utils.actions_toolcall import format_toolcall_observation_messages
|
|
10
|
+
from minisweagent.models.utils.actions_toolcall_response import (
|
|
11
|
+
format_toolcall_observation_messages as format_response_api_observation_messages,
|
|
12
|
+
)
|
|
13
|
+
from minisweagent.models.utils.openai_multimodal import expand_multimodal_content
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def make_output(content: str, actions: list[dict], cost: float = 1.0) -> dict:
    """Build a scripted assistant message for ``DeterministicModel``.

    Args:
        content: Text of the assistant reply.
        actions: Parsed actions, e.g. ``[{"command": "echo hello"}]``.
        cost: Cost recorded in ``extra`` for this output.
    """
    extra = {"actions": actions, "cost": cost, "timestamp": time.time()}
    return {"role": "assistant", "content": content, "extra": extra}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def make_toolcall_output(content: str | None, tool_calls: list[dict], actions: list[dict]) -> dict:
|
|
32
|
+
"""Helper to create a toolcall output dict for DeterministicToolcallModel.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
content: Optional text content (can be None for tool-only responses)
|
|
36
|
+
tool_calls: List of tool call dicts in OpenAI format
|
|
37
|
+
actions: List of parsed action dicts, e.g., [{"command": "echo hello", "tool_call_id": "call_123"}]
|
|
38
|
+
"""
|
|
39
|
+
return {
|
|
40
|
+
"role": "assistant",
|
|
41
|
+
"content": content,
|
|
42
|
+
"tool_calls": tool_calls,
|
|
43
|
+
"extra": {"actions": actions, "cost": 1.0, "timestamp": time.time()},
|
|
44
|
+
}
|
|
7
45
|
|
|
8
46
|
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
47
|
+
def make_response_api_output(content: str | None, actions: list[dict]) -> dict:
|
|
48
|
+
"""Helper to create an output dict for DeterministicResponseAPIToolcallModel.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
content: Optional text content (can be None for tool-only responses)
|
|
52
|
+
actions: List of action dicts with 'command' and 'tool_call_id' keys
|
|
53
|
+
"""
|
|
54
|
+
output_items = []
|
|
55
|
+
if content:
|
|
56
|
+
output_items.append(
|
|
57
|
+
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": content}]}
|
|
58
|
+
)
|
|
59
|
+
for action in actions:
|
|
60
|
+
output_items.append(
|
|
61
|
+
{
|
|
62
|
+
"type": "function_call",
|
|
63
|
+
"call_id": action["tool_call_id"],
|
|
64
|
+
"name": "bash",
|
|
65
|
+
"arguments": f'{{"command": "{action["command"]}"}}',
|
|
66
|
+
}
|
|
67
|
+
)
|
|
68
|
+
return {
|
|
69
|
+
"object": "response",
|
|
70
|
+
"output": output_items,
|
|
71
|
+
"extra": {"actions": actions, "cost": 1.0, "timestamp": time.time()},
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _process_test_actions(actions: list[dict]) -> bool:
    """Process special test actions. Returns True if the query should be retried.

    Directives, checked per action in order:
      * an action with a ``"raise"`` key raises that value;
      * ``"/sleep <seconds>"`` sleeps for the given time and returns True;
      * ``"/warning<text>"`` logs ``<text>`` through the root logger and returns True.

    Returns False when no ``/sleep`` or ``/warning`` directive was found.
    """
    for action in actions:
        if "raise" in action:
            raise action["raise"]
        cmd = action.get("command", "")
        # removeprefix keeps everything after the directive; split(prefix)[1] truncated
        # the payload if it contained the directive text again.
        if cmd.startswith("/sleep "):
            time.sleep(float(cmd.removeprefix("/sleep ")))
            return True
        if cmd.startswith("/warning"):
            logging.warning(cmd.removeprefix("/warning"))
            return True
    return False
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class DeterministicModelConfig(BaseModel):
    """Settings for ``DeterministicModel``."""

    outputs: list[dict]
    """Scripted assistant messages returned in order; each has 'role', 'content' and 'extra' (with 'actions')."""
    model_name: str = "deterministic"
    cost_per_call: float = 1.0
    observation_template: str = (
        "{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
        "<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
    )
    """Template rendering an executed action's result for the model."""
    multimodal_regex: str = ""
    """Regex for extracting multimodal content; the empty string turns this off."""
|
|
14
102
|
|
|
15
103
|
|
|
16
104
|
class DeterministicModel:
    """Test model that replays scripted text-based assistant messages in order."""

    def __init__(self, **kwargs):
        """Validate the scripted outputs; ``current_index`` is the index of the last one returned."""
        self.config = DeterministicModelConfig(**kwargs)
        self.current_index = -1

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Return the next scripted message and add ``cost_per_call`` to the global stats.

        ``messages`` are ignored. Entries whose actions are test directives (see
        ``_process_test_actions``) are consumed and the following entry is used instead.
        """
        while True:
            self.current_index += 1
            scripted = self.config.outputs[self.current_index]
            if not _process_test_actions(scripted.get("extra", {}).get("actions", [])):
                break
        GLOBAL_MODEL_STATS.add(self.config.cost_per_call)
        return scripted

    def format_message(self, **kwargs) -> dict:
        """Build a chat message from ``kwargs``, expanding multimodal content per ``config.multimodal_regex``."""
        pattern = self.config.multimodal_regex
        return expand_multimodal_content(kwargs, pattern=pattern)

    def format_observation_messages(
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
    ) -> list[dict]:
        """Render execution ``outputs`` as text observation messages (``message`` is unused)."""
        cfg = self.config
        return format_observation_messages(
            outputs,
            observation_template=cfg.observation_template,
            template_vars=template_vars,
            multimodal_regex=cfg.multimodal_regex,
        )

    def get_template_vars(self, **kwargs) -> dict[str, Any]:
        """Expose every config field as a template variable."""
        config_fields = self.config.model_dump()
        return config_fields

    def serialize(self) -> dict:
        """Return the JSON-dumped config plus the fully-qualified class name."""
        cls = type(self)
        model_config = {
            "model": self.config.model_dump(mode="json"),
            "model_type": f"{cls.__module__}.{cls.__name__}",
        }
        return {"info": {"config": model_config}}
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class DeterministicToolcallModelConfig(BaseModel):
    """Settings for ``DeterministicToolcallModel``."""

    outputs: list[dict]
    """Scripted assistant messages (with ``tool_calls``) returned in order."""
    model_name: str = "deterministic_toolcall"
    cost_per_call: float = 1.0
    observation_template: str = (
        "{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
        "<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
    )
    """Template rendering an executed action's result for the model."""
    multimodal_regex: str = ""
    """Regex for extracting multimodal content; the empty string turns this off."""
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class DeterministicToolcallModel:
    """Test model that replays scripted tool-calling assistant messages in order."""

    def __init__(self, **kwargs):
        """Validate the scripted outputs; ``current_index`` is the index of the last one returned."""
        self.config = DeterministicToolcallModelConfig(**kwargs)
        self.current_index = -1

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Return the next scripted message and add ``cost_per_call`` to the global stats.

        ``messages`` are ignored; entries holding test directives (see ``_process_test_actions``)
        are consumed and skipped.
        """
        while True:
            self.current_index += 1
            scripted = self.config.outputs[self.current_index]
            if not _process_test_actions(scripted.get("extra", {}).get("actions", [])):
                break
        GLOBAL_MODEL_STATS.add(self.config.cost_per_call)
        return scripted

    def format_message(self, **kwargs) -> dict:
        """Build a chat message from ``kwargs``, expanding multimodal content per ``config.multimodal_regex``."""
        pattern = self.config.multimodal_regex
        return expand_multimodal_content(kwargs, pattern=pattern)

    def format_observation_messages(
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
    ) -> list[dict]:
        """Render ``outputs`` as tool-result messages for the actions stored in ``message["extra"]``."""
        parsed_actions = message.get("extra", {}).get("actions", [])
        cfg = self.config
        return format_toolcall_observation_messages(
            actions=parsed_actions,
            outputs=outputs,
            observation_template=cfg.observation_template,
            template_vars=template_vars,
            multimodal_regex=cfg.multimodal_regex,
        )

    def get_template_vars(self, **kwargs) -> dict[str, Any]:
        """Expose every config field as a template variable."""
        config_fields = self.config.model_dump()
        return config_fields

    def serialize(self) -> dict:
        """Return the JSON-dumped config plus the fully-qualified class name."""
        cls = type(self)
        model_config = {
            "model": self.config.model_dump(mode="json"),
            "model_type": f"{cls.__module__}.{cls.__name__}",
        }
        return {"info": {"config": model_config}}
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
class DeterministicResponseAPIToolcallModelConfig(BaseModel):
    """Settings for ``DeterministicResponseAPIToolcallModel``."""

    outputs: list[dict]
    """Scripted Responses-API outputs returned in order."""
    model_name: str = "deterministic_response_api_toolcall"
    cost_per_call: float = 1.0
    observation_template: str = (
        "{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
        "<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
    )
    """Template rendering an executed action's result for the model."""
    multimodal_regex: str = ""
    """Regex for extracting multimodal content; the empty string turns this off."""
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class DeterministicResponseAPIToolcallModel:
|
|
219
|
+
"""Deterministic test model using OpenAI Responses API format."""
|
|
220
|
+
|
|
221
|
+
def __init__(self, **kwargs):
|
|
222
|
+
"""Initialize with a list of Response API output messages to return in sequence."""
|
|
223
|
+
self.config = DeterministicResponseAPIToolcallModelConfig(**kwargs)
|
|
224
|
+
self.current_index = -1
|
|
225
|
+
|
|
226
|
+
def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
|
|
227
|
+
self.current_index += 1
|
|
228
|
+
output = self.config.outputs[self.current_index]
|
|
229
|
+
if _process_test_actions(output.get("extra", {}).get("actions", [])):
|
|
35
230
|
return self.query(messages, **kwargs)
|
|
36
|
-
self.n_calls += 1
|
|
37
|
-
self.cost += self.config.cost_per_call
|
|
38
231
|
GLOBAL_MODEL_STATS.add(self.config.cost_per_call)
|
|
39
|
-
return
|
|
232
|
+
return output
|
|
233
|
+
|
|
234
|
+
def format_message(self, **kwargs) -> dict:
|
|
235
|
+
"""Format message in Responses API format."""
|
|
236
|
+
role = kwargs.get("role", "user")
|
|
237
|
+
content = kwargs.get("content", "")
|
|
238
|
+
extra = kwargs.get("extra")
|
|
239
|
+
content_items = [{"type": "input_text", "text": content}] if isinstance(content, str) else content
|
|
240
|
+
msg: dict = {"type": "message", "role": role, "content": content_items}
|
|
241
|
+
if extra:
|
|
242
|
+
msg["extra"] = extra
|
|
243
|
+
return msg
|
|
244
|
+
|
|
245
|
+
def format_observation_messages(
|
|
246
|
+
self, message: dict, outputs: list[dict], template_vars: dict | None = None
|
|
247
|
+
) -> list[dict]:
|
|
248
|
+
"""Format execution outputs into function_call_output messages."""
|
|
249
|
+
actions = message.get("extra", {}).get("actions", [])
|
|
250
|
+
return format_response_api_observation_messages(
|
|
251
|
+
actions=actions,
|
|
252
|
+
outputs=outputs,
|
|
253
|
+
observation_template=self.config.observation_template,
|
|
254
|
+
template_vars=template_vars,
|
|
255
|
+
multimodal_regex=self.config.multimodal_regex,
|
|
256
|
+
)
|
|
257
|
+
|
|
258
|
+
def get_template_vars(self, **kwargs) -> dict[str, Any]:
|
|
259
|
+
return self.config.model_dump()
|
|
40
260
|
|
|
41
|
-
def
|
|
42
|
-
return
|
|
261
|
+
def serialize(self) -> dict:
|
|
262
|
+
return {
|
|
263
|
+
"info": {
|
|
264
|
+
"config": {
|
|
265
|
+
"model": self.config.model_dump(mode="json"),
|
|
266
|
+
"model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
|
|
267
|
+
},
|
|
268
|
+
}
|
|
269
|
+
}
|