gac 3.6.0__py3-none-any.whl → 3.10.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gac/__init__.py +4 -6
- gac/__version__.py +1 -1
- gac/ai_utils.py +59 -43
- gac/auth_cli.py +181 -36
- gac/cli.py +26 -9
- gac/commit_executor.py +59 -0
- gac/config.py +81 -2
- gac/config_cli.py +19 -7
- gac/constants/__init__.py +34 -0
- gac/constants/commit.py +63 -0
- gac/constants/defaults.py +40 -0
- gac/constants/file_patterns.py +110 -0
- gac/constants/languages.py +119 -0
- gac/diff_cli.py +0 -22
- gac/errors.py +8 -2
- gac/git.py +6 -6
- gac/git_state_validator.py +193 -0
- gac/grouped_commit_workflow.py +458 -0
- gac/init_cli.py +2 -1
- gac/interactive_mode.py +179 -0
- gac/language_cli.py +0 -1
- gac/main.py +231 -926
- gac/model_cli.py +67 -11
- gac/model_identifier.py +70 -0
- gac/oauth/__init__.py +26 -0
- gac/oauth/claude_code.py +89 -22
- gac/oauth/qwen_oauth.py +327 -0
- gac/oauth/token_store.py +81 -0
- gac/oauth_retry.py +161 -0
- gac/postprocess.py +155 -0
- gac/prompt.py +21 -479
- gac/prompt_builder.py +88 -0
- gac/providers/README.md +437 -0
- gac/providers/__init__.py +70 -78
- gac/providers/anthropic.py +12 -46
- gac/providers/azure_openai.py +48 -88
- gac/providers/base.py +329 -0
- gac/providers/cerebras.py +10 -33
- gac/providers/chutes.py +16 -62
- gac/providers/claude_code.py +64 -87
- gac/providers/custom_anthropic.py +51 -81
- gac/providers/custom_openai.py +29 -83
- gac/providers/deepseek.py +10 -33
- gac/providers/error_handler.py +139 -0
- gac/providers/fireworks.py +10 -33
- gac/providers/gemini.py +66 -63
- gac/providers/groq.py +10 -58
- gac/providers/kimi_coding.py +19 -55
- gac/providers/lmstudio.py +64 -43
- gac/providers/minimax.py +10 -33
- gac/providers/mistral.py +10 -33
- gac/providers/moonshot.py +10 -33
- gac/providers/ollama.py +56 -33
- gac/providers/openai.py +30 -36
- gac/providers/openrouter.py +15 -52
- gac/providers/protocol.py +71 -0
- gac/providers/qwen.py +64 -0
- gac/providers/registry.py +58 -0
- gac/providers/replicate.py +140 -82
- gac/providers/streamlake.py +26 -46
- gac/providers/synthetic.py +35 -37
- gac/providers/together.py +10 -33
- gac/providers/zai.py +29 -57
- gac/py.typed +0 -0
- gac/security.py +1 -1
- gac/templates/__init__.py +1 -0
- gac/templates/question_generation.txt +60 -0
- gac/templates/system_prompt.txt +224 -0
- gac/templates/user_prompt.txt +28 -0
- gac/utils.py +36 -6
- gac/workflow_context.py +162 -0
- gac/workflow_utils.py +3 -8
- {gac-3.6.0.dist-info → gac-3.10.10.dist-info}/METADATA +6 -4
- gac-3.10.10.dist-info/RECORD +79 -0
- gac/constants.py +0 -321
- gac-3.6.0.dist-info/RECORD +0 -53
- {gac-3.6.0.dist-info → gac-3.10.10.dist-info}/WHEEL +0 -0
- {gac-3.6.0.dist-info → gac-3.10.10.dist-info}/entry_points.txt +0 -0
- {gac-3.6.0.dist-info → gac-3.10.10.dist-info}/licenses/LICENSE +0 -0
gac/providers/azure_openai.py
CHANGED
|
@@ -4,94 +4,54 @@ This provider provides native support for Azure OpenAI Service with proper
|
|
|
4
4
|
endpoint construction and API version handling.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
import json
|
|
8
|
-
import logging
|
|
9
7
|
import os
|
|
10
|
-
|
|
11
|
-
import httpx
|
|
8
|
+
from typing import Any
|
|
12
9
|
|
|
13
10
|
from gac.errors import AIError
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
AZURE_OPENAI_API_KEY
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
try:
|
|
62
|
-
response = httpx.post(url, headers=headers, json=data, timeout=120)
|
|
63
|
-
response.raise_for_status()
|
|
64
|
-
response_data = response.json()
|
|
65
|
-
|
|
66
|
-
try:
|
|
67
|
-
content = response_data["choices"][0]["message"]["content"]
|
|
68
|
-
except (KeyError, IndexError, TypeError) as e:
|
|
69
|
-
logger.error(f"Unexpected response format from Azure OpenAI API. Response: {json.dumps(response_data)}")
|
|
70
|
-
raise AIError.model_error(
|
|
71
|
-
f"Azure OpenAI API returned unexpected format. Expected response with "
|
|
72
|
-
f"'choices[0].message.content', but got: {type(e).__name__}. Check logs for full response structure."
|
|
73
|
-
) from e
|
|
74
|
-
|
|
75
|
-
if content is None:
|
|
76
|
-
raise AIError.model_error("Azure OpenAI API returned null content")
|
|
77
|
-
if content == "":
|
|
78
|
-
raise AIError.model_error("Azure OpenAI API returned empty content")
|
|
79
|
-
return content
|
|
80
|
-
except httpx.ConnectError as e:
|
|
81
|
-
raise AIError.connection_error(f"Azure OpenAI API connection failed: {str(e)}") from e
|
|
82
|
-
except httpx.HTTPStatusError as e:
|
|
83
|
-
status_code = e.response.status_code
|
|
84
|
-
error_text = e.response.text
|
|
85
|
-
|
|
86
|
-
if status_code == 401:
|
|
87
|
-
raise AIError.authentication_error(f"Azure OpenAI API authentication failed: {error_text}") from e
|
|
88
|
-
elif status_code == 429:
|
|
89
|
-
raise AIError.rate_limit_error(f"Azure OpenAI API rate limit exceeded: {error_text}") from e
|
|
90
|
-
else:
|
|
91
|
-
raise AIError.model_error(f"Azure OpenAI API error: {status_code} - {error_text}") from e
|
|
92
|
-
except httpx.TimeoutException as e:
|
|
93
|
-
raise AIError.timeout_error(f"Azure OpenAI API request timed out: {str(e)}") from e
|
|
94
|
-
except AIError:
|
|
95
|
-
raise
|
|
96
|
-
except Exception as e:
|
|
97
|
-
raise AIError.model_error(f"Error calling Azure OpenAI API: {str(e)}") from e
|
|
11
|
+
from gac.providers.base import OpenAICompatibleProvider, ProviderConfig
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AzureOpenAIProvider(OpenAICompatibleProvider):
    """Azure OpenAI provider: deployment-scoped URLs plus api-key auth."""

    # Placeholder config; the real URL is assembled per request in
    # _get_api_url from AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_API_VERSION.
    config = ProviderConfig(
        name="Azure OpenAI",
        api_key_env="AZURE_OPENAI_API_KEY",
        base_url="",  # built dynamically in _get_api_url
    )

    def __init__(self, config: ProviderConfig):
        """Read the Azure endpoint and API version from the environment.

        Raises:
            AIError: if AZURE_OPENAI_ENDPOINT or AZURE_OPENAI_API_VERSION
                is not set.
        """
        endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        if not endpoint:
            raise AIError.model_error("AZURE_OPENAI_ENDPOINT environment variable not set")

        api_version = os.getenv("AZURE_OPENAI_API_VERSION")
        if not api_version:
            raise AIError.model_error("AZURE_OPENAI_API_VERSION environment variable not set")

        self.api_version = api_version
        self.endpoint = endpoint.rstrip("/")
        config.base_url = ""  # URL is assembled per-model in _get_api_url
        super().__init__(config)

    def _get_api_url(self, model: str | None = None) -> str:
        """Compose the deployment-specific chat-completions URL.

        Azure routes requests by deployment name, so the model is part of
        the path rather than the request body.
        """
        if model is None:
            return super()._get_api_url(model)
        return f"{self.endpoint}/openai/deployments/{model}/chat/completions?api-version={self.api_version}"

    def _build_headers(self) -> dict[str, str]:
        """Swap the Bearer Authorization header for Azure's api-key header."""
        headers = super()._build_headers()
        headers.pop("Authorization", None)
        headers["api-key"] = self.api_key
        return headers

    def _build_request_body(
        self, messages: list[dict[str, Any]], temperature: float, max_tokens: int, model: str, **kwargs: Any
    ) -> dict[str, Any]:
        """Build the request body for Azure OpenAI."""
        body: dict[str, Any] = {"messages": messages, "temperature": temperature, "max_tokens": max_tokens}
        body.update(kwargs)
        return body
|
gac/providers/base.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
"""Base configured provider class to eliminate code duplication."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
|
|
11
|
+
from gac.constants import ProviderDefaults
|
|
12
|
+
from gac.errors import AIError
|
|
13
|
+
from gac.providers.protocol import ProviderProtocol
|
|
14
|
+
from gac.utils import get_ssl_verify
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class ProviderConfig:
    """Declarative settings for an AI provider."""

    # Human-readable provider name used in logs and error messages.
    name: str
    # Environment variable that stores the API key.
    api_key_env: str
    # Root URL of the provider's HTTP API.
    base_url: str
    # Request timeout in seconds.
    timeout: int = ProviderDefaults.HTTP_TIMEOUT
    # Extra request headers; defaults to a JSON content type when omitted.
    headers: dict[str, str] | None = None

    def __post_init__(self) -> None:
        """Fall back to a JSON content-type header when none were given."""
        if self.headers is None:
            self.headers = {"Content-Type": "application/json"}
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class BaseConfiguredProvider(ABC, ProviderProtocol):
    """Shared implementation for config-driven AI providers.

    Centralizes everything the concrete providers would otherwise duplicate:

    - uniform HTTP handling through httpx
    - consistent error classification (delegated to the provider-level
      error-handling decorator; see _make_http_request and generate)
    - declarative setup via ProviderConfig
    - template methods for provider-specific customization

    Satisfies ProviderProtocol for static type checking.
    """

    def __init__(self, config: ProviderConfig):
        self.config = config
        # Kept for lazy-loading; actual lookups always re-read the
        # environment (see the `api_key` property).
        self._api_key: str | None = None

    @property
    def api_key(self) -> str:
        """Resolve the API key on demand.

        Re-reads the environment on every access so tests that patch
        environment variables observe fresh values. Returns "" when no
        API-key environment variable is configured.
        """
        return self._get_api_key() if self.config.api_key_env else ""

    @property
    def name(self) -> str:
        """Human-readable provider name."""
        return self.config.name

    @property
    def api_key_env(self) -> str:
        """Name of the environment variable holding the API key."""
        return self.config.api_key_env

    @property
    def base_url(self) -> str:
        """Base URL of the provider's API."""
        return self.config.base_url

    @property
    def timeout(self) -> int:
        """Request timeout in seconds."""
        return self.config.timeout

    def _get_api_key(self) -> str:
        """Read the API key from the environment, raising when absent."""
        key = os.getenv(self.config.api_key_env)
        if not key:
            raise AIError.authentication_error(f"{self.config.api_key_env} not found in environment variables")
        return key

    @abstractmethod
    def _build_request_body(
        self, messages: list[dict[str, Any]], temperature: float, max_tokens: int, model: str, **kwargs: Any
    ) -> dict[str, Any]:
        """Assemble the JSON request body for the API call.

        Args:
            messages: List of message dictionaries.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens in the response.
            model: Model name.
            **kwargs: Additional provider-specific parameters.

        Returns:
            Request body dictionary.
        """

    @abstractmethod
    def _parse_response(self, response: dict[str, Any]) -> str:
        """Extract the generated text from the decoded API response.

        Args:
            response: Response dictionary from the API.

        Returns:
            Generated text content.
        """

    def _build_headers(self) -> dict[str, str]:
        """Return a copy of the configured headers.

        Subclasses override this to inject provider-specific headers.
        """
        return dict(self.config.headers) if self.config.headers else {}

    def _get_api_url(self, model: str | None = None) -> str:
        """Return the request URL.

        Subclasses override this when the URL depends on the model or must
        be computed dynamically.

        Args:
            model: Model name (for providers with model-specific URLs).

        Returns:
            API URL string.
        """
        return self.config.base_url

    def _make_http_request(self, url: str, body: dict[str, Any], headers: dict[str, str]) -> dict[str, Any]:
        """POST the request and return the decoded JSON response.

        httpx errors (HTTPStatusError, TimeoutException, RequestError) are
        deliberately allowed to propagate: the error-handling decorator that
        wraps each provider's API function classifies them uniformly, which
        avoids duplicated exception handling here.

        Args:
            url: API URL.
            body: Request body.
            headers: Request headers.

        Returns:
            Response JSON dictionary.
        """
        resp = httpx.post(url, json=body, headers=headers, timeout=self.config.timeout, verify=get_ssl_verify())
        resp.raise_for_status()
        return resp.json()

    def generate(
        self,
        model: str,
        messages: list[dict[str, Any]],
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs: Any,
    ) -> str:
        """Generate text using the AI provider.

        Args:
            model: Model name to use.
            messages: List of message dictionaries.
            temperature: Temperature parameter (0.0-2.0).
            max_tokens: Maximum tokens in the response.
            **kwargs: Additional provider-specific parameters.

        Returns:
            Generated text content.

        Raises:
            AIError: For any API-related errors (via the error-handling
                decorator around the provider's API function).
        """
        logger.debug(f"Generating with {self.config.name} provider (model={model})")

        headers = self._build_headers()
        url = self._get_api_url(model)
        body = self._build_request_body(messages, temperature, max_tokens, model, **kwargs)
        # Providers that omit "model" from the body get it injected here.
        body.setdefault("model", model)

        return self._parse_response(self._make_http_request(url, body, headers))
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class OpenAICompatibleProvider(BaseConfiguredProvider):
    """Base class for providers speaking the standard OpenAI API dialect.

    Concrete subclasses typically only need a ProviderConfig and, at most,
    a URL override.
    """

    def _build_request_body(
        self, messages: list[dict[str, Any]], temperature: float, max_tokens: int, model: str, **kwargs: Any
    ) -> dict[str, Any]:
        """Build an OpenAI-style request body.

        Subclasses that need ``max_completion_tokens`` instead of
        ``max_tokens`` (as the OpenAI provider does) should override this.
        """
        body: dict[str, Any] = {"messages": messages, "temperature": temperature, "max_tokens": max_tokens}
        body.update(kwargs)
        return body

    def _build_headers(self) -> dict[str, str]:
        """Attach a Bearer Authorization header when an API key is set."""
        headers = super()._build_headers()
        key = self.api_key
        if key:
            headers["Authorization"] = f"Bearer {key}"
        return headers

    def _parse_response(self, response: dict[str, Any]) -> str:
        """Pull the message content out of an OpenAI-style response."""
        choices = response.get("choices")
        if not isinstance(choices, list) or not choices:
            raise AIError.model_error("Invalid response: missing choices")
        content = choices[0].get("message", {}).get("content")
        if content is None:
            raise AIError.model_error("Invalid response: null content")
        if content == "":
            raise AIError.model_error("Invalid response: empty content")
        return content
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class AnthropicCompatibleProvider(BaseConfiguredProvider):
    """Base class for providers speaking the Anthropic Messages dialect."""

    def _build_headers(self) -> dict[str, str]:
        """Attach Anthropic-style auth and version headers."""
        headers = super()._build_headers()
        # _get_api_key raises when the key is absent from the environment.
        headers["x-api-key"] = self._get_api_key()
        headers["anthropic-version"] = "2023-06-01"
        return headers

    def _build_request_body(
        self, messages: list[dict[str, Any]], temperature: float, max_tokens: int, model: str, **kwargs: Any
    ) -> dict[str, Any]:
        """Build an Anthropic-style body, hoisting the system prompt.

        The system prompt is carried as a top-level "system" field, so
        system messages are pulled out of the list (the last one wins) and
        the remaining messages are forwarded as role/content pairs.
        """
        system_text = ""
        chat_messages: list[dict[str, Any]] = []
        for entry in messages:
            if entry["role"] == "system":
                system_text = entry["content"]
            else:
                chat_messages.append({"role": entry["role"], "content": entry["content"]})

        body: dict[str, Any] = {
            "messages": chat_messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            **kwargs,
        }
        if system_text:
            body["system"] = system_text
        return body

    def _parse_response(self, response: dict[str, Any]) -> str:
        """Pull the text out of an Anthropic-style content-block response."""
        blocks = response.get("content")
        if not isinstance(blocks, list) or not blocks:
            raise AIError.model_error("Invalid response: missing content")

        text = blocks[0].get("text")
        if text is None:
            raise AIError.model_error("Invalid response: null content")
        if text == "":
            raise AIError.model_error("Invalid response: empty content")
        return text
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
class GenericHTTPProvider(BaseConfiguredProvider):
    """Catch-all base class for providers with bespoke wire formats."""

    def _build_request_body(
        self, messages: list[dict[str, Any]], temperature: float, max_tokens: int, model: str, **kwargs: Any
    ) -> dict[str, Any]:
        """Default OpenAI-shaped body; subclasses are expected to override."""
        body: dict[str, Any] = {"messages": messages, "temperature": temperature, "max_tokens": max_tokens}
        body.update(kwargs)
        return body

    def _parse_response(self, response: dict[str, Any]) -> str:
        """Best-effort content extraction; subclasses are expected to override.

        Tries the known response shapes in order (OpenAI, Anthropic,
        Ollama) and finally falls back to the first sufficiently long
        string value found anywhere in the response.
        """
        # OpenAI shape: {"choices": [{"message": {"content": ...}}]}
        choices = response.get("choices")
        if choices and isinstance(choices, list):
            text = choices[0].get("message", {}).get("content")
            if text:
                return text

        # Anthropic shape: {"content": [{"text": ...}]}
        blocks = response.get("content")
        if blocks and isinstance(blocks, list):
            return blocks[0].get("text", "")

        # Ollama shape: {"message": {"content": ...}}
        message = response.get("message", {})
        if "content" in message:
            return message["content"]

        # Last resort: assume any string longer than 10 chars is the content.
        for value in response.values():
            if isinstance(value, str) and len(value) > 10:
                return value

        raise AIError.model_error("Could not extract content from response")
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
# Public API of this module, alphabetically ordered.
__all__ = [
    "AnthropicCompatibleProvider",
    "BaseConfiguredProvider",
    "GenericHTTPProvider",
    "OpenAICompatibleProvider",
    "ProviderConfig",
]
|
gac/providers/cerebras.py
CHANGED
|
@@ -1,38 +1,15 @@
|
|
|
1
1
|
"""Cerebras AI provider implementation."""
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
from gac.providers.base import OpenAICompatibleProvider, ProviderConfig
|
|
4
4
|
|
|
5
|
-
import httpx
|
|
6
5
|
|
|
7
|
-
|
|
6
|
+
class CerebrasProvider(OpenAICompatibleProvider):
    """Cerebras provider: stock OpenAI dialect at the Cerebras endpoint."""

    config = ProviderConfig(
        name="Cerebras",
        api_key_env="CEREBRAS_API_KEY",
        base_url="https://api.cerebras.ai/v1",
    )

    def _get_api_url(self, model: str | None = None) -> str:
        """Append the /chat/completions path to the configured base URL."""
        return self.config.base_url + "/chat/completions"
|
gac/providers/chutes.py
CHANGED
|
@@ -2,70 +2,24 @@
|
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
|
|
5
|
-
import
|
|
5
|
+
from gac.providers.base import OpenAICompatibleProvider, ProviderConfig
|
|
6
6
|
|
|
7
|
-
from gac.errors import AIError
|
|
8
7
|
|
|
8
|
+
class ChutesProvider(OpenAICompatibleProvider):
    """Chutes.ai provider: OpenAI dialect with an overridable base URL."""

    config = ProviderConfig(
        name="Chutes",
        api_key_env="CHUTES_API_KEY",
        base_url="",  # resolved in __init__ from CHUTES_BASE_URL
    )

    def __init__(self, config: ProviderConfig):
        """Resolve the base URL from CHUTES_BASE_URL (default llm.chutes.ai)."""
        root = os.getenv("CHUTES_BASE_URL", "https://llm.chutes.ai").rstrip("/")
        config.base_url = root + "/v1"
        super().__init__(config)

    def _get_api_url(self, model: str | None = None) -> str:
        """Append the /chat/completions path to the configured base URL."""
        return self.config.base_url + "/chat/completions"
|