mini-swe-agent 1.17.4__py3-none-any.whl → 2.0.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mini_swe_agent-1.17.4.dist-info → mini_swe_agent-2.0.0a1.dist-info}/METADATA +36 -52
- mini_swe_agent-2.0.0a1.dist-info/RECORD +70 -0
- {mini_swe_agent-1.17.4.dist-info → mini_swe_agent-2.0.0a1.dist-info}/WHEEL +1 -1
- mini_swe_agent-2.0.0a1.dist-info/entry_points.txt +5 -0
- minisweagent/__init__.py +19 -26
- minisweagent/agents/default.py +128 -113
- minisweagent/agents/interactive.py +119 -58
- minisweagent/config/README.md +3 -4
- minisweagent/config/__init__.py +36 -1
- minisweagent/config/benchmarks/swebench.yaml +156 -0
- minisweagent/config/{extra/swebench.yaml → benchmarks/swebench_backticks.yaml} +69 -64
- minisweagent/config/benchmarks/swebench_modal.yaml +47 -0
- minisweagent/config/{extra → benchmarks}/swebench_xml.yaml +73 -70
- minisweagent/config/default.yaml +24 -21
- minisweagent/config/inspector.tcss +42 -0
- minisweagent/config/mini.yaml +53 -71
- minisweagent/config/{github_issue.yaml → mini_textbased.yaml} +43 -29
- minisweagent/environments/__init__.py +1 -0
- minisweagent/environments/docker.py +67 -20
- minisweagent/environments/extra/bubblewrap.py +86 -47
- minisweagent/environments/extra/swerex_docker.py +53 -20
- minisweagent/environments/extra/swerex_modal.py +90 -0
- minisweagent/environments/local.py +62 -21
- minisweagent/environments/singularity.py +59 -18
- minisweagent/exceptions.py +22 -0
- minisweagent/models/__init__.py +6 -7
- minisweagent/models/extra/roulette.py +20 -17
- minisweagent/models/litellm_model.py +90 -44
- minisweagent/models/litellm_response_model.py +80 -0
- minisweagent/models/litellm_textbased_model.py +45 -0
- minisweagent/models/openrouter_model.py +87 -45
- minisweagent/models/openrouter_response_model.py +123 -0
- minisweagent/models/openrouter_textbased_model.py +76 -0
- minisweagent/models/portkey_model.py +84 -42
- minisweagent/models/portkey_response_model.py +163 -0
- minisweagent/models/requesty_model.py +91 -41
- minisweagent/models/test_models.py +246 -19
- minisweagent/models/utils/actions_text.py +60 -0
- minisweagent/models/utils/actions_toolcall.py +102 -0
- minisweagent/models/utils/actions_toolcall_response.py +110 -0
- minisweagent/models/utils/anthropic_utils.py +28 -0
- minisweagent/models/utils/cache_control.py +15 -2
- minisweagent/models/utils/content_string.py +74 -0
- minisweagent/models/utils/openai_multimodal.py +50 -0
- minisweagent/models/utils/retry.py +25 -0
- minisweagent/run/benchmarks/__init__.py +1 -0
- minisweagent/run/{extra → benchmarks}/swebench.py +57 -36
- minisweagent/run/benchmarks/swebench_single.py +89 -0
- minisweagent/run/{extra → benchmarks}/utils/batch_progress.py +1 -1
- minisweagent/run/hello_world.py +6 -0
- minisweagent/run/mini.py +54 -63
- minisweagent/run/utilities/__init__.py +1 -0
- minisweagent/run/{extra → utilities}/config.py +2 -0
- minisweagent/run/{inspector.py → utilities/inspector.py} +90 -11
- minisweagent/run/{mini_extra.py → utilities/mini_extra.py} +9 -5
- minisweagent/utils/serialize.py +26 -0
- mini_swe_agent-1.17.4.dist-info/RECORD +0 -61
- mini_swe_agent-1.17.4.dist-info/entry_points.txt +0 -5
- minisweagent/agents/interactive_textual.py +0 -450
- minisweagent/config/extra/swebench_roulette.yaml +0 -233
- minisweagent/config/mini.tcss +0 -86
- minisweagent/models/anthropic.py +0 -35
- minisweagent/models/litellm_response_api_model.py +0 -82
- minisweagent/models/portkey_response_api_model.py +0 -75
- minisweagent/models/utils/key_per_thread.py +0 -20
- minisweagent/models/utils/openai_utils.py +0 -41
- minisweagent/run/extra/swebench_single.py +0 -79
- minisweagent/run/github_issue.py +0 -87
- minisweagent/run/utils/__init__.py +0 -0
- minisweagent/run/utils/save.py +0 -78
- {mini_swe_agent-1.17.4.dist-info → mini_swe_agent-2.0.0a1.dist-info}/licenses/LICENSE.md +0 -0
- {mini_swe_agent-1.17.4.dist-info → mini_swe_agent-2.0.0a1.dist-info}/top_level.txt +0 -0
- /minisweagent/config/{extra → benchmarks}/__init__.py +0 -0
- /minisweagent/run/{extra → benchmarks}/utils/__init__.py +0 -0
|
@@ -1,72 +1,64 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
|
-
|
|
4
|
+
import time
|
|
5
5
|
from typing import Any, Literal
|
|
6
6
|
|
|
7
7
|
import requests
|
|
8
|
-
from
|
|
9
|
-
before_sleep_log,
|
|
10
|
-
retry,
|
|
11
|
-
retry_if_not_exception_type,
|
|
12
|
-
stop_after_attempt,
|
|
13
|
-
wait_exponential,
|
|
14
|
-
)
|
|
8
|
+
from pydantic import BaseModel
|
|
15
9
|
|
|
16
10
|
from minisweagent.models import GLOBAL_MODEL_STATS
|
|
11
|
+
from minisweagent.models.utils.actions_toolcall import (
|
|
12
|
+
BASH_TOOL,
|
|
13
|
+
format_toolcall_observation_messages,
|
|
14
|
+
parse_toolcall_actions,
|
|
15
|
+
)
|
|
16
|
+
from minisweagent.models.utils.anthropic_utils import _reorder_anthropic_thinking_blocks
|
|
17
17
|
from minisweagent.models.utils.cache_control import set_cache_control
|
|
18
|
+
from minisweagent.models.utils.openai_multimodal import expand_multimodal_content
|
|
19
|
+
from minisweagent.models.utils.retry import retry
|
|
18
20
|
|
|
19
21
|
logger = logging.getLogger("openrouter_model")
|
|
20
22
|
|
|
21
23
|
|
|
22
|
-
|
|
23
|
-
class OpenRouterModelConfig:
|
|
24
|
+
class OpenRouterModelConfig(BaseModel):
|
|
24
25
|
model_name: str
|
|
25
|
-
model_kwargs: dict[str, Any] =
|
|
26
|
+
model_kwargs: dict[str, Any] = {}
|
|
26
27
|
set_cache_control: Literal["default_end"] | None = None
|
|
27
28
|
"""Set explicit cache control markers, for example for Anthropic models"""
|
|
28
29
|
cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
|
|
29
30
|
"""Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""
|
|
31
|
+
format_error_template: str = "{{ error }}"
|
|
32
|
+
"""Template used when the LM's output is not in the expected format."""
|
|
33
|
+
observation_template: str = (
|
|
34
|
+
"{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
|
|
35
|
+
"<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
|
|
36
|
+
)
|
|
37
|
+
"""Template used to render the observation after executing an action."""
|
|
38
|
+
multimodal_regex: str = ""
|
|
39
|
+
"""Regex to extract multimodal content. Empty string disables multimodal processing."""
|
|
30
40
|
|
|
31
41
|
|
|
32
42
|
class OpenRouterAPIError(Exception):
|
|
33
43
|
"""Custom exception for OpenRouter API errors."""
|
|
34
44
|
|
|
35
|
-
pass
|
|
36
|
-
|
|
37
45
|
|
|
38
46
|
class OpenRouterAuthenticationError(Exception):
|
|
39
47
|
"""Custom exception for OpenRouter authentication errors."""
|
|
40
48
|
|
|
41
|
-
pass
|
|
42
|
-
|
|
43
49
|
|
|
44
50
|
class OpenRouterRateLimitError(Exception):
|
|
45
51
|
"""Custom exception for OpenRouter rate limit errors."""
|
|
46
52
|
|
|
47
|
-
pass
|
|
48
|
-
|
|
49
53
|
|
|
50
54
|
class OpenRouterModel:
|
|
55
|
+
abort_exceptions: list[type[Exception]] = [OpenRouterAuthenticationError, KeyboardInterrupt]
|
|
56
|
+
|
|
51
57
|
def __init__(self, **kwargs):
|
|
52
58
|
self.config = OpenRouterModelConfig(**kwargs)
|
|
53
|
-
self.cost = 0.0
|
|
54
|
-
self.n_calls = 0
|
|
55
59
|
self._api_url = "https://openrouter.ai/api/v1/chat/completions"
|
|
56
60
|
self._api_key = os.getenv("OPENROUTER_API_KEY", "")
|
|
57
61
|
|
|
58
|
-
@retry(
|
|
59
|
-
reraise=True,
|
|
60
|
-
stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
|
|
61
|
-
wait=wait_exponential(multiplier=1, min=4, max=60),
|
|
62
|
-
before_sleep=before_sleep_log(logger, logging.WARNING),
|
|
63
|
-
retry=retry_if_not_exception_type(
|
|
64
|
-
(
|
|
65
|
-
OpenRouterAuthenticationError,
|
|
66
|
-
KeyboardInterrupt,
|
|
67
|
-
)
|
|
68
|
-
),
|
|
69
|
-
)
|
|
70
62
|
def _query(self, messages: list[dict[str, str]], **kwargs):
|
|
71
63
|
headers = {
|
|
72
64
|
"Authorization": f"Bearer {self._api_key}",
|
|
@@ -76,6 +68,7 @@ class OpenRouterModel:
|
|
|
76
68
|
payload = {
|
|
77
69
|
"model": self.config.model_name,
|
|
78
70
|
"messages": messages,
|
|
71
|
+
"tools": [BASH_TOOL],
|
|
79
72
|
"usage": {"include": True},
|
|
80
73
|
**(self.config.model_kwargs | kwargs),
|
|
81
74
|
}
|
|
@@ -95,11 +88,27 @@ class OpenRouterModel:
|
|
|
95
88
|
except requests.exceptions.RequestException as e:
|
|
96
89
|
raise OpenRouterAPIError(f"Request failed: {e}") from e
|
|
97
90
|
|
|
91
|
+
def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
|
|
92
|
+
prepared = [{k: v for k, v in msg.items() if k != "extra"} for msg in messages]
|
|
93
|
+
prepared = _reorder_anthropic_thinking_blocks(prepared)
|
|
94
|
+
return set_cache_control(prepared, mode=self.config.set_cache_control)
|
|
95
|
+
|
|
98
96
|
def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
97
|
+
for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
|
|
98
|
+
with attempt:
|
|
99
|
+
response = self._query(self._prepare_messages_for_api(messages), **kwargs)
|
|
100
|
+
cost_output = self._calculate_cost(response)
|
|
101
|
+
GLOBAL_MODEL_STATS.add(cost_output["cost"])
|
|
102
|
+
message = dict(response["choices"][0]["message"])
|
|
103
|
+
message["extra"] = {
|
|
104
|
+
"actions": self._parse_actions(response),
|
|
105
|
+
"response": response,
|
|
106
|
+
**cost_output,
|
|
107
|
+
"timestamp": time.time(),
|
|
108
|
+
}
|
|
109
|
+
return message
|
|
102
110
|
|
|
111
|
+
def _calculate_cost(self, response) -> dict[str, float]:
|
|
103
112
|
usage = response.get("usage", {})
|
|
104
113
|
cost = usage.get("cost", 0.0)
|
|
105
114
|
if cost <= 0.0 and self.config.cost_tracking != "ignore_errors":
|
|
@@ -110,17 +119,50 @@ class OpenRouterModel:
|
|
|
110
119
|
"(for example for free/local models), more information at https://klieret.short.gy/mini-local-models "
|
|
111
120
|
"for more details. Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
|
|
112
121
|
)
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
122
|
+
return {"cost": cost}
|
|
123
|
+
|
|
124
|
+
def _parse_actions(self, response: dict) -> list[dict]:
|
|
125
|
+
"""Parse tool calls from the response. Raises FormatError if unknown tool."""
|
|
126
|
+
tool_calls = response["choices"][0]["message"].get("tool_calls") or []
|
|
127
|
+
tool_calls = [_DictToObj(tc) for tc in tool_calls]
|
|
128
|
+
return parse_toolcall_actions(tool_calls, format_error_template=self.config.format_error_template)
|
|
129
|
+
|
|
130
|
+
def format_message(self, **kwargs) -> dict:
|
|
131
|
+
return expand_multimodal_content(kwargs, pattern=self.config.multimodal_regex)
|
|
132
|
+
|
|
133
|
+
def format_observation_messages(
|
|
134
|
+
self, message: dict, outputs: list[dict], template_vars: dict | None = None
|
|
135
|
+
) -> list[dict]:
|
|
136
|
+
"""Format execution outputs into tool result messages."""
|
|
137
|
+
actions = message.get("extra", {}).get("actions", [])
|
|
138
|
+
return format_toolcall_observation_messages(
|
|
139
|
+
actions=actions,
|
|
140
|
+
outputs=outputs,
|
|
141
|
+
observation_template=self.config.observation_template,
|
|
142
|
+
template_vars=template_vars,
|
|
143
|
+
multimodal_regex=self.config.multimodal_regex,
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
def get_template_vars(self, **kwargs) -> dict[str, Any]:
|
|
147
|
+
return self.config.model_dump()
|
|
148
|
+
|
|
149
|
+
def serialize(self) -> dict:
|
|
118
150
|
return {
|
|
119
|
-
"
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
151
|
+
"info": {
|
|
152
|
+
"config": {
|
|
153
|
+
"model": self.config.model_dump(mode="json"),
|
|
154
|
+
"model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
|
|
155
|
+
},
|
|
156
|
+
}
|
|
123
157
|
}
|
|
124
158
|
|
|
125
|
-
|
|
126
|
-
|
|
159
|
+
|
|
160
|
+
class _DictToObj:
|
|
161
|
+
"""Simple wrapper to convert dict to object with attribute access."""
|
|
162
|
+
|
|
163
|
+
def __init__(self, d: dict):
|
|
164
|
+
self._d = d
|
|
165
|
+
self.id = d.get("id")
|
|
166
|
+
self.function = _DictToObj(d.get("function", {})) if "function" in d else None
|
|
167
|
+
self.name = d.get("name")
|
|
168
|
+
self.arguments = d.get("arguments")
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import time
|
|
4
|
+
|
|
5
|
+
import requests
|
|
6
|
+
|
|
7
|
+
from minisweagent.models import GLOBAL_MODEL_STATS
|
|
8
|
+
from minisweagent.models.openrouter_model import (
|
|
9
|
+
OpenRouterAPIError,
|
|
10
|
+
OpenRouterAuthenticationError,
|
|
11
|
+
OpenRouterModel,
|
|
12
|
+
OpenRouterModelConfig,
|
|
13
|
+
OpenRouterRateLimitError,
|
|
14
|
+
)
|
|
15
|
+
from minisweagent.models.utils.actions_toolcall_response import (
|
|
16
|
+
BASH_TOOL_RESPONSE_API,
|
|
17
|
+
format_toolcall_observation_messages,
|
|
18
|
+
parse_toolcall_actions_response,
|
|
19
|
+
)
|
|
20
|
+
from minisweagent.models.utils.retry import retry
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger("openrouter_response_model")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class OpenRouterResponseModelConfig(OpenRouterModelConfig):
|
|
26
|
+
pass
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class OpenRouterResponseModel(OpenRouterModel):
|
|
30
|
+
"""OpenRouter model using the Responses API with native tool calling.
|
|
31
|
+
|
|
32
|
+
Note: OpenRouter's Responses API is stateless - each request must include
|
|
33
|
+
the full conversation history. previous_response_id is not supported.
|
|
34
|
+
See: https://openrouter.ai/docs/api/reference/responses/overview
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
def __init__(self, **kwargs):
|
|
38
|
+
super().__init__(**kwargs)
|
|
39
|
+
self.config = OpenRouterResponseModelConfig(**kwargs)
|
|
40
|
+
self._api_url = "https://openrouter.ai/api/v1/responses"
|
|
41
|
+
|
|
42
|
+
def _query(self, messages: list[dict[str, str]], **kwargs):
|
|
43
|
+
headers = {
|
|
44
|
+
"Authorization": f"Bearer {self._api_key}",
|
|
45
|
+
"Content-Type": "application/json",
|
|
46
|
+
}
|
|
47
|
+
payload = {
|
|
48
|
+
"model": self.config.model_name,
|
|
49
|
+
"input": messages,
|
|
50
|
+
"tools": [BASH_TOOL_RESPONSE_API],
|
|
51
|
+
**(self.config.model_kwargs | kwargs),
|
|
52
|
+
}
|
|
53
|
+
try:
|
|
54
|
+
response = requests.post(self._api_url, headers=headers, data=json.dumps(payload), timeout=60)
|
|
55
|
+
response.raise_for_status()
|
|
56
|
+
return response.json()
|
|
57
|
+
except requests.exceptions.HTTPError as e:
|
|
58
|
+
if response.status_code == 401:
|
|
59
|
+
error_msg = "Authentication failed. You can permanently set your API key with `mini-extra config set OPENROUTER_API_KEY YOUR_KEY`."
|
|
60
|
+
raise OpenRouterAuthenticationError(error_msg) from e
|
|
61
|
+
elif response.status_code == 429:
|
|
62
|
+
raise OpenRouterRateLimitError("Rate limit exceeded") from e
|
|
63
|
+
else:
|
|
64
|
+
raise OpenRouterAPIError(f"HTTP {response.status_code}: {response.text}") from e
|
|
65
|
+
except requests.exceptions.RequestException as e:
|
|
66
|
+
raise OpenRouterAPIError(f"Request failed: {e}") from e
|
|
67
|
+
|
|
68
|
+
def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
|
|
69
|
+
"""Prepare messages for OpenRouter's stateless Responses API.
|
|
70
|
+
|
|
71
|
+
Flattens response objects into their output items since OpenRouter
|
|
72
|
+
doesn't support previous_response_id.
|
|
73
|
+
"""
|
|
74
|
+
result = []
|
|
75
|
+
for msg in messages:
|
|
76
|
+
if msg.get("object") == "response":
|
|
77
|
+
for item in msg.get("output", []):
|
|
78
|
+
result.append({k: v for k, v in item.items() if k != "extra"})
|
|
79
|
+
else:
|
|
80
|
+
result.append({k: v for k, v in msg.items() if k != "extra"})
|
|
81
|
+
return result
|
|
82
|
+
|
|
83
|
+
def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
|
|
84
|
+
for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
|
|
85
|
+
with attempt:
|
|
86
|
+
response = self._query(self._prepare_messages_for_api(messages), **kwargs)
|
|
87
|
+
cost_output = self._calculate_cost(response)
|
|
88
|
+
GLOBAL_MODEL_STATS.add(cost_output["cost"])
|
|
89
|
+
message = dict(response)
|
|
90
|
+
message["extra"] = {
|
|
91
|
+
"actions": self._parse_actions(response),
|
|
92
|
+
**cost_output,
|
|
93
|
+
"timestamp": time.time(),
|
|
94
|
+
}
|
|
95
|
+
return message
|
|
96
|
+
|
|
97
|
+
def _parse_actions(self, response: dict) -> list[dict]:
|
|
98
|
+
return parse_toolcall_actions_response(
|
|
99
|
+
response.get("output", []), format_error_template=self.config.format_error_template
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
def format_message(self, **kwargs) -> dict:
|
|
103
|
+
role = kwargs.get("role", "user")
|
|
104
|
+
content = kwargs.get("content", "")
|
|
105
|
+
extra = kwargs.get("extra")
|
|
106
|
+
content_items = [{"type": "input_text", "text": content}] if isinstance(content, str) else content
|
|
107
|
+
msg = {"type": "message", "role": role, "content": content_items}
|
|
108
|
+
if extra:
|
|
109
|
+
msg["extra"] = extra
|
|
110
|
+
return msg
|
|
111
|
+
|
|
112
|
+
def format_observation_messages(
|
|
113
|
+
self, message: dict, outputs: list[dict], template_vars: dict | None = None
|
|
114
|
+
) -> list[dict]:
|
|
115
|
+
"""Format execution outputs into tool result messages."""
|
|
116
|
+
actions = message.get("extra", {}).get("actions", [])
|
|
117
|
+
return format_toolcall_observation_messages(
|
|
118
|
+
actions=actions,
|
|
119
|
+
outputs=outputs,
|
|
120
|
+
observation_template=self.config.observation_template,
|
|
121
|
+
template_vars=template_vars,
|
|
122
|
+
multimodal_regex=self.config.multimodal_regex,
|
|
123
|
+
)
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
import requests
|
|
5
|
+
|
|
6
|
+
from minisweagent.models.openrouter_model import (
|
|
7
|
+
OpenRouterAPIError,
|
|
8
|
+
OpenRouterAuthenticationError,
|
|
9
|
+
OpenRouterModel,
|
|
10
|
+
OpenRouterModelConfig,
|
|
11
|
+
OpenRouterRateLimitError,
|
|
12
|
+
)
|
|
13
|
+
from minisweagent.models.utils.actions_text import format_observation_messages, parse_regex_actions
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger("openrouter_textbased_model")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class OpenRouterTextbasedModelConfig(OpenRouterModelConfig):
|
|
19
|
+
action_regex: str = r"```mswea_bash_command\s*\n(.*?)\n```"
|
|
20
|
+
"""Regex to extract the action from the LM's output."""
|
|
21
|
+
format_error_template: str = (
|
|
22
|
+
"Please always provide EXACTLY ONE action in triple backticks, found {{actions|length}} actions."
|
|
23
|
+
)
|
|
24
|
+
"""Template used when the LM's output is not in the expected format."""
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class OpenRouterTextbasedModel(OpenRouterModel):
|
|
28
|
+
def __init__(self, **kwargs):
|
|
29
|
+
super().__init__(**kwargs)
|
|
30
|
+
self.config = OpenRouterTextbasedModelConfig(**kwargs)
|
|
31
|
+
|
|
32
|
+
def _query(self, messages: list[dict[str, str]], **kwargs):
|
|
33
|
+
headers = {
|
|
34
|
+
"Authorization": f"Bearer {self._api_key}",
|
|
35
|
+
"Content-Type": "application/json",
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
payload = {
|
|
39
|
+
"model": self.config.model_name,
|
|
40
|
+
"messages": messages,
|
|
41
|
+
"usage": {"include": True},
|
|
42
|
+
**(self.config.model_kwargs | kwargs),
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
try:
|
|
46
|
+
response = requests.post(self._api_url, headers=headers, data=json.dumps(payload), timeout=60)
|
|
47
|
+
response.raise_for_status()
|
|
48
|
+
return response.json()
|
|
49
|
+
except requests.exceptions.HTTPError as e:
|
|
50
|
+
if response.status_code == 401:
|
|
51
|
+
error_msg = "Authentication failed. You can permanently set your API key with `mini-extra config set OPENROUTER_API_KEY YOUR_KEY`."
|
|
52
|
+
raise OpenRouterAuthenticationError(error_msg) from e
|
|
53
|
+
elif response.status_code == 429:
|
|
54
|
+
raise OpenRouterRateLimitError("Rate limit exceeded") from e
|
|
55
|
+
else:
|
|
56
|
+
raise OpenRouterAPIError(f"HTTP {response.status_code}: {response.text}") from e
|
|
57
|
+
except requests.exceptions.RequestException as e:
|
|
58
|
+
raise OpenRouterAPIError(f"Request failed: {e}") from e
|
|
59
|
+
|
|
60
|
+
def _parse_actions(self, response: dict) -> list[dict]:
|
|
61
|
+
"""Parse actions from the model response. Raises FormatError if not exactly one action."""
|
|
62
|
+
content = response["choices"][0]["message"]["content"] or ""
|
|
63
|
+
return parse_regex_actions(
|
|
64
|
+
content, action_regex=self.config.action_regex, format_error_template=self.config.format_error_template
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
def format_observation_messages(
|
|
68
|
+
self, message: dict, outputs: list[dict], template_vars: dict | None = None
|
|
69
|
+
) -> list[dict]:
|
|
70
|
+
"""Format execution outputs into observation messages."""
|
|
71
|
+
return format_observation_messages(
|
|
72
|
+
outputs,
|
|
73
|
+
observation_template=self.config.observation_template,
|
|
74
|
+
template_vars=template_vars,
|
|
75
|
+
multimodal_regex=self.config.multimodal_regex,
|
|
76
|
+
)
|
|
@@ -1,21 +1,23 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
|
-
|
|
4
|
+
import time
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
from typing import Any, Literal
|
|
7
7
|
|
|
8
8
|
import litellm
|
|
9
|
-
from
|
|
10
|
-
before_sleep_log,
|
|
11
|
-
retry,
|
|
12
|
-
retry_if_not_exception_type,
|
|
13
|
-
stop_after_attempt,
|
|
14
|
-
wait_exponential,
|
|
15
|
-
)
|
|
9
|
+
from pydantic import BaseModel
|
|
16
10
|
|
|
17
11
|
from minisweagent.models import GLOBAL_MODEL_STATS
|
|
12
|
+
from minisweagent.models.utils.actions_toolcall import (
|
|
13
|
+
BASH_TOOL,
|
|
14
|
+
format_toolcall_observation_messages,
|
|
15
|
+
parse_toolcall_actions,
|
|
16
|
+
)
|
|
17
|
+
from minisweagent.models.utils.anthropic_utils import _reorder_anthropic_thinking_blocks
|
|
18
18
|
from minisweagent.models.utils.cache_control import set_cache_control
|
|
19
|
+
from minisweagent.models.utils.openai_multimodal import expand_multimodal_content
|
|
20
|
+
from minisweagent.models.utils.retry import retry
|
|
19
21
|
|
|
20
22
|
logger = logging.getLogger("portkey_model")
|
|
21
23
|
|
|
@@ -27,10 +29,14 @@ except ImportError:
|
|
|
27
29
|
)
|
|
28
30
|
|
|
29
31
|
|
|
30
|
-
|
|
31
|
-
class PortkeyModelConfig:
|
|
32
|
+
class PortkeyModelConfig(BaseModel):
|
|
32
33
|
model_name: str
|
|
33
|
-
model_kwargs: dict[str, Any] =
|
|
34
|
+
model_kwargs: dict[str, Any] = {}
|
|
35
|
+
provider: str = ""
|
|
36
|
+
"""The LLM provider to use (e.g., 'openai', 'anthropic', 'google').
|
|
37
|
+
If not specified, will be auto-detected from model_name.
|
|
38
|
+
Required by Portkey when not using a virtual key.
|
|
39
|
+
"""
|
|
34
40
|
litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
|
|
35
41
|
"""We currently use litellm to calculate costs. Here you can register additional models to litellm's model registry.
|
|
36
42
|
Note that this might change if we get better support for Portkey and change how we calculate costs.
|
|
@@ -44,17 +50,25 @@ class PortkeyModelConfig:
|
|
|
44
50
|
"""Set explicit cache control markers, for example for Anthropic models"""
|
|
45
51
|
cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
|
|
46
52
|
"""Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""
|
|
53
|
+
format_error_template: str = "{{ error }}"
|
|
54
|
+
"""Template used when the LM's output is not in the expected format."""
|
|
55
|
+
observation_template: str = (
|
|
56
|
+
"{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
|
|
57
|
+
"<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
|
|
58
|
+
)
|
|
59
|
+
"""Template used to render the observation after executing an action."""
|
|
60
|
+
multimodal_regex: str = ""
|
|
61
|
+
"""Regex to extract multimodal content. Empty string disables multimodal processing."""
|
|
47
62
|
|
|
48
63
|
|
|
49
64
|
class PortkeyModel:
|
|
65
|
+
abort_exceptions: list[type[Exception]] = [KeyboardInterrupt, TypeError, ValueError]
|
|
66
|
+
|
|
50
67
|
def __init__(self, *, config_class: type = PortkeyModelConfig, **kwargs):
|
|
51
68
|
self.config = config_class(**kwargs)
|
|
52
|
-
self.cost = 0.0
|
|
53
|
-
self.n_calls = 0
|
|
54
69
|
if self.config.litellm_model_registry and Path(self.config.litellm_model_registry).is_file():
|
|
55
70
|
litellm.utils.register_model(json.loads(Path(self.config.litellm_model_registry).read_text()))
|
|
56
71
|
|
|
57
|
-
# Get API key from environment or raise error
|
|
58
72
|
self._api_key = os.getenv("PORTKEY_API_KEY")
|
|
59
73
|
if not self._api_key:
|
|
60
74
|
raise ValueError(
|
|
@@ -63,51 +77,79 @@ class PortkeyModel:
|
|
|
63
77
|
"`mini-extra config set PORTKEY_API_KEY YOUR_KEY`."
|
|
64
78
|
)
|
|
65
79
|
|
|
66
|
-
# Get virtual key from environment
|
|
67
80
|
virtual_key = os.getenv("PORTKEY_VIRTUAL_KEY")
|
|
68
|
-
|
|
69
|
-
# Initialize Portkey client
|
|
70
81
|
client_kwargs = {"api_key": self._api_key}
|
|
71
82
|
if virtual_key:
|
|
72
83
|
client_kwargs["virtual_key"] = virtual_key
|
|
84
|
+
elif self.config.provider:
|
|
85
|
+
# If no virtual key but provider is specified, pass it
|
|
86
|
+
client_kwargs["provider"] = self.config.provider
|
|
73
87
|
|
|
74
88
|
self.client = Portkey(**client_kwargs)
|
|
75
89
|
|
|
76
|
-
@retry(
|
|
77
|
-
reraise=True,
|
|
78
|
-
stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
|
|
79
|
-
wait=wait_exponential(multiplier=1, min=4, max=60),
|
|
80
|
-
before_sleep=before_sleep_log(logger, logging.WARNING),
|
|
81
|
-
retry=retry_if_not_exception_type((KeyboardInterrupt, TypeError, ValueError)),
|
|
82
|
-
)
|
|
83
90
|
def _query(self, messages: list[dict[str, str]], **kwargs):
|
|
84
|
-
# return self.client.with_options(metadata={"request_id": request_id}).chat.completions.create(
|
|
85
91
|
return self.client.chat.completions.create(
|
|
86
92
|
model=self.config.model_name,
|
|
87
93
|
messages=messages,
|
|
94
|
+
tools=[BASH_TOOL],
|
|
88
95
|
**(self.config.model_kwargs | kwargs),
|
|
89
96
|
)
|
|
90
97
|
|
|
98
|
+
def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
|
|
99
|
+
prepared = [{k: v for k, v in msg.items() if k != "extra"} for msg in messages]
|
|
100
|
+
prepared = _reorder_anthropic_thinking_blocks(prepared)
|
|
101
|
+
return set_cache_control(prepared, mode=self.config.set_cache_control)
|
|
102
|
+
|
|
91
103
|
def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
"
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
"cost": cost,
|
|
104
|
-
},
|
|
104
|
+
for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
|
|
105
|
+
with attempt:
|
|
106
|
+
response = self._query(self._prepare_messages_for_api(messages), **kwargs)
|
|
107
|
+
cost_output = self._calculate_cost(response)
|
|
108
|
+
GLOBAL_MODEL_STATS.add(cost_output["cost"])
|
|
109
|
+
message = response.choices[0].message.model_dump()
|
|
110
|
+
message["extra"] = {
|
|
111
|
+
"actions": self._parse_actions(response),
|
|
112
|
+
"response": response.model_dump(),
|
|
113
|
+
**cost_output,
|
|
114
|
+
"timestamp": time.time(),
|
|
105
115
|
}
|
|
116
|
+
return message
|
|
117
|
+
|
|
118
|
+
def _parse_actions(self, response) -> list[dict]:
|
|
119
|
+
"""Parse tool calls from the response. Raises FormatError if unknown tool."""
|
|
120
|
+
tool_calls = response.choices[0].message.tool_calls or []
|
|
121
|
+
return parse_toolcall_actions(tool_calls, format_error_template=self.config.format_error_template)
|
|
122
|
+
|
|
123
|
+
def format_message(self, **kwargs) -> dict:
|
|
124
|
+
return expand_multimodal_content(kwargs, pattern=self.config.multimodal_regex)
|
|
125
|
+
|
|
126
|
+
def format_observation_messages(
|
|
127
|
+
self, message: dict, outputs: list[dict], template_vars: dict | None = None
|
|
128
|
+
) -> list[dict]:
|
|
129
|
+
"""Format execution outputs into tool result messages."""
|
|
130
|
+
actions = message.get("extra", {}).get("actions", [])
|
|
131
|
+
return format_toolcall_observation_messages(
|
|
132
|
+
actions=actions,
|
|
133
|
+
outputs=outputs,
|
|
134
|
+
observation_template=self.config.observation_template,
|
|
135
|
+
template_vars=template_vars,
|
|
136
|
+
multimodal_regex=self.config.multimodal_regex,
|
|
137
|
+
)
|
|
106
138
|
|
|
107
|
-
def get_template_vars(self) -> dict[str, Any]:
|
|
108
|
-
return
|
|
139
|
+
def get_template_vars(self, **kwargs) -> dict[str, Any]:
|
|
140
|
+
return self.config.model_dump()
|
|
141
|
+
|
|
142
|
+
def serialize(self) -> dict:
|
|
143
|
+
return {
|
|
144
|
+
"info": {
|
|
145
|
+
"config": {
|
|
146
|
+
"model": self.config.model_dump(mode="json"),
|
|
147
|
+
"model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
|
|
148
|
+
},
|
|
149
|
+
}
|
|
150
|
+
}
|
|
109
151
|
|
|
110
|
-
def _calculate_cost(self, response) -> float:
|
|
152
|
+
def _calculate_cost(self, response) -> dict[str, float]:
|
|
111
153
|
response_for_cost_calc = response.model_copy()
|
|
112
154
|
if self.config.litellm_model_name_override:
|
|
113
155
|
if response_for_cost_calc.model:
|
|
@@ -152,4 +194,4 @@ class PortkeyModel:
|
|
|
152
194
|
)
|
|
153
195
|
logger.critical(msg)
|
|
154
196
|
raise RuntimeError(msg) from e
|
|
155
|
-
return cost
|
|
197
|
+
return {"cost": cost}
|