mini-swe-agent 1.16.0 (mini_swe_agent-1.16.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mini_swe_agent-1.16.0.dist-info/METADATA +314 -0
- mini_swe_agent-1.16.0.dist-info/RECORD +62 -0
- mini_swe_agent-1.16.0.dist-info/WHEEL +5 -0
- mini_swe_agent-1.16.0.dist-info/entry_points.txt +5 -0
- mini_swe_agent-1.16.0.dist-info/licenses/LICENSE.md +21 -0
- mini_swe_agent-1.16.0.dist-info/top_level.txt +1 -0
- minisweagent/__init__.py +83 -0
- minisweagent/__main__.py +7 -0
- minisweagent/agents/__init__.py +1 -0
- minisweagent/agents/default.py +131 -0
- minisweagent/agents/interactive.py +153 -0
- minisweagent/agents/interactive_textual.py +450 -0
- minisweagent/config/README.md +10 -0
- minisweagent/config/__init__.py +27 -0
- minisweagent/config/default.yaml +157 -0
- minisweagent/config/extra/__init__.py +1 -0
- minisweagent/config/extra/swebench.yaml +230 -0
- minisweagent/config/extra/swebench_roulette.yaml +233 -0
- minisweagent/config/extra/swebench_xml.yaml +215 -0
- minisweagent/config/github_issue.yaml +146 -0
- minisweagent/config/mini.tcss +86 -0
- minisweagent/config/mini.yaml +158 -0
- minisweagent/config/mini_no_temp.yaml +158 -0
- minisweagent/environments/__init__.py +31 -0
- minisweagent/environments/docker.py +114 -0
- minisweagent/environments/extra/__init__.py +0 -0
- minisweagent/environments/extra/bubblewrap.py +112 -0
- minisweagent/environments/extra/swerex_docker.py +47 -0
- minisweagent/environments/local.py +38 -0
- minisweagent/environments/singularity.py +97 -0
- minisweagent/models/__init__.py +114 -0
- minisweagent/models/anthropic.py +35 -0
- minisweagent/models/extra/__init__.py +0 -0
- minisweagent/models/extra/roulette.py +61 -0
- minisweagent/models/litellm_model.py +100 -0
- minisweagent/models/litellm_response_api_model.py +80 -0
- minisweagent/models/openrouter_model.py +125 -0
- minisweagent/models/portkey_model.py +154 -0
- minisweagent/models/portkey_response_api_model.py +74 -0
- minisweagent/models/requesty_model.py +119 -0
- minisweagent/models/test_models.py +42 -0
- minisweagent/models/utils/__init__.py +0 -0
- minisweagent/models/utils/cache_control.py +54 -0
- minisweagent/models/utils/key_per_thread.py +20 -0
- minisweagent/models/utils/openai_utils.py +41 -0
- minisweagent/py.typed +0 -0
- minisweagent/run/__init__.py +1 -0
- minisweagent/run/extra/__init__.py +0 -0
- minisweagent/run/extra/config.py +114 -0
- minisweagent/run/extra/swebench.py +266 -0
- minisweagent/run/extra/swebench_single.py +79 -0
- minisweagent/run/extra/utils/__init__.py +0 -0
- minisweagent/run/extra/utils/batch_progress.py +178 -0
- minisweagent/run/github_issue.py +87 -0
- minisweagent/run/hello_world.py +36 -0
- minisweagent/run/inspector.py +212 -0
- minisweagent/run/mini.py +108 -0
- minisweagent/run/mini_extra.py +44 -0
- minisweagent/run/utils/__init__.py +0 -0
- minisweagent/run/utils/save.py +78 -0
- minisweagent/utils/__init__.py +0 -0
- minisweagent/utils/log.py +36 -0

minisweagent/models/__init__.py
@@ -0,0 +1,114 @@
"""This file provides convenience functions for selecting models.

You can ignore this file completely if you explicitly set your model in your run script.
"""

import copy
import importlib
import os
import threading

from minisweagent import Model


class GlobalModelStats:
    """Global model statistics tracker with optional limits."""

    def __init__(self):
        self._cost = 0.0
        self._n_calls = 0
        self._lock = threading.Lock()
        self.cost_limit = float(os.getenv("MSWEA_GLOBAL_COST_LIMIT", "0"))
        self.call_limit = int(os.getenv("MSWEA_GLOBAL_CALL_LIMIT", "0"))
        if (self.cost_limit > 0 or self.call_limit > 0) and not os.getenv("MSWEA_SILENT_STARTUP"):
            print(f"Global cost/call limit: ${self.cost_limit:.4f} / {self.call_limit}")

    def add(self, cost: float) -> None:
        """Add a model call with its cost, checking limits."""
        with self._lock:
            self._cost += cost
            self._n_calls += 1
            if 0 < self.cost_limit < self._cost or 0 < self.call_limit < self._n_calls + 1:
                raise RuntimeError(f"Global cost/call limit exceeded: ${self._cost:.4f} / {self._n_calls + 1}")

    @property
    def cost(self) -> float:
        return self._cost

    @property
    def n_calls(self) -> int:
        return self._n_calls


GLOBAL_MODEL_STATS = GlobalModelStats()


def get_model(input_model_name: str | None = None, config: dict | None = None) -> Model:
    """Get an initialized model object from any kind of user input or settings."""
    resolved_model_name = get_model_name(input_model_name, config)
    if config is None:
        config = {}
    config = copy.deepcopy(config)
    config["model_name"] = resolved_model_name

    model_class = get_model_class(resolved_model_name, config.pop("model_class", ""))

    if (from_env := os.getenv("MSWEA_MODEL_API_KEY")) and not str(type(model_class)).endswith("DeterministicModel"):
        config.setdefault("model_kwargs", {})["api_key"] = from_env

    if (
        any(s in resolved_model_name.lower() for s in ["anthropic", "sonnet", "opus", "claude"])
        and "set_cache_control" not in config
    ):
        # Select cache control for Anthropic models by default
        config["set_cache_control"] = "default_end"

    return model_class(**config)


def get_model_name(input_model_name: str | None = None, config: dict | None = None) -> str:
    """Get a model name from any kind of user input or settings."""
    if config is None:
        config = {}
    if input_model_name:
        return input_model_name
    if from_config := config.get("model_name"):
        return from_config
    if from_env := os.getenv("MSWEA_MODEL_NAME"):
        return from_env
    raise ValueError("No default model set. Please run `mini-extra config setup` to set one.")


_MODEL_CLASS_MAPPING = {
    "anthropic": "minisweagent.models.anthropic.AnthropicModel",
    "litellm": "minisweagent.models.litellm_model.LitellmModel",
    "litellm_response": "minisweagent.models.litellm_response_api_model.LitellmResponseAPIModel",
    "openrouter": "minisweagent.models.openrouter_model.OpenRouterModel",
    "portkey": "minisweagent.models.portkey_model.PortkeyModel",
    "portkey_response": "minisweagent.models.portkey_response_api_model.PortkeyResponseAPIModel",
    "requesty": "minisweagent.models.requesty_model.RequestyModel",
    "deterministic": "minisweagent.models.test_models.DeterministicModel",
}


def get_model_class(model_name: str, model_class: str = "") -> type:
    """Select the best model class.

    If a model_class is provided (as shortcut name, or as full import path,
    e.g., "anthropic" or "minisweagent.models.anthropic.AnthropicModel"),
    it takes precedence over the `model_name`.
    Otherwise, the model_name is used to select the best model class.
    """
    if model_class:
        full_path = _MODEL_CLASS_MAPPING.get(model_class, model_class)
        try:
            module_name, class_name = full_path.rsplit(".", 1)
            module = importlib.import_module(module_name)
            return getattr(module, class_name)
        except (ValueError, ImportError, AttributeError):
            msg = f"Unknown model class: {model_class} (resolved to {full_path}, available: {_MODEL_CLASS_MAPPING})"
            raise ValueError(msg)

    # Default to LitellmModel
    from minisweagent.models.litellm_model import LitellmModel

    return LitellmModel
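
Taken together: the name is resolved from argument, then config, then `MSWEA_MODEL_NAME`; the class comes from the shortcut table; and `GLOBAL_MODEL_STATS` enforces process-wide limits. A minimal usage sketch (model name and limit values are illustrative, not from the package):

import os

# The limits are read when GlobalModelStats is constructed at import time,
# so they must be set before minisweagent.models is first imported.
os.environ["MSWEA_GLOBAL_COST_LIMIT"] = "5.00"  # illustrative: stop after $5
os.environ["MSWEA_GLOBAL_CALL_LIMIT"] = "100"   # illustrative: stop after 100 calls

from minisweagent.models import GLOBAL_MODEL_STATS, get_model

# "litellm" is a shortcut key in _MODEL_CLASS_MAPPING; an explicit
# model_class always wins over name-based selection.
model = get_model("claude-sonnet-4-20250514", config={"model_class": "litellm"})

# The name contains "sonnet"/"claude", so get_model injected
# set_cache_control="default_end" into the config.
assert model.config.set_cache_control == "default_end"
print(GLOBAL_MODEL_STATS.cost, GLOBAL_MODEL_STATS.n_calls)  # 0.0 0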

minisweagent/models/anthropic.py
@@ -0,0 +1,35 @@
import os
import warnings
from typing import Literal

from minisweagent.models.litellm_model import LitellmModel, LitellmModelConfig
from minisweagent.models.utils.cache_control import set_cache_control
from minisweagent.models.utils.key_per_thread import get_key_per_thread


class AnthropicModelConfig(LitellmModelConfig):
    set_cache_control: Literal["default_end"] | None = "default_end"
    """Set explicit cache control markers, for example for Anthropic models"""


class AnthropicModel(LitellmModel):
    """This class is now only a thin wrapper around the LitellmModel class.
    It is largely kept for backwards compatibility.
    It will not be selected by `get_model` and `get_model_class` unless explicitly specified.
    """

    def __init__(self, *, config_class: type = AnthropicModelConfig, **kwargs):
        super().__init__(config_class=config_class, **kwargs)

    def query(self, messages: list[dict], **kwargs) -> dict:
        api_key = None
        # Legacy only
        if rotating_keys := os.getenv("ANTHROPIC_API_KEYS"):
            warnings.warn(
                "ANTHROPIC_API_KEYS is deprecated and will be removed in the future. "
                "Simply use the ANTHROPIC_API_KEY environment variable instead. "
                "Key rotation is no longer required."
            )
            api_key = get_key_per_thread(rotating_keys.split("::"))
        messages = set_cache_control(messages, mode="default_end")
        return super().query(messages, api_key=api_key, **kwargs)
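
For reference, the deprecated rotation variable holds several keys joined by `::`, while the supported setup is a single key. A sketch with placeholder key values and an illustrative model name:

import os

# Supported: one key for everyone.
os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."  # placeholder

# Deprecated: query() warns, splits on "::", and pins one key per worker
# thread via get_key_per_thread().
os.environ["ANTHROPIC_API_KEYS"] = "sk-ant-key-a::sk-ant-key-b"  # placeholders

from minisweagent.models.anthropic import AnthropicModel

model = AnthropicModel(model_name="claude-opus-4-20250514")  # illustrative name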

minisweagent/models/extra/__init__.py
File without changes

minisweagent/models/extra/roulette.py
@@ -0,0 +1,61 @@
import random
from dataclasses import asdict, dataclass

from minisweagent import Model
from minisweagent.models import get_model


@dataclass
class RouletteModelConfig:
    model_kwargs: list[dict]
    """The models to choose from"""
    model_name: str = "roulette"


class RouletteModel:
    def __init__(self, *, config_class: type = RouletteModelConfig, **kwargs):
        """This "meta"-model randomly selects one of the models at every call"""
        self.config = config_class(**kwargs)
        self.models = [get_model(config=config) for config in self.config.model_kwargs]

    @property
    def cost(self) -> float:
        return sum(model.cost for model in self.models)

    @property
    def n_calls(self) -> int:
        return sum(model.n_calls for model in self.models)

    def get_template_vars(self) -> dict:
        return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}

    def select_model(self) -> Model:
        return random.choice(self.models)

    def query(self, *args, **kwargs) -> dict:
        model = self.select_model()
        response = model.query(*args, **kwargs)
        response["model_name"] = model.config.model_name
        return response


@dataclass
class InterleavingModelConfig:
    model_kwargs: list[dict]
    sequence: list[int] | None = None
    """If set to 0, 0, 1, we will return the first model 2 times, then the second model 1 time,
    then the first model again, etc."""
    model_name: str = "interleaving"


class InterleavingModel(RouletteModel):
    def __init__(self, *, config_class: type = InterleavingModelConfig, **kwargs):
        """This "meta"-model alternates between the models in the sequence for every call"""
        super().__init__(config_class=config_class, **kwargs)

    def select_model(self) -> Model:
        if self.config.sequence is None:
            i_model = self.n_calls % len(self.models)
        else:
            i_model = self.config.sequence[self.n_calls % len(self.config.sequence)]
        return self.models[i_model]
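
The sequence semantics are easiest to see with concrete numbers; a hypothetical two-model setup (model names are placeholders):

from minisweagent.models.extra.roulette import InterleavingModel, RouletteModel

# Each entry is handed to get_model(config=...) as-is.
model_kwargs = [
    {"model_name": "gpt-4o", "model_class": "litellm"},
    {"model_name": "claude-sonnet-4-20250514", "model_class": "litellm"},
]

roulette = RouletteModel(model_kwargs=model_kwargs)  # uniform random pick per call
interleaving = InterleavingModel(
    model_kwargs=model_kwargs,
    sequence=[0, 0, 1],  # model 0 twice, then model 1 once, repeating
)
# With sequence=None, InterleavingModel simply round-robins over the models.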

minisweagent/models/litellm_model.py
@@ -0,0 +1,100 @@
import json
import logging
import os
from collections.abc import Callable
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Literal

import litellm
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from minisweagent.models import GLOBAL_MODEL_STATS
from minisweagent.models.utils.cache_control import set_cache_control

logger = logging.getLogger("litellm_model")


@dataclass
class LitellmModelConfig:
    model_name: str
    model_kwargs: dict[str, Any] = field(default_factory=dict)
    litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
    set_cache_control: Literal["default_end"] | None = None
    """Set explicit cache control markers, for example for Anthropic models"""
    cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
    """Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""


class LitellmModel:
    def __init__(self, *, config_class: Callable = LitellmModelConfig, **kwargs):
        self.config = config_class(**kwargs)
        self.cost = 0.0
        self.n_calls = 0
        if self.config.litellm_model_registry and Path(self.config.litellm_model_registry).is_file():
            litellm.utils.register_model(json.loads(Path(self.config.litellm_model_registry).read_text()))

    @retry(
        stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
        wait=wait_exponential(multiplier=1, min=4, max=60),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        retry=retry_if_not_exception_type(
            (
                litellm.exceptions.UnsupportedParamsError,
                litellm.exceptions.NotFoundError,
                litellm.exceptions.PermissionDeniedError,
                litellm.exceptions.ContextWindowExceededError,
                litellm.exceptions.APIError,
                litellm.exceptions.AuthenticationError,
                KeyboardInterrupt,
            )
        ),
    )
    def _query(self, messages: list[dict[str, str]], **kwargs):
        try:
            return litellm.completion(
                model=self.config.model_name, messages=messages, **(self.config.model_kwargs | kwargs)
            )
        except litellm.exceptions.AuthenticationError as e:
            e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
            raise e

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        if self.config.set_cache_control:
            messages = set_cache_control(messages, mode=self.config.set_cache_control)
        response = self._query(messages, **kwargs)
        try:
            cost = litellm.cost_calculator.completion_cost(response)
            if cost <= 0.0:
                raise ValueError(f"Cost must be > 0.0, got {cost}")
        except Exception as e:
            cost = 0.0
            if self.config.cost_tracking != "ignore_errors":
                msg = (
                    f"Error calculating cost for model {self.config.model_name}: {e}, perhaps it's not registered? "
                    "You can ignore this issue from your config file with cost_tracking: 'ignore_errors' or "
                    "globally with export MSWEA_COST_TRACKING='ignore_errors'. "
                    "Alternatively check the 'Cost tracking' section in the documentation at "
                    "https://klieret.short.gy/mini-local-models. "
                    " Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
                )
                logger.critical(msg)
                raise RuntimeError(msg) from e
        self.n_calls += 1
        self.cost += cost
        GLOBAL_MODEL_STATS.add(cost)
        return {
            "content": response.choices[0].message.content or "",  # type: ignore
            "extra": {
                "response": response.model_dump(),
            },
        }

    def get_template_vars(self) -> dict[str, Any]:
        return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}
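
For models that litellm cannot price (e.g. self-hosted ones), `litellm_model_registry` can point to a JSON file that is fed to `litellm.utils.register_model` at construction time. A sketch of such a file, with an invented model name and invented prices; the keys follow litellm's model registry format:

import json
from pathlib import Path

from minisweagent.models.litellm_model import LitellmModel

# Invented registry entry mapping model name -> pricing metadata.
registry = {
    "my-local-model": {
        "max_tokens": 8192,
        "input_cost_per_token": 1e-06,
        "output_cost_per_token": 2e-06,
        "litellm_provider": "openai",
        "mode": "chat",
    }
}
Path("registry.json").write_text(json.dumps(registry))

# Passing the path directly avoids relying on LITELLM_MODEL_REGISTRY_PATH,
# which is only read when the module is imported.
model = LitellmModel(model_name="my-local-model", litellm_model_registry="registry.json")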

minisweagent/models/litellm_response_api_model.py
@@ -0,0 +1,80 @@
import logging
from collections.abc import Callable
from dataclasses import dataclass

import litellm
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from minisweagent.models.litellm_model import LitellmModel, LitellmModelConfig
from minisweagent.models.utils.openai_utils import coerce_responses_text

logger = logging.getLogger("litellm_response_api_model")


@dataclass
class LitellmResponseAPIModelConfig(LitellmModelConfig):
    pass


class LitellmResponseAPIModel(LitellmModel):
    def __init__(self, *, config_class: Callable = LitellmResponseAPIModelConfig, **kwargs):
        super().__init__(config_class=config_class, **kwargs)
        self._previous_response_id: str | None = None

    @retry(
        stop=stop_after_attempt(10),
        wait=wait_exponential(multiplier=1, min=4, max=60),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        retry=retry_if_not_exception_type(
            (
                litellm.exceptions.UnsupportedParamsError,
                litellm.exceptions.NotFoundError,
                litellm.exceptions.PermissionDeniedError,
                litellm.exceptions.ContextWindowExceededError,
                litellm.exceptions.APIError,
                litellm.exceptions.AuthenticationError,
                KeyboardInterrupt,
            )
        ),
    )
    def _query(self, messages: list[dict[str, str]], **kwargs):
        try:
            resp = litellm.responses(
                model=self.config.model_name,
                input=messages if self._previous_response_id is None else messages[-1:],
                previous_response_id=self._previous_response_id,
                **(self.config.model_kwargs | kwargs),
            )
            self._previous_response_id = getattr(resp, "id", None)
            return resp
        except litellm.exceptions.AuthenticationError as e:
            e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
            raise e

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        response = self._query(messages, **kwargs)
        print(response)
        text = coerce_responses_text(response)
        try:
            cost = litellm.cost_calculator.completion_cost(response)
        except Exception as e:
            logger.critical(
                f"Error calculating cost for model {self.config.model_name}: {e}. "
                "Please check the 'Updating the model registry' section in the documentation. "
                "http://bit.ly/4p31bi4 Still stuck? Please open a github issue for help!"
            )
            raise
        self.n_calls += 1
        self.cost += cost
        from minisweagent.models import GLOBAL_MODEL_STATS

        GLOBAL_MODEL_STATS.add(cost)
        return {
            "content": text,
        }
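
The effect of `previous_response_id` is that conversation state lives server-side: only the first call uploads the full message list as `input`; every later call sends just `messages[-1:]` together with the stored id. A small sketch (model name illustrative; no request is made here):

from minisweagent.models.litellm_response_api_model import LitellmResponseAPIModel

model = LitellmResponseAPIModel(model_name="openai/gpt-4.1-mini")  # illustrative

# No call made yet, so the first query() would send the full history;
# after it returns, this attribute holds the response id used for chaining.
assert model._previous_response_id is None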

minisweagent/models/openrouter_model.py
@@ -0,0 +1,125 @@
import json
import logging
import os
from dataclasses import asdict, dataclass, field
from typing import Any, Literal

import requests
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from minisweagent.models import GLOBAL_MODEL_STATS
from minisweagent.models.utils.cache_control import set_cache_control

logger = logging.getLogger("openrouter_model")


@dataclass
class OpenRouterModelConfig:
    model_name: str
    model_kwargs: dict[str, Any] = field(default_factory=dict)
    set_cache_control: Literal["default_end"] | None = None
    """Set explicit cache control markers, for example for Anthropic models"""
    cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
    """Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""


class OpenRouterAPIError(Exception):
    """Custom exception for OpenRouter API errors."""

    pass


class OpenRouterAuthenticationError(Exception):
    """Custom exception for OpenRouter authentication errors."""

    pass


class OpenRouterRateLimitError(Exception):
    """Custom exception for OpenRouter rate limit errors."""

    pass


class OpenRouterModel:
    def __init__(self, **kwargs):
        self.config = OpenRouterModelConfig(**kwargs)
        self.cost = 0.0
        self.n_calls = 0
        self._api_url = "https://openrouter.ai/api/v1/chat/completions"
        self._api_key = os.getenv("OPENROUTER_API_KEY", "")

    @retry(
        stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
        wait=wait_exponential(multiplier=1, min=4, max=60),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        retry=retry_if_not_exception_type(
            (
                OpenRouterAuthenticationError,
                KeyboardInterrupt,
            )
        ),
    )
    def _query(self, messages: list[dict[str, str]], **kwargs):
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }

        payload = {
            "model": self.config.model_name,
            "messages": messages,
            "usage": {"include": True},
            **(self.config.model_kwargs | kwargs),
        }

        try:
            response = requests.post(self._api_url, headers=headers, data=json.dumps(payload), timeout=60)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            if response.status_code == 401:
                error_msg = "Authentication failed. You can permanently set your API key with `mini-extra config set OPENROUTER_API_KEY YOUR_KEY`."
                raise OpenRouterAuthenticationError(error_msg) from e
            elif response.status_code == 429:
                raise OpenRouterRateLimitError("Rate limit exceeded") from e
            else:
                raise OpenRouterAPIError(f"HTTP {response.status_code}: {response.text}") from e
        except requests.exceptions.RequestException as e:
            raise OpenRouterAPIError(f"Request failed: {e}") from e

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        if self.config.set_cache_control:
            messages = set_cache_control(messages, mode=self.config.set_cache_control)
        response = self._query(messages, **kwargs)

        usage = response.get("usage", {})
        cost = usage.get("cost", 0.0)
        if cost <= 0.0 and self.config.cost_tracking != "ignore_errors":
            raise RuntimeError(
                f"No valid cost information available from OpenRouter API for model {self.config.model_name}: "
                f"Usage {usage}, cost {cost}. Cost must be > 0.0. Set cost_tracking: 'ignore_errors' in your config file or "
                "export MSWEA_COST_TRACKING='ignore_errors' to ignore cost tracking errors "
                "(for example for free/local models), more information at https://klieret.short.gy/mini-local-models "
                "for more details. Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
            )

        self.n_calls += 1
        self.cost += cost
        GLOBAL_MODEL_STATS.add(cost)

        return {
            "content": response["choices"][0]["message"]["content"] or "",
            "extra": {
                "response": response,  # already is json
            },
        }

    def get_template_vars(self) -> dict[str, Any]:
        return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}
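
Unlike the litellm-based classes, cost here comes straight from the provider: the payload always sets `usage: {"include": true}` and `query()` trusts the reply's `usage.cost`. The same exchange as a stand-alone sketch (model id is illustrative; the API key must be in the environment):

import os

import requests

payload = {
    "model": "anthropic/claude-sonnet-4",  # illustrative OpenRouter model id
    "messages": [{"role": "user", "content": "ping"}],
    "usage": {"include": True},  # ask OpenRouter to report spend in the reply
}
response = requests.post(
    "https://openrouter.ai/api/v1/chat/completions",
    headers={"Authorization": f"Bearer {os.environ['OPENROUTER_API_KEY']}"},
    json=payload,
    timeout=60,
)
data = response.json()
print(data["choices"][0]["message"]["content"])
print(data["usage"]["cost"])  # dollars for this call, as read by query() above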

minisweagent/models/portkey_model.py
@@ -0,0 +1,154 @@
import json
import logging
import os
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Literal

import litellm
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from minisweagent.models import GLOBAL_MODEL_STATS
from minisweagent.models.utils.cache_control import set_cache_control

logger = logging.getLogger("portkey_model")

try:
    from portkey_ai import Portkey
except ImportError:
    raise ImportError(
        "The portkey-ai package is required to use PortkeyModel. Please install it with: pip install portkey-ai"
    )


@dataclass
class PortkeyModelConfig:
    model_name: str
    model_kwargs: dict[str, Any] = field(default_factory=dict)
    litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
    """We currently use litellm to calculate costs. Here you can register additional models to litellm's model registry.
    Note that this might change if we get better support for Portkey and change how we calculate costs.
    """
    litellm_model_name_override: str = ""
    """We currently use litellm to calculate costs. Here you can override the model name to use for litellm in case it
    doesn't match the Portkey model name.
    Note that this might change if we get better support for Portkey and change how we calculate costs.
    """
    set_cache_control: Literal["default_end"] | None = None
    """Set explicit cache control markers, for example for Anthropic models"""
    cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
    """Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""


class PortkeyModel:
    def __init__(self, *, config_class: type = PortkeyModelConfig, **kwargs):
        self.config = config_class(**kwargs)
        self.cost = 0.0
        self.n_calls = 0
        if self.config.litellm_model_registry and Path(self.config.litellm_model_registry).is_file():
            litellm.utils.register_model(json.loads(Path(self.config.litellm_model_registry).read_text()))

        # Get API key from environment or raise error
        self._api_key = os.getenv("PORTKEY_API_KEY")
        if not self._api_key:
            raise ValueError(
                "Portkey API key is required. Set it via the "
                "PORTKEY_API_KEY environment variable. You can permanently set it with "
                "`mini-extra config set PORTKEY_API_KEY YOUR_KEY`."
            )

        # Get virtual key from environment
        virtual_key = os.getenv("PORTKEY_VIRTUAL_KEY")

        # Initialize Portkey client
        client_kwargs = {"api_key": self._api_key}
        if virtual_key:
            client_kwargs["virtual_key"] = virtual_key

        self.client = Portkey(**client_kwargs)

    @retry(
        stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
        wait=wait_exponential(multiplier=1, min=4, max=60),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        retry=retry_if_not_exception_type((KeyboardInterrupt, TypeError, ValueError)),
    )
    def _query(self, messages: list[dict[str, str]], **kwargs):
        # return self.client.with_options(metadata={"request_id": request_id}).chat.completions.create(
        return self.client.chat.completions.create(
            model=self.config.model_name,
            messages=messages,
            **(self.config.model_kwargs | kwargs),
        )

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        if self.config.set_cache_control:
            messages = set_cache_control(messages, mode=self.config.set_cache_control)
        response = self._query(messages, **kwargs)
        cost = self._calculate_cost(response)
        self.n_calls += 1
        self.cost += cost
        GLOBAL_MODEL_STATS.add(cost)
        return {
            "content": response.choices[0].message.content or "",
            "extra": {
                "response": response.model_dump(),
                "cost": cost,
            },
        }

    def get_template_vars(self) -> dict[str, Any]:
        return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}

    def _calculate_cost(self, response) -> float:
        response_for_cost_calc = response.model_copy()
        if self.config.litellm_model_name_override:
            if response_for_cost_calc.model:
                response_for_cost_calc.model = self.config.litellm_model_name_override
        prompt_tokens = response_for_cost_calc.usage.prompt_tokens
        if prompt_tokens is None:
            logger.warning(
                f"Prompt tokens are None for model {self.config.model_name}. Setting to 0. Full response: {response_for_cost_calc.model_dump()}"
            )
            prompt_tokens = 0
        total_tokens = response_for_cost_calc.usage.total_tokens
        completion_tokens = response_for_cost_calc.usage.completion_tokens
        if completion_tokens is None:
            logger.warning(
                f"Completion tokens are None for model {self.config.model_name}. Setting to 0. Full response: {response_for_cost_calc.model_dump()}"
            )
            completion_tokens = 0
        if total_tokens - prompt_tokens - completion_tokens != 0:
            # This is most likely related to how portkey treats cached tokens: It doesn't count them towards the prompt tokens (?)
            logger.warning(
                f"WARNING: Total tokens - prompt tokens - completion tokens != 0: {response_for_cost_calc.model_dump()}."
                " This is probably a portkey bug or incompatibility with litellm cost tracking. "
                "Setting prompt tokens based on total tokens and completion tokens. You might want to double check your costs. "
                f"Full response: {response_for_cost_calc.model_dump()}"
            )
            response_for_cost_calc.usage.prompt_tokens = total_tokens - completion_tokens
        try:
            cost = litellm.cost_calculator.completion_cost(
                response_for_cost_calc, model=self.config.litellm_model_name_override or None
            )
            assert cost >= 0.0, f"Cost is negative: {cost}"
        except Exception as e:
            cost = 0.0
            if self.config.cost_tracking != "ignore_errors":
                msg = (
                    f"Error calculating cost for model {self.config.model_name} based on {response_for_cost_calc.model_dump()}: {e}. "
                    "You can ignore this issue from your config file with cost_tracking: 'ignore_errors' or "
                    "globally with export MSWEA_COST_TRACKING='ignore_errors' to ignore this error. "
                    "Alternatively check the 'Cost tracking' section in the documentation at "
                    "https://klieret.short.gy/mini-local-models. "
                    "Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
                )
                logger.critical(msg)
                raise RuntimeError(msg) from e
        return cost
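
Since requests travel through Portkey but pricing comes from litellm, the override matters whenever the two ecosystems name a model differently. A hypothetical configuration (requires the portkey-ai package; both names below are illustrative):

from minisweagent.models.portkey_model import PortkeyModel

# PORTKEY_API_KEY (and optionally PORTKEY_VIRTUAL_KEY) must be set in the
# environment, as enforced in __init__ above.
model = PortkeyModel(
    model_name="claude-sonnet-4",  # hypothetical name on the Portkey side
    litellm_model_name_override="anthropic/claude-sonnet-4-20250514",  # name litellm can price
)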