gac-1.1.0-py3-none-any.whl → gac-1.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- gac/__init__.py +6 -6
- gac/__version__.py +1 -1
- gac/ai.py +43 -34
- gac/ai_utils.py +113 -62
- gac/errors.py +5 -0
- gac/providers/__init__.py +17 -1
- gac/providers/anthropic.py +32 -131
- gac/providers/cerebras.py +19 -124
- gac/providers/groq.py +43 -119
- gac/providers/ollama.py +27 -127
- gac/providers/openai.py +18 -123
- gac/providers/openrouter.py +19 -98
- {gac-1.1.0.dist-info → gac-1.2.1.dist-info}/METADATA +1 -1
- gac-1.2.1.dist-info/RECORD +28 -0
- gac-1.1.0.dist-info/RECORD +0 -28
- {gac-1.1.0.dist-info → gac-1.2.1.dist-info}/WHEEL +0 -0
- {gac-1.1.0.dist-info → gac-1.2.1.dist-info}/entry_points.txt +0 -0
- {gac-1.1.0.dist-info → gac-1.2.1.dist-info}/licenses/LICENSE +0 -0
gac/providers/cerebras.py
CHANGED
@@ -1,134 +1,29 @@
-"""Cerebras
+"""Cerebras AI provider implementation."""

-import logging
 import os
-import time

 import httpx
-from halo import Halo

-from gac.ai_utils import _classify_error
-from gac.constants import EnvDefaults
 from gac.errors import AIError

-logger = logging.getLogger(__name__)
-
-    model: str,
-    prompt: str | tuple[str, str],
-    temperature: float = EnvDefaults.TEMPERATURE,
-    max_tokens: int = EnvDefaults.MAX_OUTPUT_TOKENS,
-    max_retries: int = EnvDefaults.MAX_RETRIES,
-    quiet: bool = False,
-) -> str:
-    """Generate commit message using Cerebras API with retry logic.
-
-    Args:
-        model: The model name (e.g., 'llama3.1-8b', 'llama3.1-70b')
-        prompt: Either a string prompt or tuple of (system_prompt, user_prompt)
-        temperature: Controls randomness (0.0-1.0)
-        max_tokens: Maximum tokens in the response
-        max_retries: Number of retry attempts if generation fails
-        quiet: If True, suppress progress indicators
-
-    Returns:
-        A formatted commit message string
-
-    Raises:
-        AIError: If generation fails after max_retries attempts
-    """
+def call_cerebras_api(model: str, messages: list[dict], temperature: float, max_tokens: int) -> str:
+    """Call Cerebras API directly."""
     api_key = os.getenv("CEREBRAS_API_KEY")
     if not api_key:
-        raise AIError.model_error("CEREBRAS_API_KEY
-
-    headers = {
-        "Content-Type": "application/json",
-        "Authorization": f"Bearer {api_key}",
-    }
-
-    return _make_request_with_retry(
-        url="https://api.cerebras.ai/v1/chat/completions",
-        headers=headers,
-        payload=payload,
-        provider_name=f"Cerebras {model}",
-        max_retries=max_retries,
-        quiet=quiet,
-        response_parser=lambda r: r["choices"][0]["message"]["content"],
-    )
-
-
-def _make_request_with_retry(
-    url: str,
-    headers: dict,
-    payload: dict,
-    provider_name: str,
-    max_retries: int,
-    quiet: bool,
-    response_parser: callable,
-) -> str:
-    """Make HTTP request with retry logic and common error handling."""
-    if quiet:
-        spinner = None
-    else:
-        spinner = Halo(text=f"Generating commit message with {provider_name}...", spinner="dots")
-        spinner.start()
-
-    last_error = None
-    retry_count = 0
-
-    while retry_count < max_retries:
-        try:
-            logger.debug(f"Trying with {provider_name} (attempt {retry_count + 1}/{max_retries})")
-
-            with httpx.Client(timeout=30.0) as client:
-                response = client.post(url, headers=headers, json=payload)
-                response.raise_for_status()
-
-            response_data = response.json()
-            message = response_parser(response_data)
-
-            if spinner:
-                spinner.succeed(f"Generated commit message with {provider_name}")
-
-            return message
-
-        except Exception as e:
-            last_error = e
-            retry_count += 1
-
-            if retry_count == max_retries:
-                logger.warning(f"Error generating commit message: {e}. Giving up.")
-                break
-
-            wait_time = 2**retry_count
-            logger.warning(f"Error generating commit message: {e}. Retrying in {wait_time}s...")
-            if spinner:
-                for i in range(wait_time, 0, -1):
-                    spinner.text = f"Retry {retry_count}/{max_retries} in {i}s..."
-                    time.sleep(1)
-            else:
-                time.sleep(wait_time)
-
-    if spinner:
-        spinner.fail(f"Failed to generate commit message with {provider_name}")
-
-    error_type = _classify_error(str(last_error))
-    raise AIError(
-        f"Failed to generate commit message after {max_retries} attempts: {last_error}", error_type=error_type
-    )
+        raise AIError.model_error("CEREBRAS_API_KEY not found in environment variables")
+
+    url = "https://api.cerebras.ai/v1/chat/completions"
+    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+
+    data = {"model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens}
+
+    try:
+        response = httpx.post(url, headers=headers, json=data, timeout=120)
+        response.raise_for_status()
+        response_data = response.json()
+        return response_data["choices"][0]["message"]["content"]
+    except httpx.HTTPStatusError as e:
+        raise AIError.model_error(f"Cerebras API error: {e.response.status_code} - {e.response.text}") from e
+    except Exception as e:
+        raise AIError.model_error(f"Error calling Cerebras API: {str(e)}") from e
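Taken together, the Cerebras change drops the in-module retry/spinner loop and leaves a single direct HTTP call. Per the file list above, gac/ai.py and gac/ai_utils.py grew in this release, so retry handling presumably moved upstream, though that code is outside this file. A minimal sketch of how the new call_cerebras_api entry point might be invoked, assuming only the signature shown in this diff (the model name and message contents are illustrative, not gac's actual call site):

from gac.providers.cerebras import call_cerebras_api

# Requires CEREBRAS_API_KEY in the environment; otherwise AIError.model_error is raised.
messages = [
    {"role": "system", "content": "Write a conventional commit message for the staged diff."},
    {"role": "user", "content": "diff --git a/app.py b/app.py\n+print('hello')\n"},
]
commit_message = call_cerebras_api(
    model="llama3.1-8b",  # illustrative model name
    messages=messages,
    temperature=0.7,
    max_tokens=256,
)
print(commit_message)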
gac/providers/groq.py
CHANGED
@@ -2,133 +2,57 @@

 import logging
 import os
-import time

 import httpx
-from halo import Halo

-from gac.ai_utils import _classify_error
-from gac.constants import EnvDefaults
 from gac.errors import AIError

 logger = logging.getLogger(__name__)

-def
-    prompt: str | tuple[str, str],
-    temperature: float = EnvDefaults.TEMPERATURE,
-    max_tokens: int = EnvDefaults.MAX_OUTPUT_TOKENS,
-    max_retries: int = EnvDefaults.MAX_RETRIES,
-    quiet: bool = False,
-) -> str:
-    """Generate commit message using Groq API with retry logic.
-
-    Args:
-        model: The model name (e.g., 'llama3-8b-8192', 'llama3-70b-8192')
-        prompt: Either a string prompt or tuple of (system_prompt, user_prompt)
-        temperature: Controls randomness (0.0-1.0)
-        max_tokens: Maximum tokens in the response
-        max_retries: Number of retry attempts if generation fails
-        quiet: If True, suppress progress indicators
-
-    Returns:
-        A formatted commit message string
-
-    Raises:
-        AIError: If generation fails after max_retries attempts
-    """
+def call_groq_api(model: str, messages: list[dict], temperature: float, max_tokens: int) -> str:
+    """Call Groq API directly."""
     api_key = os.getenv("GROQ_API_KEY")
     if not api_key:
-        raise AIError.model_error("GROQ_API_KEY
-
-def _make_request_with_retry(
-    url: str,
-    headers: dict,
-    payload: dict,
-    provider_name: str,
-    max_retries: int,
-    quiet: bool,
-    response_parser: callable,
-) -> str:
-    """Make HTTP request with retry logic and common error handling."""
-    if quiet:
-        spinner = None
-    else:
-        spinner = Halo(text=f"Generating commit message with {provider_name}...", spinner="dots")
-        spinner.start()
-
-    last_error = None
-    retry_count = 0
-
-    while retry_count < max_retries:
-        try:
-            logger.debug(f"Trying with {provider_name} (attempt {retry_count + 1}/{max_retries})")
-
-            with httpx.Client(timeout=30.0) as client:
-                response = client.post(url, headers=headers, json=payload)
-                response.raise_for_status()
-
-            response_data = response.json()
-            message = response_parser(response_data)
-
-            if spinner:
-                spinner.succeed(f"Generated commit message with {provider_name}")
-
-            return message
-
-        except Exception as e:
-            last_error = e
-            retry_count += 1
-
-            if retry_count == max_retries:
-                logger.warning(f"Error generating commit message: {e}. Giving up.")
-                break
-
-            wait_time = 2**retry_count
-            logger.warning(f"Error generating commit message: {e}. Retrying in {wait_time}s...")
-            if spinner:
-                for i in range(wait_time, 0, -1):
-                    spinner.text = f"Retry {retry_count}/{max_retries} in {i}s..."
-                    time.sleep(1)
+        raise AIError.model_error("GROQ_API_KEY not found in environment variables")
+
+    url = "https://api.groq.com/openai/v1/chat/completions"
+    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+
+    data = {"model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens}
+
+    try:
+        response = httpx.post(url, headers=headers, json=data, timeout=120)
+        response.raise_for_status()
+        response_data = response.json()
+
+        # Debug logging to understand response structure
+        logger.debug(f"Groq API response: {response_data}")
+
+        # Handle different response formats
+        if "choices" in response_data and len(response_data["choices"]) > 0:
+            choice = response_data["choices"][0]
+            if "message" in choice and "content" in choice["message"]:
+                content = choice["message"]["content"]
+                logger.debug(f"Found content in message.content: {repr(content)}")
+                if content is None:
+                    logger.warning("Groq API returned None content in message.content")
+                    return ""
+                return content
+            elif "text" in choice:
+                content = choice["text"]
+                logger.debug(f"Found content in choice.text: {repr(content)}")
+                if content is None:
+                    logger.warning("Groq API returned None content in choice.text")
+                    return ""
+                return content
             else:
-                time.sleep(wait_time)
+                logger.warning(f"Unexpected choice structure: {choice}")
+
+        # If we can't find content in the expected places, raise an error
+        logger.error(f"Unexpected response format from Groq API: {response_data}")
+        raise AIError.model_error(f"Unexpected response format from Groq API: {response_data}")
+    except httpx.HTTPStatusError as e:
+        raise AIError.model_error(f"Groq API error: {e.response.status_code} - {e.response.text}") from e
+    except Exception as e:
+        raise AIError.model_error(f"Error calling Groq API: {str(e)}") from e
gac/providers/ollama.py
CHANGED
@@ -1,135 +1,35 @@
-"""Ollama
+"""Ollama AI provider implementation."""

-import logging
 import os
-import time

 import httpx
-from halo import Halo

-from gac.ai_utils import _classify_error
-from gac.constants import EnvDefaults
 from gac.errors import AIError

-logger = logging.getLogger(__name__)
-
-    if isinstance(prompt, tuple):
-        system_prompt, user_prompt = prompt
-        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
-    else:
-        # Backward compatibility: treat string as user prompt
-        messages = [{"role": "user", "content": prompt}]
-
-    payload = {
-        "model": model,
-        "messages": messages,
-        "stream": False,
-        "options": {
-            "temperature": temperature,
-            "num_predict": max_tokens,
-        },
-    }
-
-    headers = {
-        "Content-Type": "application/json",
-    }
-
-    # Ollama typically runs locally on port 11434
-    ollama_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
-
-    return _make_request_with_retry(
-        url=f"{ollama_url}/api/chat",
-        headers=headers,
-        payload=payload,
-        provider_name=f"Ollama {model}",
-        max_retries=max_retries,
-        quiet=quiet,
-        response_parser=lambda r: r["message"]["content"],
-    )
-
-
-def _make_request_with_retry(
-    url: str,
-    headers: dict,
-    payload: dict,
-    provider_name: str,
-    max_retries: int,
-    quiet: bool,
-    response_parser: callable,
-) -> str:
-    """Make HTTP request with retry logic and common error handling."""
-    if quiet:
-        spinner = None
-    else:
-        spinner = Halo(text=f"Generating commit message with {provider_name}...", spinner="dots")
-        spinner.start()
-
-    last_error = None
-    retry_count = 0
-
-    while retry_count < max_retries:
-        try:
-            logger.debug(f"Trying with {provider_name} (attempt {retry_count + 1}/{max_retries})")
-
-            with httpx.Client(timeout=30.0) as client:
-                response = client.post(url, headers=headers, json=payload)
-                response.raise_for_status()
-
-            response_data = response.json()
-            message = response_parser(response_data)
-
-            if spinner:
-                spinner.succeed(f"Generated commit message with {provider_name}")
-
-            return message
-
-        except Exception as e:
-            last_error = e
-            retry_count += 1
-
-            if retry_count == max_retries:
-                logger.warning(f"Error generating commit message: {e}. Giving up.")
-                break
-
-            wait_time = 2**retry_count
-            logger.warning(f"Error generating commit message: {e}. Retrying in {wait_time}s...")
-            if spinner:
-                for i in range(wait_time, 0, -1):
-                    spinner.text = f"Retry {retry_count}/{max_retries} in {i}s..."
-                    time.sleep(1)
-            else:
-                time.sleep(wait_time)
-
-    if spinner:
-        spinner.fail(f"Failed to generate commit message with {provider_name}")
-
-    error_type = _classify_error(str(last_error))
-    raise AIError(
-        f"Failed to generate commit message after {max_retries} attempts: {last_error}", error_type=error_type
-    )
+def call_ollama_api(model: str, messages: list[dict], temperature: float, max_tokens: int) -> str:
+    """Call Ollama API directly."""
+    api_url = os.getenv("OLLAMA_API_URL", "http://localhost:11434")
+
+    url = f"{api_url.rstrip('/')}/api/chat"
+    data = {"model": model, "messages": messages, "temperature": temperature, "stream": False}
+
+    try:
+        response = httpx.post(url, json=data, timeout=120)
+        response.raise_for_status()
+        response_data = response.json()
+
+        # Handle different response formats from Ollama
+        if "message" in response_data and "content" in response_data["message"]:
+            return response_data["message"]["content"]
+        elif "response" in response_data:
+            return response_data["response"]
+        else:
+            # Fallback: return the full response as string
+            return str(response_data)
+    except httpx.ConnectError as e:
+        raise AIError.connection_error(f"Ollama connection failed. Make sure Ollama is running: {str(e)}") from e
+    except httpx.HTTPStatusError as e:
+        raise AIError.model_error(f"Ollama API error: {e.response.status_code} - {e.response.text}") from e
+    except Exception as e:
+        raise AIError.model_error(f"Error calling Ollama API: {str(e)}") from e
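Two details of the Ollama rewrite are easy to miss: the server address now comes from OLLAMA_API_URL (previously OLLAMA_URL), and the new payload passes temperature at the top level rather than under Ollama's options object, while max_tokens is accepted but not forwarded at all, so an Ollama server may ignore both settings. A small sketch of calling the new function against a local instance, assuming only what this diff shows (model name and prompt are illustrative):

import os

from gac.providers.ollama import call_ollama_api

# Optional override; the function defaults to http://localhost:11434
os.environ["OLLAMA_API_URL"] = "http://localhost:11434"

messages = [{"role": "user", "content": "Summarize the staged diff as a commit message."}]

# Assumes a local model has already been pulled, e.g. with `ollama pull llama3.2` (illustrative).
commit_message = call_ollama_api("llama3.2", messages, temperature=0.7, max_tokens=256)
print(commit_message)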
gac/providers/openai.py
CHANGED
@@ -1,134 +1,29 @@
 """OpenAI API provider for gac."""

-import logging
 import os
-import time

 import httpx
-from halo import Halo

-from gac.ai_utils import _classify_error
-from gac.constants import EnvDefaults
 from gac.errors import AIError

-logger = logging.getLogger(__name__)
-
-    model: str,
-    prompt: str | tuple[str, str],
-    temperature: float = EnvDefaults.TEMPERATURE,
-    max_tokens: int = EnvDefaults.MAX_OUTPUT_TOKENS,
-    max_retries: int = EnvDefaults.MAX_RETRIES,
-    quiet: bool = False,
-) -> str:
-    """Generate commit message using OpenAI API with retry logic.
-
-    Args:
-        model: The model name (e.g., 'gpt-4', 'gpt-3.5-turbo')
-        prompt: Either a string prompt or tuple of (system_prompt, user_prompt)
-        temperature: Controls randomness (0.0-1.0)
-        max_tokens: Maximum tokens in the response
-        max_retries: Number of retry attempts if generation fails
-        quiet: If True, suppress progress indicators
-
-    Returns:
-        A formatted commit message string
-
-    Raises:
-        AIError: If generation fails after max_retries attempts
-    """
+def call_openai_api(model: str, messages: list[dict], temperature: float, max_tokens: int) -> str:
+    """Call OpenAI API directly."""
     api_key = os.getenv("OPENAI_API_KEY")
     if not api_key:
-        raise AIError.model_error("OPENAI_API_KEY
-
-    headers = {
-        "Content-Type": "application/json",
-        "Authorization": f"Bearer {api_key}",
-    }
-
-    return _make_request_with_retry(
-        url="https://api.openai.com/v1/chat/completions",
-        headers=headers,
-        payload=payload,
-        provider_name=f"OpenAI {model}",
-        max_retries=max_retries,
-        quiet=quiet,
-        response_parser=lambda r: r["choices"][0]["message"]["content"],
-    )
-
-
-def _make_request_with_retry(
-    url: str,
-    headers: dict,
-    payload: dict,
-    provider_name: str,
-    max_retries: int,
-    quiet: bool,
-    response_parser: callable,
-) -> str:
-    """Make HTTP request with retry logic and common error handling."""
-    if quiet:
-        spinner = None
-    else:
-        spinner = Halo(text=f"Generating commit message with {provider_name}...", spinner="dots")
-        spinner.start()
-
-    last_error = None
-    retry_count = 0
-
-    while retry_count < max_retries:
-        try:
-            logger.debug(f"Trying with {provider_name} (attempt {retry_count + 1}/{max_retries})")
-
-            with httpx.Client(timeout=30.0) as client:
-                response = client.post(url, headers=headers, json=payload)
-                response.raise_for_status()
-
-            response_data = response.json()
-            message = response_parser(response_data)
-
-            if spinner:
-                spinner.succeed(f"Generated commit message with {provider_name}")
-
-            return message
-
-        except Exception as e:
-            last_error = e
-            retry_count += 1
-
-            if retry_count == max_retries:
-                logger.warning(f"Error generating commit message: {e}. Giving up.")
-                break
-
-            wait_time = 2**retry_count
-            logger.warning(f"Error generating commit message: {e}. Retrying in {wait_time}s...")
-            if spinner:
-                for i in range(wait_time, 0, -1):
-                    spinner.text = f"Retry {retry_count}/{max_retries} in {i}s..."
-                    time.sleep(1)
-            else:
-                time.sleep(wait_time)
-
-    if spinner:
-        spinner.fail(f"Failed to generate commit message with {provider_name}")
-
-    error_type = _classify_error(str(last_error))
-    raise AIError(
-        f"Failed to generate commit message after {max_retries} attempts: {last_error}", error_type=error_type
-    )
+        raise AIError.model_error("OPENAI_API_KEY not found in environment variables")
+
+    url = "https://api.openai.com/v1/chat/completions"
+    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+
+    data = {"model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens}
+
+    try:
+        response = httpx.post(url, headers=headers, json=data, timeout=120)
+        response.raise_for_status()
+        response_data = response.json()
+        return response_data["choices"][0]["message"]["content"]
+    except httpx.HTTPStatusError as e:
+        raise AIError.model_error(f"OpenAI API error: {e.response.status_code} - {e.response.text}") from e
+    except Exception as e:
+        raise AIError.model_error(f"Error calling OpenAI API: {str(e)}") from e
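Across all four providers the shape of the change is identical: the per-module _make_request_with_retry helper (Halo spinner, exponential backoff, _classify_error) is deleted, and each module now exposes one call_<provider>_api function that performs a single blocking httpx.post with a 120-second timeout and wraps failures in AIError. Judging from the file list, retry behavior is now expected to live in the callers (gac/ai.py and gac/ai_utils.py both grew), but that code is not part of the hunks shown here. A hedged sketch of how a caller could reinstate simple exponential backoff around any of these functions, purely illustrative and not gac's actual implementation:

import time

from gac.errors import AIError
from gac.providers.openai import call_openai_api


def call_with_backoff(model: str, messages: list[dict], max_retries: int = 3) -> str:
    """Retry a direct provider call with exponential backoff (illustrative helper, not part of gac)."""
    last_error: AIError | None = None
    for attempt in range(max_retries):
        try:
            return call_openai_api(model, messages, temperature=0.7, max_tokens=256)
        except AIError as exc:
            last_error = exc
            if attempt + 1 < max_retries:
                time.sleep(2 ** (attempt + 1))  # wait 2s, 4s, ... between attempts
    raise AIError.model_error(f"Failed after {max_retries} attempts: {last_error}")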