mini-swe-agent 1.16.0 (mini_swe_agent-1.16.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. mini_swe_agent-1.16.0.dist-info/METADATA +314 -0
  2. mini_swe_agent-1.16.0.dist-info/RECORD +62 -0
  3. mini_swe_agent-1.16.0.dist-info/WHEEL +5 -0
  4. mini_swe_agent-1.16.0.dist-info/entry_points.txt +5 -0
  5. mini_swe_agent-1.16.0.dist-info/licenses/LICENSE.md +21 -0
  6. mini_swe_agent-1.16.0.dist-info/top_level.txt +1 -0
  7. minisweagent/__init__.py +83 -0
  8. minisweagent/__main__.py +7 -0
  9. minisweagent/agents/__init__.py +1 -0
  10. minisweagent/agents/default.py +131 -0
  11. minisweagent/agents/interactive.py +153 -0
  12. minisweagent/agents/interactive_textual.py +450 -0
  13. minisweagent/config/README.md +10 -0
  14. minisweagent/config/__init__.py +27 -0
  15. minisweagent/config/default.yaml +157 -0
  16. minisweagent/config/extra/__init__.py +1 -0
  17. minisweagent/config/extra/swebench.yaml +230 -0
  18. minisweagent/config/extra/swebench_roulette.yaml +233 -0
  19. minisweagent/config/extra/swebench_xml.yaml +215 -0
  20. minisweagent/config/github_issue.yaml +146 -0
  21. minisweagent/config/mini.tcss +86 -0
  22. minisweagent/config/mini.yaml +158 -0
  23. minisweagent/config/mini_no_temp.yaml +158 -0
  24. minisweagent/environments/__init__.py +31 -0
  25. minisweagent/environments/docker.py +114 -0
  26. minisweagent/environments/extra/__init__.py +0 -0
  27. minisweagent/environments/extra/bubblewrap.py +112 -0
  28. minisweagent/environments/extra/swerex_docker.py +47 -0
  29. minisweagent/environments/local.py +38 -0
  30. minisweagent/environments/singularity.py +97 -0
  31. minisweagent/models/__init__.py +114 -0
  32. minisweagent/models/anthropic.py +35 -0
  33. minisweagent/models/extra/__init__.py +0 -0
  34. minisweagent/models/extra/roulette.py +61 -0
  35. minisweagent/models/litellm_model.py +100 -0
  36. minisweagent/models/litellm_response_api_model.py +80 -0
  37. minisweagent/models/openrouter_model.py +125 -0
  38. minisweagent/models/portkey_model.py +154 -0
  39. minisweagent/models/portkey_response_api_model.py +74 -0
  40. minisweagent/models/requesty_model.py +119 -0
  41. minisweagent/models/test_models.py +42 -0
  42. minisweagent/models/utils/__init__.py +0 -0
  43. minisweagent/models/utils/cache_control.py +54 -0
  44. minisweagent/models/utils/key_per_thread.py +20 -0
  45. minisweagent/models/utils/openai_utils.py +41 -0
  46. minisweagent/py.typed +0 -0
  47. minisweagent/run/__init__.py +1 -0
  48. minisweagent/run/extra/__init__.py +0 -0
  49. minisweagent/run/extra/config.py +114 -0
  50. minisweagent/run/extra/swebench.py +266 -0
  51. minisweagent/run/extra/swebench_single.py +79 -0
  52. minisweagent/run/extra/utils/__init__.py +0 -0
  53. minisweagent/run/extra/utils/batch_progress.py +178 -0
  54. minisweagent/run/github_issue.py +87 -0
  55. minisweagent/run/hello_world.py +36 -0
  56. minisweagent/run/inspector.py +212 -0
  57. minisweagent/run/mini.py +108 -0
  58. minisweagent/run/mini_extra.py +44 -0
  59. minisweagent/run/utils/__init__.py +0 -0
  60. minisweagent/run/utils/save.py +78 -0
  61. minisweagent/utils/__init__.py +0 -0
  62. minisweagent/utils/log.py +36 -0
minisweagent/models/__init__.py
@@ -0,0 +1,114 @@
+"""This file provides convenience functions for selecting models.
+You can ignore this file completely if you explicitly set your model in your run script.
+"""
+
+import copy
+import importlib
+import os
+import threading
+
+from minisweagent import Model
+
+
+class GlobalModelStats:
+    """Global model statistics tracker with optional limits."""
+
+    def __init__(self):
+        self._cost = 0.0
+        self._n_calls = 0
+        self._lock = threading.Lock()
+        self.cost_limit = float(os.getenv("MSWEA_GLOBAL_COST_LIMIT", "0"))
+        self.call_limit = int(os.getenv("MSWEA_GLOBAL_CALL_LIMIT", "0"))
+        if (self.cost_limit > 0 or self.call_limit > 0) and not os.getenv("MSWEA_SILENT_STARTUP"):
+            print(f"Global cost/call limit: ${self.cost_limit:.4f} / {self.call_limit}")
+
+    def add(self, cost: float) -> None:
+        """Add a model call with its cost, checking limits."""
+        with self._lock:
+            self._cost += cost
+            self._n_calls += 1
+            if 0 < self.cost_limit < self._cost or 0 < self.call_limit < self._n_calls + 1:
+                raise RuntimeError(f"Global cost/call limit exceeded: ${self._cost:.4f} / {self._n_calls + 1}")
+
+    @property
+    def cost(self) -> float:
+        return self._cost
+
+    @property
+    def n_calls(self) -> int:
+        return self._n_calls
+
+
+GLOBAL_MODEL_STATS = GlobalModelStats()
+
+
+def get_model(input_model_name: str | None = None, config: dict | None = None) -> Model:
+    """Get an initialized model object from any kind of user input or settings."""
+    resolved_model_name = get_model_name(input_model_name, config)
+    if config is None:
+        config = {}
+    config = copy.deepcopy(config)
+    config["model_name"] = resolved_model_name
+
+    model_class = get_model_class(resolved_model_name, config.pop("model_class", ""))
+
+    # Note: don't inject API keys into the deterministic test model
+    if (from_env := os.getenv("MSWEA_MODEL_API_KEY")) and model_class.__name__ != "DeterministicModel":
+        config.setdefault("model_kwargs", {})["api_key"] = from_env
+
+    if (
+        any(s in resolved_model_name.lower() for s in ["anthropic", "sonnet", "opus", "claude"])
+        and "set_cache_control" not in config
+    ):
+        # Select cache control for Anthropic models by default
+        config["set_cache_control"] = "default_end"
+
+    return model_class(**config)
+
+
+def get_model_name(input_model_name: str | None = None, config: dict | None = None) -> str:
+    """Get a model name from any kind of user input or settings."""
+    if config is None:
+        config = {}
+    if input_model_name:
+        return input_model_name
+    if from_config := config.get("model_name"):
+        return from_config
+    if from_env := os.getenv("MSWEA_MODEL_NAME"):
+        return from_env
+    raise ValueError("No default model set. Please run `mini-extra config setup` to set one.")
+
+
+_MODEL_CLASS_MAPPING = {
+    "anthropic": "minisweagent.models.anthropic.AnthropicModel",
+    "litellm": "minisweagent.models.litellm_model.LitellmModel",
+    "litellm_response": "minisweagent.models.litellm_response_api_model.LitellmResponseAPIModel",
+    "openrouter": "minisweagent.models.openrouter_model.OpenRouterModel",
+    "portkey": "minisweagent.models.portkey_model.PortkeyModel",
+    "portkey_response": "minisweagent.models.portkey_response_api_model.PortkeyResponseAPIModel",
+    "requesty": "minisweagent.models.requesty_model.RequestyModel",
+    "deterministic": "minisweagent.models.test_models.DeterministicModel",
+}
+
+
+def get_model_class(model_name: str, model_class: str = "") -> type:
+    """Select the best model class.
+
+    If a model_class is provided (as shortcut name, or as full import path,
+    e.g., "anthropic" or "minisweagent.models.anthropic.AnthropicModel"),
+    it takes precedence over the `model_name`.
+    Otherwise, the model_name is used to select the best model class.
+    """
+    if model_class:
+        full_path = _MODEL_CLASS_MAPPING.get(model_class, model_class)
+        try:
+            module_name, class_name = full_path.rsplit(".", 1)
+            module = importlib.import_module(module_name)
+            return getattr(module, class_name)
+        except (ValueError, ImportError, AttributeError) as e:
+            msg = f"Unknown model class: {model_class} (resolved to {full_path}, available: {_MODEL_CLASS_MAPPING})"
+            raise ValueError(msg) from e
+
+    # Default to LitellmModel
+    from minisweagent.models.litellm_model import LitellmModel
+
+    return LitellmModel
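To make the selection flow above concrete, here is a minimal usage sketch. It is illustrative only: the model name is an assumption, an appropriate provider API key is assumed to be configured, and the limit value is arbitrary.

    import os

    # GLOBAL_MODEL_STATS reads its limits at import time, so set them first.
    os.environ["MSWEA_GLOBAL_COST_LIMIT"] = "5.0"  # abort once ~$5 has been spent (0 disables)

    from minisweagent.models import GLOBAL_MODEL_STATS, get_model

    # Name resolution order: explicit argument > config["model_name"] > MSWEA_MODEL_NAME.
    model = get_model("gpt-4o", config={"model_kwargs": {"temperature": 0.0}})
    response = model.query([{"role": "user", "content": "Say hi."}])
    print(response["content"], GLOBAL_MODEL_STATS.cost, GLOBAL_MODEL_STATS.n_calls)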
minisweagent/models/anthropic.py
@@ -0,0 +1,35 @@
+import os
+import warnings
+from typing import Literal
+
+from minisweagent.models.litellm_model import LitellmModel, LitellmModelConfig
+from minisweagent.models.utils.cache_control import set_cache_control
+from minisweagent.models.utils.key_per_thread import get_key_per_thread
+
+
+class AnthropicModelConfig(LitellmModelConfig):
+    set_cache_control: Literal["default_end"] | None = "default_end"
+    """Set explicit cache control markers, for example for Anthropic models"""
+
+
+class AnthropicModel(LitellmModel):
+    """This class is now only a thin wrapper around the LitellmModel class.
+    It is largely kept for backwards compatibility.
+    It will not be selected by `get_model` and `get_model_class` unless explicitly specified.
+    """
+
+    def __init__(self, *, config_class: type = AnthropicModelConfig, **kwargs):
+        super().__init__(config_class=config_class, **kwargs)
+
+    def query(self, messages: list[dict], **kwargs) -> dict:
+        api_key = None
+        # Legacy only
+        if rotating_keys := os.getenv("ANTHROPIC_API_KEYS"):
+            warnings.warn(
+                "ANTHROPIC_API_KEYS is deprecated and will be removed in the future. "
+                "Simply use the ANTHROPIC_API_KEY environment variable instead. "
+                "Key rotation is no longer required."
+            )
+            api_key = get_key_per_thread(rotating_keys.split("::"))
+        messages = set_cache_control(messages, mode="default_end")
+        return super().query(messages, api_key=api_key, **kwargs)
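Since `get_model` resolves claude-style names to `LitellmModel` (with cache control enabled) rather than to this wrapper, the legacy class has to be requested explicitly. A hedged sketch, with an illustrative model name:

    from minisweagent.models import get_model

    # "anthropic" is the shortcut key in _MODEL_CLASS_MAPPING; a dotted import path also works.
    model = get_model("claude-sonnet-4-20250514", config={"model_class": "anthropic"})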
minisweagent/models/extra/__init__.py (file without changes)
minisweagent/models/extra/roulette.py
@@ -0,0 +1,61 @@
+import random
+from dataclasses import asdict, dataclass
+
+from minisweagent import Model
+from minisweagent.models import get_model
+
+
+@dataclass
+class RouletteModelConfig:
+    model_kwargs: list[dict]
+    """The models to choose from"""
+    model_name: str = "roulette"
+
+
+class RouletteModel:
+    def __init__(self, *, config_class: type = RouletteModelConfig, **kwargs):
+        """This "meta"-model randomly selects one of the models at every call"""
+        self.config = config_class(**kwargs)
+        self.models = [get_model(config=config) for config in self.config.model_kwargs]
+
+    @property
+    def cost(self) -> float:
+        return sum(model.cost for model in self.models)
+
+    @property
+    def n_calls(self) -> int:
+        return sum(model.n_calls for model in self.models)
+
+    def get_template_vars(self) -> dict:
+        return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}
+
+    def select_model(self) -> Model:
+        return random.choice(self.models)
+
+    def query(self, *args, **kwargs) -> dict:
+        model = self.select_model()
+        response = model.query(*args, **kwargs)
+        response["model_name"] = model.config.model_name
+        return response
+
+
+@dataclass
+class InterleavingModelConfig:
+    model_kwargs: list[dict]
+    sequence: list[int] | None = None
+    """If set to 0, 0, 1, we will return the first model 2 times, then the second model 1 time,
+    then the first model again, etc."""
+    model_name: str = "interleaving"
+
+
+class InterleavingModel(RouletteModel):
+    def __init__(self, *, config_class: type = InterleavingModelConfig, **kwargs):
+        """This "meta"-model alternates between the models in the sequence for every call"""
+        super().__init__(config_class=config_class, **kwargs)
+
+    def select_model(self) -> Model:
+        if self.config.sequence is None:
+            i_model = self.n_calls % len(self.models)
+        else:
+            i_model = self.config.sequence[self.n_calls % len(self.config.sequence)]
+        return self.models[i_model]
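A sketch of how these meta-models might be configured directly; the model names are illustrative, and each `model_kwargs` entry is resolved through `get_model` as shown above.

    from minisweagent.models.extra.roulette import InterleavingModel, RouletteModel

    # Deterministic 2:1 alternation between two backends.
    model = InterleavingModel(
        model_kwargs=[
            {"model_name": "gpt-4o"},
            {"model_name": "claude-sonnet-4-20250514"},
        ],
        sequence=[0, 0, 1],  # first model twice, second model once, repeating
    )

    # Uniform random choice per call instead:
    roulette = RouletteModel(model_kwargs=[{"model_name": "gpt-4o"}, {"model_name": "gpt-4o-mini"}])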
minisweagent/models/litellm_model.py
@@ -0,0 +1,100 @@
+import json
+import logging
+import os
+from collections.abc import Callable
+from dataclasses import asdict, dataclass, field
+from pathlib import Path
+from typing import Any, Literal
+
+import litellm
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_not_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
+
+from minisweagent.models import GLOBAL_MODEL_STATS
+from minisweagent.models.utils.cache_control import set_cache_control
+
+logger = logging.getLogger("litellm_model")
+
+
+@dataclass
+class LitellmModelConfig:
+    model_name: str
+    model_kwargs: dict[str, Any] = field(default_factory=dict)
+    litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
+    set_cache_control: Literal["default_end"] | None = None
+    """Set explicit cache control markers, for example for Anthropic models"""
+    cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
+    """Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""
+
+
+class LitellmModel:
+    def __init__(self, *, config_class: Callable = LitellmModelConfig, **kwargs):
+        self.config = config_class(**kwargs)
+        self.cost = 0.0
+        self.n_calls = 0
+        if self.config.litellm_model_registry and Path(self.config.litellm_model_registry).is_file():
+            litellm.utils.register_model(json.loads(Path(self.config.litellm_model_registry).read_text()))
+
+    @retry(
+        stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
+        wait=wait_exponential(multiplier=1, min=4, max=60),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+        retry=retry_if_not_exception_type(
+            (
+                litellm.exceptions.UnsupportedParamsError,
+                litellm.exceptions.NotFoundError,
+                litellm.exceptions.PermissionDeniedError,
+                litellm.exceptions.ContextWindowExceededError,
+                litellm.exceptions.APIError,
+                litellm.exceptions.AuthenticationError,
+                KeyboardInterrupt,
+            )
+        ),
+    )
+    def _query(self, messages: list[dict[str, str]], **kwargs):
+        try:
+            return litellm.completion(
+                model=self.config.model_name, messages=messages, **(self.config.model_kwargs | kwargs)
+            )
+        except litellm.exceptions.AuthenticationError as e:
+            e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
+            raise e
+
+    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
+        if self.config.set_cache_control:
+            messages = set_cache_control(messages, mode=self.config.set_cache_control)
+        response = self._query(messages, **kwargs)
+        try:
+            cost = litellm.cost_calculator.completion_cost(response)
+            if cost <= 0.0:
+                raise ValueError(f"Cost must be > 0.0, got {cost}")
+        except Exception as e:
+            cost = 0.0
+            if self.config.cost_tracking != "ignore_errors":
+                msg = (
+                    f"Error calculating cost for model {self.config.model_name}: {e}, perhaps it's not registered? "
+                    "You can ignore this issue from your config file with cost_tracking: 'ignore_errors' or "
+                    "globally with export MSWEA_COST_TRACKING='ignore_errors'. "
+                    "Alternatively check the 'Cost tracking' section in the documentation at "
+                    "https://klieret.short.gy/mini-local-models. "
+                    "Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
+                )
+                logger.critical(msg)
+                raise RuntimeError(msg) from e
+        self.n_calls += 1
+        self.cost += cost
+        GLOBAL_MODEL_STATS.add(cost)
+        return {
+            "content": response.choices[0].message.content or "",  # type: ignore
+            "extra": {
+                "response": response.model_dump(),
+            },
+        }
+
+    def get_template_vars(self) -> dict[str, Any]:
+        return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}
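For models that litellm cannot price (local or self-hosted ones), the two escape hatches above can be combined. A sketch, with a made-up model name and a hypothetical registry path:

    from minisweagent.models.litellm_model import LitellmModel

    # Option 1: skip cost accounting entirely (also available globally via
    # export MSWEA_COST_TRACKING='ignore_errors').
    model = LitellmModel(model_name="ollama/llama3", cost_tracking="ignore_errors")

    # Option 2: register prices so completion_cost() succeeds; the JSON file follows
    # litellm's model-registry format and is loaded in __init__ above.
    model = LitellmModel(
        model_name="my-provider/my-model",
        litellm_model_registry="/path/to/registry.json",  # hypothetical path
    )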
minisweagent/models/litellm_response_api_model.py
@@ -0,0 +1,80 @@
+import logging
+from collections.abc import Callable
+from dataclasses import dataclass
+
+import litellm
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_not_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
+
+from minisweagent.models.litellm_model import LitellmModel, LitellmModelConfig
+from minisweagent.models.utils.openai_utils import coerce_responses_text
+
+logger = logging.getLogger("litellm_response_api_model")
+
+
+@dataclass
+class LitellmResponseAPIModelConfig(LitellmModelConfig):
+    pass
+
+
+class LitellmResponseAPIModel(LitellmModel):
+    def __init__(self, *, config_class: Callable = LitellmResponseAPIModelConfig, **kwargs):
+        super().__init__(config_class=config_class, **kwargs)
+        self._previous_response_id: str | None = None
+
+    @retry(
+        stop=stop_after_attempt(10),
+        wait=wait_exponential(multiplier=1, min=4, max=60),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+        retry=retry_if_not_exception_type(
+            (
+                litellm.exceptions.UnsupportedParamsError,
+                litellm.exceptions.NotFoundError,
+                litellm.exceptions.PermissionDeniedError,
+                litellm.exceptions.ContextWindowExceededError,
+                litellm.exceptions.APIError,
+                litellm.exceptions.AuthenticationError,
+                KeyboardInterrupt,
+            )
+        ),
+    )
+    def _query(self, messages: list[dict[str, str]], **kwargs):
+        try:
+            resp = litellm.responses(
+                model=self.config.model_name,
+                input=messages if self._previous_response_id is None else messages[-1:],
+                previous_response_id=self._previous_response_id,
+                **(self.config.model_kwargs | kwargs),
+            )
+            self._previous_response_id = getattr(resp, "id", None)
+            return resp
+        except litellm.exceptions.AuthenticationError as e:
+            e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
+            raise e
+
+    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
+        response = self._query(messages, **kwargs)
+        text = coerce_responses_text(response)
+        try:
+            cost = litellm.cost_calculator.completion_cost(response)
+        except Exception as e:
+            logger.critical(
+                f"Error calculating cost for model {self.config.model_name}: {e}. "
+                "Please check the 'Updating the model registry' section in the documentation at "
+                "http://bit.ly/4p31bi4. Still stuck? Please open a github issue for help!"
+            )
+            raise
+        self.n_calls += 1
+        self.cost += cost
+        from minisweagent.models import GLOBAL_MODEL_STATS
+
+        GLOBAL_MODEL_STATS.add(cost)
+        return {
+            "content": text,
+        }
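The wrapper is stateful: after the first call, only the newest message is transmitted and the conversation is chained via `previous_response_id`. A usage sketch under those assumptions (the model name is illustrative; an OPENAI_API_KEY is assumed):

    from minisweagent.models.litellm_response_api_model import LitellmResponseAPIModel

    model = LitellmResponseAPIModel(model_name="openai/gpt-4o")  # illustrative name
    first = model.query([{"role": "user", "content": "Hello"}])
    # Second call sends only messages[-1:], linked to the stored response id.
    history = [{"role": "user", "content": "Hello"}, {"role": "user", "content": "And again"}]
    second = model.query(history)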
minisweagent/models/openrouter_model.py
@@ -0,0 +1,125 @@
+import json
+import logging
+import os
+from dataclasses import asdict, dataclass, field
+from typing import Any, Literal
+
+import requests
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_not_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
+
+from minisweagent.models import GLOBAL_MODEL_STATS
+from minisweagent.models.utils.cache_control import set_cache_control
+
+logger = logging.getLogger("openrouter_model")
+
+
+@dataclass
+class OpenRouterModelConfig:
+    model_name: str
+    model_kwargs: dict[str, Any] = field(default_factory=dict)
+    set_cache_control: Literal["default_end"] | None = None
+    """Set explicit cache control markers, for example for Anthropic models"""
+    cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
+    """Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""
+
+
+class OpenRouterAPIError(Exception):
+    """Custom exception for OpenRouter API errors."""
+
+    pass
+
+
+class OpenRouterAuthenticationError(Exception):
+    """Custom exception for OpenRouter authentication errors."""
+
+    pass
+
+
+class OpenRouterRateLimitError(Exception):
+    """Custom exception for OpenRouter rate limit errors."""
+
+    pass
+
+
+class OpenRouterModel:
+    def __init__(self, **kwargs):
+        self.config = OpenRouterModelConfig(**kwargs)
+        self.cost = 0.0
+        self.n_calls = 0
+        self._api_url = "https://openrouter.ai/api/v1/chat/completions"
+        self._api_key = os.getenv("OPENROUTER_API_KEY", "")
+
+    @retry(
+        stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
+        wait=wait_exponential(multiplier=1, min=4, max=60),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+        retry=retry_if_not_exception_type(
+            (
+                OpenRouterAuthenticationError,
+                KeyboardInterrupt,
+            )
+        ),
+    )
+    def _query(self, messages: list[dict[str, str]], **kwargs):
+        headers = {
+            "Authorization": f"Bearer {self._api_key}",
+            "Content-Type": "application/json",
+        }
+
+        payload = {
+            "model": self.config.model_name,
+            "messages": messages,
+            "usage": {"include": True},
+            **(self.config.model_kwargs | kwargs),
+        }
+
+        try:
+            response = requests.post(self._api_url, headers=headers, data=json.dumps(payload), timeout=60)
+            response.raise_for_status()
+            return response.json()
+        except requests.exceptions.HTTPError as e:
+            if response.status_code == 401:
+                error_msg = "Authentication failed. You can permanently set your API key with `mini-extra config set OPENROUTER_API_KEY YOUR_KEY`."
+                raise OpenRouterAuthenticationError(error_msg) from e
+            elif response.status_code == 429:
+                raise OpenRouterRateLimitError("Rate limit exceeded") from e
+            else:
+                raise OpenRouterAPIError(f"HTTP {response.status_code}: {response.text}") from e
+        except requests.exceptions.RequestException as e:
+            raise OpenRouterAPIError(f"Request failed: {e}") from e
+
+    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
+        if self.config.set_cache_control:
+            messages = set_cache_control(messages, mode=self.config.set_cache_control)
+        response = self._query(messages, **kwargs)
+
+        usage = response.get("usage", {})
+        cost = usage.get("cost") or 0.0  # None-safe: treat missing/null cost as 0.0
+        if cost <= 0.0 and self.config.cost_tracking != "ignore_errors":
+            raise RuntimeError(
+                f"No valid cost information available from OpenRouter API for model {self.config.model_name}: "
+                f"Usage {usage}, cost {cost}. Cost must be > 0.0. Set cost_tracking: 'ignore_errors' in your config file or "
+                "export MSWEA_COST_TRACKING='ignore_errors' to ignore cost tracking errors "
+                "(for example for free/local models); more information at https://klieret.short.gy/mini-local-models. "
+                "Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
+            )
+
+        self.n_calls += 1
+        self.cost += cost
+        GLOBAL_MODEL_STATS.add(cost)
+
+        return {
+            "content": response["choices"][0]["message"]["content"] or "",
+            "extra": {
+                "response": response,  # already a plain dict (JSON-decoded)
+            },
+        }
+
+    def get_template_vars(self) -> dict[str, Any]:
+        return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}
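Usage sketch (the model slug is illustrative; an OPENROUTER_API_KEY must be set, e.g. via `mini-extra config set OPENROUTER_API_KEY YOUR_KEY`):

    from minisweagent.models.openrouter_model import OpenRouterModel

    model = OpenRouterModel(model_name="anthropic/claude-sonnet-4")  # illustrative slug
    out = model.query([{"role": "user", "content": "ping"}])
    # Cost comes straight from the "usage" block that usage: {"include": True} requests.
    print(out["content"], model.cost)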
minisweagent/models/portkey_model.py
@@ -0,0 +1,154 @@
+import json
+import logging
+import os
+from dataclasses import asdict, dataclass, field
+from pathlib import Path
+from typing import Any, Literal
+
+import litellm
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_not_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
+
+from minisweagent.models import GLOBAL_MODEL_STATS
+from minisweagent.models.utils.cache_control import set_cache_control
+
+logger = logging.getLogger("portkey_model")
+
+try:
+    from portkey_ai import Portkey
+except ImportError:
+    raise ImportError(
+        "The portkey-ai package is required to use PortkeyModel. Please install it with: pip install portkey-ai"
+    ) from None
+
+
+@dataclass
+class PortkeyModelConfig:
+    model_name: str
+    model_kwargs: dict[str, Any] = field(default_factory=dict)
+    litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
+    """We currently use litellm to calculate costs. Here you can register additional models to litellm's model registry.
+    Note that this might change if we get better support for Portkey and change how we calculate costs.
+    """
+    litellm_model_name_override: str = ""
+    """We currently use litellm to calculate costs. Here you can override the model name to use for litellm in case it
+    doesn't match the Portkey model name.
+    Note that this might change if we get better support for Portkey and change how we calculate costs.
+    """
+    set_cache_control: Literal["default_end"] | None = None
+    """Set explicit cache control markers, for example for Anthropic models"""
+    cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
+    """Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""
+
+
+class PortkeyModel:
+    def __init__(self, *, config_class: type = PortkeyModelConfig, **kwargs):
+        self.config = config_class(**kwargs)
+        self.cost = 0.0
+        self.n_calls = 0
+        if self.config.litellm_model_registry and Path(self.config.litellm_model_registry).is_file():
+            litellm.utils.register_model(json.loads(Path(self.config.litellm_model_registry).read_text()))
+
+        # Get API key from environment or raise error
+        self._api_key = os.getenv("PORTKEY_API_KEY")
+        if not self._api_key:
+            raise ValueError(
+                "Portkey API key is required. Set it via the "
+                "PORTKEY_API_KEY environment variable. You can permanently set it with "
+                "`mini-extra config set PORTKEY_API_KEY YOUR_KEY`."
+            )
+
+        # Get virtual key from environment
+        virtual_key = os.getenv("PORTKEY_VIRTUAL_KEY")
+
+        # Initialize Portkey client
+        client_kwargs = {"api_key": self._api_key}
+        if virtual_key:
+            client_kwargs["virtual_key"] = virtual_key
+
+        self.client = Portkey(**client_kwargs)
+
+    @retry(
+        stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
+        wait=wait_exponential(multiplier=1, min=4, max=60),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+        retry=retry_if_not_exception_type((KeyboardInterrupt, TypeError, ValueError)),
+    )
+    def _query(self, messages: list[dict[str, str]], **kwargs):
+        return self.client.chat.completions.create(
+            model=self.config.model_name,
+            messages=messages,
+            **(self.config.model_kwargs | kwargs),
+        )
+
+    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
+        if self.config.set_cache_control:
+            messages = set_cache_control(messages, mode=self.config.set_cache_control)
+        response = self._query(messages, **kwargs)
+        cost = self._calculate_cost(response)
+        self.n_calls += 1
+        self.cost += cost
+        GLOBAL_MODEL_STATS.add(cost)
+        return {
+            "content": response.choices[0].message.content or "",
+            "extra": {
+                "response": response.model_dump(),
+                "cost": cost,
+            },
+        }
+
+    def get_template_vars(self) -> dict[str, Any]:
+        return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}
+
+    def _calculate_cost(self, response) -> float:
+        response_for_cost_calc = response.model_copy()
+        if self.config.litellm_model_name_override:
+            if response_for_cost_calc.model:
+                response_for_cost_calc.model = self.config.litellm_model_name_override
+        prompt_tokens = response_for_cost_calc.usage.prompt_tokens
+        if prompt_tokens is None:
+            logger.warning(
+                f"Prompt tokens are None for model {self.config.model_name}. Setting to 0. "
+                f"Full response: {response_for_cost_calc.model_dump()}"
+            )
+            prompt_tokens = 0
+        total_tokens = response_for_cost_calc.usage.total_tokens
+        completion_tokens = response_for_cost_calc.usage.completion_tokens
+        if completion_tokens is None:
+            logger.warning(
+                f"Completion tokens are None for model {self.config.model_name}. Setting to 0. "
+                f"Full response: {response_for_cost_calc.model_dump()}"
+            )
+            completion_tokens = 0
+        if total_tokens - prompt_tokens - completion_tokens != 0:
+            # Most likely related to how Portkey treats cached tokens: they don't seem to
+            # count towards the prompt tokens, which confuses litellm's cost tracking.
+            logger.warning(
+                "Total tokens - prompt tokens - completion tokens != 0. "
+                "This is probably a Portkey bug or an incompatibility with litellm cost tracking. "
+                "Setting prompt tokens based on total tokens and completion tokens; you might want to double-check your costs. "
+                f"Full response: {response_for_cost_calc.model_dump()}"
+            )
+            response_for_cost_calc.usage.prompt_tokens = total_tokens - completion_tokens
+        try:
+            cost = litellm.cost_calculator.completion_cost(
+                response_for_cost_calc, model=self.config.litellm_model_name_override or None
+            )
+            assert cost >= 0.0, f"Cost is negative: {cost}"
+        except Exception as e:
+            cost = 0.0
+            if self.config.cost_tracking != "ignore_errors":
+                msg = (
+                    f"Error calculating cost for model {self.config.model_name} based on {response_for_cost_calc.model_dump()}: {e}. "
+                    "You can ignore this issue from your config file with cost_tracking: 'ignore_errors' or "
+                    "globally with export MSWEA_COST_TRACKING='ignore_errors'. "
+                    "Alternatively check the 'Cost tracking' section in the documentation at "
+                    "https://klieret.short.gy/mini-local-models. "
+                    "Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
+                )
+                logger.critical(msg)
+                raise RuntimeError(msg) from e
+        return cost
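Usage sketch (names are illustrative; PORTKEY_API_KEY, and typically PORTKEY_VIRTUAL_KEY, are assumed to be set in the environment):

    from minisweagent.models.portkey_model import PortkeyModel

    model = PortkeyModel(
        model_name="gpt-4o",                   # whatever your Portkey gateway routes
        litellm_model_name_override="gpt-4o",  # name litellm should price against
    )
    out = model.query([{"role": "user", "content": "ping"}])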