mini-swe-agent 1.17.4__py3-none-any.whl → 2.0.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. {mini_swe_agent-1.17.4.dist-info → mini_swe_agent-2.0.0a1.dist-info}/METADATA +36 -52
  2. mini_swe_agent-2.0.0a1.dist-info/RECORD +70 -0
  3. {mini_swe_agent-1.17.4.dist-info → mini_swe_agent-2.0.0a1.dist-info}/WHEEL +1 -1
  4. mini_swe_agent-2.0.0a1.dist-info/entry_points.txt +5 -0
  5. minisweagent/__init__.py +19 -26
  6. minisweagent/agents/default.py +128 -113
  7. minisweagent/agents/interactive.py +119 -58
  8. minisweagent/config/README.md +3 -4
  9. minisweagent/config/__init__.py +36 -1
  10. minisweagent/config/benchmarks/swebench.yaml +156 -0
  11. minisweagent/config/{extra/swebench.yaml → benchmarks/swebench_backticks.yaml} +69 -64
  12. minisweagent/config/benchmarks/swebench_modal.yaml +47 -0
  13. minisweagent/config/{extra → benchmarks}/swebench_xml.yaml +73 -70
  14. minisweagent/config/default.yaml +24 -21
  15. minisweagent/config/inspector.tcss +42 -0
  16. minisweagent/config/mini.yaml +53 -71
  17. minisweagent/config/{github_issue.yaml → mini_textbased.yaml} +43 -29
  18. minisweagent/environments/__init__.py +1 -0
  19. minisweagent/environments/docker.py +67 -20
  20. minisweagent/environments/extra/bubblewrap.py +86 -47
  21. minisweagent/environments/extra/swerex_docker.py +53 -20
  22. minisweagent/environments/extra/swerex_modal.py +90 -0
  23. minisweagent/environments/local.py +62 -21
  24. minisweagent/environments/singularity.py +59 -18
  25. minisweagent/exceptions.py +22 -0
  26. minisweagent/models/__init__.py +6 -7
  27. minisweagent/models/extra/roulette.py +20 -17
  28. minisweagent/models/litellm_model.py +90 -44
  29. minisweagent/models/litellm_response_model.py +80 -0
  30. minisweagent/models/litellm_textbased_model.py +45 -0
  31. minisweagent/models/openrouter_model.py +87 -45
  32. minisweagent/models/openrouter_response_model.py +123 -0
  33. minisweagent/models/openrouter_textbased_model.py +76 -0
  34. minisweagent/models/portkey_model.py +84 -42
  35. minisweagent/models/portkey_response_model.py +163 -0
  36. minisweagent/models/requesty_model.py +91 -41
  37. minisweagent/models/test_models.py +246 -19
  38. minisweagent/models/utils/actions_text.py +60 -0
  39. minisweagent/models/utils/actions_toolcall.py +102 -0
  40. minisweagent/models/utils/actions_toolcall_response.py +110 -0
  41. minisweagent/models/utils/anthropic_utils.py +28 -0
  42. minisweagent/models/utils/cache_control.py +15 -2
  43. minisweagent/models/utils/content_string.py +74 -0
  44. minisweagent/models/utils/openai_multimodal.py +50 -0
  45. minisweagent/models/utils/retry.py +25 -0
  46. minisweagent/run/benchmarks/__init__.py +1 -0
  47. minisweagent/run/{extra → benchmarks}/swebench.py +57 -36
  48. minisweagent/run/benchmarks/swebench_single.py +89 -0
  49. minisweagent/run/{extra → benchmarks}/utils/batch_progress.py +1 -1
  50. minisweagent/run/hello_world.py +6 -0
  51. minisweagent/run/mini.py +54 -63
  52. minisweagent/run/utilities/__init__.py +1 -0
  53. minisweagent/run/{extra → utilities}/config.py +2 -0
  54. minisweagent/run/{inspector.py → utilities/inspector.py} +90 -11
  55. minisweagent/run/{mini_extra.py → utilities/mini_extra.py} +9 -5
  56. minisweagent/utils/serialize.py +26 -0
  57. mini_swe_agent-1.17.4.dist-info/RECORD +0 -61
  58. mini_swe_agent-1.17.4.dist-info/entry_points.txt +0 -5
  59. minisweagent/agents/interactive_textual.py +0 -450
  60. minisweagent/config/extra/swebench_roulette.yaml +0 -233
  61. minisweagent/config/mini.tcss +0 -86
  62. minisweagent/models/anthropic.py +0 -35
  63. minisweagent/models/litellm_response_api_model.py +0 -82
  64. minisweagent/models/portkey_response_api_model.py +0 -75
  65. minisweagent/models/utils/key_per_thread.py +0 -20
  66. minisweagent/models/utils/openai_utils.py +0 -41
  67. minisweagent/run/extra/swebench_single.py +0 -79
  68. minisweagent/run/github_issue.py +0 -87
  69. minisweagent/run/utils/__init__.py +0 -0
  70. minisweagent/run/utils/save.py +0 -78
  71. {mini_swe_agent-1.17.4.dist-info → mini_swe_agent-2.0.0a1.dist-info}/licenses/LICENSE.md +0 -0
  72. {mini_swe_agent-1.17.4.dist-info → mini_swe_agent-2.0.0a1.dist-info}/top_level.txt +0 -0
  73. minisweagent/config/{extra → benchmarks}/__init__.py +0 -0
  74. minisweagent/run/{extra → benchmarks}/utils/__init__.py +0 -0
@@ -1,72 +1,64 @@
1
1
  import json
2
2
  import logging
3
3
  import os
4
- from dataclasses import asdict, dataclass, field
4
+ import time
5
5
  from typing import Any, Literal
6
6
 
7
7
  import requests
8
- from tenacity import (
9
- before_sleep_log,
10
- retry,
11
- retry_if_not_exception_type,
12
- stop_after_attempt,
13
- wait_exponential,
14
- )
8
+ from pydantic import BaseModel
15
9
 
16
10
  from minisweagent.models import GLOBAL_MODEL_STATS
11
+ from minisweagent.models.utils.actions_toolcall import (
12
+ BASH_TOOL,
13
+ format_toolcall_observation_messages,
14
+ parse_toolcall_actions,
15
+ )
16
+ from minisweagent.models.utils.anthropic_utils import _reorder_anthropic_thinking_blocks
17
17
  from minisweagent.models.utils.cache_control import set_cache_control
18
+ from minisweagent.models.utils.openai_multimodal import expand_multimodal_content
19
+ from minisweagent.models.utils.retry import retry
18
20
 
19
21
  logger = logging.getLogger("openrouter_model")
20
22
 
21
23
 
22
- @dataclass
23
- class OpenRouterModelConfig:
24
+ class OpenRouterModelConfig(BaseModel):
24
25
  model_name: str
25
- model_kwargs: dict[str, Any] = field(default_factory=dict)
26
+ model_kwargs: dict[str, Any] = {}
26
27
  set_cache_control: Literal["default_end"] | None = None
27
28
  """Set explicit cache control markers, for example for Anthropic models"""
28
29
  cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
29
30
  """Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""
31
+ format_error_template: str = "{{ error }}"
32
+ """Template used when the LM's output is not in the expected format."""
33
+ observation_template: str = (
34
+ "{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
35
+ "<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
36
+ )
37
+ """Template used to render the observation after executing an action."""
38
+ multimodal_regex: str = ""
39
+ """Regex to extract multimodal content. Empty string disables multimodal processing."""
30
40
 
31
41
 
32
42
  class OpenRouterAPIError(Exception):
33
43
  """Custom exception for OpenRouter API errors."""
34
44
 
35
- pass
36
-
37
45
 
38
46
  class OpenRouterAuthenticationError(Exception):
39
47
  """Custom exception for OpenRouter authentication errors."""
40
48
 
41
- pass
42
-
43
49
 
44
50
  class OpenRouterRateLimitError(Exception):
45
51
  """Custom exception for OpenRouter rate limit errors."""
46
52
 
47
- pass
48
-
49
53
 
50
54
  class OpenRouterModel:
55
+ abort_exceptions: list[type[Exception]] = [OpenRouterAuthenticationError, KeyboardInterrupt]
56
+
51
57
  def __init__(self, **kwargs):
52
58
  self.config = OpenRouterModelConfig(**kwargs)
53
- self.cost = 0.0
54
- self.n_calls = 0
55
59
  self._api_url = "https://openrouter.ai/api/v1/chat/completions"
56
60
  self._api_key = os.getenv("OPENROUTER_API_KEY", "")
57
61
 
58
- @retry(
59
- reraise=True,
60
- stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
61
- wait=wait_exponential(multiplier=1, min=4, max=60),
62
- before_sleep=before_sleep_log(logger, logging.WARNING),
63
- retry=retry_if_not_exception_type(
64
- (
65
- OpenRouterAuthenticationError,
66
- KeyboardInterrupt,
67
- )
68
- ),
69
- )
70
62
  def _query(self, messages: list[dict[str, str]], **kwargs):
71
63
  headers = {
72
64
  "Authorization": f"Bearer {self._api_key}",
@@ -76,6 +68,7 @@ class OpenRouterModel:
76
68
  payload = {
77
69
  "model": self.config.model_name,
78
70
  "messages": messages,
71
+ "tools": [BASH_TOOL],
79
72
  "usage": {"include": True},
80
73
  **(self.config.model_kwargs | kwargs),
81
74
  }
@@ -95,11 +88,27 @@ class OpenRouterModel:
95
88
  except requests.exceptions.RequestException as e:
96
89
  raise OpenRouterAPIError(f"Request failed: {e}") from e
97
90
 
91
+ def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
92
+ prepared = [{k: v for k, v in msg.items() if k != "extra"} for msg in messages]
93
+ prepared = _reorder_anthropic_thinking_blocks(prepared)
94
+ return set_cache_control(prepared, mode=self.config.set_cache_control)
95
+
98
96
  def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
99
- if self.config.set_cache_control:
100
- messages = set_cache_control(messages, mode=self.config.set_cache_control)
101
- response = self._query([{"role": msg["role"], "content": msg["content"]} for msg in messages], **kwargs)
97
+ for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
98
+ with attempt:
99
+ response = self._query(self._prepare_messages_for_api(messages), **kwargs)
100
+ cost_output = self._calculate_cost(response)
101
+ GLOBAL_MODEL_STATS.add(cost_output["cost"])
102
+ message = dict(response["choices"][0]["message"])
103
+ message["extra"] = {
104
+ "actions": self._parse_actions(response),
105
+ "response": response,
106
+ **cost_output,
107
+ "timestamp": time.time(),
108
+ }
109
+ return message
102
110
 
111
+ def _calculate_cost(self, response) -> dict[str, float]:
103
112
  usage = response.get("usage", {})
104
113
  cost = usage.get("cost", 0.0)
105
114
  if cost <= 0.0 and self.config.cost_tracking != "ignore_errors":
@@ -110,17 +119,50 @@ class OpenRouterModel:
110
119
  "(for example for free/local models), more information at https://klieret.short.gy/mini-local-models "
111
120
  "for more details. Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
112
121
  )
113
-
114
- self.n_calls += 1
115
- self.cost += cost
116
- GLOBAL_MODEL_STATS.add(cost)
117
-
122
+ return {"cost": cost}
123
+
124
+ def _parse_actions(self, response: dict) -> list[dict]:
125
+ """Parse tool calls from the response. Raises FormatError if unknown tool."""
126
+ tool_calls = response["choices"][0]["message"].get("tool_calls") or []
127
+ tool_calls = [_DictToObj(tc) for tc in tool_calls]
128
+ return parse_toolcall_actions(tool_calls, format_error_template=self.config.format_error_template)
129
+
130
+ def format_message(self, **kwargs) -> dict:
131
+ return expand_multimodal_content(kwargs, pattern=self.config.multimodal_regex)
132
+
133
+ def format_observation_messages(
134
+ self, message: dict, outputs: list[dict], template_vars: dict | None = None
135
+ ) -> list[dict]:
136
+ """Format execution outputs into tool result messages."""
137
+ actions = message.get("extra", {}).get("actions", [])
138
+ return format_toolcall_observation_messages(
139
+ actions=actions,
140
+ outputs=outputs,
141
+ observation_template=self.config.observation_template,
142
+ template_vars=template_vars,
143
+ multimodal_regex=self.config.multimodal_regex,
144
+ )
145
+
146
+ def get_template_vars(self, **kwargs) -> dict[str, Any]:
147
+ return self.config.model_dump()
148
+
149
+ def serialize(self) -> dict:
118
150
  return {
119
- "content": response["choices"][0]["message"]["content"] or "",
120
- "extra": {
121
- "response": response, # already is json
122
- },
151
+ "info": {
152
+ "config": {
153
+ "model": self.config.model_dump(mode="json"),
154
+ "model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
155
+ },
156
+ }
123
157
  }
124
158
 
125
- def get_template_vars(self) -> dict[str, Any]:
126
- return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}
159
+
160
+ class _DictToObj:
161
+ """Simple wrapper to convert dict to object with attribute access."""
162
+
163
+ def __init__(self, d: dict):
164
+ self._d = d
165
+ self.id = d.get("id")
166
+ self.function = _DictToObj(d.get("function", {})) if "function" in d else None
167
+ self.name = d.get("name")
168
+ self.arguments = d.get("arguments")
@@ -0,0 +1,123 @@
1
+ import json
2
+ import logging
3
+ import time
4
+
5
+ import requests
6
+
7
+ from minisweagent.models import GLOBAL_MODEL_STATS
8
+ from minisweagent.models.openrouter_model import (
9
+ OpenRouterAPIError,
10
+ OpenRouterAuthenticationError,
11
+ OpenRouterModel,
12
+ OpenRouterModelConfig,
13
+ OpenRouterRateLimitError,
14
+ )
15
+ from minisweagent.models.utils.actions_toolcall_response import (
16
+ BASH_TOOL_RESPONSE_API,
17
+ format_toolcall_observation_messages,
18
+ parse_toolcall_actions_response,
19
+ )
20
+ from minisweagent.models.utils.retry import retry
21
+
22
+ logger = logging.getLogger("openrouter_response_model")
23
+
24
+
25
+ class OpenRouterResponseModelConfig(OpenRouterModelConfig):
26
+ pass
27
+
28
+
29
+ class OpenRouterResponseModel(OpenRouterModel):
30
+ """OpenRouter model using the Responses API with native tool calling.
31
+
32
+ Note: OpenRouter's Responses API is stateless - each request must include
33
+ the full conversation history. previous_response_id is not supported.
34
+ See: https://openrouter.ai/docs/api/reference/responses/overview
35
+ """
36
+
37
+ def __init__(self, **kwargs):
38
+ super().__init__(**kwargs)
39
+ self.config = OpenRouterResponseModelConfig(**kwargs)
40
+ self._api_url = "https://openrouter.ai/api/v1/responses"
41
+
42
+ def _query(self, messages: list[dict[str, str]], **kwargs):
43
+ headers = {
44
+ "Authorization": f"Bearer {self._api_key}",
45
+ "Content-Type": "application/json",
46
+ }
47
+ payload = {
48
+ "model": self.config.model_name,
49
+ "input": messages,
50
+ "tools": [BASH_TOOL_RESPONSE_API],
51
+ **(self.config.model_kwargs | kwargs),
52
+ }
53
+ try:
54
+ response = requests.post(self._api_url, headers=headers, data=json.dumps(payload), timeout=60)
55
+ response.raise_for_status()
56
+ return response.json()
57
+ except requests.exceptions.HTTPError as e:
58
+ if response.status_code == 401:
59
+ error_msg = "Authentication failed. You can permanently set your API key with `mini-extra config set OPENROUTER_API_KEY YOUR_KEY`."
60
+ raise OpenRouterAuthenticationError(error_msg) from e
61
+ elif response.status_code == 429:
62
+ raise OpenRouterRateLimitError("Rate limit exceeded") from e
63
+ else:
64
+ raise OpenRouterAPIError(f"HTTP {response.status_code}: {response.text}") from e
65
+ except requests.exceptions.RequestException as e:
66
+ raise OpenRouterAPIError(f"Request failed: {e}") from e
67
+
68
+ def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
69
+ """Prepare messages for OpenRouter's stateless Responses API.
70
+
71
+ Flattens response objects into their output items since OpenRouter
72
+ doesn't support previous_response_id.
73
+ """
74
+ result = []
75
+ for msg in messages:
76
+ if msg.get("object") == "response":
77
+ for item in msg.get("output", []):
78
+ result.append({k: v for k, v in item.items() if k != "extra"})
79
+ else:
80
+ result.append({k: v for k, v in msg.items() if k != "extra"})
81
+ return result
82
+
83
+ def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
84
+ for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
85
+ with attempt:
86
+ response = self._query(self._prepare_messages_for_api(messages), **kwargs)
87
+ cost_output = self._calculate_cost(response)
88
+ GLOBAL_MODEL_STATS.add(cost_output["cost"])
89
+ message = dict(response)
90
+ message["extra"] = {
91
+ "actions": self._parse_actions(response),
92
+ **cost_output,
93
+ "timestamp": time.time(),
94
+ }
95
+ return message
96
+
97
+ def _parse_actions(self, response: dict) -> list[dict]:
98
+ return parse_toolcall_actions_response(
99
+ response.get("output", []), format_error_template=self.config.format_error_template
100
+ )
101
+
102
+ def format_message(self, **kwargs) -> dict:
103
+ role = kwargs.get("role", "user")
104
+ content = kwargs.get("content", "")
105
+ extra = kwargs.get("extra")
106
+ content_items = [{"type": "input_text", "text": content}] if isinstance(content, str) else content
107
+ msg = {"type": "message", "role": role, "content": content_items}
108
+ if extra:
109
+ msg["extra"] = extra
110
+ return msg
111
+
112
+ def format_observation_messages(
113
+ self, message: dict, outputs: list[dict], template_vars: dict | None = None
114
+ ) -> list[dict]:
115
+ """Format execution outputs into tool result messages."""
116
+ actions = message.get("extra", {}).get("actions", [])
117
+ return format_toolcall_observation_messages(
118
+ actions=actions,
119
+ outputs=outputs,
120
+ observation_template=self.config.observation_template,
121
+ template_vars=template_vars,
122
+ multimodal_regex=self.config.multimodal_regex,
123
+ )
@@ -0,0 +1,76 @@
1
+ import json
2
+ import logging
3
+
4
+ import requests
5
+
6
+ from minisweagent.models.openrouter_model import (
7
+ OpenRouterAPIError,
8
+ OpenRouterAuthenticationError,
9
+ OpenRouterModel,
10
+ OpenRouterModelConfig,
11
+ OpenRouterRateLimitError,
12
+ )
13
+ from minisweagent.models.utils.actions_text import format_observation_messages, parse_regex_actions
14
+
15
+ logger = logging.getLogger("openrouter_textbased_model")
16
+
17
+
18
+ class OpenRouterTextbasedModelConfig(OpenRouterModelConfig):
19
+ action_regex: str = r"```mswea_bash_command\s*\n(.*?)\n```"
20
+ """Regex to extract the action from the LM's output."""
21
+ format_error_template: str = (
22
+ "Please always provide EXACTLY ONE action in triple backticks, found {{actions|length}} actions."
23
+ )
24
+ """Template used when the LM's output is not in the expected format."""
25
+
26
+
27
+ class OpenRouterTextbasedModel(OpenRouterModel):
28
+ def __init__(self, **kwargs):
29
+ super().__init__(**kwargs)
30
+ self.config = OpenRouterTextbasedModelConfig(**kwargs)
31
+
32
+ def _query(self, messages: list[dict[str, str]], **kwargs):
33
+ headers = {
34
+ "Authorization": f"Bearer {self._api_key}",
35
+ "Content-Type": "application/json",
36
+ }
37
+
38
+ payload = {
39
+ "model": self.config.model_name,
40
+ "messages": messages,
41
+ "usage": {"include": True},
42
+ **(self.config.model_kwargs | kwargs),
43
+ }
44
+
45
+ try:
46
+ response = requests.post(self._api_url, headers=headers, data=json.dumps(payload), timeout=60)
47
+ response.raise_for_status()
48
+ return response.json()
49
+ except requests.exceptions.HTTPError as e:
50
+ if response.status_code == 401:
51
+ error_msg = "Authentication failed. You can permanently set your API key with `mini-extra config set OPENROUTER_API_KEY YOUR_KEY`."
52
+ raise OpenRouterAuthenticationError(error_msg) from e
53
+ elif response.status_code == 429:
54
+ raise OpenRouterRateLimitError("Rate limit exceeded") from e
55
+ else:
56
+ raise OpenRouterAPIError(f"HTTP {response.status_code}: {response.text}") from e
57
+ except requests.exceptions.RequestException as e:
58
+ raise OpenRouterAPIError(f"Request failed: {e}") from e
59
+
60
+ def _parse_actions(self, response: dict) -> list[dict]:
61
+ """Parse actions from the model response. Raises FormatError if not exactly one action."""
62
+ content = response["choices"][0]["message"]["content"] or ""
63
+ return parse_regex_actions(
64
+ content, action_regex=self.config.action_regex, format_error_template=self.config.format_error_template
65
+ )
66
+
67
+ def format_observation_messages(
68
+ self, message: dict, outputs: list[dict], template_vars: dict | None = None
69
+ ) -> list[dict]:
70
+ """Format execution outputs into observation messages."""
71
+ return format_observation_messages(
72
+ outputs,
73
+ observation_template=self.config.observation_template,
74
+ template_vars=template_vars,
75
+ multimodal_regex=self.config.multimodal_regex,
76
+ )
@@ -1,21 +1,23 @@
1
1
  import json
2
2
  import logging
3
3
  import os
4
- from dataclasses import asdict, dataclass, field
4
+ import time
5
5
  from pathlib import Path
6
6
  from typing import Any, Literal
7
7
 
8
8
  import litellm
9
- from tenacity import (
10
- before_sleep_log,
11
- retry,
12
- retry_if_not_exception_type,
13
- stop_after_attempt,
14
- wait_exponential,
15
- )
9
+ from pydantic import BaseModel
16
10
 
17
11
  from minisweagent.models import GLOBAL_MODEL_STATS
12
+ from minisweagent.models.utils.actions_toolcall import (
13
+ BASH_TOOL,
14
+ format_toolcall_observation_messages,
15
+ parse_toolcall_actions,
16
+ )
17
+ from minisweagent.models.utils.anthropic_utils import _reorder_anthropic_thinking_blocks
18
18
  from minisweagent.models.utils.cache_control import set_cache_control
19
+ from minisweagent.models.utils.openai_multimodal import expand_multimodal_content
20
+ from minisweagent.models.utils.retry import retry
19
21
 
20
22
  logger = logging.getLogger("portkey_model")
21
23
 
@@ -27,10 +29,14 @@ except ImportError:
27
29
  )
28
30
 
29
31
 
30
- @dataclass
31
- class PortkeyModelConfig:
32
+ class PortkeyModelConfig(BaseModel):
32
33
  model_name: str
33
- model_kwargs: dict[str, Any] = field(default_factory=dict)
34
+ model_kwargs: dict[str, Any] = {}
35
+ provider: str = ""
36
+ """The LLM provider to use (e.g., 'openai', 'anthropic', 'google').
37
+ If not specified, will be auto-detected from model_name.
38
+ Required by Portkey when not using a virtual key.
39
+ """
34
40
  litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
35
41
  """We currently use litellm to calculate costs. Here you can register additional models to litellm's model registry.
36
42
  Note that this might change if we get better support for Portkey and change how we calculate costs.
@@ -44,17 +50,25 @@ class PortkeyModelConfig:
44
50
  """Set explicit cache control markers, for example for Anthropic models"""
45
51
  cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
46
52
  """Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""
53
+ format_error_template: str = "{{ error }}"
54
+ """Template used when the LM's output is not in the expected format."""
55
+ observation_template: str = (
56
+ "{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
57
+ "<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
58
+ )
59
+ """Template used to render the observation after executing an action."""
60
+ multimodal_regex: str = ""
61
+ """Regex to extract multimodal content. Empty string disables multimodal processing."""
47
62
 
48
63
 
49
64
  class PortkeyModel:
65
+ abort_exceptions: list[type[Exception]] = [KeyboardInterrupt, TypeError, ValueError]
66
+
50
67
  def __init__(self, *, config_class: type = PortkeyModelConfig, **kwargs):
51
68
  self.config = config_class(**kwargs)
52
- self.cost = 0.0
53
- self.n_calls = 0
54
69
  if self.config.litellm_model_registry and Path(self.config.litellm_model_registry).is_file():
55
70
  litellm.utils.register_model(json.loads(Path(self.config.litellm_model_registry).read_text()))
56
71
 
57
- # Get API key from environment or raise error
58
72
  self._api_key = os.getenv("PORTKEY_API_KEY")
59
73
  if not self._api_key:
60
74
  raise ValueError(
@@ -63,51 +77,79 @@ class PortkeyModel:
63
77
  "`mini-extra config set PORTKEY_API_KEY YOUR_KEY`."
64
78
  )
65
79
 
66
- # Get virtual key from environment
67
80
  virtual_key = os.getenv("PORTKEY_VIRTUAL_KEY")
68
-
69
- # Initialize Portkey client
70
81
  client_kwargs = {"api_key": self._api_key}
71
82
  if virtual_key:
72
83
  client_kwargs["virtual_key"] = virtual_key
84
+ elif self.config.provider:
85
+ # If no virtual key but provider is specified, pass it
86
+ client_kwargs["provider"] = self.config.provider
73
87
 
74
88
  self.client = Portkey(**client_kwargs)
75
89
 
76
- @retry(
77
- reraise=True,
78
- stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
79
- wait=wait_exponential(multiplier=1, min=4, max=60),
80
- before_sleep=before_sleep_log(logger, logging.WARNING),
81
- retry=retry_if_not_exception_type((KeyboardInterrupt, TypeError, ValueError)),
82
- )
83
90
  def _query(self, messages: list[dict[str, str]], **kwargs):
84
- # return self.client.with_options(metadata={"request_id": request_id}).chat.completions.create(
85
91
  return self.client.chat.completions.create(
86
92
  model=self.config.model_name,
87
93
  messages=messages,
94
+ tools=[BASH_TOOL],
88
95
  **(self.config.model_kwargs | kwargs),
89
96
  )
90
97
 
98
+ def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
99
+ prepared = [{k: v for k, v in msg.items() if k != "extra"} for msg in messages]
100
+ prepared = _reorder_anthropic_thinking_blocks(prepared)
101
+ return set_cache_control(prepared, mode=self.config.set_cache_control)
102
+
91
103
  def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
92
- if self.config.set_cache_control:
93
- messages = set_cache_control(messages, mode=self.config.set_cache_control)
94
- response = self._query([{"role": msg["role"], "content": msg["content"]} for msg in messages], **kwargs)
95
- cost = self._calculate_cost(response)
96
- self.n_calls += 1
97
- self.cost += cost
98
- GLOBAL_MODEL_STATS.add(cost)
99
- return {
100
- "content": response.choices[0].message.content or "",
101
- "extra": {
102
- "response": response.model_dump(),
103
- "cost": cost,
104
- },
104
+ for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
105
+ with attempt:
106
+ response = self._query(self._prepare_messages_for_api(messages), **kwargs)
107
+ cost_output = self._calculate_cost(response)
108
+ GLOBAL_MODEL_STATS.add(cost_output["cost"])
109
+ message = response.choices[0].message.model_dump()
110
+ message["extra"] = {
111
+ "actions": self._parse_actions(response),
112
+ "response": response.model_dump(),
113
+ **cost_output,
114
+ "timestamp": time.time(),
105
115
  }
116
+ return message
117
+
118
+ def _parse_actions(self, response) -> list[dict]:
119
+ """Parse tool calls from the response. Raises FormatError if unknown tool."""
120
+ tool_calls = response.choices[0].message.tool_calls or []
121
+ return parse_toolcall_actions(tool_calls, format_error_template=self.config.format_error_template)
122
+
123
+ def format_message(self, **kwargs) -> dict:
124
+ return expand_multimodal_content(kwargs, pattern=self.config.multimodal_regex)
125
+
126
+ def format_observation_messages(
127
+ self, message: dict, outputs: list[dict], template_vars: dict | None = None
128
+ ) -> list[dict]:
129
+ """Format execution outputs into tool result messages."""
130
+ actions = message.get("extra", {}).get("actions", [])
131
+ return format_toolcall_observation_messages(
132
+ actions=actions,
133
+ outputs=outputs,
134
+ observation_template=self.config.observation_template,
135
+ template_vars=template_vars,
136
+ multimodal_regex=self.config.multimodal_regex,
137
+ )
106
138
 
107
- def get_template_vars(self) -> dict[str, Any]:
108
- return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}
139
+ def get_template_vars(self, **kwargs) -> dict[str, Any]:
140
+ return self.config.model_dump()
141
+
142
+ def serialize(self) -> dict:
143
+ return {
144
+ "info": {
145
+ "config": {
146
+ "model": self.config.model_dump(mode="json"),
147
+ "model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
148
+ },
149
+ }
150
+ }
109
151
 
110
- def _calculate_cost(self, response) -> float:
152
+ def _calculate_cost(self, response) -> dict[str, float]:
111
153
  response_for_cost_calc = response.model_copy()
112
154
  if self.config.litellm_model_name_override:
113
155
  if response_for_cost_calc.model:
@@ -152,4 +194,4 @@ class PortkeyModel:
152
194
  )
153
195
  logger.critical(msg)
154
196
  raise RuntimeError(msg) from e
155
- return cost
197
+ return {"cost": cost}