gac 3.6.0__py3-none-any.whl → 3.10.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gac/__init__.py +4 -6
- gac/__version__.py +1 -1
- gac/ai_utils.py +59 -43
- gac/auth_cli.py +181 -36
- gac/cli.py +26 -9
- gac/commit_executor.py +59 -0
- gac/config.py +81 -2
- gac/config_cli.py +19 -7
- gac/constants/__init__.py +34 -0
- gac/constants/commit.py +63 -0
- gac/constants/defaults.py +40 -0
- gac/constants/file_patterns.py +110 -0
- gac/constants/languages.py +119 -0
- gac/diff_cli.py +0 -22
- gac/errors.py +8 -2
- gac/git.py +6 -6
- gac/git_state_validator.py +193 -0
- gac/grouped_commit_workflow.py +458 -0
- gac/init_cli.py +2 -1
- gac/interactive_mode.py +179 -0
- gac/language_cli.py +0 -1
- gac/main.py +231 -926
- gac/model_cli.py +67 -11
- gac/model_identifier.py +70 -0
- gac/oauth/__init__.py +26 -0
- gac/oauth/claude_code.py +89 -22
- gac/oauth/qwen_oauth.py +327 -0
- gac/oauth/token_store.py +81 -0
- gac/oauth_retry.py +161 -0
- gac/postprocess.py +155 -0
- gac/prompt.py +21 -479
- gac/prompt_builder.py +88 -0
- gac/providers/README.md +437 -0
- gac/providers/__init__.py +70 -78
- gac/providers/anthropic.py +12 -46
- gac/providers/azure_openai.py +48 -88
- gac/providers/base.py +329 -0
- gac/providers/cerebras.py +10 -33
- gac/providers/chutes.py +16 -62
- gac/providers/claude_code.py +64 -87
- gac/providers/custom_anthropic.py +51 -81
- gac/providers/custom_openai.py +29 -83
- gac/providers/deepseek.py +10 -33
- gac/providers/error_handler.py +139 -0
- gac/providers/fireworks.py +10 -33
- gac/providers/gemini.py +66 -63
- gac/providers/groq.py +10 -58
- gac/providers/kimi_coding.py +19 -55
- gac/providers/lmstudio.py +64 -43
- gac/providers/minimax.py +10 -33
- gac/providers/mistral.py +10 -33
- gac/providers/moonshot.py +10 -33
- gac/providers/ollama.py +56 -33
- gac/providers/openai.py +30 -36
- gac/providers/openrouter.py +15 -52
- gac/providers/protocol.py +71 -0
- gac/providers/qwen.py +64 -0
- gac/providers/registry.py +58 -0
- gac/providers/replicate.py +140 -82
- gac/providers/streamlake.py +26 -46
- gac/providers/synthetic.py +35 -37
- gac/providers/together.py +10 -33
- gac/providers/zai.py +29 -57
- gac/py.typed +0 -0
- gac/security.py +1 -1
- gac/templates/__init__.py +1 -0
- gac/templates/question_generation.txt +60 -0
- gac/templates/system_prompt.txt +224 -0
- gac/templates/user_prompt.txt +28 -0
- gac/utils.py +36 -6
- gac/workflow_context.py +162 -0
- gac/workflow_utils.py +3 -8
- {gac-3.6.0.dist-info → gac-3.10.10.dist-info}/METADATA +6 -4
- gac-3.10.10.dist-info/RECORD +79 -0
- gac/constants.py +0 -321
- gac-3.6.0.dist-info/RECORD +0 -53
- {gac-3.6.0.dist-info → gac-3.10.10.dist-info}/WHEEL +0 -0
- {gac-3.6.0.dist-info → gac-3.10.10.dist-info}/entry_points.txt +0 -0
- {gac-3.6.0.dist-info → gac-3.10.10.dist-info}/licenses/LICENSE +0 -0
gac/providers/ollama.py
CHANGED
|
@@ -1,50 +1,73 @@
|
|
|
1
|
-
"""Ollama
|
|
1
|
+
"""Ollama API provider for gac."""
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
|
+
from typing import Any
|
|
4
5
|
|
|
5
|
-
import
|
|
6
|
+
from gac.providers.base import OpenAICompatibleProvider, ProviderConfig
|
|
6
7
|
|
|
7
|
-
from gac.errors import AIError
|
|
8
8
|
|
|
9
|
+
class OllamaProvider(OpenAICompatibleProvider):
    """Ollama provider for local LLM models with optional authentication.

    Talks to Ollama's native ``/api/chat`` endpoint. The server URL can be
    overridden with ``OLLAMA_API_URL``; a Bearer token is sent only when
    ``OLLAMA_API_KEY`` is set.
    """

    config = ProviderConfig(
        name="Ollama",
        api_key_env="OLLAMA_API_KEY",
        base_url="http://localhost:11434",
    )

    def __init__(self, config: ProviderConfig):
        """Initialize with a server URL taken from ``OLLAMA_API_URL`` when set.

        Args:
            config: Provider configuration. It is copied before the URL
                override is applied, so the caller's object is not mutated.
        """
        import copy

        super().__init__(config)
        # Callers may pass the shared class-level ``config`` object; copy it so
        # the URL override below does not leak into every other instance.
        self.config = copy.copy(self.config)
        # An unset *or empty* OLLAMA_API_URL falls back to the local default.
        api_url = os.getenv("OLLAMA_API_URL") or "http://localhost:11434"
        self.config.base_url = api_url.rstrip("/")

    def _build_headers(self) -> dict[str, str]:
        """Build headers, adding a Bearer token only when an API key is configured."""
        headers = super()._build_headers()
        api_key = self._get_api_key()
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        else:
            # No key configured: never send an empty "Bearer " header that the
            # parent class may have produced from the blank key.
            headers.pop("Authorization", None)
        return headers

    def _build_request_body(
        self, messages: list[dict[str, Any]], temperature: float, max_tokens: int, model: str, **kwargs: Any
    ) -> dict[str, Any]:
        """Build a non-streaming ``/api/chat`` request body.

        ``max_tokens`` is accepted for interface compatibility but is not sent.
        Extra ``kwargs`` are merged last and can override any field.
        """
        return {
            # /api/chat requires the model name in the request body.
            "model": model,
            "messages": messages,
            # Kept at top level for backward compatibility; Ollama's native API
            # reads sampling parameters from "options", so mirror it there.
            "temperature": temperature,
            "options": {"temperature": temperature},
            "stream": False,
            **kwargs,
        }

    def _get_api_url(self, model: str | None = None) -> str:
        """Get API URL with /api/chat endpoint (``model`` is not part of the URL)."""
        return f"{self.config.base_url}/api/chat"

    def _get_api_key(self) -> str:
        """Return the optional API key, or ``""`` when none is configured."""
        return os.getenv(self.config.api_key_env) or ""

    def _parse_response(self, response: dict[str, Any]) -> str:
        """Extract the generated text from an Ollama response.

        Accepts the chat shape (``{"message": {"content": ...}}``) and the
        generate shape (``{"response": ...}``); anything else is stringified.

        Raises:
            AIError: If the extracted content is ``None`` or empty.
        """
        from gac.errors import AIError

        # Handle different response formats from Ollama
        message = response.get("message")
        if isinstance(message, dict) and "content" in message:
            content = message["content"]
        elif "response" in response:
            content = response["response"]
        else:
            # Fallback: try to serialize response
            content = str(response) if response else ""

        if content is None:
            raise AIError.model_error("Ollama API returned null content")
        if content == "":
            raise AIError.model_error("Ollama API returned empty content")

        return content
|
|
41
|
-
except httpx.ConnectError as e:
|
|
42
|
-
raise AIError.connection_error(f"Ollama connection failed. Make sure Ollama is running: {str(e)}") from e
|
|
43
|
-
except httpx.HTTPStatusError as e:
|
|
44
|
-
if e.response.status_code == 429:
|
|
45
|
-
raise AIError.rate_limit_error(f"Ollama API rate limit exceeded: {e.response.text}") from e
|
|
46
|
-
raise AIError.model_error(f"Ollama API error: {e.response.status_code} - {e.response.text}") from e
|
|
47
|
-
except httpx.TimeoutException as e:
|
|
48
|
-
raise AIError.timeout_error(f"Ollama API request timed out: {str(e)}") from e
|
|
49
|
-
except Exception as e:
|
|
50
|
-
raise AIError.model_error(f"Error calling Ollama API: {str(e)}") from e
|
gac/providers/openai.py
CHANGED
|
@@ -1,38 +1,32 @@
|
|
|
1
1
|
"""OpenAI API provider for gac."""
|
|
2
2
|
|
|
3
|
-
import
|
|
4
|
-
|
|
5
|
-
import
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
"""
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
raise AIError.rate_limit_error(f"OpenAI API rate limit exceeded: {e.response.text}") from e
|
|
34
|
-
raise AIError.model_error(f"OpenAI API error: {e.response.status_code} - {e.response.text}") from e
|
|
35
|
-
except httpx.TimeoutException as e:
|
|
36
|
-
raise AIError.timeout_error(f"OpenAI API request timed out: {str(e)}") from e
|
|
37
|
-
except Exception as e:
|
|
38
|
-
raise AIError.model_error(f"Error calling OpenAI API: {str(e)}") from e
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from gac.providers.base import OpenAICompatibleProvider, ProviderConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class OpenAIProvider(OpenAICompatibleProvider):
    """Provider for the official OpenAI Chat Completions API.

    It differs from the generic OpenAI-compatible provider only in the request
    body: the token limit is sent as ``max_completion_tokens``, and the
    ``response_format``/``stop`` options are copied through explicitly.
    """

    config = ProviderConfig(name="OpenAI", api_key_env="OPENAI_API_KEY", base_url="https://api.openai.com/v1")

    def _get_api_url(self, model: str | None = None) -> str:
        """Return the chat-completions endpoint; the model is not part of the URL."""
        base = self.config.base_url
        return f"{base}/chat/completions"

    def _build_request_body(
        self, messages: list[dict[str, Any]], temperature: float, max_tokens: int, model: str, **kwargs: Any
    ) -> dict[str, Any]:
        """Build the request body expected by OpenAI's chat-completions endpoint."""
        body = super()._build_request_body(messages, temperature, max_tokens, model, **kwargs)

        # OpenAI expects max_completion_tokens rather than max_tokens.
        token_limit = body.pop("max_tokens")
        body["max_completion_tokens"] = token_limit

        # Forward the optional parameters only when the caller supplied them.
        for option in ("response_format", "stop"):
            if option in kwargs:
                body[option] = kwargs[option]

        return body
|
gac/providers/openrouter.py
CHANGED
|
@@ -1,58 +1,21 @@
|
|
|
1
1
|
"""OpenRouter API provider for gac."""
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
from gac.providers.base import OpenAICompatibleProvider, ProviderConfig
|
|
4
4
|
|
|
5
|
-
import httpx
|
|
6
5
|
|
|
7
|
-
|
|
6
|
+
class OpenRouterProvider(OpenAICompatibleProvider):
    """Provider for OpenRouter's OpenAI-compatible chat-completions API."""

    config = ProviderConfig(
        name="OpenRouter",
        api_key_env="OPENROUTER_API_KEY",
        base_url="https://openrouter.ai/api/v1",
    )

    def _get_api_url(self, model: str | None = None) -> str:
        """Return the chat-completions endpoint; ``model`` does not affect the URL."""
        root = self.config.base_url
        return f"{root}/chat/completions"

    def _build_headers(self) -> dict[str, str]:
        """Extend the parent's headers with an ``HTTP-Referer`` identifying gac."""
        extra = {"HTTP-Referer": "https://github.com/codeindolence/gac"}
        merged = super()._build_headers()
        merged.update(extra)
        return merged
|
|
gac/providers/protocol.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""Provider protocol for type-safe AI provider implementations."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Protocol, runtime_checkable
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@runtime_checkable
class ProviderProtocol(Protocol):
    """Structural interface that AI provider objects are expected to expose.

    An object satisfies the protocol by offering ``name``, ``api_key_env``,
    ``base_url`` and ``timeout`` attributes plus a ``generate`` method; no
    inheritance is required. Because the class is ``runtime_checkable``,
    ``isinstance`` may be used, but it only checks that the members exist,
    not their signatures or types. Bare functions lack these members and
    therefore do not satisfy the protocol.
    """

    @property
    def name(self) -> str:
        """Identifier of the provider."""

    @property
    def api_key_env(self) -> str:
        """Name of the environment variable that holds the provider's API key."""

    @property
    def base_url(self) -> str:
        """Root URL of the provider's API."""

    @property
    def timeout(self) -> int:
        """Request timeout, in seconds."""

    def generate(
        self, model: str, messages: list[dict[str, Any]], temperature: float, max_tokens: int, **kwargs: Any
    ) -> str:
        """Produce text from ``model`` for the given chat ``messages``.

        Args:
            model: Name of the model to query.
            messages: Conversation as a list of chat-format message dicts.
            temperature: Sampling temperature (0.0-2.0).
            max_tokens: Upper bound on tokens in the reply.
            **kwargs: Extra, provider-specific options.

        Returns:
            The generated text.

        Raises:
            AIError: When generation fails for any reason.
        """
|
gac/providers/qwen.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""Qwen API provider for gac with OAuth-only support."""
|
|
2
|
+
|
|
3
|
+
from gac.errors import AIError
|
|
4
|
+
from gac.oauth import QwenOAuthProvider, TokenStore
|
|
5
|
+
from gac.providers.base import OpenAICompatibleProvider, ProviderConfig
|
|
6
|
+
|
|
7
|
+
# Default Qwen API base URL: the provider's configured base_url, and the
# fallback used when the stored OAuth token carries no "resource_url".
QWEN_DEFAULT_API_URL = "https://chat.qwen.ai/api/v1"
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class QwenProvider(OpenAICompatibleProvider):
    """Qwen provider with OAuth-only authentication.

    There is no API-key environment variable. The access token, and optionally
    the API base URL, come from the OAuth token saved by ``gac auth qwen login``.
    """

    config = ProviderConfig(
        name="Qwen",
        api_key_env="",
        base_url=QWEN_DEFAULT_API_URL,
    )

    def __init__(self, config: ProviderConfig):
        """Initialize and resolve the OAuth token and API base URL up front.

        Raises:
            AIError: If no usable OAuth token is stored.
        """
        super().__init__(config)
        self._auth_token, self._resolved_base_url = self._get_oauth_token()

    def _get_api_key(self) -> str:
        """Return placeholder for parent class compatibility (OAuth is used instead)."""
        return "oauth-token"

    @staticmethod
    def _normalize_base_url(resource_url: str) -> str:
        """Turn a token's ``resource_url`` into an API base URL ending in ``/v1``.

        A missing scheme defaults to ``https://``. Trailing slashes are removed
        *before* the ``/v1`` check, so ``host/v1/`` yields ``https://host/v1``
        rather than ``https://host/v1/v1``.
        """
        url = resource_url.strip()
        if not url.startswith(("http://", "https://")):
            url = f"https://{url}"
        url = url.rstrip("/")
        if not url.endswith("/v1"):
            url = f"{url}/v1"
        return url

    def _get_oauth_token(self) -> tuple[str, str]:
        """Get Qwen OAuth token from token store.

        Returns:
            Tuple of (access_token, api_url) for authentication.

        Raises:
            AIError: If no OAuth token (or no non-empty access token) is found.
        """
        oauth_provider = QwenOAuthProvider(TokenStore())
        token = oauth_provider.get_token()
        access_token = token.get("access_token") if token else None
        if not access_token:
            raise AIError.authentication_error("Qwen OAuth token not found. Run 'gac auth qwen login' to authenticate.")

        resource_url = token.get("resource_url")
        base_url = self._normalize_base_url(resource_url) if resource_url else QWEN_DEFAULT_API_URL
        return access_token, base_url

    def _build_headers(self) -> dict[str, str]:
        """Build headers, authorizing with the stored OAuth access token."""
        headers = super()._build_headers()
        # Overwrite any Authorization header the parent built (presumably from
        # the placeholder key) with the real OAuth token.
        headers["Authorization"] = f"Bearer {self._auth_token}"
        return headers

    def _get_api_url(self, model: str | None = None) -> str:
        """Get Qwen API URL with /chat/completions endpoint."""
        return f"{self._resolved_base_url}/chat/completions"
|
|
gac/providers/registry.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""Provider registry for AI providers."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from functools import wraps
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from gac.providers.base import BaseConfiguredProvider
|
|
9
|
+
|
|
10
|
+
# Global registry for provider functions: maps a provider name (e.g. "openai")
# to the call function built by create_provider_func(); filled by register_provider().
PROVIDER_REGISTRY: dict[str, Callable[..., str]] = {}
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def create_provider_func(provider_class: type["BaseConfiguredProvider"]) -> Callable[..., str]:
    """Create a provider function from a provider class.

    The returned callable:
      1. Instantiates ``provider_class`` with its ``config`` class attribute.
      2. Calls ``generate()`` with the given arguments, forwarding any extra
         keyword arguments (``generate`` accepts ``**kwargs``).
      3. Is wrapped with ``handle_provider_errors`` for consistent error handling.

    Args:
        provider_class: A provider class with a ``config`` class attribute.

    Returns:
        A callable named ``call_<provider>_api`` that generates text.
    """
    from gac.providers.error_handler import handle_provider_errors

    provider_name = provider_class.config.name

    @handle_provider_errors(provider_name)
    @wraps(provider_class.generate)
    def provider_func(
        model: str, messages: list[dict[str, Any]], temperature: float, max_tokens: int, **kwargs: Any
    ) -> str:
        provider = provider_class(provider_class.config)
        return provider.generate(
            model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, **kwargs
        )

    # Add metadata for introspection. @wraps copied generate's qualname, so
    # override it too; otherwise repr() would still show "<Class>.generate".
    func_name = f"call_{provider_name.lower().replace(' ', '_').replace('.', '_')}_api"
    provider_func.__name__ = func_name
    provider_func.__qualname__ = func_name
    provider_func.__doc__ = f"Call {provider_name} API to generate text."

    return provider_func
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def register_provider(name: str, provider_class: type["BaseConfiguredProvider"]) -> None:
    """Build the call function for ``provider_class`` and store it under ``name``.

    Registering a name that already exists replaces the previous entry.

    Args:
        name: Registry key, e.g. "openai" or "anthropic".
        provider_class: Provider class exposing a ``config`` class attribute.
    """
    generated = create_provider_func(provider_class)
    PROVIDER_REGISTRY[name] = generated
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Names exported by a star-import of this module; create_provider_func can be
# imported directly but is not listed here.
__all__ = [
    "PROVIDER_REGISTRY",
    "register_provider",
]
|