claude-dev-cli 0.13.0-py3-none-any.whl → 0.16.0-py3-none-any.whl
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
Potentially problematic release.
This version of claude-dev-cli might be problematic.
- claude_dev_cli/__init__.py +1 -1
- claude_dev_cli/cli.py +234 -16
- claude_dev_cli/config.py +95 -9
- claude_dev_cli/core.py +48 -53
- claude_dev_cli/multi_file_handler.py +400 -16
- claude_dev_cli/providers/__init__.py +28 -0
- claude_dev_cli/providers/anthropic.py +216 -0
- claude_dev_cli/providers/base.py +168 -0
- claude_dev_cli/providers/factory.py +114 -0
- claude_dev_cli/providers/ollama.py +283 -0
- claude_dev_cli/providers/openai.py +268 -0
- {claude_dev_cli-0.13.0.dist-info → claude_dev_cli-0.16.0.dist-info}/METADATA +297 -15
- {claude_dev_cli-0.13.0.dist-info → claude_dev_cli-0.16.0.dist-info}/RECORD +17 -11
- {claude_dev_cli-0.13.0.dist-info → claude_dev_cli-0.16.0.dist-info}/WHEEL +0 -0
- {claude_dev_cli-0.13.0.dist-info → claude_dev_cli-0.16.0.dist-info}/entry_points.txt +0 -0
- {claude_dev_cli-0.13.0.dist-info → claude_dev_cli-0.16.0.dist-info}/licenses/LICENSE +0 -0
- {claude_dev_cli-0.13.0.dist-info → claude_dev_cli-0.16.0.dist-info}/top_level.txt +0 -0
claude_dev_cli/providers/ollama.py
@@ -0,0 +1,283 @@
+"""Ollama local AI provider implementation."""
+
+import json
+from datetime import datetime
+from typing import Iterator, Optional, List, Dict, Any
+
+from claude_dev_cli.providers.base import (
+    AIProvider,
+    ModelInfo,
+    UsageInfo,
+    ProviderConnectionError,
+    ModelNotFoundError,
+    ProviderError,
+)
+
+# Try to import requests, handle gracefully if not installed
+try:
+    import requests
+    REQUESTS_AVAILABLE = True
+except ImportError:
+    REQUESTS_AVAILABLE = False
+    requests = None  # type: ignore
+
+
+class OllamaProvider(AIProvider):
+    """Ollama local model provider implementation.
+
+    Provides zero-cost local inference with models like:
+    - llama2, mistral, codellama, phi, deepseek-coder, mixtral, etc.
+    """
+
+    # Known model templates for common models
+    KNOWN_MODELS = {
+        "mistral": {
+            "display_name": "Mistral 7B",
+            "context_window": 8192,
+            "capabilities": ["chat", "code"]
+        },
+        "llama2": {
+            "display_name": "Llama 2",
+            "context_window": 4096,
+            "capabilities": ["chat", "general"]
+        },
+        "codellama": {
+            "display_name": "Code Llama",
+            "context_window": 16384,
+            "capabilities": ["code", "chat"]
+        },
+        "phi": {
+            "display_name": "Phi",
+            "context_window": 2048,
+            "capabilities": ["chat", "reasoning"]
+        },
+        "deepseek-coder": {
+            "display_name": "DeepSeek Coder",
+            "context_window": 16384,
+            "capabilities": ["code"]
+        },
+        "mixtral": {
+            "display_name": "Mixtral 8x7B",
+            "context_window": 32768,
+            "capabilities": ["chat", "code", "analysis"]
+        },
+    }
+
+    def __init__(self, config: Any) -> None:
+        """Initialize Ollama provider.
+
+        Args:
+            config: ProviderConfig with optional base_url
+
+        Raises:
+            RuntimeError: If requests library is not installed
+        """
+        super().__init__(config)
+
+        if not REQUESTS_AVAILABLE:
+            raise RuntimeError(
+                "Ollama provider requires the requests package. "
+                "Install it with: pip install 'claude-dev-cli[ollama]'"
+            )
+
+        # No API key needed for local!
+        self.base_url = getattr(config, 'base_url', None) or "http://localhost:11434"
+        self.timeout = 120  # Local inference can be slow
+        self.last_usage: Optional[UsageInfo] = None
+
+    def call(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        model: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        temperature: float = 1.0,
+    ) -> str:
+        """Make a synchronous call to Ollama API."""
+        model = model or "mistral"
+
+        # Build messages for chat endpoint
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        start_time = datetime.utcnow()
+
+        try:
+            # Use chat endpoint (preferred for conversational use)
+            response = requests.post(
+                f"{self.base_url}/api/chat",
+                json={
+                    "model": model,
+                    "messages": messages,
+                    "stream": False,
+                    "options": {
+                        "temperature": temperature,
+                        "num_predict": max_tokens or 4096,
+                    }
+                },
+                timeout=self.timeout
+            )
+            response.raise_for_status()
+            data = response.json()
+
+        except requests.ConnectionError:
+            raise ProviderConnectionError(
+                "Cannot connect to Ollama. Is it running? Start with: ollama serve",
+                provider="ollama"
+            )
+        except requests.Timeout:
+            raise ProviderError(
+                f"Ollama request timed out after {self.timeout}s. "
+                "Local models can be slow - consider using a smaller model."
+            )
+        except requests.HTTPError as e:
+            if e.response.status_code == 404:
+                raise ModelNotFoundError(
+                    f"Model '{model}' not found. Pull it with: ollama pull {model}",
+                    model=model,
+                    provider="ollama"
+                )
+            raise ProviderError(f"Ollama API error: {e}")
+
+        end_time = datetime.utcnow()
+        duration_ms = int((end_time - start_time).total_seconds() * 1000)
+
+        # Extract response
+        response_text = data.get("message", {}).get("content", "")
+
+        # Get token counts if available
+        prompt_tokens = data.get("prompt_eval_count", 0)
+        completion_tokens = data.get("eval_count", 0)
+
+        # Store usage info (always zero cost!)
+        self.last_usage = UsageInfo(
+            input_tokens=prompt_tokens,
+            output_tokens=completion_tokens,
+            duration_ms=duration_ms,
+            model=model,
+            timestamp=end_time,
+            cost_usd=0.0  # Free!
+        )
+
+        return response_text
+
+    def call_streaming(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        model: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        temperature: float = 1.0,
+    ) -> Iterator[str]:
+        """Make a streaming call to Ollama API."""
+        model = model or "mistral"
+
+        # Build messages
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        try:
+            response = requests.post(
+                f"{self.base_url}/api/chat",
+                json={
+                    "model": model,
+                    "messages": messages,
+                    "stream": True,
+                    "options": {
+                        "temperature": temperature,
+                        "num_predict": max_tokens or 4096,
+                    }
+                },
+                stream=True,
+                timeout=self.timeout
+            )
+            response.raise_for_status()
+
+            # Stream chunks
+            for line in response.iter_lines():
+                if line:
+                    chunk = json.loads(line)
+                    if "message" in chunk:
+                        content = chunk["message"].get("content", "")
+                        if content:
+                            yield content
+
+        except requests.ConnectionError:
+            raise ProviderConnectionError(
+                "Cannot connect to Ollama. Is it running? Start with: ollama serve",
+                provider="ollama"
+            )
+        except requests.HTTPError as e:
+            if e.response.status_code == 404:
+                raise ModelNotFoundError(
+                    f"Model '{model}' not found. Pull it with: ollama pull {model}",
+                    model=model,
+                    provider="ollama"
+                )
+            raise ProviderError(f"Ollama API error: {e}")
+
+    def list_models(self) -> List[ModelInfo]:
+        """List available Ollama models."""
+        try:
+            response = requests.get(
+                f"{self.base_url}/api/tags",
+                timeout=10
+            )
+            response.raise_for_status()
+            data = response.json()
+
+            models = []
+            for model_data in data.get("models", []):
+                model_name = model_data.get("name", "")
+                base_name = model_name.split(":")[0]  # Remove tag (e.g., "mistral:7b" -> "mistral")
+
+                # Get info from known models or use defaults
+                info = self.KNOWN_MODELS.get(base_name, {
+                    "display_name": model_name,
+                    "context_window": 4096,
+                    "capabilities": ["chat"]
+                })
+
+                models.append(ModelInfo(
+                    model_id=model_name,
+                    display_name=info.get("display_name", model_name),
+                    provider="ollama",
+                    context_window=info.get("context_window", 4096),
+                    input_price_per_mtok=0.0,  # Free!
+                    output_price_per_mtok=0.0,  # Free!
+                    capabilities=info.get("capabilities", ["chat"])
+                ))
+
+            return models
+
+        except requests.ConnectionError:
+            raise ProviderConnectionError(
+                "Cannot connect to Ollama. Is it running? Start with: ollama serve",
+                provider="ollama"
+            )
+        except Exception as e:
+            raise ProviderError(f"Failed to list Ollama models: {e}")
+
+    def get_last_usage(self) -> Optional[UsageInfo]:
+        """Get usage information from the last API call."""
+        return self.last_usage
+
+    @property
+    def provider_name(self) -> str:
+        """Get the provider's name."""
+        return "ollama"
+
+    def test_connection(self) -> bool:
+        """Test if Ollama is accessible."""
+        try:
+            response = requests.get(
+                f"{self.base_url}/api/version",
+                timeout=5
+            )
+            return response.status_code == 200
+        except Exception:
+            return False
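For orientation, here is a minimal usage sketch of the new Ollama provider. It is an illustration only, not taken from the package's documentation: it assumes a local Ollama server on the default port, uses types.SimpleNamespace as a stand-in for the package's ProviderConfig object, and assumes the AIProvider base class simply stores the config it is given.

from types import SimpleNamespace
from claude_dev_cli.providers.ollama import OllamaProvider

# Hypothetical config object; the real ProviderConfig may carry more fields.
config = SimpleNamespace(base_url="http://localhost:11434")
provider = OllamaProvider(config)

if provider.test_connection():
    reply = provider.call(
        "Explain list comprehensions in one sentence.",
        model="mistral",
    )
    print(reply)
    # cost_usd is always 0.0 for local models
    print(provider.get_last_usage())

Note that failures surface as the package's own exception types (ProviderConnectionError, ModelNotFoundError, ProviderError), so callers can handle "Ollama not running" and "model not pulled" separately.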
claude_dev_cli/providers/openai.py
@@ -0,0 +1,268 @@
+"""OpenAI (GPT-4, GPT-3.5) AI provider implementation."""
+
+from datetime import datetime
+from typing import Iterator, Optional, List, Dict, Any
+
+from claude_dev_cli.providers.base import (
+    AIProvider,
+    ModelInfo,
+    UsageInfo,
+    InsufficientCreditsError,
+    ProviderConnectionError,
+    ModelNotFoundError,
+    ProviderError,
+)
+
+# Try to import openai, handle gracefully if not installed
+try:
+    from openai import OpenAI, APIError, AuthenticationError, RateLimitError, NotFoundError
+    OPENAI_AVAILABLE = True
+except ImportError:
+    OPENAI_AVAILABLE = False
+    OpenAI = None  # type: ignore
+    APIError = Exception  # type: ignore
+    AuthenticationError = Exception  # type: ignore
+    RateLimitError = Exception  # type: ignore
+    NotFoundError = Exception  # type: ignore
+
+
+class OpenAIProvider(AIProvider):
+    """OpenAI GPT API provider implementation."""
+
+    # Known OpenAI models with their capabilities
+    KNOWN_MODELS = {
+        "gpt-4-turbo-preview": {
+            "display_name": "GPT-4 Turbo",
+            "context_window": 128000,
+            "input_price": 10.00,
+            "output_price": 30.00,
+            "capabilities": ["chat", "code", "analysis", "vision"]
+        },
+        "gpt-4-turbo": {
+            "display_name": "GPT-4 Turbo",
+            "context_window": 128000,
+            "input_price": 10.00,
+            "output_price": 30.00,
+            "capabilities": ["chat", "code", "analysis", "vision"]
+        },
+        "gpt-4": {
+            "display_name": "GPT-4",
+            "context_window": 8192,
+            "input_price": 30.00,
+            "output_price": 60.00,
+            "capabilities": ["chat", "code", "analysis"]
+        },
+        "gpt-3.5-turbo": {
+            "display_name": "GPT-3.5 Turbo",
+            "context_window": 16385,
+            "input_price": 0.50,
+            "output_price": 1.50,
+            "capabilities": ["chat", "code"]
+        },
+    }
+
+    def __init__(self, config: Any) -> None:
+        """Initialize OpenAI provider.
+
+        Args:
+            config: ProviderConfig with api_key and optional base_url
+
+        Raises:
+            RuntimeError: If openai SDK is not installed
+        """
+        super().__init__(config)
+
+        if not OPENAI_AVAILABLE:
+            raise RuntimeError(
+                "OpenAI provider requires the openai package. "
+                "Install it with: pip install 'claude-dev-cli[openai]'"
+            )
+
+        # Extract API key from config
+        api_key = getattr(config, 'api_key', None)
+        if not api_key:
+            raise ValueError("OpenAI provider requires api_key in config")
+
+        # Get optional base_url for custom endpoints (Azure, proxies, etc.)
+        base_url = getattr(config, 'base_url', None)
+
+        # Initialize OpenAI client
+        client_kwargs: Dict[str, Any] = {"api_key": api_key}
+        if base_url:
+            client_kwargs["base_url"] = base_url
+
+        self.client = OpenAI(**client_kwargs)
+        self.last_usage: Optional[UsageInfo] = None
+
+    def call(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        model: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        temperature: float = 1.0,
+    ) -> str:
+        """Make a synchronous call to OpenAI API."""
+        model = model or "gpt-4-turbo-preview"
+        max_tokens = max_tokens or 4096
+
+        # Build messages array (OpenAI format)
+        messages: List[Dict[str, str]] = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        start_time = datetime.utcnow()
+
+        try:
+            response = self.client.chat.completions.create(
+                model=model,
+                messages=messages,  # type: ignore
+                max_tokens=max_tokens,
+                temperature=temperature
+            )
+        except AuthenticationError as e:
+            raise ProviderConnectionError(
+                f"OpenAI authentication failed: {e}",
+                provider="openai"
+            )
+        except RateLimitError as e:
+            raise ProviderError(f"OpenAI rate limit exceeded: {e}")
+        except NotFoundError as e:
+            raise ModelNotFoundError(
+                f"Model not found: {model}",
+                model=model,
+                provider="openai"
+            )
+        except APIError as e:
+            # Check for quota/billing issues
+            error_message = str(e).lower()
+            if "quota" in error_message or "billing" in error_message or "insufficient" in error_message:
+                raise InsufficientCreditsError(
+                    f"Insufficient OpenAI credits: {e}",
+                    provider="openai"
+                )
+            raise ProviderConnectionError(
+                f"OpenAI API error: {e}",
+                provider="openai"
+            )
+
+        end_time = datetime.utcnow()
+        duration_ms = int((end_time - start_time).total_seconds() * 1000)
+
+        # Calculate cost
+        model_info = self.KNOWN_MODELS.get(model, {})
+        input_price = model_info.get("input_price", 0.0)
+        output_price = model_info.get("output_price", 0.0)
+
+        input_tokens = response.usage.prompt_tokens if response.usage else 0
+        output_tokens = response.usage.completion_tokens if response.usage else 0
+
+        input_cost = (input_tokens / 1_000_000) * input_price
+        output_cost = (output_tokens / 1_000_000) * output_price
+        total_cost = input_cost + output_cost
+
+        # Store usage info
+        self.last_usage = UsageInfo(
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            duration_ms=duration_ms,
+            model=model,
+            timestamp=end_time,
+            cost_usd=total_cost
+        )
+
+        # Extract text from response
+        if response.choices and len(response.choices) > 0:
+            message = response.choices[0].message
+            return message.content or ""
+        return ""
+
+    def call_streaming(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        model: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        temperature: float = 1.0,
+    ) -> Iterator[str]:
+        """Make a streaming call to OpenAI API."""
+        model = model or "gpt-4-turbo-preview"
+        max_tokens = max_tokens or 4096
+
+        # Build messages array
+        messages: List[Dict[str, str]] = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        try:
+            stream = self.client.chat.completions.create(
+                model=model,
+                messages=messages,  # type: ignore
+                max_tokens=max_tokens,
+                temperature=temperature,
+                stream=True
+            )
+
+            for chunk in stream:
+                if chunk.choices and len(chunk.choices) > 0:
+                    delta = chunk.choices[0].delta
+                    if delta.content:
+                        yield delta.content
+
+        except AuthenticationError as e:
+            raise ProviderConnectionError(
+                f"OpenAI authentication failed: {e}",
+                provider="openai"
+            )
+        except RateLimitError as e:
+            raise ProviderError(f"OpenAI rate limit exceeded: {e}")
+        except APIError as e:
+            error_message = str(e).lower()
+            if "quota" in error_message or "billing" in error_message:
+                raise InsufficientCreditsError(
+                    f"Insufficient OpenAI credits: {e}",
+                    provider="openai"
+                )
+            raise ProviderConnectionError(
+                f"OpenAI API error: {e}",
+                provider="openai"
+            )
+
+    def list_models(self) -> List[ModelInfo]:
+        """List available OpenAI models."""
+        models = []
+        for model_id, info in self.KNOWN_MODELS.items():
+            models.append(ModelInfo(
+                model_id=model_id,
+                display_name=info["display_name"],
+                provider="openai",
+                context_window=info["context_window"],
+                input_price_per_mtok=info["input_price"],
+                output_price_per_mtok=info["output_price"],
+                capabilities=info["capabilities"]
+            ))
+        return models
+
+    def get_last_usage(self) -> Optional[UsageInfo]:
+        """Get usage information from the last API call."""
+        return self.last_usage
+
+    @property
+    def provider_name(self) -> str:
+        """Get the provider's name."""
+        return "openai"
+
+    def test_connection(self) -> bool:
+        """Test if the OpenAI API is accessible."""
+        try:
+            # Make a minimal API call to test credentials
+            response = self.client.chat.completions.create(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "test"}],
+                max_tokens=5
+            )
+            return True
+        except Exception:
+            return False
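And a comparable sketch of how the OpenAI provider might be driven, again purely illustrative: SimpleNamespace stands in for the real ProviderConfig, the API key is read from the OPENAI_API_KEY environment variable, and the base class is assumed to just store the config.

import os
from types import SimpleNamespace
from claude_dev_cli.providers.openai import OpenAIProvider

# Hypothetical config; api_key is required, base_url is optional (e.g. for proxies).
config = SimpleNamespace(api_key=os.environ["OPENAI_API_KEY"], base_url=None)
provider = OpenAIProvider(config)

# Stream a completion chunk by chunk; errors map to the package's exception types
# (InsufficientCreditsError, ProviderConnectionError, ModelNotFoundError, ProviderError).
for chunk in provider.call_streaming(
    "Summarize the SOLID principles.",
    model="gpt-3.5-turbo",
    max_tokens=256,
):
    print(chunk, end="", flush=True)

Unlike the Ollama provider, the synchronous call() path here also computes an estimated cost_usd from the hard-coded per-million-token prices in KNOWN_MODELS, so get_last_usage() reflects spend rather than always reporting zero.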