gitwit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gitwit/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ """GitWit - AI-powered git commit messages using free APIs."""
2
+
3
+ __version__ = "0.1.0"
4
+ __author__ = "GitWit Contributors"
gitwit/ai.py ADDED
@@ -0,0 +1,324 @@
1
+ """AI provider implementations for GitWit."""
2
+
3
+ import asyncio
4
+ import time
5
+ from abc import ABC, abstractmethod
6
+ from typing import Any
7
+
8
+ import httpx
9
+
10
+ from .config import get_api_key, get_model, get_provider
11
+
12
+
13
class AIError(Exception):
    """Base class for every AI-related failure raised by this module."""
16
+
17
+
18
class RateLimitError(AIError):
    """Raised when a provider reports that requests are being throttled."""
21
+
22
+
23
class AuthenticationError(AIError):
    """Raised when a provider rejects the configured API credentials."""
26
+
27
+
28
class AIProvider(ABC):
    """Interface that every concrete AI backend must implement."""

    @abstractmethod
    async def generate(self, system_prompt: str, user_prompt: str) -> str:
        """
        Produce a completion for the given prompts.

        Args:
            system_prompt: Instructions framing the model's behavior.
            user_prompt: The actual request content.

        Returns:
            The generated text response.
        """

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable provider name."""
50
+
51
+
52
class GroqProvider(AIProvider):
    """Groq API provider (OpenAI-compatible chat-completions endpoint)."""

    BASE_URL = "https://api.groq.com/openai/v1/chat/completions"
    DEFAULT_MODEL = "llama-3.3-70b-versatile"

    def __init__(self, api_key: str, model: str | None = None):
        """
        Args:
            api_key: Groq API key (sent as a Bearer token).
            model: Model id; falls back to DEFAULT_MODEL when None/empty.
        """
        self.api_key = api_key
        self.model = model or self.DEFAULT_MODEL

    @property
    def name(self) -> str:
        return "Groq"

    async def generate(self, system_prompt: str, user_prompt: str) -> str:
        """Generate a response using the Groq API.

        Raises:
            AuthenticationError: On HTTP 401 (invalid API key).
            RateLimitError: On HTTP 429.
            AIError: On any other non-200 status, malformed response body,
                timeout, or network error.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": 0.3,  # Low temperature for consistent commit messages
            "max_tokens": 256,
        }

        async with httpx.AsyncClient(timeout=30.0) as client:
            try:
                response = await client.post(
                    self.BASE_URL,
                    headers=headers,
                    json=payload,
                )

                if response.status_code == 401:
                    raise AuthenticationError(
                        "Invalid Groq API key. Get one at https://console.groq.com"
                    )
                elif response.status_code == 429:
                    raise RateLimitError(
                        "Groq rate limit exceeded. Wait a moment and try again."
                    )
                elif response.status_code != 200:
                    raise AIError(
                        f"Groq API error ({response.status_code}): {response.text}"
                    )

                data = response.json()
                try:
                    return data["choices"][0]["message"]["content"].strip()
                except (KeyError, IndexError, TypeError) as e:
                    # Surface a schema mismatch as an AIError instead of a raw
                    # KeyError/IndexError that callers don't expect.
                    raise AIError(f"Unexpected Groq response format: {data}") from e

            except httpx.TimeoutException as e:
                # Chain the cause (PEP 3134) so tracebacks show the transport error.
                raise AIError("Groq API request timed out") from e
            except httpx.RequestError as e:
                raise AIError(f"Network error: {e}") from e
111
+
112
+
113
class GeminiProvider(AIProvider):
    """Google Gemini API provider."""

    BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models"
    DEFAULT_MODEL = "gemini-1.5-flash"

    def __init__(self, api_key: str, model: str | None = None):
        """
        Args:
            api_key: Google AI Studio API key.
            model: Model id; falls back to DEFAULT_MODEL when None/empty.
        """
        self.api_key = api_key
        self.model = model or self.DEFAULT_MODEL

    @property
    def name(self) -> str:
        return "Gemini"

    async def generate(self, system_prompt: str, user_prompt: str) -> str:
        """Generate a response using the Gemini API.

        The generateContent request here has no separate system role, so the
        system and user prompts are concatenated into a single text part.

        Raises:
            AuthenticationError: On HTTP 401/403 (invalid API key).
            RateLimitError: On HTTP 429.
            AIError: On any other non-200 status, malformed response body,
                timeout, or network error.
        """
        url = f"{self.BASE_URL}/{self.model}:generateContent"

        payload = {
            "contents": [
                {
                    "parts": [
                        {"text": f"{system_prompt}\n\n{user_prompt}"}
                    ]
                }
            ],
            "generationConfig": {
                "temperature": 0.3,
                "maxOutputTokens": 256,
            },
        }

        async with httpx.AsyncClient(timeout=30.0) as client:
            try:
                response = await client.post(
                    url,
                    params={"key": self.api_key},
                    json=payload,
                )

                if response.status_code in (401, 403):
                    raise AuthenticationError(
                        "Invalid Gemini API key. Get one at https://aistudio.google.com"
                    )
                elif response.status_code == 429:
                    raise RateLimitError(
                        "Gemini rate limit exceeded. Wait a moment and try again."
                    )
                elif response.status_code != 200:
                    raise AIError(
                        f"Gemini API error ({response.status_code}): {response.text}"
                    )

                data = response.json()
                try:
                    return data["candidates"][0]["content"]["parts"][0]["text"].strip()
                except (KeyError, IndexError, TypeError) as e:
                    # NOTE(review): a 200 response can apparently omit these keys
                    # (e.g. filtered output) — raise AIError rather than a raw
                    # KeyError callers don't handle.
                    raise AIError(f"Unexpected Gemini response format: {data}") from e

            except httpx.TimeoutException as e:
                # Chain the cause (PEP 3134) so tracebacks show the transport error.
                raise AIError("Gemini API request timed out") from e
            except httpx.RequestError as e:
                raise AIError(f"Network error: {e}") from e
173
+
174
+
175
class OllamaProvider(AIProvider):
    """Ollama local API provider."""

    BASE_URL = "http://localhost:11434/api/generate"
    DEFAULT_MODEL = "llama3.2"

    def __init__(self, model: str | None = None, base_url: str | None = None):
        """
        Args:
            model: Model name; falls back to DEFAULT_MODEL when None/empty.
            base_url: Override for the local Ollama endpoint.
        """
        self.model = model or self.DEFAULT_MODEL
        self.base_url = base_url or self.BASE_URL

    @property
    def name(self) -> str:
        return "Ollama"

    async def generate(self, system_prompt: str, user_prompt: str) -> str:
        """Generate a response using the local Ollama API.

        Uses a longer (60s) timeout than the hosted providers because the
        model may need to load into memory on first use.

        Raises:
            AIError: If the model is missing (404), Ollama is unreachable,
                the response is malformed, the request times out, or any
                other non-200 status is returned.
        """
        payload = {
            "model": self.model,
            "prompt": f"{system_prompt}\n\n{user_prompt}",
            "stream": False,
            "options": {
                "temperature": 0.3,
                "num_predict": 256,
            },
        }

        async with httpx.AsyncClient(timeout=60.0) as client:
            try:
                response = await client.post(
                    self.base_url,
                    json=payload,
                )

                if response.status_code == 404:
                    raise AIError(
                        f"Model '{self.model}' not found. "
                        f"Run 'ollama pull {self.model}' first."
                    )
                elif response.status_code != 200:
                    raise AIError(
                        f"Ollama API error ({response.status_code}): {response.text}"
                    )

                data = response.json()
                try:
                    return data["response"].strip()
                except (KeyError, TypeError, AttributeError) as e:
                    # Surface a schema mismatch as an AIError instead of a raw
                    # KeyError callers don't expect.
                    raise AIError(f"Unexpected Ollama response format: {data}") from e

            except httpx.ConnectError as e:
                # Chain the cause (PEP 3134) so tracebacks show the transport error.
                raise AIError(
                    "Cannot connect to Ollama. "
                    "Make sure Ollama is running: https://ollama.ai"
                ) from e
            except httpx.TimeoutException as e:
                raise AIError("Ollama request timed out (model may be loading)") from e
            except httpx.RequestError as e:
                raise AIError(f"Network error: {e}") from e
230
+
231
+
232
def get_ai_provider() -> AIProvider:
    """
    Build the AI provider instance selected by configuration.

    Returns:
        A concrete AIProvider for the configured backend.

    Raises:
        AuthenticationError: If the selected hosted provider has no API key set.
        AIError: If the configured provider name is unknown.
    """
    provider_name = get_provider()
    api_key = get_api_key()
    model = get_model() or None  # normalize empty string to None

    # Local provider: no API key required.
    if provider_name == "ollama":
        return OllamaProvider(model)

    if provider_name == "groq":
        if not api_key:
            raise AuthenticationError(
                "Groq API key not configured. Run:\n"
                "  gitwit config set api-key YOUR_GROQ_KEY\n\n"
                "Get a free key at: https://console.groq.com"
            )
        return GroqProvider(api_key, model)

    if provider_name == "gemini":
        if not api_key:
            raise AuthenticationError(
                "Gemini API key not configured. Run:\n"
                "  gitwit config set api-key YOUR_GEMINI_KEY\n\n"
                "Get a free key at: https://aistudio.google.com"
            )
        return GeminiProvider(api_key, model)

    raise AIError(
        f"Unknown provider: {provider_name}. "
        f"Supported: groq, gemini, ollama"
    )
272
+
273
+
274
async def generate_with_retry(
    provider: "AIProvider",  # string annotation: not evaluated at def time
    system_prompt: str,
    user_prompt: str,
    max_retries: int = 3,
    initial_delay: float = 1.0,
) -> str:
    """
    Generate a response with automatic retry on rate limits.

    Uses exponential backoff: the delay doubles after every rate-limited
    attempt. All other errors (including authentication failures) propagate
    immediately — the original ``except (AuthenticationError, AIError): raise``
    was a no-op (re-raising an otherwise-unhandled exception is the default)
    and has been removed.

    Args:
        provider: The AI provider to use.
        system_prompt: The system prompt.
        user_prompt: The user prompt.
        max_retries: Maximum number of retries after the first attempt.
        initial_delay: Initial delay in seconds (exponential backoff).

    Returns:
        The generated text response.

    Raises:
        RateLimitError: If still rate limited after all retries.
        AIError: For any other provider failure.
    """
    last_error: "RateLimitError | None" = None
    delay = initial_delay

    for attempt in range(max_retries + 1):
        try:
            return await provider.generate(system_prompt, user_prompt)
        except RateLimitError as e:
            last_error = e
            if attempt < max_retries:
                await asyncio.sleep(delay)
                delay *= 2  # Exponential backoff

    raise last_error or AIError("Failed after retries")
310
+
311
+
312
def generate_sync(system_prompt: str, user_prompt: str) -> str:
    """
    Blocking convenience wrapper around the async generation pipeline.

    Builds the configured provider, then drives generate_with_retry to
    completion on a fresh event loop via asyncio.run().

    Args:
        system_prompt: The system prompt.
        user_prompt: The user prompt.

    Returns:
        The generated text response.
    """
    return asyncio.run(
        generate_with_retry(get_ai_provider(), system_prompt, user_prompt)
    )