mini-swe-agent 1.17.5__py3-none-any.whl → 2.0.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. {mini_swe_agent-1.17.5.dist-info → mini_swe_agent-2.0.0a1.dist-info}/METADATA +36 -52
  2. mini_swe_agent-2.0.0a1.dist-info/RECORD +70 -0
  3. mini_swe_agent-2.0.0a1.dist-info/entry_points.txt +5 -0
  4. minisweagent/__init__.py +19 -26
  5. minisweagent/agents/default.py +128 -113
  6. minisweagent/agents/interactive.py +119 -58
  7. minisweagent/config/README.md +3 -4
  8. minisweagent/config/__init__.py +36 -1
  9. minisweagent/config/benchmarks/swebench.yaml +156 -0
  10. minisweagent/config/{extra/swebench.yaml → benchmarks/swebench_backticks.yaml} +69 -64
  11. minisweagent/config/benchmarks/swebench_modal.yaml +47 -0
  12. minisweagent/config/{extra → benchmarks}/swebench_xml.yaml +73 -70
  13. minisweagent/config/default.yaml +24 -21
  14. minisweagent/config/inspector.tcss +42 -0
  15. minisweagent/config/mini.yaml +53 -71
  16. minisweagent/config/{github_issue.yaml → mini_textbased.yaml} +43 -29
  17. minisweagent/environments/__init__.py +1 -0
  18. minisweagent/environments/docker.py +67 -20
  19. minisweagent/environments/extra/bubblewrap.py +86 -47
  20. minisweagent/environments/extra/swerex_docker.py +53 -20
  21. minisweagent/environments/extra/swerex_modal.py +90 -0
  22. minisweagent/environments/local.py +62 -21
  23. minisweagent/environments/singularity.py +59 -18
  24. minisweagent/exceptions.py +22 -0
  25. minisweagent/models/__init__.py +6 -7
  26. minisweagent/models/extra/roulette.py +20 -17
  27. minisweagent/models/litellm_model.py +90 -44
  28. minisweagent/models/litellm_response_model.py +80 -0
  29. minisweagent/models/litellm_textbased_model.py +45 -0
  30. minisweagent/models/openrouter_model.py +87 -45
  31. minisweagent/models/openrouter_response_model.py +123 -0
  32. minisweagent/models/openrouter_textbased_model.py +76 -0
  33. minisweagent/models/portkey_model.py +84 -42
  34. minisweagent/models/portkey_response_model.py +163 -0
  35. minisweagent/models/requesty_model.py +91 -41
  36. minisweagent/models/test_models.py +246 -19
  37. minisweagent/models/utils/actions_text.py +60 -0
  38. minisweagent/models/utils/actions_toolcall.py +102 -0
  39. minisweagent/models/utils/actions_toolcall_response.py +110 -0
  40. minisweagent/models/utils/anthropic_utils.py +28 -0
  41. minisweagent/models/utils/cache_control.py +15 -2
  42. minisweagent/models/utils/content_string.py +74 -0
  43. minisweagent/models/utils/openai_multimodal.py +50 -0
  44. minisweagent/models/utils/retry.py +25 -0
  45. minisweagent/run/benchmarks/__init__.py +1 -0
  46. minisweagent/run/{extra → benchmarks}/swebench.py +56 -35
  47. minisweagent/run/{extra → benchmarks}/swebench_single.py +36 -26
  48. minisweagent/run/{extra → benchmarks}/utils/batch_progress.py +1 -1
  49. minisweagent/run/hello_world.py +6 -0
  50. minisweagent/run/mini.py +54 -63
  51. minisweagent/run/utilities/__init__.py +1 -0
  52. minisweagent/run/{extra → utilities}/config.py +2 -0
  53. minisweagent/run/{inspector.py → utilities/inspector.py} +90 -11
  54. minisweagent/run/{mini_extra.py → utilities/mini_extra.py} +9 -5
  55. minisweagent/utils/serialize.py +26 -0
  56. mini_swe_agent-1.17.5.dist-info/RECORD +0 -61
  57. mini_swe_agent-1.17.5.dist-info/entry_points.txt +0 -5
  58. minisweagent/agents/interactive_textual.py +0 -450
  59. minisweagent/config/extra/swebench_roulette.yaml +0 -233
  60. minisweagent/config/mini.tcss +0 -86
  61. minisweagent/models/anthropic.py +0 -35
  62. minisweagent/models/litellm_response_api_model.py +0 -82
  63. minisweagent/models/portkey_response_api_model.py +0 -75
  64. minisweagent/models/utils/key_per_thread.py +0 -20
  65. minisweagent/models/utils/openai_utils.py +0 -41
  66. minisweagent/run/github_issue.py +0 -87
  67. minisweagent/run/utils/__init__.py +0 -0
  68. minisweagent/run/utils/save.py +0 -78
  69. {mini_swe_agent-1.17.5.dist-info → mini_swe_agent-2.0.0a1.dist-info}/WHEEL +0 -0
  70. {mini_swe_agent-1.17.5.dist-info → mini_swe_agent-2.0.0a1.dist-info}/licenses/LICENSE.md +0 -0
  71. {mini_swe_agent-1.17.5.dist-info → mini_swe_agent-2.0.0a1.dist-info}/top_level.txt +0 -0
  72. /minisweagent/config/{extra → benchmarks}/__init__.py +0 -0
  73. /minisweagent/run/{extra → benchmarks}/utils/__init__.py +0 -0
@@ -1,14 +1,17 @@
1
1
  import os
2
2
  import platform
3
3
  import subprocess
4
- from dataclasses import asdict, dataclass, field
5
4
  from typing import Any
6
5
 
6
+ from pydantic import BaseModel
7
7
 
8
- @dataclass
9
- class LocalEnvironmentConfig:
8
+ from minisweagent.exceptions import Submitted
9
+ from minisweagent.utils.serialize import recursive_merge
10
+
11
+
12
+ class LocalEnvironmentConfig(BaseModel):
10
13
  cwd: str = ""
11
- env: dict[str, str] = field(default_factory=dict)
14
+ env: dict[str, str] = {}
12
15
  timeout: int = 30
13
16
 
14
17
 
@@ -17,22 +20,60 @@ class LocalEnvironment:
17
20
  """This class executes bash commands directly on the local machine."""
18
21
  self.config = config_class(**kwargs)
19
22
 
20
- def execute(self, command: str, cwd: str = "", *, timeout: int | None = None):
23
+ def execute(self, action: dict, cwd: str = "", *, timeout: int | None = None) -> dict[str, Any]:
21
24
  """Execute a command in the local environment and return the result as a dict."""
25
+ command = action.get("command", "")
22
26
  cwd = cwd or self.config.cwd or os.getcwd()
23
- result = subprocess.run(
24
- command,
25
- shell=True,
26
- text=True,
27
- cwd=cwd,
28
- env=os.environ | self.config.env,
29
- timeout=timeout or self.config.timeout,
30
- encoding="utf-8",
31
- errors="replace",
32
- stdout=subprocess.PIPE,
33
- stderr=subprocess.STDOUT,
34
- )
35
- return {"output": result.stdout, "returncode": result.returncode}
36
-
37
- def get_template_vars(self) -> dict[str, Any]:
38
- return asdict(self.config) | platform.uname()._asdict() | os.environ
27
+ try:
28
+ result = subprocess.run(
29
+ command,
30
+ shell=True,
31
+ text=True,
32
+ cwd=cwd,
33
+ env=os.environ | self.config.env,
34
+ timeout=timeout or self.config.timeout,
35
+ encoding="utf-8",
36
+ errors="replace",
37
+ stdout=subprocess.PIPE,
38
+ stderr=subprocess.STDOUT,
39
+ )
40
+ output = {"output": result.stdout, "returncode": result.returncode, "exception_info": ""}
41
+ except Exception as e:
42
+ raw_output = getattr(e, "output", None)
43
+ raw_output = (
44
+ raw_output.decode("utf-8", errors="replace") if isinstance(raw_output, bytes) else (raw_output or "")
45
+ )
46
+ output = {
47
+ "output": raw_output,
48
+ "returncode": -1,
49
+ "exception_info": f"An error occurred while executing the command: {e}",
50
+ "extra": {"exception_type": type(e).__name__, "exception": str(e)},
51
+ }
52
+ self._check_finished(output)
53
+ return output
54
+
55
+ def _check_finished(self, output: dict):
56
+ """Raises Submitted if the output indicates task completion."""
57
+ lines = output.get("output", "").lstrip().splitlines(keepends=True)
58
+ if lines and lines[0].strip() == "COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT" and output["returncode"] == 0:
59
+ submission = "".join(lines[1:])
60
+ raise Submitted(
61
+ {
62
+ "role": "exit",
63
+ "content": submission,
64
+ "extra": {"exit_status": "Submitted", "submission": submission},
65
+ }
66
+ )
67
+
68
+ def get_template_vars(self, **kwargs) -> dict[str, Any]:
69
+ return recursive_merge(self.config.model_dump(), platform.uname()._asdict(), os.environ, kwargs)
70
+
71
+ def serialize(self) -> dict:
72
+ return {
73
+ "info": {
74
+ "config": {
75
+ "environment": self.config.model_dump(mode="json"),
76
+ "environment_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
77
+ }
78
+ }
79
+ }
@@ -6,18 +6,21 @@ import shutil
6
6
  import subprocess
7
7
  import tempfile
8
8
  import uuid
9
- from dataclasses import asdict, dataclass, field
10
9
  from pathlib import Path
11
10
  from typing import Any
12
11
 
12
+ from pydantic import BaseModel
13
13
 
14
- @dataclass
15
- class SingularityEnvironmentConfig:
14
+ from minisweagent.exceptions import Submitted
15
+ from minisweagent.utils.serialize import recursive_merge
16
+
17
+
18
+ class SingularityEnvironmentConfig(BaseModel):
16
19
  image: str
17
20
  cwd: str = "/"
18
- env: dict[str, str] = field(default_factory=dict)
21
+ env: dict[str, str] = {}
19
22
  """Environment variables to set in the container."""
20
- forward_env: list[str] = field(default_factory=list)
23
+ forward_env: list[str] = []
21
24
  """Environment variables to forward to the container."""
22
25
  timeout: int = 30
23
26
  """Timeout for executing commands in the container."""
@@ -57,11 +60,22 @@ class SingularityEnvironment:
57
60
  raise
58
61
  return sandbox_dir
59
62
 
60
- def get_template_vars(self) -> dict[str, Any]:
61
- return asdict(self.config)
63
+ def get_template_vars(self, **kwargs) -> dict[str, Any]:
64
+ return recursive_merge(self.config.model_dump(), kwargs)
62
65
 
63
- def execute(self, command: str, cwd: str = "", *, timeout: int | None = None) -> dict[str, Any]:
66
+ def serialize(self) -> dict:
67
+ return {
68
+ "info": {
69
+ "config": {
70
+ "environment": self.config.model_dump(mode="json"),
71
+ "environment_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
72
+ }
73
+ }
74
+ }
75
+
76
+ def execute(self, action: dict, cwd: str = "", *, timeout: int | None = None) -> dict[str, Any]:
64
77
  """Execute a command in a Singularity container and return the result as a dict."""
78
+ command = action.get("command", "")
65
79
  cmd = [self.config.executable, "exec"]
66
80
 
67
81
  # Do not inherit directories and env vars from host
@@ -78,16 +92,43 @@ class SingularityEnvironment:
78
92
  cmd.extend(["--env", f"{key}={value}"])
79
93
 
80
94
  cmd.extend(["--writable", str(self.sandbox_dir), "bash", "-c", command])
81
- result = subprocess.run(
82
- cmd,
83
- text=True,
84
- timeout=timeout or self.config.timeout,
85
- encoding="utf-8",
86
- errors="replace",
87
- stdout=subprocess.PIPE,
88
- stderr=subprocess.STDOUT,
89
- )
90
- return {"output": result.stdout, "returncode": result.returncode}
95
+ try:
96
+ result = subprocess.run(
97
+ cmd,
98
+ text=True,
99
+ timeout=timeout or self.config.timeout,
100
+ encoding="utf-8",
101
+ errors="replace",
102
+ stdout=subprocess.PIPE,
103
+ stderr=subprocess.STDOUT,
104
+ )
105
+ output = {"output": result.stdout, "returncode": result.returncode, "exception_info": ""}
106
+ except Exception as e:
107
+ raw_output = getattr(e, "output", None)
108
+ raw_output = (
109
+ raw_output.decode("utf-8", errors="replace") if isinstance(raw_output, bytes) else (raw_output or "")
110
+ )
111
+ output = {
112
+ "output": raw_output,
113
+ "returncode": -1,
114
+ "exception_info": f"An error occurred while executing the command: {e}",
115
+ "extra": {"exception_type": type(e).__name__, "exception": str(e)},
116
+ }
117
+ self._check_finished(output)
118
+ return output
119
+
120
+ def _check_finished(self, output: dict):
121
+ """Raises Submitted if the output indicates task completion."""
122
+ lines = output.get("output", "").lstrip().splitlines(keepends=True)
123
+ if lines and lines[0].strip() == "COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT" and output["returncode"] == 0:
124
+ submission = "".join(lines[1:])
125
+ raise Submitted(
126
+ {
127
+ "role": "exit",
128
+ "content": submission,
129
+ "extra": {"exit_status": "Submitted", "submission": submission},
130
+ }
131
+ )
91
132
 
92
133
  def cleanup(self):
93
134
  shutil.rmtree(self.sandbox_dir, ignore_errors=True)
@@ -0,0 +1,22 @@
1
+ class InterruptAgentFlow(Exception):
2
+ """Raised to interrupt the agent flow and add messages."""
3
+
4
+ def __init__(self, *messages: dict):
5
+ self.messages = messages
6
+ super().__init__()
7
+
8
+
9
+ class Submitted(InterruptAgentFlow):
10
+ """Raised when the agent has completed its task."""
11
+
12
+
13
+ class LimitsExceeded(InterruptAgentFlow):
14
+ """Raised when the agent has exceeded its cost or step limit."""
15
+
16
+
17
+ class UserInterruption(InterruptAgentFlow):
18
+ """Raised when the user interrupts the agent."""
19
+
20
+
21
+ class FormatError(InterruptAgentFlow):
22
+ """Raised when the LM's output is not in the expected format."""
@@ -28,7 +28,7 @@ class GlobalModelStats:
28
28
  self._cost += cost
29
29
  self._n_calls += 1
30
30
  if 0 < self.cost_limit < self._cost or 0 < self.call_limit < self._n_calls + 1:
31
- raise RuntimeError(f"Global cost/call limit exceeded: ${self._cost:.4f} / {self._n_calls + 1}")
31
+ raise RuntimeError(f"Global cost/call limit exceeded: ${self._cost:.4f} / {self._n_calls}")
32
32
 
33
33
  @property
34
34
  def cost(self) -> float:
@@ -52,9 +52,6 @@ def get_model(input_model_name: str | None = None, config: dict | None = None) -
52
52
 
53
53
  model_class = get_model_class(resolved_model_name, config.pop("model_class", ""))
54
54
 
55
- if (from_env := os.getenv("MSWEA_MODEL_API_KEY")) and not str(type(model_class)).endswith("DeterministicModel"):
56
- config.setdefault("model_kwargs", {})["api_key"] = from_env
57
-
58
55
  if (
59
56
  any(s in resolved_model_name.lower() for s in ["anthropic", "sonnet", "opus", "claude"])
60
57
  and "set_cache_control" not in config
@@ -79,12 +76,14 @@ def get_model_name(input_model_name: str | None = None, config: dict | None = No
79
76
 
80
77
 
81
78
  _MODEL_CLASS_MAPPING = {
82
- "anthropic": "minisweagent.models.anthropic.AnthropicModel",
83
79
  "litellm": "minisweagent.models.litellm_model.LitellmModel",
84
- "litellm_response": "minisweagent.models.litellm_response_api_model.LitellmResponseAPIModel",
80
+ "litellm_textbased": "minisweagent.models.litellm_textbased_model.LitellmTextbasedModel",
81
+ "litellm_response": "minisweagent.models.litellm_response_model.LitellmResponseModel",
85
82
  "openrouter": "minisweagent.models.openrouter_model.OpenRouterModel",
83
+ "openrouter_textbased": "minisweagent.models.openrouter_textbased_model.OpenRouterTextbasedModel",
84
+ "openrouter_response": "minisweagent.models.openrouter_response_model.OpenRouterResponseModel",
86
85
  "portkey": "minisweagent.models.portkey_model.PortkeyModel",
87
- "portkey_response": "minisweagent.models.portkey_response_api_model.PortkeyResponseAPIModel",
86
+ "portkey_response": "minisweagent.models.portkey_response_model.PortkeyResponseAPIModel",
88
87
  "requesty": "minisweagent.models.requesty_model.RequestyModel",
89
88
  "deterministic": "minisweagent.models.test_models.DeterministicModel",
90
89
  }
@@ -1,12 +1,12 @@
1
1
  import random
2
- from dataclasses import asdict, dataclass
2
+
3
+ from pydantic import BaseModel
3
4
 
4
5
  from minisweagent import Model
5
6
  from minisweagent.models import get_model
6
7
 
7
8
 
8
- @dataclass
9
- class RouletteModelConfig:
9
+ class RouletteModelConfig(BaseModel):
10
10
  model_kwargs: list[dict]
11
11
  """The models to choose from"""
12
12
  model_name: str = "roulette"
@@ -17,30 +17,33 @@ class RouletteModel:
17
17
  """This "meta"-model randomly selects one of the models at every call"""
18
18
  self.config = config_class(**kwargs)
19
19
  self.models = [get_model(config=config) for config in self.config.model_kwargs]
20
+ self._n_calls = 0
20
21
 
21
- @property
22
- def cost(self) -> float:
23
- return sum(model.cost for model in self.models)
24
-
25
- @property
26
- def n_calls(self) -> int:
27
- return sum(model.n_calls for model in self.models)
28
-
29
- def get_template_vars(self) -> dict:
30
- return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}
22
+ def get_template_vars(self, **kwargs) -> dict:
23
+ return self.config.model_dump()
31
24
 
32
25
  def select_model(self) -> Model:
33
26
  return random.choice(self.models)
34
27
 
35
28
  def query(self, *args, **kwargs) -> dict:
36
29
  model = self.select_model()
30
+ self._n_calls += 1
37
31
  response = model.query(*args, **kwargs)
38
32
  response["model_name"] = model.config.model_name
39
33
  return response
40
34
 
35
+ def serialize(self) -> dict:
36
+ return {
37
+ "info": {
38
+ "config": {
39
+ "model": self.config.model_dump(mode="json"),
40
+ "model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
41
+ },
42
+ }
43
+ }
44
+
41
45
 
42
- @dataclass
43
- class InterleavingModelConfig:
46
+ class InterleavingModelConfig(BaseModel):
44
47
  model_kwargs: list[dict]
45
48
  sequence: list[int] | None = None
46
49
  """If set to 0, 0, 1, we will return the first model 2 times, then the second model 1 time,
@@ -55,7 +58,7 @@ class InterleavingModel(RouletteModel):
55
58
 
56
59
  def select_model(self) -> Model:
57
60
  if self.config.sequence is None:
58
- i_model = self.n_calls % len(self.models)
61
+ i_model = self._n_calls % len(self.models)
59
62
  else:
60
- i_model = self.config.sequence[self.n_calls % len(self.config.sequence)]
63
+ i_model = self.config.sequence[self._n_calls % len(self.config.sequence)]
61
64
  return self.models[i_model]
@@ -1,75 +1,98 @@
1
1
  import json
2
2
  import logging
3
3
  import os
4
+ import time
4
5
  from collections.abc import Callable
5
- from dataclasses import asdict, dataclass, field
6
6
  from pathlib import Path
7
7
  from typing import Any, Literal
8
8
 
9
9
  import litellm
10
- from tenacity import (
11
- before_sleep_log,
12
- retry,
13
- retry_if_not_exception_type,
14
- stop_after_attempt,
15
- wait_exponential,
16
- )
10
+ from pydantic import BaseModel
17
11
 
18
12
  from minisweagent.models import GLOBAL_MODEL_STATS
13
+ from minisweagent.models.utils.actions_toolcall import (
14
+ BASH_TOOL,
15
+ format_toolcall_observation_messages,
16
+ parse_toolcall_actions,
17
+ )
18
+ from minisweagent.models.utils.anthropic_utils import _reorder_anthropic_thinking_blocks
19
19
  from minisweagent.models.utils.cache_control import set_cache_control
20
+ from minisweagent.models.utils.openai_multimodal import expand_multimodal_content
21
+ from minisweagent.models.utils.retry import retry
20
22
 
21
23
  logger = logging.getLogger("litellm_model")
22
24
 
23
25
 
24
- @dataclass
25
- class LitellmModelConfig:
26
+ class LitellmModelConfig(BaseModel):
26
27
  model_name: str
27
- model_kwargs: dict[str, Any] = field(default_factory=dict)
28
+ """Model name. Highly recommended to include the provider in the model name, e.g., `anthropic/claude-sonnet-4-5-20250929`."""
29
+ model_kwargs: dict[str, Any] = {}
30
+ """Additional arguments passed to the API."""
28
31
  litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
32
+ """Model registry for cost tracking and model metadata. See the local model guide (https://mini-swe-agent.com/latest/models/local_models/) for more details."""
29
33
  set_cache_control: Literal["default_end"] | None = None
30
34
  """Set explicit cache control markers, for example for Anthropic models"""
31
35
  cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
32
36
  """Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""
37
+ format_error_template: str = "{{ error }}"
38
+ """Template used when the LM's output is not in the expected format."""
39
+ observation_template: str = (
40
+ "{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
41
+ "<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
42
+ )
43
+ """Template used to render the observation after executing an action."""
44
+ multimodal_regex: str = ""
45
+ """Regex to extract multimodal content. Empty string disables multimodal processing."""
33
46
 
34
47
 
35
48
  class LitellmModel:
49
+ abort_exceptions: list[type[Exception]] = [
50
+ litellm.exceptions.UnsupportedParamsError,
51
+ litellm.exceptions.NotFoundError,
52
+ litellm.exceptions.PermissionDeniedError,
53
+ litellm.exceptions.ContextWindowExceededError,
54
+ litellm.exceptions.AuthenticationError,
55
+ KeyboardInterrupt,
56
+ ]
57
+
36
58
  def __init__(self, *, config_class: Callable = LitellmModelConfig, **kwargs):
37
59
  self.config = config_class(**kwargs)
38
- self.cost = 0.0
39
- self.n_calls = 0
40
60
  if self.config.litellm_model_registry and Path(self.config.litellm_model_registry).is_file():
41
61
  litellm.utils.register_model(json.loads(Path(self.config.litellm_model_registry).read_text()))
42
62
 
43
- @retry(
44
- reraise=True,
45
- stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
46
- wait=wait_exponential(multiplier=1, min=4, max=60),
47
- before_sleep=before_sleep_log(logger, logging.WARNING),
48
- retry=retry_if_not_exception_type(
49
- (
50
- litellm.exceptions.UnsupportedParamsError,
51
- litellm.exceptions.NotFoundError,
52
- litellm.exceptions.PermissionDeniedError,
53
- litellm.exceptions.ContextWindowExceededError,
54
- litellm.exceptions.APIError,
55
- litellm.exceptions.AuthenticationError,
56
- KeyboardInterrupt,
57
- )
58
- ),
59
- )
60
63
  def _query(self, messages: list[dict[str, str]], **kwargs):
61
64
  try:
62
65
  return litellm.completion(
63
- model=self.config.model_name, messages=messages, **(self.config.model_kwargs | kwargs)
66
+ model=self.config.model_name,
67
+ messages=messages,
68
+ tools=[BASH_TOOL],
69
+ **(self.config.model_kwargs | kwargs),
64
70
  )
65
71
  except litellm.exceptions.AuthenticationError as e:
66
72
  e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
67
73
  raise e
68
74
 
75
+ def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
76
+ prepared = [{k: v for k, v in msg.items() if k != "extra"} for msg in messages]
77
+ prepared = _reorder_anthropic_thinking_blocks(prepared)
78
+ return set_cache_control(prepared, mode=self.config.set_cache_control)
79
+
69
80
  def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
70
- if self.config.set_cache_control:
71
- messages = set_cache_control(messages, mode=self.config.set_cache_control)
72
- response = self._query([{"role": msg["role"], "content": msg["content"]} for msg in messages], **kwargs)
81
+ for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
82
+ with attempt:
83
+ response = self._query(self._prepare_messages_for_api(messages), **kwargs)
84
+ cost_output = self._calculate_cost(response)
85
+ GLOBAL_MODEL_STATS.add(cost_output["cost"])
86
+ message = response.choices[0].message.model_dump()
87
+ message["extra"] = {
88
+ "actions": self._parse_actions(response),
89
+ "response": response.model_dump(),
90
+ **cost_output,
91
+ "timestamp": time.time(),
92
+ }
93
+ return message
94
+
95
+ def _calculate_cost(self, response) -> dict[str, float]:
73
96
  try:
74
97
  cost = litellm.cost_calculator.completion_cost(response, model=self.config.model_name)
75
98
  if cost <= 0.0:
@@ -87,15 +110,38 @@ class LitellmModel:
87
110
  )
88
111
  logger.critical(msg)
89
112
  raise RuntimeError(msg) from e
90
- self.n_calls += 1
91
- self.cost += cost
92
- GLOBAL_MODEL_STATS.add(cost)
113
+ return {"cost": cost}
114
+
115
+ def _parse_actions(self, response) -> list[dict]:
116
+ """Parse tool calls from the response. Raises FormatError if unknown tool."""
117
+ tool_calls = response.choices[0].message.tool_calls or []
118
+ return parse_toolcall_actions(tool_calls, format_error_template=self.config.format_error_template)
119
+
120
+ def format_message(self, **kwargs) -> dict:
121
+ return expand_multimodal_content(kwargs, pattern=self.config.multimodal_regex)
122
+
123
+ def format_observation_messages(
124
+ self, message: dict, outputs: list[dict], template_vars: dict | None = None
125
+ ) -> list[dict]:
126
+ """Format execution outputs into tool result messages."""
127
+ actions = message.get("extra", {}).get("actions", [])
128
+ return format_toolcall_observation_messages(
129
+ actions=actions,
130
+ outputs=outputs,
131
+ observation_template=self.config.observation_template,
132
+ template_vars=template_vars,
133
+ multimodal_regex=self.config.multimodal_regex,
134
+ )
135
+
136
+ def get_template_vars(self, **kwargs) -> dict[str, Any]:
137
+ return self.config.model_dump()
138
+
139
+ def serialize(self) -> dict:
93
140
  return {
94
- "content": response.choices[0].message.content or "", # type: ignore
95
- "extra": {
96
- "response": response.model_dump(),
97
- },
141
+ "info": {
142
+ "config": {
143
+ "model": self.config.model_dump(mode="json"),
144
+ "model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
145
+ },
146
+ }
98
147
  }
99
-
100
- def get_template_vars(self) -> dict[str, Any]:
101
- return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}
@@ -0,0 +1,80 @@
1
+ import logging
2
+ import time
3
+ from collections.abc import Callable
4
+
5
+ import litellm
6
+
7
+ from minisweagent.models import GLOBAL_MODEL_STATS
8
+ from minisweagent.models.litellm_model import LitellmModel, LitellmModelConfig
9
+ from minisweagent.models.utils.actions_toolcall_response import (
10
+ BASH_TOOL_RESPONSE_API,
11
+ format_toolcall_observation_messages,
12
+ parse_toolcall_actions_response,
13
+ )
14
+ from minisweagent.models.utils.retry import retry
15
+
16
+ logger = logging.getLogger("litellm_response_model")
17
+
18
+
19
+ class LitellmResponseModelConfig(LitellmModelConfig):
20
+ pass
21
+
22
+
23
+ class LitellmResponseModel(LitellmModel):
24
+ def __init__(self, *, config_class: Callable = LitellmResponseModelConfig, **kwargs):
25
+ super().__init__(config_class=config_class, **kwargs)
26
+
27
+ def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
28
+ """Flatten response objects into their output items for stateless API calls."""
29
+ result = []
30
+ for msg in messages:
31
+ if msg.get("object") == "response":
32
+ for item in msg.get("output", []):
33
+ result.append({k: v for k, v in item.items() if k != "extra"})
34
+ else:
35
+ result.append({k: v for k, v in msg.items() if k != "extra"})
36
+ return result
37
+
38
+ def _query(self, messages: list[dict[str, str]], **kwargs):
39
+ try:
40
+ return litellm.responses(
41
+ model=self.config.model_name,
42
+ input=messages,
43
+ tools=[BASH_TOOL_RESPONSE_API],
44
+ **(self.config.model_kwargs | kwargs),
45
+ )
46
+ except litellm.exceptions.AuthenticationError as e:
47
+ e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
48
+ raise e
49
+
50
+ def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
51
+ for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
52
+ with attempt:
53
+ response = self._query(self._prepare_messages_for_api(messages), **kwargs)
54
+ cost_output = self._calculate_cost(response)
55
+ GLOBAL_MODEL_STATS.add(cost_output["cost"])
56
+ message = response.model_dump() if hasattr(response, "model_dump") else dict(response)
57
+ message["extra"] = {
58
+ "actions": self._parse_actions(response),
59
+ **cost_output,
60
+ "timestamp": time.time(),
61
+ }
62
+ return message
63
+
64
+ def _parse_actions(self, response) -> list[dict]:
65
+ return parse_toolcall_actions_response(
66
+ getattr(response, "output", []), format_error_template=self.config.format_error_template
67
+ )
68
+
69
+ def format_observation_messages(
70
+ self, message: dict, outputs: list[dict], template_vars: dict | None = None
71
+ ) -> list[dict]:
72
+ """Format execution outputs into tool result messages."""
73
+ actions = message.get("extra", {}).get("actions", [])
74
+ return format_toolcall_observation_messages(
75
+ actions=actions,
76
+ outputs=outputs,
77
+ observation_template=self.config.observation_template,
78
+ template_vars=template_vars,
79
+ multimodal_regex=self.config.multimodal_regex,
80
+ )
@@ -0,0 +1,45 @@
1
+ import litellm
2
+
3
+ from minisweagent.models.litellm_model import LitellmModel, LitellmModelConfig
4
+ from minisweagent.models.utils.actions_text import format_observation_messages, parse_regex_actions
5
+
6
+
7
+ class LitellmTextbasedModelConfig(LitellmModelConfig):
8
+ action_regex: str = r"```mswea_bash_command\s*\n(.*?)\n```"
9
+ """Regex to extract the action from the LM's output."""
10
+ format_error_template: str = (
11
+ "Please always provide EXACTLY ONE action in triple backticks, found {{actions|length}} actions."
12
+ )
13
+ """Template used when the LM's output is not in the expected format."""
14
+
15
+
16
+ class LitellmTextbasedModel(LitellmModel):
17
+ def __init__(self, **kwargs):
18
+ super().__init__(config_class=LitellmTextbasedModelConfig, **kwargs)
19
+
20
+ def _query(self, messages: list[dict[str, str]], **kwargs):
21
+ try:
22
+ return litellm.completion(
23
+ model=self.config.model_name, messages=messages, **(self.config.model_kwargs | kwargs)
24
+ )
25
+ except litellm.exceptions.AuthenticationError as e:
26
+ e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
27
+ raise e
28
+
29
+ def _parse_actions(self, response: dict) -> list[dict]:
30
+ """Parse actions from the model response. Raises FormatError if not exactly one action."""
31
+ content = response.choices[0].message.content or ""
32
+ return parse_regex_actions(
33
+ content, action_regex=self.config.action_regex, format_error_template=self.config.format_error_template
34
+ )
35
+
36
+ def format_observation_messages(
37
+ self, message: dict, outputs: list[dict], template_vars: dict | None = None
38
+ ) -> list[dict]:
39
+ """Format execution outputs into observation messages."""
40
+ return format_observation_messages(
41
+ outputs,
42
+ observation_template=self.config.observation_template,
43
+ template_vars=template_vars,
44
+ multimodal_regex=self.config.multimodal_regex,
45
+ )