gitwit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gitwit/__init__.py +4 -0
- gitwit/ai.py +324 -0
- gitwit/cli.py +520 -0
- gitwit/config.py +169 -0
- gitwit/git.py +312 -0
- gitwit/license.py +124 -0
- gitwit/prompts.py +161 -0
- gitwit-0.1.0.dist-info/METADATA +201 -0
- gitwit-0.1.0.dist-info/RECORD +12 -0
- gitwit-0.1.0.dist-info/WHEEL +4 -0
- gitwit-0.1.0.dist-info/entry_points.txt +2 -0
- gitwit-0.1.0.dist-info/licenses/LICENSE +21 -0
gitwit/__init__.py
ADDED
gitwit/ai.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
1
|
+
"""AI provider implementations for GitWit."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import time
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import httpx
|
|
9
|
+
|
|
10
|
+
from .config import get_api_key, get_model, get_provider
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AIError(Exception):
    """Base class for every error raised by GitWit's AI provider layer."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RateLimitError(AIError):
    """Raised when a provider rejects a request because of rate limiting."""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AuthenticationError(AIError):
    """Raised when credentials are missing or rejected by a provider."""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class AIProvider(ABC):
    """
    Interface that every AI backend implements.

    A concrete provider exposes a human-readable ``name`` and an async
    ``generate`` coroutine that turns a system/user prompt pair into text.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable name of the provider."""

    @abstractmethod
    async def generate(self, system_prompt: str, user_prompt: str) -> str:
        """
        Produce a completion from the underlying model.

        Args:
            system_prompt: Instructions that frame the model's behaviour.
            user_prompt: The request content itself.

        Returns:
            The text produced by the model.
        """
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class GroqProvider(AIProvider):
    """Groq API provider (OpenAI-compatible chat completions endpoint)."""

    BASE_URL = "https://api.groq.com/openai/v1/chat/completions"
    DEFAULT_MODEL = "llama-3.3-70b-versatile"

    def __init__(self, api_key: str, model: str | None = None):
        """
        Create a Groq provider.

        Args:
            api_key: Groq API key, sent as a Bearer token.
            model: Model name; any falsy value selects DEFAULT_MODEL.
        """
        self.api_key = api_key
        self.model = model or self.DEFAULT_MODEL

    @property
    def name(self) -> str:
        """The provider name."""
        return "Groq"

    async def generate(self, system_prompt: str, user_prompt: str) -> str:
        """
        Generate a response using the Groq API.

        Args:
            system_prompt: Sent as the ``system`` chat message.
            user_prompt: Sent as the ``user`` chat message.

        Returns:
            The first choice's message content, stripped of surrounding
            whitespace.

        Raises:
            AuthenticationError: The API answered HTTP 401.
            RateLimitError: The API answered HTTP 429.
            AIError: Any other non-200 status, a timeout, a network error,
                or a response body that does not contain message content.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": 0.3,  # Low temperature for consistent commit messages
            "max_tokens": 256,
        }

        async with httpx.AsyncClient(timeout=30.0) as client:
            # Only the transport call can raise httpx errors; keep the try tight.
            try:
                response = await client.post(
                    self.BASE_URL,
                    headers=headers,
                    json=payload,
                )
            # TimeoutException is a subclass of RequestError, so it goes first.
            except httpx.TimeoutException as e:
                raise AIError("Groq API request timed out") from e
            except httpx.RequestError as e:
                raise AIError(f"Network error: {e}") from e

            if response.status_code == 401:
                raise AuthenticationError(
                    "Invalid Groq API key. Get one at https://console.groq.com"
                )
            elif response.status_code == 429:
                raise RateLimitError(
                    "Groq rate limit exceeded. Wait a moment and try again."
                )
            elif response.status_code != 200:
                raise AIError(
                    f"Groq API error ({response.status_code}): {response.text}"
                )

            # A 200 with an unexpected body (invalid JSON, missing keys, empty
            # choices) must surface as AIError, not a raw KeyError/IndexError.
            try:
                content = response.json()["choices"][0]["message"]["content"]
            except (ValueError, KeyError, IndexError, TypeError) as e:
                raise AIError(
                    f"Unexpected Groq API response: {response.text}"
                ) from e

            if not isinstance(content, str):
                raise AIError(
                    f"Groq API returned no message content: {response.text}"
                )
            return content.strip()
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class GeminiProvider(AIProvider):
    """Google Gemini API provider."""

    BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models"
    DEFAULT_MODEL = "gemini-1.5-flash"

    def __init__(self, api_key: str, model: str | None = None):
        """
        Create a Gemini provider.

        Args:
            api_key: Google AI Studio API key.
            model: Model name; any falsy value selects DEFAULT_MODEL.
        """
        self.api_key = api_key
        self.model = model or self.DEFAULT_MODEL

    @property
    def name(self) -> str:
        """The provider name."""
        return "Gemini"

    async def generate(self, system_prompt: str, user_prompt: str) -> str:
        """
        Generate a response using the Gemini ``generateContent`` endpoint.

        The system and user prompts are concatenated into a single text part.

        Args:
            system_prompt: Instructions prepended to the request text.
            user_prompt: The request content itself.

        Returns:
            The concatenated text parts of the first candidate, stripped.

        Raises:
            AuthenticationError: The API rejected the key (HTTP 401/403, or
                HTTP 400 with an invalid-API-key error).
            RateLimitError: The API answered HTTP 429.
            AIError: Any other non-200 status, a timeout, a network error,
                or a response without usable text (e.g. a blocked prompt).
        """
        url = f"{self.BASE_URL}/{self.model}:generateContent"

        # Send the key in a header rather than the query string so it does
        # not end up in URLs (and therefore in logs or proxy records).
        headers = {"x-goog-api-key": self.api_key}

        payload = {
            "contents": [
                {
                    "parts": [
                        {"text": f"{system_prompt}\n\n{user_prompt}"}
                    ]
                }
            ],
            "generationConfig": {
                "temperature": 0.3,
                "maxOutputTokens": 256,
            },
        }

        async with httpx.AsyncClient(timeout=30.0) as client:
            try:
                response = await client.post(
                    url,
                    headers=headers,
                    json=payload,
                )
            # TimeoutException is a subclass of RequestError, so it goes first.
            except httpx.TimeoutException as e:
                raise AIError("Gemini API request timed out") from e
            except httpx.RequestError as e:
                raise AIError(f"Network error: {e}") from e

            # Gemini reports a bad key as HTTP 400 (reason API_KEY_INVALID)
            # rather than 401/403, so the body has to be inspected as well.
            invalid_key = response.status_code == 400 and (
                "API_KEY_INVALID" in response.text
                or "API key not valid" in response.text
            )
            if response.status_code in (401, 403) or invalid_key:
                raise AuthenticationError(
                    "Invalid Gemini API key. Get one at https://aistudio.google.com"
                )
            elif response.status_code == 429:
                raise RateLimitError(
                    "Gemini rate limit exceeded. Wait a moment and try again."
                )
            elif response.status_code != 200:
                raise AIError(
                    f"Gemini API error ({response.status_code}): {response.text}"
                )

            try:
                data = response.json()
            except ValueError as e:
                raise AIError(
                    f"Unexpected Gemini API response: {response.text}"
                ) from e

            return self._extract_text(data)

    @staticmethod
    def _extract_text(data: dict[str, Any]) -> str:
        """
        Pull the reply text out of a ``generateContent`` response body.

        Raises:
            AIError: The body has no candidates (e.g. the prompt was blocked)
                or the first candidate carries no text parts.
        """
        if not isinstance(data, dict):
            raise AIError(f"Unexpected Gemini API response: {data!r}")

        candidates = data.get("candidates") or []
        if not candidates:
            feedback = data.get("promptFeedback") or {}
            reason = feedback.get("blockReason", "unknown")
            raise AIError(f"Gemini returned no candidates (block reason: {reason})")

        candidate = candidates[0]
        parts = (candidate.get("content") or {}).get("parts") or []
        texts = [
            part["text"]
            for part in parts
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        ]
        if not texts:
            reason = candidate.get("finishReason", "unknown")
            raise AIError(f"Gemini returned no text (finish reason: {reason})")

        return "".join(texts).strip()
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class OllamaProvider(AIProvider):
    """Ollama local API provider."""

    BASE_URL = "http://localhost:11434/api/generate"
    DEFAULT_MODEL = "llama3.2"

    def __init__(self, model: str | None = None, base_url: str | None = None):
        """
        Create an Ollama provider.

        Args:
            model: Model name; any falsy value selects DEFAULT_MODEL.
            base_url: Full URL of the generate endpoint; any falsy value
                selects BASE_URL.
        """
        self.model = model or self.DEFAULT_MODEL
        self.base_url = base_url or self.BASE_URL

    @property
    def name(self) -> str:
        """The provider name."""
        return "Ollama"

    async def generate(self, system_prompt: str, user_prompt: str) -> str:
        """
        Generate a response using a local Ollama server (non-streaming).

        Args:
            system_prompt: Prepended to the prompt text.
            user_prompt: The request content itself.

        Returns:
            The ``response`` field of the reply, stripped.

        Raises:
            AIError: The model is missing (HTTP 404), any other non-200
                status, the server is unreachable, the request timed out,
                or the reply body has no ``response`` text.
        """
        payload = {
            "model": self.model,
            "prompt": f"{system_prompt}\n\n{user_prompt}",
            "stream": False,
            "options": {
                "temperature": 0.3,
                "num_predict": 256,
            },
        }

        async with httpx.AsyncClient(timeout=60.0) as client:
            try:
                response = await client.post(
                    self.base_url,
                    json=payload,
                )
            except httpx.ConnectError as e:
                raise AIError(
                    "Cannot connect to Ollama. "
                    "Make sure Ollama is running: https://ollama.ai"
                ) from e
            # TimeoutException is a subclass of RequestError, so it goes first.
            except httpx.TimeoutException as e:
                raise AIError("Ollama request timed out (model may be loading)") from e
            except httpx.RequestError as e:
                raise AIError(f"Network error: {e}") from e

            if response.status_code == 404:
                raise AIError(
                    f"Model '{self.model}' not found. "
                    f"Run 'ollama pull {self.model}' first."
                )
            elif response.status_code != 200:
                raise AIError(
                    f"Ollama API error ({response.status_code}): {response.text}"
                )

            # Surface malformed bodies as AIError instead of raw JSON/KeyError.
            try:
                text = response.json()["response"]
            except (ValueError, KeyError, TypeError) as e:
                raise AIError(
                    f"Unexpected Ollama API response: {response.text}"
                ) from e

            if not isinstance(text, str):
                raise AIError(
                    f"Ollama returned no response text: {response.text}"
                )
            return text.strip()
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def get_ai_provider() -> AIProvider:
    """
    Get the configured AI provider instance.

    Reads the provider name, API key and model from the config module and
    builds the matching provider. Groq and Gemini require an API key;
    Ollama does not use one.

    Returns:
        An AI provider instance based on configuration.

    Raises:
        AuthenticationError: Groq or Gemini is selected but no API key is set.
        AIError: The configured provider name is not supported.
    """
    provider_name = get_provider()
    api_key = get_api_key()
    model = get_model()

    if provider_name == "groq":
        if not api_key:
            raise AuthenticationError(
                "Groq API key not configured. Run:\n"
                "  gitwit config set api-key YOUR_GROQ_KEY\n\n"
                "Get a free key at: https://console.groq.com"
            )
        # Empty-string model settings collapse to None -> provider default.
        return GroqProvider(api_key, model or None)

    elif provider_name == "gemini":
        if not api_key:
            raise AuthenticationError(
                "Gemini API key not configured. Run:\n"
                "  gitwit config set api-key YOUR_GEMINI_KEY\n\n"
                "Get a free key at: https://aistudio.google.com"
            )
        return GeminiProvider(api_key, model or None)

    elif provider_name == "ollama":
        return OllamaProvider(model or None)

    else:
        raise AIError(
            f"Unknown provider: {provider_name}. "
            "Supported: groq, gemini, ollama"
        )
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
async def generate_with_retry(
    provider: AIProvider,
    system_prompt: str,
    user_prompt: str,
    max_retries: int = 3,
    initial_delay: float = 1.0,
) -> str:
    """
    Generate a response with automatic retry on rate limits.

    Only RateLimitError is retried, sleeping ``initial_delay`` seconds before
    the first retry and doubling the delay each time. Every other exception
    from ``provider.generate`` (authentication, network, bad response, ...)
    propagates immediately.

    Args:
        provider: The AI provider to use.
        system_prompt: The system prompt.
        user_prompt: The user prompt.
        max_retries: Maximum number of retries after the first attempt;
            negative values are treated as 0 (a single attempt).
        initial_delay: Initial delay between retries (exponential backoff).

    Returns:
        The generated text response.

    Raises:
        RateLimitError: Still rate limited after all retries were used.
    """
    last_error: RateLimitError | None = None
    delay = initial_delay
    # Always make at least one attempt, even if max_retries is negative.
    retries = max(max_retries, 0)

    for attempt in range(retries + 1):
        try:
            return await provider.generate(system_prompt, user_prompt)
        except RateLimitError as e:
            last_error = e

        if attempt < retries:
            await asyncio.sleep(delay)
            delay *= 2  # Exponential backoff

    raise last_error or AIError("Failed after retries")
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def generate_sync(system_prompt: str, user_prompt: str) -> str:
    """
    Run AI generation to completion from synchronous code.

    Looks up the configured provider with get_ai_provider() and drives
    generate_with_retry() to completion on a fresh event loop via
    asyncio.run(). Because of that, it must not be called while an event
    loop is already running in the current thread.

    Args:
        system_prompt: The system prompt.
        user_prompt: The user prompt.

    Returns:
        The generated text response.
    """
    job = generate_with_retry(get_ai_provider(), system_prompt, user_prompt)
    return asyncio.run(job)
|