gac 3.6.0__py3-none-any.whl → 3.10.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. gac/__init__.py +4 -6
  2. gac/__version__.py +1 -1
  3. gac/ai_utils.py +59 -43
  4. gac/auth_cli.py +181 -36
  5. gac/cli.py +26 -9
  6. gac/commit_executor.py +59 -0
  7. gac/config.py +81 -2
  8. gac/config_cli.py +19 -7
  9. gac/constants/__init__.py +34 -0
  10. gac/constants/commit.py +63 -0
  11. gac/constants/defaults.py +40 -0
  12. gac/constants/file_patterns.py +110 -0
  13. gac/constants/languages.py +119 -0
  14. gac/diff_cli.py +0 -22
  15. gac/errors.py +8 -2
  16. gac/git.py +6 -6
  17. gac/git_state_validator.py +193 -0
  18. gac/grouped_commit_workflow.py +458 -0
  19. gac/init_cli.py +2 -1
  20. gac/interactive_mode.py +179 -0
  21. gac/language_cli.py +0 -1
  22. gac/main.py +231 -926
  23. gac/model_cli.py +67 -11
  24. gac/model_identifier.py +70 -0
  25. gac/oauth/__init__.py +26 -0
  26. gac/oauth/claude_code.py +89 -22
  27. gac/oauth/qwen_oauth.py +327 -0
  28. gac/oauth/token_store.py +81 -0
  29. gac/oauth_retry.py +161 -0
  30. gac/postprocess.py +155 -0
  31. gac/prompt.py +21 -479
  32. gac/prompt_builder.py +88 -0
  33. gac/providers/README.md +437 -0
  34. gac/providers/__init__.py +70 -78
  35. gac/providers/anthropic.py +12 -46
  36. gac/providers/azure_openai.py +48 -88
  37. gac/providers/base.py +329 -0
  38. gac/providers/cerebras.py +10 -33
  39. gac/providers/chutes.py +16 -62
  40. gac/providers/claude_code.py +64 -87
  41. gac/providers/custom_anthropic.py +51 -81
  42. gac/providers/custom_openai.py +29 -83
  43. gac/providers/deepseek.py +10 -33
  44. gac/providers/error_handler.py +139 -0
  45. gac/providers/fireworks.py +10 -33
  46. gac/providers/gemini.py +66 -63
  47. gac/providers/groq.py +10 -58
  48. gac/providers/kimi_coding.py +19 -55
  49. gac/providers/lmstudio.py +64 -43
  50. gac/providers/minimax.py +10 -33
  51. gac/providers/mistral.py +10 -33
  52. gac/providers/moonshot.py +10 -33
  53. gac/providers/ollama.py +56 -33
  54. gac/providers/openai.py +30 -36
  55. gac/providers/openrouter.py +15 -52
  56. gac/providers/protocol.py +71 -0
  57. gac/providers/qwen.py +64 -0
  58. gac/providers/registry.py +58 -0
  59. gac/providers/replicate.py +140 -82
  60. gac/providers/streamlake.py +26 -46
  61. gac/providers/synthetic.py +35 -37
  62. gac/providers/together.py +10 -33
  63. gac/providers/zai.py +29 -57
  64. gac/py.typed +0 -0
  65. gac/security.py +1 -1
  66. gac/templates/__init__.py +1 -0
  67. gac/templates/question_generation.txt +60 -0
  68. gac/templates/system_prompt.txt +224 -0
  69. gac/templates/user_prompt.txt +28 -0
  70. gac/utils.py +36 -6
  71. gac/workflow_context.py +162 -0
  72. gac/workflow_utils.py +3 -8
  73. {gac-3.6.0.dist-info → gac-3.10.10.dist-info}/METADATA +6 -4
  74. gac-3.10.10.dist-info/RECORD +79 -0
  75. gac/constants.py +0 -321
  76. gac-3.6.0.dist-info/RECORD +0 -53
  77. {gac-3.6.0.dist-info → gac-3.10.10.dist-info}/WHEEL +0 -0
  78. {gac-3.6.0.dist-info → gac-3.10.10.dist-info}/entry_points.txt +0 -0
  79. {gac-3.6.0.dist-info → gac-3.10.10.dist-info}/licenses/LICENSE +0 -0
@@ -4,94 +4,54 @@ This provider provides native support for Azure OpenAI Service with proper
4
4
  endpoint construction and API version handling.
5
5
  """
6
6
 
7
- import json
8
- import logging
9
7
  import os
10
-
11
- import httpx
8
+ from typing import Any
12
9
 
13
10
  from gac.errors import AIError
14
-
15
- logger = logging.getLogger(__name__)
16
-
17
-
18
- def call_azure_openai_api(model: str, messages: list[dict], temperature: float, max_tokens: int) -> str:
19
- """Call Azure OpenAI Service API.
20
-
21
- Environment variables:
22
- AZURE_OPENAI_API_KEY: Azure OpenAI API key (required)
23
- AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint URL (required)
24
- Example: https://your-resource.openai.azure.com
25
- AZURE_OPENAI_API_VERSION: Azure OpenAI API version (required)
26
- Example: 2025-01-01-preview
27
- Example: 2024-02-15-preview
28
-
29
- Args:
30
- model: The deployment name in Azure OpenAI (e.g., 'gpt-4o', 'gpt-35-turbo')
31
- messages: List of message dictionaries with 'role' and 'content' keys
32
- temperature: Controls randomness (0.0-1.0)
33
- max_tokens: Maximum tokens in the response
34
-
35
- Returns:
36
- The generated commit message
37
-
38
- Raises:
39
- AIError: If authentication fails, API errors occur, or response is invalid
40
- """
41
- api_key = os.getenv("AZURE_OPENAI_API_KEY")
42
- if not api_key:
43
- raise AIError.authentication_error("AZURE_OPENAI_API_KEY environment variable not set")
44
-
45
- endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
46
- if not endpoint:
47
- raise AIError.model_error("AZURE_OPENAI_ENDPOINT environment variable not set")
48
-
49
- api_version = os.getenv("AZURE_OPENAI_API_VERSION")
50
- if not api_version:
51
- raise AIError.model_error("AZURE_OPENAI_API_VERSION environment variable not set")
52
-
53
- # Build Azure OpenAI URL with proper structure
54
- endpoint = endpoint.rstrip("/")
55
- url = f"{endpoint}/openai/deployments/{model}/chat/completions?api-version={api_version}"
56
-
57
- headers = {"api-key": api_key, "Content-Type": "application/json"}
58
-
59
- data = {"messages": messages, "temperature": temperature, "max_tokens": max_tokens}
60
-
61
- try:
62
- response = httpx.post(url, headers=headers, json=data, timeout=120)
63
- response.raise_for_status()
64
- response_data = response.json()
65
-
66
- try:
67
- content = response_data["choices"][0]["message"]["content"]
68
- except (KeyError, IndexError, TypeError) as e:
69
- logger.error(f"Unexpected response format from Azure OpenAI API. Response: {json.dumps(response_data)}")
70
- raise AIError.model_error(
71
- f"Azure OpenAI API returned unexpected format. Expected response with "
72
- f"'choices[0].message.content', but got: {type(e).__name__}. Check logs for full response structure."
73
- ) from e
74
-
75
- if content is None:
76
- raise AIError.model_error("Azure OpenAI API returned null content")
77
- if content == "":
78
- raise AIError.model_error("Azure OpenAI API returned empty content")
79
- return content
80
- except httpx.ConnectError as e:
81
- raise AIError.connection_error(f"Azure OpenAI API connection failed: {str(e)}") from e
82
- except httpx.HTTPStatusError as e:
83
- status_code = e.response.status_code
84
- error_text = e.response.text
85
-
86
- if status_code == 401:
87
- raise AIError.authentication_error(f"Azure OpenAI API authentication failed: {error_text}") from e
88
- elif status_code == 429:
89
- raise AIError.rate_limit_error(f"Azure OpenAI API rate limit exceeded: {error_text}") from e
90
- else:
91
- raise AIError.model_error(f"Azure OpenAI API error: {status_code} - {error_text}") from e
92
- except httpx.TimeoutException as e:
93
- raise AIError.timeout_error(f"Azure OpenAI API request timed out: {str(e)}") from e
94
- except AIError:
95
- raise
96
- except Exception as e:
97
- raise AIError.model_error(f"Error calling Azure OpenAI API: {str(e)}") from e
11
+ from gac.providers.base import OpenAICompatibleProvider, ProviderConfig
12
+
13
+
14
class AzureOpenAIProvider(OpenAICompatibleProvider):
    """Azure OpenAI-compatible provider with custom URL construction and headers."""

    config = ProviderConfig(
        name="Azure OpenAI",
        api_key_env="AZURE_OPENAI_API_KEY",
        base_url="",  # never used directly; the per-request URL comes from _get_api_url
    )

    def __init__(self, config: ProviderConfig):
        """Read the Azure endpoint and API version from the environment.

        Raises:
            AIError: If AZURE_OPENAI_ENDPOINT or AZURE_OPENAI_API_VERSION is unset.
        """
        raw_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        if not raw_endpoint:
            raise AIError.model_error("AZURE_OPENAI_ENDPOINT environment variable not set")

        version = os.getenv("AZURE_OPENAI_API_VERSION")
        if not version:
            raise AIError.model_error("AZURE_OPENAI_API_VERSION environment variable not set")

        self.api_version = version
        self.endpoint = raw_endpoint.rstrip("/")
        config.base_url = ""  # real URL is computed per request in _get_api_url
        super().__init__(config)

    def _get_api_url(self, model: str | None = None) -> str:
        """Build the deployment-specific Azure URL; defer to the base class when no model is given."""
        if model is None:
            return super()._get_api_url(model)
        return f"{self.endpoint}/openai/deployments/{model}/chat/completions?api-version={self.api_version}"

    def _build_headers(self) -> dict[str, str]:
        """Swap the standard Bearer authorization for Azure's api-key header."""
        headers = super()._build_headers()
        # Azure authenticates with an api-key header rather than a Bearer token.
        headers.pop("Authorization", None)
        headers["api-key"] = self.api_key
        return headers

    def _build_request_body(
        self, messages: list[dict[str, Any]], temperature: float, max_tokens: int, model: str, **kwargs: Any
    ) -> dict[str, Any]:
        """Build request body for Azure OpenAI."""
        return {"messages": messages, "temperature": temperature, "max_tokens": max_tokens, **kwargs}
gac/providers/base.py ADDED
@@ -0,0 +1,329 @@
1
+ """Base configured provider class to eliminate code duplication."""
2
+
3
+ import logging
4
+ import os
5
+ from abc import ABC, abstractmethod
6
+ from dataclasses import dataclass
7
+ from typing import Any
8
+
9
+ import httpx
10
+
11
+ from gac.constants import ProviderDefaults
12
+ from gac.errors import AIError
13
+ from gac.providers.protocol import ProviderProtocol
14
+ from gac.utils import get_ssl_verify
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
@dataclass
class ProviderConfig:
    """Configuration for AI providers."""

    name: str  # human-readable provider name
    api_key_env: str  # environment variable that holds the API key
    base_url: str  # base endpoint URL (some providers set this dynamically)
    timeout: int = ProviderDefaults.HTTP_TIMEOUT  # HTTP timeout in seconds
    headers: dict[str, str] | None = None  # extra headers; defaults to JSON content type

    def __post_init__(self) -> None:
        """Fall back to a JSON content-type header when none was supplied."""
        self.headers = self.headers if self.headers is not None else {"Content-Type": "application/json"}
33
+
34
+
35
class BaseConfiguredProvider(ABC, ProviderProtocol):
    """Base class for configured AI providers.

    This class eliminates code duplication by providing:
    - Standardized HTTP handling with httpx
    - Common error handling patterns
    - Flexible configuration via ProviderConfig
    - Template methods for customization

    Implements ProviderProtocol for type safety.
    """

    def __init__(self, config: ProviderConfig):
        """Store the provider configuration.

        Args:
            config: Provider settings (name, key env var, base URL, timeout, headers)
        """
        self.config = config
        # NOTE: the API key is intentionally not cached here — see the api_key
        # property. (A previous `self._api_key` cache attribute was dead code.)

    @property
    def api_key(self) -> str:
        """Resolve the API key from the environment on every access.

        The value is deliberately re-read each time (no caching) so that
        environment changes — e.g. test isolation or key rotation — are
        picked up. Returns an empty string when no key variable is declared.

        Raises:
            AIError: If the declared environment variable is not set.
        """
        if self.config.api_key_env:
            return self._get_api_key()
        return ""

    @property
    def name(self) -> str:
        """Get the provider name."""
        return self.config.name

    @property
    def api_key_env(self) -> str:
        """Get the environment variable name for the API key."""
        return self.config.api_key_env

    @property
    def base_url(self) -> str:
        """Get the base URL for the API."""
        return self.config.base_url

    @property
    def timeout(self) -> int:
        """Get the timeout in seconds."""
        return self.config.timeout

    def _get_api_key(self) -> str:
        """Get API key from environment variables.

        Raises:
            AIError: If the configured environment variable is unset or empty.
        """
        api_key = os.getenv(self.config.api_key_env)
        if not api_key:
            raise AIError.authentication_error(f"{self.config.api_key_env} not found in environment variables")
        return api_key

    @abstractmethod
    def _build_request_body(
        self, messages: list[dict[str, Any]], temperature: float, max_tokens: int, model: str, **kwargs: Any
    ) -> dict[str, Any]:
        """Build the request body for the API call.

        Args:
            messages: List of message dictionaries
            temperature: Temperature parameter
            max_tokens: Maximum tokens in response
            model: Model name to use
            **kwargs: Additional provider-specific parameters

        Returns:
            Request body dictionary
        """

    @abstractmethod
    def _parse_response(self, response: dict[str, Any]) -> str:
        """Parse the API response and extract content.

        Args:
            response: Response dictionary from API

        Returns:
            Generated text content
        """

    def _build_headers(self) -> dict[str, str]:
        """Build headers for the API request.

        Can be overridden by subclasses to add provider-specific headers.
        Returns a copy so subclasses may mutate freely.
        """
        return self.config.headers.copy() if self.config.headers else {}

    def _get_api_url(self, model: str | None = None) -> str:
        """Get the API URL for the request.

        Can be overridden by subclasses for dynamic URLs.

        Args:
            model: Model name (for providers that need model-specific URLs)

        Returns:
            API URL string
        """
        return self.config.base_url

    def _make_http_request(self, url: str, body: dict[str, Any], headers: dict[str, str]) -> dict[str, Any]:
        """Make the HTTP request.

        Error handling is delegated to the @handle_provider_errors decorator
        which wraps the provider's API function. This avoids duplicate exception
        handling and ensures consistent error classification across all providers.

        Args:
            url: API URL
            body: Request body
            headers: Request headers

        Returns:
            Response JSON dictionary

        Raises:
            httpx.HTTPStatusError: For HTTP errors (handled by decorator)
            httpx.TimeoutException: For timeout errors (handled by decorator)
            httpx.RequestError: For network errors (handled by decorator)
        """
        response = httpx.post(url, json=body, headers=headers, timeout=self.config.timeout, verify=get_ssl_verify())
        response.raise_for_status()
        return response.json()

    def generate(
        self,
        model: str,
        messages: list[dict[str, Any]],
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs: Any,
    ) -> str:
        """Generate text using the AI provider.

        Error handling is delegated to the @handle_provider_errors decorator
        which wraps the provider's API function. This ensures consistent error
        classification across all providers.

        Args:
            model: Model name to use
            messages: List of message dictionaries
            temperature: Temperature parameter (0.0-2.0)
            max_tokens: Maximum tokens in response
            **kwargs: Additional provider-specific parameters

        Returns:
            Generated text content

        Raises:
            AIError: For any API-related errors (via decorator)
        """
        logger.debug(f"Generating with {self.config.name} provider (model={model})")

        # Build request components
        url = self._get_api_url(model)
        headers = self._build_headers()
        body = self._build_request_body(messages, temperature, max_tokens, model, **kwargs)

        # Add model to body if not already present (some providers encode it in the URL)
        if "model" not in body:
            body["model"] = model

        # Make HTTP request
        response_data = self._make_http_request(url, body, headers)

        # Parse response
        return self._parse_response(response_data)
203
+
204
+
205
class OpenAICompatibleProvider(BaseConfiguredProvider):
    """Base class for OpenAI-compatible providers.

    Handles standard OpenAI API format with minimal customization needed.
    """

    def _build_request_body(
        self, messages: list[dict[str, Any]], temperature: float, max_tokens: int, model: str, **kwargs: Any
    ) -> dict[str, Any]:
        """Build OpenAI-style request body.

        Note: Subclasses should override this if they need max_completion_tokens
        instead of max_tokens (like OpenAI provider does).
        """
        return {"messages": messages, "temperature": temperature, "max_tokens": max_tokens, **kwargs}

    def _build_headers(self) -> dict[str, str]:
        """Attach OpenAI-style Bearer authorization when a key is configured."""
        headers = super()._build_headers()
        key = self.api_key
        if key:
            headers["Authorization"] = f"Bearer {key}"
        return headers

    def _parse_response(self, response: dict[str, Any]) -> str:
        """Extract the message content from an OpenAI-style response.

        Raises:
            AIError: If choices are missing or the content is null/empty.
        """
        choices = response.get("choices")
        if not isinstance(choices, list) or not choices:
            raise AIError.model_error("Invalid response: missing choices")
        text = choices[0].get("message", {}).get("content")
        if text is None:
            raise AIError.model_error("Invalid response: null content")
        if text == "":
            raise AIError.model_error("Invalid response: empty content")
        return text
239
+
240
+
241
class AnthropicCompatibleProvider(BaseConfiguredProvider):
    """Base class for Anthropic-compatible providers."""

    def _build_headers(self) -> dict[str, str]:
        """Build headers with Anthropic-style authorization.

        Uses the `api_key` property (consistent with OpenAICompatibleProvider)
        so the environment is re-read on every request, and pins the
        anthropic-version header expected by the Messages API.
        """
        headers = super()._build_headers()
        headers["x-api-key"] = self.api_key
        headers["anthropic-version"] = "2023-06-01"
        return headers

    def _build_request_body(
        self, messages: list[dict[str, Any]], temperature: float, max_tokens: int, model: str, **kwargs: Any
    ) -> dict[str, Any]:
        """Build Anthropic-style request body.

        System messages are hoisted into the top-level "system" field; if
        several system messages appear, the last one wins.

        Args:
            messages: OpenAI-style message dicts with 'role' and 'content'
            temperature: Sampling temperature
            max_tokens: Maximum tokens in the response
            model: Model name (added to the body by the caller if absent)
            **kwargs: Extra provider-specific body fields

        Returns:
            Request body dictionary in Anthropic Messages format
        """
        # Convert messages to Anthropic format
        anthropic_messages = []
        system_message = ""

        for msg in messages:
            if msg["role"] == "system":
                system_message = msg["content"]
            else:
                anthropic_messages.append({"role": msg["role"], "content": msg["content"]})

        body = {"messages": anthropic_messages, "temperature": temperature, "max_tokens": max_tokens, **kwargs}

        if system_message:
            body["system"] = system_message

        return body

    def _parse_response(self, response: dict[str, Any]) -> str:
        """Parse Anthropic-style response.

        Only the first content block is read; later blocks are ignored.

        Raises:
            AIError: If content is missing, null, or empty.
        """
        content = response.get("content")
        if not content or not isinstance(content, list):
            raise AIError.model_error("Invalid response: missing content")

        text_content = content[0].get("text")
        if text_content is None:
            raise AIError.model_error("Invalid response: null content")
        if text_content == "":
            raise AIError.model_error("Invalid response: empty content")
        return text_content
285
+
286
+
287
class GenericHTTPProvider(BaseConfiguredProvider):
    """Base class for completely custom providers."""

    def _build_request_body(
        self, messages: list[dict[str, Any]], temperature: float, max_tokens: int, model: str, **kwargs: Any
    ) -> dict[str, Any]:
        """Default implementation - override this in subclasses."""
        return {"messages": messages, "temperature": temperature, "max_tokens": max_tokens, **kwargs}

    def _parse_response(self, response: dict[str, Any]) -> str:
        """Best-effort extraction: try OpenAI, Anthropic, then Ollama shapes.

        Falls back to scanning top-level values for a long string before
        giving up. Override in subclasses for a known response format.

        Raises:
            AIError: If no content can be located in the response.
        """
        # OpenAI shape: choices[0].message.content
        choices = response.get("choices")
        if choices and isinstance(choices, list):
            text = choices[0].get("message", {}).get("content")
            if text:
                return text

        # Anthropic shape: content[0].text
        blocks = response.get("content")
        if blocks and isinstance(blocks, list):
            return blocks[0].get("text", "")

        # Ollama shape: message.content
        msg = response.get("message", {})
        if "content" in msg:
            return msg["content"]

        # Last resort: any sufficiently long top-level string value
        for candidate in response.values():
            if isinstance(candidate, str) and len(candidate) > 10:  # assume longer strings are content
                return candidate

        raise AIError.model_error("Could not extract content from response")
321
+
322
+
323
# Public API of the provider base module (alphabetical).
__all__ = [
    "AnthropicCompatibleProvider",
    "BaseConfiguredProvider",
    "GenericHTTPProvider",
    "OpenAICompatibleProvider",
    "ProviderConfig",
]
gac/providers/cerebras.py CHANGED
@@ -1,38 +1,15 @@
1
1
  """Cerebras AI provider implementation."""
2
2
 
3
- import os
3
+ from gac.providers.base import OpenAICompatibleProvider, ProviderConfig
4
4
 
5
- import httpx
6
5
 
7
- from gac.errors import AIError
6
class CerebrasProvider(OpenAICompatibleProvider):
    """Cerebras OpenAI-compatible provider."""

    config = ProviderConfig(
        name="Cerebras",
        api_key_env="CEREBRAS_API_KEY",
        base_url="https://api.cerebras.ai/v1",
    )

    def _get_api_url(self, model: str | None = None) -> str:
        """Get Cerebras API URL with /chat/completions endpoint."""
        return self.config.base_url + "/chat/completions"
gac/providers/chutes.py CHANGED
@@ -2,70 +2,24 @@
2
2
 
3
3
  import os
4
4
 
5
- import httpx
5
+ from gac.providers.base import OpenAICompatibleProvider, ProviderConfig
6
6
 
7
- from gac.errors import AIError
8
7
 
8
class ChutesProvider(OpenAICompatibleProvider):
    """Chutes.ai OpenAI-compatible provider with custom base URL."""

    config = ProviderConfig(
        name="Chutes",
        api_key_env="CHUTES_API_KEY",
        base_url="",  # resolved from the environment in __init__
    )

    def __init__(self, config: ProviderConfig):
        """Resolve the base URL from CHUTES_BASE_URL (default https://llm.chutes.ai)."""
        root = os.getenv("CHUTES_BASE_URL", "https://llm.chutes.ai").rstrip("/")
        config.base_url = root + "/v1"
        super().__init__(config)

    def _get_api_url(self, model: str | None = None) -> str:
        """Point requests at the /chat/completions endpoint."""
        return self.config.base_url + "/chat/completions"
+ return f"{self.config.base_url}/chat/completions"