claude-dev-cli 0.13.3-py3-none-any.whl → 0.16.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of claude-dev-cli may be problematic.
- claude_dev_cli/__init__.py +1 -1
- claude_dev_cli/cli.py +231 -13
- claude_dev_cli/config.py +95 -9
- claude_dev_cli/core.py +48 -53
- claude_dev_cli/providers/__init__.py +28 -0
- claude_dev_cli/providers/anthropic.py +216 -0
- claude_dev_cli/providers/base.py +168 -0
- claude_dev_cli/providers/factory.py +114 -0
- claude_dev_cli/providers/ollama.py +283 -0
- claude_dev_cli/providers/openai.py +268 -0
- claude_dev_cli/workflows.py +49 -1
- {claude_dev_cli-0.13.3.dist-info → claude_dev_cli-0.16.1.dist-info}/METADATA +196 -15
- {claude_dev_cli-0.13.3.dist-info → claude_dev_cli-0.16.1.dist-info}/RECORD +17 -11
- {claude_dev_cli-0.13.3.dist-info → claude_dev_cli-0.16.1.dist-info}/WHEEL +0 -0
- {claude_dev_cli-0.13.3.dist-info → claude_dev_cli-0.16.1.dist-info}/entry_points.txt +0 -0
- {claude_dev_cli-0.13.3.dist-info → claude_dev_cli-0.16.1.dist-info}/licenses/LICENSE +0 -0
- {claude_dev_cli-0.13.3.dist-info → claude_dev_cli-0.16.1.dist-info}/top_level.txt +0 -0
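The new claude_dev_cli/providers/base.py (+168 lines) is not included in this excerpt. Judging from the two concrete providers shown below, the shared interface looks roughly like the following sketch; the method names and signatures are inferred from ollama.py and openai.py, and everything else (class body, docstrings) is assumed rather than taken from the package:

from abc import ABC, abstractmethod
from typing import Any, Iterator, List, Optional

class AIProvider(ABC):  # inferred shape only, not the actual base.py
    def __init__(self, config: Any) -> None:
        self.config = config

    @abstractmethod
    def call(self, prompt: str, system_prompt: Optional[str] = None,
             model: Optional[str] = None, max_tokens: Optional[int] = None,
             temperature: float = 1.0) -> str: ...

    @abstractmethod
    def call_streaming(self, prompt: str, system_prompt: Optional[str] = None,
                       model: Optional[str] = None, max_tokens: Optional[int] = None,
                       temperature: float = 1.0) -> Iterator[str]: ...

    @abstractmethod
    def list_models(self) -> List["ModelInfo"]: ...

    @abstractmethod
    def get_last_usage(self) -> Optional["UsageInfo"]: ...

    @property
    @abstractmethod
    def provider_name(self) -> str: ...

    @abstractmethod
    def test_connection(self) -> bool: ...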
claude_dev_cli/providers/ollama.py
NEW
@@ -0,0 +1,283 @@
+"""Ollama local AI provider implementation."""
+
+import json
+from datetime import datetime
+from typing import Iterator, Optional, List, Dict, Any
+
+from claude_dev_cli.providers.base import (
+    AIProvider,
+    ModelInfo,
+    UsageInfo,
+    ProviderConnectionError,
+    ModelNotFoundError,
+    ProviderError,
+)
+
+# Try to import requests, handle gracefully if not installed
+try:
+    import requests
+    REQUESTS_AVAILABLE = True
+except ImportError:
+    REQUESTS_AVAILABLE = False
+    requests = None  # type: ignore
+
+
+class OllamaProvider(AIProvider):
+    """Ollama local model provider implementation.
+
+    Provides zero-cost local inference with models like:
+    - llama2, mistral, codellama, phi, deepseek-coder, mixtral, etc.
+    """
+
+    # Known model templates for common models
+    KNOWN_MODELS = {
+        "mistral": {
+            "display_name": "Mistral 7B",
+            "context_window": 8192,
+            "capabilities": ["chat", "code"]
+        },
+        "llama2": {
+            "display_name": "Llama 2",
+            "context_window": 4096,
+            "capabilities": ["chat", "general"]
+        },
+        "codellama": {
+            "display_name": "Code Llama",
+            "context_window": 16384,
+            "capabilities": ["code", "chat"]
+        },
+        "phi": {
+            "display_name": "Phi",
+            "context_window": 2048,
+            "capabilities": ["chat", "reasoning"]
+        },
+        "deepseek-coder": {
+            "display_name": "DeepSeek Coder",
+            "context_window": 16384,
+            "capabilities": ["code"]
+        },
+        "mixtral": {
+            "display_name": "Mixtral 8x7B",
+            "context_window": 32768,
+            "capabilities": ["chat", "code", "analysis"]
+        },
+    }
+
+    def __init__(self, config: Any) -> None:
+        """Initialize Ollama provider.
+
+        Args:
+            config: ProviderConfig with optional base_url
+
+        Raises:
+            RuntimeError: If requests library is not installed
+        """
+        super().__init__(config)
+
+        if not REQUESTS_AVAILABLE:
+            raise RuntimeError(
+                "Ollama provider requires the requests package. "
+                "Install it with: pip install 'claude-dev-cli[ollama]'"
+            )
+
+        # No API key needed for local!
+        self.base_url = getattr(config, 'base_url', None) or "http://localhost:11434"
+        self.timeout = 120  # Local inference can be slow
+        self.last_usage: Optional[UsageInfo] = None
+
+    def call(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        model: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        temperature: float = 1.0,
+    ) -> str:
+        """Make a synchronous call to Ollama API."""
+        model = model or "mistral"
+
+        # Build messages for chat endpoint
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        start_time = datetime.utcnow()
+
+        try:
+            # Use chat endpoint (preferred for conversational use)
+            response = requests.post(
+                f"{self.base_url}/api/chat",
+                json={
+                    "model": model,
+                    "messages": messages,
+                    "stream": False,
+                    "options": {
+                        "temperature": temperature,
+                        "num_predict": max_tokens or 4096,
+                    }
+                },
+                timeout=self.timeout
+            )
+            response.raise_for_status()
+            data = response.json()
+
+        except requests.ConnectionError:
+            raise ProviderConnectionError(
+                "Cannot connect to Ollama. Is it running? Start with: ollama serve",
+                provider="ollama"
+            )
+        except requests.Timeout:
+            raise ProviderError(
+                f"Ollama request timed out after {self.timeout}s. "
+                "Local models can be slow - consider using a smaller model."
+            )
+        except requests.HTTPError as e:
+            if e.response.status_code == 404:
+                raise ModelNotFoundError(
+                    f"Model '{model}' not found. Pull it with: ollama pull {model}",
+                    model=model,
+                    provider="ollama"
+                )
+            raise ProviderError(f"Ollama API error: {e}")
+
+        end_time = datetime.utcnow()
+        duration_ms = int((end_time - start_time).total_seconds() * 1000)
+
+        # Extract response
+        response_text = data.get("message", {}).get("content", "")
+
+        # Get token counts if available
+        prompt_tokens = data.get("prompt_eval_count", 0)
+        completion_tokens = data.get("eval_count", 0)
+
+        # Store usage info (always zero cost!)
+        self.last_usage = UsageInfo(
+            input_tokens=prompt_tokens,
+            output_tokens=completion_tokens,
+            duration_ms=duration_ms,
+            model=model,
+            timestamp=end_time,
+            cost_usd=0.0  # Free!
+        )
+
+        return response_text
+
+    def call_streaming(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        model: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        temperature: float = 1.0,
+    ) -> Iterator[str]:
+        """Make a streaming call to Ollama API."""
+        model = model or "mistral"
+
+        # Build messages
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        try:
+            response = requests.post(
+                f"{self.base_url}/api/chat",
+                json={
+                    "model": model,
+                    "messages": messages,
+                    "stream": True,
+                    "options": {
+                        "temperature": temperature,
+                        "num_predict": max_tokens or 4096,
+                    }
+                },
+                stream=True,
+                timeout=self.timeout
+            )
+            response.raise_for_status()
+
+            # Stream chunks
+            for line in response.iter_lines():
+                if line:
+                    chunk = json.loads(line)
+                    if "message" in chunk:
+                        content = chunk["message"].get("content", "")
+                        if content:
+                            yield content
+
+        except requests.ConnectionError:
+            raise ProviderConnectionError(
+                "Cannot connect to Ollama. Is it running? Start with: ollama serve",
+                provider="ollama"
+            )
+        except requests.HTTPError as e:
+            if e.response.status_code == 404:
+                raise ModelNotFoundError(
+                    f"Model '{model}' not found. Pull it with: ollama pull {model}",
+                    model=model,
+                    provider="ollama"
+                )
+            raise ProviderError(f"Ollama API error: {e}")
+
+    def list_models(self) -> List[ModelInfo]:
+        """List available Ollama models."""
+        try:
+            response = requests.get(
+                f"{self.base_url}/api/tags",
+                timeout=10
+            )
+            response.raise_for_status()
+            data = response.json()
+
+            models = []
+            for model_data in data.get("models", []):
+                model_name = model_data.get("name", "")
+                base_name = model_name.split(":")[0]  # Remove tag (e.g., "mistral:7b" -> "mistral")
+
+                # Get info from known models or use defaults
+                info = self.KNOWN_MODELS.get(base_name, {
+                    "display_name": model_name,
+                    "context_window": 4096,
+                    "capabilities": ["chat"]
+                })
+
+                models.append(ModelInfo(
+                    model_id=model_name,
+                    display_name=info.get("display_name", model_name),
+                    provider="ollama",
+                    context_window=info.get("context_window", 4096),
+                    input_price_per_mtok=0.0,  # Free!
+                    output_price_per_mtok=0.0,  # Free!
+                    capabilities=info.get("capabilities", ["chat"])
+                ))
+
+            return models
+
+        except requests.ConnectionError:
+            raise ProviderConnectionError(
+                "Cannot connect to Ollama. Is it running? Start with: ollama serve",
+                provider="ollama"
+            )
+        except Exception as e:
+            raise ProviderError(f"Failed to list Ollama models: {e}")
+
+    def get_last_usage(self) -> Optional[UsageInfo]:
+        """Get usage information from the last API call."""
+        return self.last_usage
+
+    @property
+    def provider_name(self) -> str:
+        """Get the provider's name."""
+        return "ollama"
+
+    def test_connection(self) -> bool:
+        """Test if Ollama is accessible."""
+        try:
+            response = requests.get(
+                f"{self.base_url}/api/version",
+                timeout=5
+            )
+            return response.status_code == 200
+        except Exception:
+            return False
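For orientation, a minimal usage sketch of the new Ollama provider, assuming a small config object exposing the attributes the constructor reads (the real package passes its own ProviderConfig, which is not shown in this diff):

from types import SimpleNamespace
from claude_dev_cli.providers.ollama import OllamaProvider

# Hypothetical config stand-in; base_url is optional and defaults to localhost:11434.
config = SimpleNamespace(base_url="http://localhost:11434")
provider = OllamaProvider(config)

if provider.test_connection():
    print(provider.call("Explain what a context manager is.", model="mistral"))
    print(provider.get_last_usage())  # cost_usd is always 0.0 for local models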
claude_dev_cli/providers/openai.py
NEW
@@ -0,0 +1,268 @@
+"""OpenAI (GPT-4, GPT-3.5) AI provider implementation."""
+
+from datetime import datetime
+from typing import Iterator, Optional, List, Dict, Any
+
+from claude_dev_cli.providers.base import (
+    AIProvider,
+    ModelInfo,
+    UsageInfo,
+    InsufficientCreditsError,
+    ProviderConnectionError,
+    ModelNotFoundError,
+    ProviderError,
+)
+
+# Try to import openai, handle gracefully if not installed
+try:
+    from openai import OpenAI, APIError, AuthenticationError, RateLimitError, NotFoundError
+    OPENAI_AVAILABLE = True
+except ImportError:
+    OPENAI_AVAILABLE = False
+    OpenAI = None  # type: ignore
+    APIError = Exception  # type: ignore
+    AuthenticationError = Exception  # type: ignore
+    RateLimitError = Exception  # type: ignore
+    NotFoundError = Exception  # type: ignore
+
+
+class OpenAIProvider(AIProvider):
+    """OpenAI GPT API provider implementation."""
+
+    # Known OpenAI models with their capabilities
+    KNOWN_MODELS = {
+        "gpt-4-turbo-preview": {
+            "display_name": "GPT-4 Turbo",
+            "context_window": 128000,
+            "input_price": 10.00,
+            "output_price": 30.00,
+            "capabilities": ["chat", "code", "analysis", "vision"]
+        },
+        "gpt-4-turbo": {
+            "display_name": "GPT-4 Turbo",
+            "context_window": 128000,
+            "input_price": 10.00,
+            "output_price": 30.00,
+            "capabilities": ["chat", "code", "analysis", "vision"]
+        },
+        "gpt-4": {
+            "display_name": "GPT-4",
+            "context_window": 8192,
+            "input_price": 30.00,
+            "output_price": 60.00,
+            "capabilities": ["chat", "code", "analysis"]
+        },
+        "gpt-3.5-turbo": {
+            "display_name": "GPT-3.5 Turbo",
+            "context_window": 16385,
+            "input_price": 0.50,
+            "output_price": 1.50,
+            "capabilities": ["chat", "code"]
+        },
+    }
+
+    def __init__(self, config: Any) -> None:
+        """Initialize OpenAI provider.
+
+        Args:
+            config: ProviderConfig with api_key and optional base_url
+
+        Raises:
+            RuntimeError: If openai SDK is not installed
+        """
+        super().__init__(config)
+
+        if not OPENAI_AVAILABLE:
+            raise RuntimeError(
+                "OpenAI provider requires the openai package. "
+                "Install it with: pip install 'claude-dev-cli[openai]'"
+            )
+
+        # Extract API key from config
+        api_key = getattr(config, 'api_key', None)
+        if not api_key:
+            raise ValueError("OpenAI provider requires api_key in config")
+
+        # Get optional base_url for custom endpoints (Azure, proxies, etc.)
+        base_url = getattr(config, 'base_url', None)
+
+        # Initialize OpenAI client
+        client_kwargs: Dict[str, Any] = {"api_key": api_key}
+        if base_url:
+            client_kwargs["base_url"] = base_url
+
+        self.client = OpenAI(**client_kwargs)
+        self.last_usage: Optional[UsageInfo] = None
+
+    def call(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        model: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        temperature: float = 1.0,
+    ) -> str:
+        """Make a synchronous call to OpenAI API."""
+        model = model or "gpt-4-turbo-preview"
+        max_tokens = max_tokens or 4096
+
+        # Build messages array (OpenAI format)
+        messages: List[Dict[str, str]] = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        start_time = datetime.utcnow()
+
+        try:
+            response = self.client.chat.completions.create(
+                model=model,
+                messages=messages,  # type: ignore
+                max_tokens=max_tokens,
+                temperature=temperature
+            )
+        except AuthenticationError as e:
+            raise ProviderConnectionError(
+                f"OpenAI authentication failed: {e}",
+                provider="openai"
+            )
+        except RateLimitError as e:
+            raise ProviderError(f"OpenAI rate limit exceeded: {e}")
+        except NotFoundError as e:
+            raise ModelNotFoundError(
+                f"Model not found: {model}",
+                model=model,
+                provider="openai"
+            )
+        except APIError as e:
+            # Check for quota/billing issues
+            error_message = str(e).lower()
+            if "quota" in error_message or "billing" in error_message or "insufficient" in error_message:
+                raise InsufficientCreditsError(
+                    f"Insufficient OpenAI credits: {e}",
+                    provider="openai"
+                )
+            raise ProviderConnectionError(
+                f"OpenAI API error: {e}",
+                provider="openai"
+            )
+
+        end_time = datetime.utcnow()
+        duration_ms = int((end_time - start_time).total_seconds() * 1000)
+
+        # Calculate cost
+        model_info = self.KNOWN_MODELS.get(model, {})
+        input_price = model_info.get("input_price", 0.0)
+        output_price = model_info.get("output_price", 0.0)
+
+        input_tokens = response.usage.prompt_tokens if response.usage else 0
+        output_tokens = response.usage.completion_tokens if response.usage else 0
+
+        input_cost = (input_tokens / 1_000_000) * input_price
+        output_cost = (output_tokens / 1_000_000) * output_price
+        total_cost = input_cost + output_cost
+
+        # Store usage info
+        self.last_usage = UsageInfo(
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            duration_ms=duration_ms,
+            model=model,
+            timestamp=end_time,
+            cost_usd=total_cost
+        )
+
+        # Extract text from response
+        if response.choices and len(response.choices) > 0:
+            message = response.choices[0].message
+            return message.content or ""
+        return ""
+
+    def call_streaming(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        model: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        temperature: float = 1.0,
+    ) -> Iterator[str]:
+        """Make a streaming call to OpenAI API."""
+        model = model or "gpt-4-turbo-preview"
+        max_tokens = max_tokens or 4096
+
+        # Build messages array
+        messages: List[Dict[str, str]] = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        try:
+            stream = self.client.chat.completions.create(
+                model=model,
+                messages=messages,  # type: ignore
+                max_tokens=max_tokens,
+                temperature=temperature,
+                stream=True
+            )
+
+            for chunk in stream:
+                if chunk.choices and len(chunk.choices) > 0:
+                    delta = chunk.choices[0].delta
+                    if delta.content:
+                        yield delta.content
+
+        except AuthenticationError as e:
+            raise ProviderConnectionError(
+                f"OpenAI authentication failed: {e}",
+                provider="openai"
+            )
+        except RateLimitError as e:
+            raise ProviderError(f"OpenAI rate limit exceeded: {e}")
+        except APIError as e:
+            error_message = str(e).lower()
+            if "quota" in error_message or "billing" in error_message:
+                raise InsufficientCreditsError(
+                    f"Insufficient OpenAI credits: {e}",
+                    provider="openai"
+                )
+            raise ProviderConnectionError(
+                f"OpenAI API error: {e}",
+                provider="openai"
+            )
+
+    def list_models(self) -> List[ModelInfo]:
+        """List available OpenAI models."""
+        models = []
+        for model_id, info in self.KNOWN_MODELS.items():
+            models.append(ModelInfo(
+                model_id=model_id,
+                display_name=info["display_name"],
+                provider="openai",
+                context_window=info["context_window"],
+                input_price_per_mtok=info["input_price"],
+                output_price_per_mtok=info["output_price"],
+                capabilities=info["capabilities"]
+            ))
+        return models
+
+    def get_last_usage(self) -> Optional[UsageInfo]:
+        """Get usage information from the last API call."""
+        return self.last_usage
+
+    @property
+    def provider_name(self) -> str:
+        """Get the provider's name."""
+        return "openai"
+
+    def test_connection(self) -> bool:
+        """Test if the OpenAI API is accessible."""
+        try:
+            # Make a minimal API call to test credentials
+            response = self.client.chat.completions.create(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "test"}],
+                max_tokens=5
+            )
+            return True
+        except Exception:
+            return False
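A comparable usage sketch for the OpenAI provider, again assuming a stand-in config object rather than the package's own ProviderConfig (which is not shown here); the key value is a placeholder:

from types import SimpleNamespace
from claude_dev_cli.providers.openai import OpenAIProvider

# Hypothetical config stand-in; api_key is required, base_url is optional.
config = SimpleNamespace(api_key="sk-placeholder", base_url=None)
provider = OpenAIProvider(config)

for chunk in provider.call_streaming("Summarize this diff in one sentence."):
    print(chunk, end="", flush=True)

usage = provider.get_last_usage()  # populated by call(); call_streaming() does not record usage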
claude_dev_cli/workflows.py
CHANGED
@@ -130,6 +130,7 @@ class WorkflowEngine:
             generate_tests, code_review, debug_code,
             generate_docs, refactor_code, git_commit_message
         )
+        from claude_dev_cli.core import ClaudeClient
 
         # Map commands to functions
         command_map = {
@@ -141,11 +142,23 @@
             'git commit': git_commit_message,
         }
 
+        # Handle special commands that need different execution
+        if command == 'ask':
+            return self._execute_ask_command(interpolated_args)
+        elif command in ['generate code', 'generate feature']:
+            # These are CLI-only commands that would need file generation logic
+            # For now, redirect to shell equivalent
+            return StepResult(
+                success=False,
+                output="",
+                error=f"Command '{command}' not yet supported in workflows. Use shell step with 'cdc {command}' instead."
+            )
+
         if command not in command_map:
             return StepResult(
                 success=False,
                 output="",
-                error=f"Unknown command: {command}"
+                error=f"Unknown command: {command}. Supported: {', '.join(command_map.keys())}, ask"
             )
 
         try:
@@ -159,6 +172,8 @@
             func_args['error_message'] = interpolated_args['error']
         if 'api' in interpolated_args:
             func_args['api_config_name'] = interpolated_args['api']
+        if 'model' in interpolated_args:
+            func_args['model'] = interpolated_args['model']
 
         # Execute
         result = func(**func_args)
@@ -180,6 +195,39 @@
                 error=str(e)
             )
 
+    def _execute_ask_command(self, args: Dict[str, Any]) -> StepResult:
+        """Execute an ask command step."""
+        try:
+            from claude_dev_cli.core import ClaudeClient
+
+            prompt = args.get('prompt', args.get('question', ''))
+            if not prompt:
+                return StepResult(
+                    success=False,
+                    output="",
+                    error="ask command requires 'prompt' or 'question' argument"
+                )
+
+            api = args.get('api')
+            model = args.get('model')
+            system = args.get('system')
+
+            client = ClaudeClient(api_config_name=api)
+            result = client.call(prompt, system_prompt=system, model=model)
+
+            return StepResult(
+                success=True,
+                output=result,
+                metadata={'command': 'ask'}
+            )
+
+        except Exception as e:
+            return StepResult(
+                success=False,
+                output="",
+                error=str(e)
+            )
+
     def _execute_shell_step(self, step: Dict[str, Any], context: WorkflowContext) -> StepResult:
         """Execute a shell command step."""
         command = step['shell']
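As a rough illustration of the new ask handling, the arguments a workflow step would pass to _execute_ask_command() look like the following; the surrounding step schema (how workflows name and interpolate these arguments) is only partially visible in this diff, so treat this as a sketch:

# Hypothetical step arguments as _execute_ask_command() reads them:
# 'prompt' (or 'question') is required; 'api', 'model', and 'system' are optional.
step_args = {
    "prompt": "Review the staged changes and suggest a commit message.",
    "model": "mistral",
    "system": "You are a concise code reviewer.",
}
# engine._execute_ask_command(step_args) returns a StepResult with success=True
# and the model's reply in .output, or success=False with .error set.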