cat-stack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cat_stack/__about__.py +10 -0
- cat_stack/__init__.py +128 -0
- cat_stack/_batch.py +1388 -0
- cat_stack/_category_analysis.py +348 -0
- cat_stack/_chunked.py +424 -0
- cat_stack/_embeddings.py +189 -0
- cat_stack/_formatter.py +169 -0
- cat_stack/_providers.py +1048 -0
- cat_stack/_tiebreaker.py +277 -0
- cat_stack/_utils.py +512 -0
- cat_stack/_web_fetch.py +194 -0
- cat_stack/calls/CoVe.py +287 -0
- cat_stack/calls/__init__.py +25 -0
- cat_stack/calls/all_calls.py +622 -0
- cat_stack/calls/image_CoVe.py +386 -0
- cat_stack/calls/image_stepback.py +210 -0
- cat_stack/calls/pdf_CoVe.py +386 -0
- cat_stack/calls/pdf_stepback.py +210 -0
- cat_stack/calls/stepback.py +180 -0
- cat_stack/calls/top_n.py +217 -0
- cat_stack/classify.py +682 -0
- cat_stack/explore.py +111 -0
- cat_stack/extract.py +218 -0
- cat_stack/image_functions.py +2078 -0
- cat_stack/images/circle.png +0 -0
- cat_stack/images/cube.png +0 -0
- cat_stack/images/diamond.png +0 -0
- cat_stack/images/overlapping_pentagons.png +0 -0
- cat_stack/images/rectangles.png +0 -0
- cat_stack/model_reference_list.py +94 -0
- cat_stack/pdf_functions.py +2087 -0
- cat_stack/summarize.py +290 -0
- cat_stack/text_functions.py +1358 -0
- cat_stack/text_functions_ensemble.py +3644 -0
- cat_stack-0.1.0.dist-info/METADATA +150 -0
- cat_stack-0.1.0.dist-info/RECORD +38 -0
- cat_stack-0.1.0.dist-info/WHEEL +4 -0
- cat_stack-0.1.0.dist-info/licenses/LICENSE +672 -0
cat_stack/_providers.py
ADDED
|
@@ -0,0 +1,1048 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Unified LLM provider infrastructure for CatLLM.
|
|
3
|
+
|
|
4
|
+
This module provides a unified HTTP-based approach for calling multiple LLM providers
|
|
5
|
+
(OpenAI, Anthropic, Google, Mistral, xAI, Perplexity, HuggingFace, and Ollama)
|
|
6
|
+
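without requiring provider-specific SDKs. The "claude-code" provider is the one exception: it runs the `claude` CLI as a subprocess instead of making HTTP calls.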
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import time
|
|
11
|
+
import requests
|
|
12
|
+
|
|
13
|
+
# Public API of this module. Two underscore-prefixed helpers are listed
# explicitly; listing them here is what makes ``from ... import *`` include them.
__all__ = [
    # Main client
    "UnifiedLLMClient", "PROVIDER_CONFIG",
    # Provider detection
    "detect_provider", "_detect_model_source", "_detect_huggingface_endpoint",
    # Ollama utilities
    "set_ollama_endpoint", "check_ollama_running", "list_ollama_models",
    "check_ollama_model", "check_system_resources",
    "get_ollama_model_size_estimate", "pull_ollama_model", "OLLAMA_MODEL_SIZES",
    # Claude Code utilities
    "check_claude_cli_available",
]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# =============================================================================
|
|
36
|
+
# HuggingFace Endpoint Auto-Detection
|
|
37
|
+
# =============================================================================
|
|
38
|
+
|
|
39
|
+
def _detect_huggingface_endpoint(api_key: str, model: str) -> str:
    """
    Probe the HuggingFace router to find a route that serves ``model``.

    Sends a tiny chat completion ("hi", max 5 tokens) to the generic router and,
    failing that, to the Together route. The first route answering HTTP 200 wins.
    Note that each probe is a real inference request.

    Args:
        api_key: HuggingFace API key (sent as a Bearer token)
        model: Model name to probe

    Returns:
        The winning base URL with the trailing ``/chat/completions`` removed.
        Falls back to the generic router URL when no probe succeeds, so the
        real request later surfaces the provider's error.
    """
    probe_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    probe_body = {
        "model": model,
        "messages": [{"role": "user", "content": "hi"}],
        "max_tokens": 5,
    }
    candidate_urls = (
        "https://router.huggingface.co/v1/chat/completions",
        "https://router.huggingface.co/together/v1/chat/completions",
    )

    for url in candidate_urls:
        try:
            reply = requests.post(url, headers=probe_headers, json=probe_body, timeout=30)
        except Exception:
            # Probing is best-effort: any failure just moves on to the next route.
            continue
        if reply.status_code == 200:
            return url.replace("/chat/completions", "")

    return "https://router.huggingface.co/v1"
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# =============================================================================
|
|
81
|
+
# Provider Configuration
|
|
82
|
+
# =============================================================================
|
|
83
|
+
|
|
84
|
+
# Per-provider HTTP settings used by UnifiedLLMClient:
#   endpoint    - request URL; "{model}" is substituted where present (Google);
#                 None means the provider does not use HTTP (claude-code).
#   auth_header - header carrying the API key, or None when no auth is sent.
#   auth_prefix - text placed before the key in that header (e.g. "Bearer ").
PROVIDER_CONFIG = {
    "openai": dict(
        endpoint="https://api.openai.com/v1/chat/completions",
        auth_header="Authorization",
        auth_prefix="Bearer ",
    ),
    "anthropic": dict(
        endpoint="https://api.anthropic.com/v1/messages",
        auth_header="x-api-key",
        auth_prefix="",
    ),
    "google": dict(
        endpoint="https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent",
        auth_header="x-goog-api-key",
        auth_prefix="",
    ),
    "mistral": dict(
        endpoint="https://api.mistral.ai/v1/chat/completions",
        auth_header="Authorization",
        auth_prefix="Bearer ",
    ),
    "perplexity": dict(
        endpoint="https://api.perplexity.ai/chat/completions",
        auth_header="Authorization",
        auth_prefix="Bearer ",
    ),
    "xai": dict(
        endpoint="https://api.x.ai/v1/chat/completions",
        auth_header="Authorization",
        auth_prefix="Bearer ",
    ),
    "huggingface": dict(
        endpoint="https://router.huggingface.co/v1/chat/completions",
        auth_header="Authorization",
        auth_prefix="Bearer ",
    ),
    "huggingface-together": dict(
        endpoint="https://router.huggingface.co/together/v1/chat/completions",
        auth_header="Authorization",
        auth_prefix="Bearer ",
    ),
    # Local Ollama needs no auth; set_ollama_endpoint() may rewrite the URL.
    "ollama": dict(
        endpoint="http://localhost:11434/v1/chat/completions",
        auth_header=None,
        auth_prefix="",
    ),
    # Claude Code goes through the `claude` CLI subprocess, not HTTP.
    "claude-code": dict(
        endpoint=None,
        auth_header=None,
        auth_prefix="",
    ),
}
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
# =============================================================================
|
|
139
|
+
# Unified API Client
|
|
140
|
+
# =============================================================================
|
|
141
|
+
|
|
142
|
+
class UnifiedLLMClient:
|
|
143
|
+
"""A unified client for calling various LLM providers via HTTP."""
|
|
144
|
+
|
|
145
|
+
def __init__(self, provider: str, api_key: str, model: str):
|
|
146
|
+
self.provider = provider.lower()
|
|
147
|
+
self.api_key = api_key
|
|
148
|
+
self.model = model
|
|
149
|
+
|
|
150
|
+
# Auto-detect HuggingFace endpoint
|
|
151
|
+
if self.provider == "huggingface":
|
|
152
|
+
detected_url = _detect_huggingface_endpoint(api_key, model)
|
|
153
|
+
if "together" in detected_url:
|
|
154
|
+
self.provider = "huggingface-together"
|
|
155
|
+
|
|
156
|
+
if self.provider not in PROVIDER_CONFIG:
|
|
157
|
+
raise ValueError(f"Unsupported provider: {provider}. "
|
|
158
|
+
f"Supported: {list(PROVIDER_CONFIG.keys())}")
|
|
159
|
+
|
|
160
|
+
self.config = PROVIDER_CONFIG[self.provider]
|
|
161
|
+
|
|
162
|
+
def _get_endpoint(self) -> str:
|
|
163
|
+
"""Get the API endpoint, substituting model if needed."""
|
|
164
|
+
endpoint = self.config["endpoint"]
|
|
165
|
+
if "{model}" in endpoint:
|
|
166
|
+
endpoint = endpoint.format(model=self.model)
|
|
167
|
+
return endpoint
|
|
168
|
+
|
|
169
|
+
def _get_headers(self) -> dict:
|
|
170
|
+
"""Build request headers for the provider."""
|
|
171
|
+
headers = {"Content-Type": "application/json"}
|
|
172
|
+
auth_header = self.config["auth_header"]
|
|
173
|
+
auth_prefix = self.config["auth_prefix"]
|
|
174
|
+
|
|
175
|
+
# Some providers (like Ollama) don't require auth
|
|
176
|
+
if auth_header is not None:
|
|
177
|
+
headers[auth_header] = f"{auth_prefix}{self.api_key}"
|
|
178
|
+
|
|
179
|
+
# Anthropic requires additional headers
|
|
180
|
+
if self.provider == "anthropic":
|
|
181
|
+
headers["anthropic-version"] = "2023-06-01"
|
|
182
|
+
|
|
183
|
+
return headers
|
|
184
|
+
|
|
185
|
+
def _build_payload(
|
|
186
|
+
self,
|
|
187
|
+
messages: list,
|
|
188
|
+
json_schema: dict = None,
|
|
189
|
+
creativity: float = None,
|
|
190
|
+
max_tokens: int = 4096,
|
|
191
|
+
thinking_budget: int = None,
|
|
192
|
+
force_json: bool = True,
|
|
193
|
+
) -> dict:
|
|
194
|
+
"""Build the request payload for the specific provider."""
|
|
195
|
+
|
|
196
|
+
if self.provider == "anthropic":
|
|
197
|
+
return self._build_anthropic_payload(messages, json_schema, creativity, max_tokens, thinking_budget)
|
|
198
|
+
elif self.provider == "google":
|
|
199
|
+
return self._build_google_payload(messages, json_schema, creativity, thinking_budget, force_json)
|
|
200
|
+
elif self.provider == "openai":
|
|
201
|
+
return self._build_openai_payload(messages, json_schema, creativity, force_json, thinking_budget)
|
|
202
|
+
elif self.provider in ("huggingface", "huggingface-together"):
|
|
203
|
+
# HuggingFace needs thinking_budget to disable thinking on models that reason by default
|
|
204
|
+
return self._build_openai_payload(messages, json_schema, creativity, force_json, thinking_budget)
|
|
205
|
+
else:
|
|
206
|
+
# Other OpenAI-compatible providers (xai, mistral, etc.)
|
|
207
|
+
return self._build_openai_payload(messages, json_schema, creativity, force_json)
|
|
208
|
+
|
|
209
|
+
def _build_openai_payload(
|
|
210
|
+
self,
|
|
211
|
+
messages: list,
|
|
212
|
+
json_schema: dict = None,
|
|
213
|
+
creativity: float = None,
|
|
214
|
+
force_json: bool = True,
|
|
215
|
+
thinking_budget: int = None,
|
|
216
|
+
) -> dict:
|
|
217
|
+
"""Build payload for OpenAI-compatible APIs.
|
|
218
|
+
|
|
219
|
+
Args:
|
|
220
|
+
force_json: If False and no json_schema, don't set response_format (for text responses)
|
|
221
|
+
thinking_budget: For OpenAI models, maps to reasoning_effort:
|
|
222
|
+
0 or None → "minimal", >0 → "high"
|
|
223
|
+
"""
|
|
224
|
+
payload = {
|
|
225
|
+
"model": self.model,
|
|
226
|
+
"messages": messages,
|
|
227
|
+
}
|
|
228
|
+
|
|
229
|
+
# Structured output
|
|
230
|
+
# Ollama and HuggingFace only support json_object mode, not strict json_schema
|
|
231
|
+
if json_schema and self.provider not in ["ollama", "huggingface", "huggingface-together"]:
|
|
232
|
+
payload["response_format"] = {
|
|
233
|
+
"type": "json_schema",
|
|
234
|
+
"json_schema": {
|
|
235
|
+
"name": "classification_result",
|
|
236
|
+
"strict": True,
|
|
237
|
+
"schema": json_schema,
|
|
238
|
+
}
|
|
239
|
+
}
|
|
240
|
+
elif json_schema:
|
|
241
|
+
# Ollama/HuggingFace - use json_object mode
|
|
242
|
+
payload["response_format"] = {"type": "json_object"}
|
|
243
|
+
elif force_json:
|
|
244
|
+
# No schema but force JSON output
|
|
245
|
+
payload["response_format"] = {"type": "json_object"}
|
|
246
|
+
# else: no response_format - allow text responses
|
|
247
|
+
|
|
248
|
+
# OpenAI reasoning models (o-series, GPT-5) only accept temperature=1.
|
|
249
|
+
# Use reasoning_effort to control reasoning depth instead.
|
|
250
|
+
_is_reasoning_model = self.provider == "openai" and any(
|
|
251
|
+
self.model.startswith(p) for p in ("o1", "o3", "o4", "gpt-5")
|
|
252
|
+
)
|
|
253
|
+
if _is_reasoning_model:
|
|
254
|
+
# Never set temperature for reasoning models (only default=1 is valid)
|
|
255
|
+
if thinking_budget is not None:
|
|
256
|
+
if thinking_budget > 0:
|
|
257
|
+
payload["reasoning_effort"] = "high"
|
|
258
|
+
else:
|
|
259
|
+
payload["reasoning_effort"] = "minimal"
|
|
260
|
+
elif creativity is not None:
|
|
261
|
+
payload["temperature"] = creativity
|
|
262
|
+
|
|
263
|
+
# HuggingFace: disable thinking for models that reason by default (e.g., Qwen3)
|
|
264
|
+
# when thinking_budget is explicitly set to 0
|
|
265
|
+
if self.provider in ("huggingface", "huggingface-together") and thinking_budget is not None and thinking_budget == 0:
|
|
266
|
+
payload["chat_template_kwargs"] = {"enable_thinking": False}
|
|
267
|
+
|
|
268
|
+
return payload
|
|
269
|
+
|
|
270
|
+
def _build_anthropic_payload(
|
|
271
|
+
self,
|
|
272
|
+
messages: list,
|
|
273
|
+
json_schema: dict = None,
|
|
274
|
+
creativity: float = None,
|
|
275
|
+
max_tokens: int = 4096,
|
|
276
|
+
thinking_budget: int = None,
|
|
277
|
+
) -> dict:
|
|
278
|
+
"""Build payload for Anthropic API.
|
|
279
|
+
|
|
280
|
+
Args:
|
|
281
|
+
thinking_budget: Controls extended thinking for Anthropic models.
|
|
282
|
+
0 or None → thinking disabled.
|
|
283
|
+
>0 → thinking enabled with budget_tokens = max(thinking_budget, 1024).
|
|
284
|
+
"""
|
|
285
|
+
# Extract system message if present
|
|
286
|
+
system_content = None
|
|
287
|
+
user_messages = []
|
|
288
|
+
for msg in messages:
|
|
289
|
+
if msg["role"] == "system":
|
|
290
|
+
system_content = msg["content"]
|
|
291
|
+
else:
|
|
292
|
+
user_messages.append(msg)
|
|
293
|
+
|
|
294
|
+
payload = {
|
|
295
|
+
"model": self.model,
|
|
296
|
+
"max_tokens": max_tokens,
|
|
297
|
+
"messages": user_messages,
|
|
298
|
+
}
|
|
299
|
+
|
|
300
|
+
if system_content:
|
|
301
|
+
payload["system"] = system_content
|
|
302
|
+
|
|
303
|
+
# Extended thinking for Anthropic (minimum 1024 tokens)
|
|
304
|
+
# When thinking is enabled, temperature must be 1 (Anthropic requirement),
|
|
305
|
+
# so we skip setting temperature from creativity in that case
|
|
306
|
+
if thinking_budget and thinking_budget > 0:
|
|
307
|
+
budget = max(thinking_budget, 1024)
|
|
308
|
+
payload["thinking"] = {
|
|
309
|
+
"type": "enabled",
|
|
310
|
+
"budget_tokens": budget,
|
|
311
|
+
}
|
|
312
|
+
payload["temperature"] = 1
|
|
313
|
+
# When thinking is enabled, max_tokens must be larger than budget_tokens
|
|
314
|
+
if payload["max_tokens"] <= budget:
|
|
315
|
+
payload["max_tokens"] = budget + 4096
|
|
316
|
+
elif creativity is not None:
|
|
317
|
+
payload["temperature"] = creativity
|
|
318
|
+
|
|
319
|
+
# Use tool calling for structured output (most reliable for Anthropic)
|
|
320
|
+
# When thinking is enabled, forced tool_choice is not allowed — use "auto"
|
|
321
|
+
if json_schema:
|
|
322
|
+
payload["tools"] = [{
|
|
323
|
+
"name": "return_categories",
|
|
324
|
+
"description": "Return categorization results",
|
|
325
|
+
"input_schema": json_schema,
|
|
326
|
+
}]
|
|
327
|
+
if thinking_budget and thinking_budget > 0:
|
|
328
|
+
payload["tool_choice"] = {"type": "auto"}
|
|
329
|
+
else:
|
|
330
|
+
payload["tool_choice"] = {"type": "tool", "name": "return_categories"}
|
|
331
|
+
|
|
332
|
+
return payload
|
|
333
|
+
|
|
334
|
+
def _build_google_payload(
|
|
335
|
+
self,
|
|
336
|
+
messages: list,
|
|
337
|
+
json_schema: dict = None,
|
|
338
|
+
creativity: float = None,
|
|
339
|
+
thinking_budget: int = None,
|
|
340
|
+
force_json: bool = True,
|
|
341
|
+
) -> dict:
|
|
342
|
+
"""Build payload for Google Gemini API."""
|
|
343
|
+
# Convert messages to Google format
|
|
344
|
+
# Combine system + user messages into a single prompt
|
|
345
|
+
combined_text = ""
|
|
346
|
+
for msg in messages:
|
|
347
|
+
if msg["role"] == "system":
|
|
348
|
+
combined_text += msg["content"] + "\n\n"
|
|
349
|
+
elif msg["role"] == "user":
|
|
350
|
+
combined_text += msg["content"]
|
|
351
|
+
elif msg["role"] == "assistant":
|
|
352
|
+
combined_text += "\n\nAssistant: " + msg["content"] + "\n\n"
|
|
353
|
+
|
|
354
|
+
payload = {
|
|
355
|
+
"contents": [{"parts": [{"text": combined_text}]}],
|
|
356
|
+
"generationConfig": {}
|
|
357
|
+
}
|
|
358
|
+
|
|
359
|
+
if json_schema:
|
|
360
|
+
payload["generationConfig"]["responseMimeType"] = "application/json"
|
|
361
|
+
payload["generationConfig"]["responseSchema"] = json_schema
|
|
362
|
+
elif force_json:
|
|
363
|
+
payload["generationConfig"]["responseMimeType"] = "application/json"
|
|
364
|
+
# else: no mime type - allow text responses
|
|
365
|
+
|
|
366
|
+
if creativity is not None:
|
|
367
|
+
payload["generationConfig"]["temperature"] = creativity
|
|
368
|
+
|
|
369
|
+
# Add thinking budget for extended thinking (Google-specific)
|
|
370
|
+
# Must be inside generationConfig, not at top level
|
|
371
|
+
# Google requires a reasonable minimum budget (enforce 128 tokens minimum)
|
|
372
|
+
if thinking_budget and thinking_budget > 0:
|
|
373
|
+
budget = max(thinking_budget, 128)
|
|
374
|
+
payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": budget}
|
|
375
|
+
|
|
376
|
+
return payload
|
|
377
|
+
|
|
378
|
+
def _parse_response(self, response_json: dict) -> str:
|
|
379
|
+
"""Parse the response based on provider format."""
|
|
380
|
+
if self.provider == "anthropic":
|
|
381
|
+
return self._parse_anthropic_response(response_json)
|
|
382
|
+
elif self.provider == "google":
|
|
383
|
+
return self._parse_google_response(response_json)
|
|
384
|
+
else:
|
|
385
|
+
# OpenAI-compatible
|
|
386
|
+
return self._parse_openai_response(response_json)
|
|
387
|
+
|
|
388
|
+
def _parse_openai_response(self, response_json: dict) -> str:
|
|
389
|
+
"""Parse OpenAI-compatible response."""
|
|
390
|
+
return response_json["choices"][0]["message"]["content"]
|
|
391
|
+
|
|
392
|
+
def _parse_anthropic_response(self, response_json: dict) -> str:
|
|
393
|
+
"""Parse Anthropic response (handles both text and tool use)."""
|
|
394
|
+
content = response_json.get("content", [])
|
|
395
|
+
for block in content:
|
|
396
|
+
if block.get("type") == "tool_use":
|
|
397
|
+
# Return the tool input as JSON string
|
|
398
|
+
return json.dumps(block.get("input", {}))
|
|
399
|
+
elif block.get("type") == "text":
|
|
400
|
+
return block.get("text", "")
|
|
401
|
+
return ""
|
|
402
|
+
|
|
403
|
+
def _parse_google_response(self, response_json: dict) -> str:
|
|
404
|
+
"""Parse Google Gemini response."""
|
|
405
|
+
candidates = response_json.get("candidates", [])
|
|
406
|
+
if candidates:
|
|
407
|
+
parts = candidates[0].get("content", {}).get("parts", [])
|
|
408
|
+
if parts:
|
|
409
|
+
return parts[0].get("text", "")
|
|
410
|
+
return ""
|
|
411
|
+
|
|
412
|
+
def _call_claude_cli(
|
|
413
|
+
self,
|
|
414
|
+
messages: list,
|
|
415
|
+
max_retries: int = 3,
|
|
416
|
+
initial_delay: float = 2.0,
|
|
417
|
+
) -> tuple[str, str | None]:
|
|
418
|
+
"""
|
|
419
|
+
Call the Claude CLI (claude -p) as a subprocess.
|
|
420
|
+
|
|
421
|
+
Args:
|
|
422
|
+
messages: List of message dicts with 'role' and 'content'
|
|
423
|
+
max_retries: Maximum retry attempts
|
|
424
|
+
initial_delay: Initial delay for exponential backoff
|
|
425
|
+
|
|
426
|
+
Returns:
|
|
427
|
+
tuple: (response_text, error_message)
|
|
428
|
+
"""
|
|
429
|
+
import subprocess
|
|
430
|
+
|
|
431
|
+
# Extract system and user messages
|
|
432
|
+
system_parts = []
|
|
433
|
+
user_parts = []
|
|
434
|
+
for msg in messages:
|
|
435
|
+
if msg["role"] == "system":
|
|
436
|
+
system_parts.append(msg["content"])
|
|
437
|
+
elif msg["role"] in ("user", "assistant"):
|
|
438
|
+
user_parts.append(msg["content"])
|
|
439
|
+
|
|
440
|
+
system_prompt = "\n\n".join(system_parts) if system_parts else None
|
|
441
|
+
user_prompt = "\n\n".join(user_parts)
|
|
442
|
+
|
|
443
|
+
# Build command
|
|
444
|
+
cmd = ["claude", "-p", "--output-format", "text", "--model", self.model]
|
|
445
|
+
if system_prompt:
|
|
446
|
+
cmd.extend(["--system-prompt", system_prompt])
|
|
447
|
+
cmd.append(user_prompt)
|
|
448
|
+
|
|
449
|
+
for attempt in range(max_retries):
|
|
450
|
+
try:
|
|
451
|
+
result = subprocess.run(
|
|
452
|
+
cmd,
|
|
453
|
+
capture_output=True,
|
|
454
|
+
text=True,
|
|
455
|
+
timeout=120,
|
|
456
|
+
)
|
|
457
|
+
if result.returncode == 0:
|
|
458
|
+
return result.stdout.strip(), None
|
|
459
|
+
else:
|
|
460
|
+
error_msg = result.stderr.strip() or f"CLI exited with code {result.returncode}"
|
|
461
|
+
if attempt < max_retries - 1:
|
|
462
|
+
wait_time = initial_delay * (2 ** attempt)
|
|
463
|
+
print(f"Claude CLI error: {error_msg}. Retrying in {wait_time}s...")
|
|
464
|
+
time.sleep(wait_time)
|
|
465
|
+
else:
|
|
466
|
+
return None, f"Claude CLI failed after {max_retries} attempts: {error_msg}"
|
|
467
|
+
except subprocess.TimeoutExpired:
|
|
468
|
+
if attempt < max_retries - 1:
|
|
469
|
+
wait_time = initial_delay * (2 ** attempt)
|
|
470
|
+
print(f"Claude CLI timeout. Retrying in {wait_time}s...")
|
|
471
|
+
time.sleep(wait_time)
|
|
472
|
+
else:
|
|
473
|
+
return None, "Claude CLI timeout after retries"
|
|
474
|
+
except FileNotFoundError:
|
|
475
|
+
return None, (
|
|
476
|
+
"Claude CLI not found. Install it: "
|
|
477
|
+
"https://docs.anthropic.com/en/docs/claude-code"
|
|
478
|
+
)
|
|
479
|
+
|
|
480
|
+
return None, "Max retries exceeded"
|
|
481
|
+
|
|
482
|
+
def complete(
|
|
483
|
+
self,
|
|
484
|
+
messages: list,
|
|
485
|
+
json_schema: dict = None,
|
|
486
|
+
creativity: float = None,
|
|
487
|
+
thinking_budget: int = None,
|
|
488
|
+
force_json: bool = True,
|
|
489
|
+
max_retries: int = 5,
|
|
490
|
+
initial_delay: float = 2.0,
|
|
491
|
+
) -> tuple[str, str | None]:
|
|
492
|
+
"""
|
|
493
|
+
Make a completion request to the LLM provider.
|
|
494
|
+
|
|
495
|
+
Args:
|
|
496
|
+
messages: List of message dicts with 'role' and 'content'
|
|
497
|
+
json_schema: Optional JSON schema for structured output
|
|
498
|
+
creativity: Temperature setting (None for default)
|
|
499
|
+
thinking_budget: Controls reasoning behavior per provider:
|
|
500
|
+
- Google: Token budget for extended thinking (0 to disable, >0 to enable)
|
|
501
|
+
- OpenAI: Maps to reasoning_effort (0 → "minimal", >0 → "high")
|
|
502
|
+
- Anthropic: Enables extended thinking (0 to disable, >0 to enable with min 1024)
|
|
503
|
+
force_json: If True and no json_schema, still request JSON output.
|
|
504
|
+
Set to False for text-only responses (e.g., CoVe intermediate steps)
|
|
505
|
+
max_retries: Maximum retry attempts
|
|
506
|
+
initial_delay: Initial delay for exponential backoff
|
|
507
|
+
|
|
508
|
+
Returns:
|
|
509
|
+
tuple: (response_text, error_message)
|
|
510
|
+
error_message is None on success
|
|
511
|
+
"""
|
|
512
|
+
if self.provider == "claude-code":
|
|
513
|
+
return self._call_claude_cli(messages, max_retries=max_retries, initial_delay=initial_delay)
|
|
514
|
+
|
|
515
|
+
endpoint = self._get_endpoint()
|
|
516
|
+
headers = self._get_headers()
|
|
517
|
+
payload = self._build_payload(messages, json_schema, creativity, thinking_budget=thinking_budget, force_json=force_json)
|
|
518
|
+
|
|
519
|
+
for attempt in range(max_retries):
|
|
520
|
+
try:
|
|
521
|
+
response = requests.post(
|
|
522
|
+
endpoint,
|
|
523
|
+
headers=headers,
|
|
524
|
+
json=payload,
|
|
525
|
+
timeout=120,
|
|
526
|
+
)
|
|
527
|
+
|
|
528
|
+
# Check for HTTP errors
|
|
529
|
+
if response.status_code == 404:
|
|
530
|
+
return None, f"Model '{self.model}' not found for {self.provider}"
|
|
531
|
+
elif response.status_code in [401, 403]:
|
|
532
|
+
return None, f"Authentication failed for {self.provider}"
|
|
533
|
+
elif response.status_code == 429:
|
|
534
|
+
# Rate limited - retry with backoff
|
|
535
|
+
if attempt < max_retries - 1:
|
|
536
|
+
wait_time = initial_delay * (2 ** attempt) * 5 # Longer wait for rate limits
|
|
537
|
+
print(f"Rate limited. Waiting {wait_time}s...")
|
|
538
|
+
time.sleep(wait_time)
|
|
539
|
+
continue
|
|
540
|
+
else:
|
|
541
|
+
return None, "Rate limit exceeded after retries"
|
|
542
|
+
elif response.status_code >= 500:
|
|
543
|
+
# Server error - retry
|
|
544
|
+
if attempt < max_retries - 1:
|
|
545
|
+
wait_time = initial_delay * (2 ** attempt)
|
|
546
|
+
print(f"Server error {response.status_code}. Retrying in {wait_time}s...")
|
|
547
|
+
time.sleep(wait_time)
|
|
548
|
+
continue
|
|
549
|
+
else:
|
|
550
|
+
return None, f"Server error {response.status_code} after retries"
|
|
551
|
+
|
|
552
|
+
response.raise_for_status()
|
|
553
|
+
response_json = response.json()
|
|
554
|
+
result = self._parse_response(response_json)
|
|
555
|
+
return result, None
|
|
556
|
+
|
|
557
|
+
except requests.exceptions.Timeout:
|
|
558
|
+
if attempt < max_retries - 1:
|
|
559
|
+
wait_time = initial_delay * (2 ** attempt)
|
|
560
|
+
print(f"Request timeout. Retrying in {wait_time}s...")
|
|
561
|
+
time.sleep(wait_time)
|
|
562
|
+
else:
|
|
563
|
+
return None, "Request timeout after retries"
|
|
564
|
+
|
|
565
|
+
except requests.exceptions.RequestException as e:
|
|
566
|
+
if attempt < max_retries - 1:
|
|
567
|
+
wait_time = initial_delay * (2 ** attempt)
|
|
568
|
+
print(f"Request error: {e}. Retrying in {wait_time}s...")
|
|
569
|
+
time.sleep(wait_time)
|
|
570
|
+
else:
|
|
571
|
+
return None, f"Request failed: {e}"
|
|
572
|
+
|
|
573
|
+
except json.JSONDecodeError as e:
|
|
574
|
+
return None, f"Failed to parse response JSON: {e}"
|
|
575
|
+
|
|
576
|
+
return None, "Max retries exceeded"
|
|
577
|
+
|
|
578
|
+
|
|
579
|
+
# =============================================================================
|
|
580
|
+
# Provider Detection
|
|
581
|
+
# =============================================================================
|
|
582
|
+
|
|
583
|
+
def _detect_model_source(user_model, model_source):
    """Resolve the provider key for ``user_model``.

    Args:
        user_model: Model name, e.g. "gpt-4o" or "meta-llama/Llama-3.1-8B".
        model_source: Explicit provider name (case-insensitive), or ``None`` /
            ``"auto"`` to infer it from ``user_model``.

    Returns:
        The lower-cased provider key. Explicit values other than "auto" are
        returned lower-cased without further validation.

    Raises:
        ValueError: If auto-detection finds no matching provider.
    """
    # Check for None *before* lower-casing so None really means "auto-detect"
    # (calling .lower() on None would raise AttributeError).
    if model_source is not None:
        model_source = model_source.lower()

    # Explicit provider pass-through (no auto-detection needed)
    if model_source == "claude-code":
        return "claude-code"

    if model_source is not None and model_source != "auto":
        return model_source

    name = user_model.lower()

    # Rules are checked in order; the first substring hit wins.
    if "gpt" in name or name.startswith(("o1", "o3", "o4")):
        return "openai"
    if "claude" in name:
        return "anthropic"
    if "gemini" in name or "gemma" in name:
        return "google"
    if "llama" in name or "meta" in name:
        return "huggingface"
    if "mistral" in name or "mixtral" in name:
        return "mistral"
    if "sonar" in name or "pplx" in name:
        return "perplexity"
    if "deepseek" in name or "qwen" in name:
        return "huggingface"
    if "grok" in name:
        return "xai"

    raise ValueError(
        f"Could not auto-detect model source from '{user_model}'. "
        "Please specify model_source explicitly: OpenAI, Anthropic, Perplexity, Google, xAI, Huggingface, or Mistral"
    )
|
|
616
|
+
|
|
617
|
+
|
|
618
|
+
def detect_provider(model_name: str, provider: str = "auto") -> str:
    """Return the provider key for ``model_name``.

    Args:
        model_name: Model identifier, e.g. "gpt-4o", "o4-mini", "claude-3-5-sonnet".
        provider: Explicit provider (case-insensitive). Any truthy value other
            than "auto" is returned lower-cased as-is; "auto" or a falsy value
            triggers detection from ``model_name``.

    Returns:
        Lower-cased provider key ("openai", "anthropic", "google", "mistral",
        "perplexity", "xai" or "huggingface" when auto-detected).

    Raises:
        ValueError: If no rule matches ``model_name``.
    """
    if provider and provider.lower() != "auto":
        return provider.lower()

    model_lower = model_name.lower()
    # o1/o3 are matched anywhere in the name. o4 is matched only at the start
    # of the last path segment ("o4-mini", "openai/o4-mini") to limit false positives.
    leaf = model_lower.rsplit("/", 1)[-1]

    if "gpt" in model_lower or "o1" in model_lower or "o3" in model_lower or leaf.startswith("o4"):
        return "openai"
    elif "claude" in model_lower:
        return "anthropic"
    elif "gemini" in model_lower or "gemma" in model_lower:
        return "google"
    elif "mistral" in model_lower or "mixtral" in model_lower:
        return "mistral"
    elif "sonar" in model_lower or "pplx" in model_lower:
        return "perplexity"
    elif "grok" in model_lower:
        return "xai"
    elif "llama" in model_lower or "meta" in model_lower or "deepseek" in model_lower or "qwen" in model_lower:
        return "huggingface"
    else:
        raise ValueError(
            f"Could not auto-detect provider from '{model_name}'. "
            "Please specify provider explicitly: openai, anthropic, google, mistral, "
            "perplexity, xai, huggingface, or ollama."
        )
|
|
645
|
+
|
|
646
|
+
|
|
647
|
+
# =============================================================================
|
|
648
|
+
# Ollama Functions
|
|
649
|
+
# =============================================================================
|
|
650
|
+
|
|
651
|
+
def set_ollama_endpoint(host: str = "localhost", port: int = 11434):
    """
    Point the "ollama" provider at a different host and/or port.

    Rewrites ``PROVIDER_CONFIG["ollama"]["endpoint"]`` in place. The dict is
    mutated rather than replaced, so clients already holding that config entry
    see the new URL on their next request.

    Args:
        host: Hostname where Ollama is running (default: localhost)
        port: Port number (default: 11434)

    Example:
        set_ollama_endpoint("192.168.1.100", 11434)
    """
    base_url = f"http://{host}:{port}"
    ollama_settings = PROVIDER_CONFIG["ollama"]
    ollama_settings["endpoint"] = base_url + "/v1/chat/completions"
|
|
665
|
+
|
|
666
|
+
|
|
667
|
+
def check_ollama_running(host: str = "localhost", port: int = 11434) -> bool:
    """
    Report whether an Ollama server answers at ``host:port``.

    Sends ``GET /api/tags`` with a 5-second timeout.

    Args:
        host: Hostname where Ollama should be running
        port: Port number

    Returns:
        True only when the request succeeds with HTTP 200. Any ``requests``
        error (connection refused, timeout, ...) yields False.
    """
    tags_url = f"http://{host}:{port}/api/tags"
    try:
        reply = requests.get(tags_url, timeout=5)
    except requests.exceptions.RequestException:
        return False
    return reply.status_code == 200
|
|
683
|
+
|
|
684
|
+
|
|
685
|
+
def list_ollama_models(host: str = "localhost", port: int = 11434) -> list:
    """
    List all models available in the local Ollama installation.

    Queries ``GET /api/tags`` (5-second timeout) and collects the ``name`` of
    each entry in the response's ``models`` array. This is best-effort: it
    never raises for network or response-format problems.

    Args:
        host: Hostname where Ollama is running
        port: Port number

    Returns:
        List of model names (e.g. "llama3.2:latest"). Empty if Ollama is
        unreachable, answers with a non-200 status, or returns a body that is
        not the expected JSON shape.
    """
    try:
        response = requests.get(f"http://{host}:{port}/api/tags", timeout=5)
    except requests.exceptions.RequestException:
        return []

    if response.status_code != 200:
        return []

    try:
        data = response.json()
    except ValueError:
        # Non-JSON body: keep the best-effort contract instead of crashing.
        return []

    models = data.get("models") if isinstance(data, dict) else None
    if not isinstance(models, list):
        return []
    # Skip malformed entries rather than raising KeyError on them.
    return [entry["name"] for entry in models if isinstance(entry, dict) and "name" in entry]
|
|
704
|
+
|
|
705
|
+
|
|
706
|
+
def check_ollama_model(model: str, host: str = "localhost", port: int = 11434) -> bool:
    """
    Check if a specific model is available in Ollama.

    Names are compared case-insensitively as ``base[:tag]``:

    * An untagged request ("llama3.2") matches any installed tag of the same
      base ("llama3.2:latest", "llama3.2:3b").
    * A tagged request ("llama3.2:3b") matches only that exact tag. An
      installed name without a tag counts as ":latest".
    * Different bases never match, so "llama3.2" is NOT satisfied by "llama3"
      and "phi3" is NOT satisfied by "phi".

    Args:
        model: Model name to check (e.g., "llama3.2", "mistral", "llama3.2:3b")
        host: Hostname where Ollama is running
        port: Port number

    Returns:
        True if a matching model is installed, False otherwise (including when
        Ollama is unreachable, since ``list_ollama_models`` then returns []).
    """
    wanted_base, has_tag, wanted_tag = model.strip().lower().partition(":")

    for installed in list_ollama_models(host, port):
        base, _, tag = installed.lower().partition(":")
        if base != wanted_base:
            continue
        if not has_tag:
            # NOTE(review): kept lenient for compatibility - any local tag counts,
            # even though Ollama itself would resolve an untagged name to ":latest".
            return True
        if (tag or "latest") == wanted_tag:
            return True
    return False
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
def _format_bytes(size_bytes: int) -> str:
|
|
730
|
+
"""Format bytes into human-readable string."""
|
|
731
|
+
if size_bytes < 1024:
|
|
732
|
+
return f"{size_bytes} B"
|
|
733
|
+
elif size_bytes < 1024 ** 2:
|
|
734
|
+
return f"{size_bytes / 1024:.1f} KB"
|
|
735
|
+
elif size_bytes < 1024 ** 3:
|
|
736
|
+
return f"{size_bytes / (1024 ** 2):.1f} MB"
|
|
737
|
+
else:
|
|
738
|
+
return f"{size_bytes / (1024 ** 3):.2f} GB"
|
|
739
|
+
|
|
740
|
+
|
|
741
|
+
def _parse_size_string(size_str: str) -> int:
|
|
742
|
+
"""Parse a size string like '2.0 GB' into bytes."""
|
|
743
|
+
if size_str == "unknown":
|
|
744
|
+
return 0
|
|
745
|
+
|
|
746
|
+
size_str = size_str.strip().upper()
|
|
747
|
+
try:
|
|
748
|
+
if "GB" in size_str:
|
|
749
|
+
return int(float(size_str.replace("GB", "").strip()) * 1024 ** 3)
|
|
750
|
+
elif "MB" in size_str:
|
|
751
|
+
return int(float(size_str.replace("MB", "").strip()) * 1024 ** 2)
|
|
752
|
+
elif "KB" in size_str:
|
|
753
|
+
return int(float(size_str.replace("KB", "").strip()) * 1024)
|
|
754
|
+
else:
|
|
755
|
+
return int(float(size_str.replace("B", "").strip()))
|
|
756
|
+
except ValueError:
|
|
757
|
+
return 0
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
# Common model sizes (approximate) for user reference.
# Maps a lowercase Ollama model name to a human-readable download size. A key
# is either a bare family name ("llama3.2") or a "name:tag" pair ("llama3.2:1b").
# Keys must stay lowercase because lookups lowercase the requested model name
# first. Values are fed through _parse_size_string, so they must use the
# "<number> <GB|MB|KB|B>" form. These figures are hand-maintained and may drift
# from the Ollama registry; NOTE(review): confirm periodically.
OLLAMA_MODEL_SIZES = {
    # llama family
    "llama3.2": "2.0 GB",
    "llama3.2:1b": "1.3 GB",
    "llama3.2:3b": "2.0 GB",
    "llama3.1": "4.7 GB",
    "llama3.1:8b": "4.7 GB",
    "llama3.1:70b": "40 GB",
    "llama3": "4.7 GB",
    "llama2": "3.8 GB",
    # mistral family
    "mistral": "4.1 GB",
    "mixtral": "26 GB",
    # phi family
    "phi3": "2.2 GB",
    "phi3:mini": "2.2 GB",
    # gemma family
    "gemma": "5.0 GB",
    "gemma:2b": "1.7 GB",
    "gemma:7b": "5.0 GB",
    "gemma2": "5.4 GB",
    "gemma2:2b": "1.6 GB",
    "gemma2:9b": "5.4 GB",
    "gemma2:27b": "16 GB",
    # qwen family
    "qwen2.5": "4.7 GB",
    "qwen2.5:0.5b": "397 MB",
    "qwen2.5:1.5b": "986 MB",
    "qwen2.5:3b": "1.9 GB",
    "qwen2.5:7b": "4.7 GB",
    # other chat / code / embedding models
    "deepseek-r1": "4.7 GB",
    "codellama": "3.8 GB",
    "codegemma": "5.0 GB",
    "nomic-embed-text": "274 MB",
}
|
|
791
|
+
|
|
792
|
+
|
|
793
|
+
def get_ollama_model_size_estimate(model: str) -> str:
    """
    Look up an approximate download size for an Ollama model.

    The name is matched case-insensitively against ``OLLAMA_MODEL_SIZES``.
    The full name as given (e.g. "gemma2:27b") is tried first, then the base
    name with any ":tag" suffix removed (e.g. "gemma2").

    Args:
        model: Model name, optionally with a ":tag" suffix

    Returns:
        Size string such as "4.7 GB", or "unknown" if no entry matches
    """
    normalized = model.lower()
    for candidate in (normalized, normalized.split(":")[0]):
        if candidate in OLLAMA_MODEL_SIZES:
            return OLLAMA_MODEL_SIZES[candidate]
    return "unknown"
|
|
815
|
+
|
|
816
|
+
|
|
817
|
+
def check_system_resources(model: str) -> dict:
|
|
818
|
+
"""
|
|
819
|
+
Check if system has enough resources to download and run a model.
|
|
820
|
+
|
|
821
|
+
Args:
|
|
822
|
+
model: Model name to check
|
|
823
|
+
|
|
824
|
+
Returns:
|
|
825
|
+
dict with 'can_download', 'can_run', 'warnings', and 'details'
|
|
826
|
+
"""
|
|
827
|
+
import shutil
|
|
828
|
+
import os
|
|
829
|
+
|
|
830
|
+
result = {
|
|
831
|
+
"can_download": True,
|
|
832
|
+
"can_run": True,
|
|
833
|
+
"warnings": [],
|
|
834
|
+
"details": {}
|
|
835
|
+
}
|
|
836
|
+
|
|
837
|
+
size_estimate = get_ollama_model_size_estimate(model)
|
|
838
|
+
model_size_bytes = _parse_size_string(size_estimate)
|
|
839
|
+
|
|
840
|
+
# Check disk space (Ollama typically stores models in ~/.ollama)
|
|
841
|
+
ollama_dir = os.path.expanduser("~/.ollama")
|
|
842
|
+
if not os.path.exists(ollama_dir):
|
|
843
|
+
ollama_dir = os.path.expanduser("~")
|
|
844
|
+
|
|
845
|
+
try:
|
|
846
|
+
disk_usage = shutil.disk_usage(ollama_dir)
|
|
847
|
+
free_space = disk_usage.free
|
|
848
|
+
result["details"]["free_disk_space"] = _format_bytes(free_space)
|
|
849
|
+
result["details"]["model_size"] = size_estimate
|
|
850
|
+
|
|
851
|
+
# Need at least 1.5x model size for download + extraction
|
|
852
|
+
required_space = int(model_size_bytes * 1.5) if model_size_bytes > 0 else 0
|
|
853
|
+
|
|
854
|
+
if required_space > 0 and free_space < required_space:
|
|
855
|
+
result["can_download"] = False
|
|
856
|
+
result["warnings"].append(
|
|
857
|
+
f"Insufficient disk space. Need ~{_format_bytes(required_space)}, "
|
|
858
|
+
f"but only {_format_bytes(free_space)} available."
|
|
859
|
+
)
|
|
860
|
+
elif required_space > 0 and free_space < required_space * 2:
|
|
861
|
+
result["warnings"].append(
|
|
862
|
+
f"Low disk space warning: {_format_bytes(free_space)} available."
|
|
863
|
+
)
|
|
864
|
+
except Exception:
|
|
865
|
+
result["details"]["free_disk_space"] = "unknown"
|
|
866
|
+
|
|
867
|
+
# Estimate RAM requirements (rough guide: model size * 1.2 for inference)
|
|
868
|
+
# This is approximate - actual requirements vary by quantization
|
|
869
|
+
if model_size_bytes > 0:
|
|
870
|
+
estimated_ram = model_size_bytes * 1.2
|
|
871
|
+
result["details"]["estimated_ram"] = _format_bytes(int(estimated_ram))
|
|
872
|
+
|
|
873
|
+
# Try to get system RAM (works on most systems)
|
|
874
|
+
try:
|
|
875
|
+
import subprocess
|
|
876
|
+
if os.name == 'posix': # Linux/macOS
|
|
877
|
+
if os.path.exists('/proc/meminfo'): # Linux
|
|
878
|
+
with open('/proc/meminfo', 'r') as f:
|
|
879
|
+
for line in f:
|
|
880
|
+
if line.startswith('MemTotal:'):
|
|
881
|
+
total_ram = int(line.split()[1]) * 1024 # Convert KB to bytes
|
|
882
|
+
break
|
|
883
|
+
else: # macOS
|
|
884
|
+
output = subprocess.check_output(['sysctl', '-n', 'hw.memsize'], text=True)
|
|
885
|
+
total_ram = int(output.strip())
|
|
886
|
+
|
|
887
|
+
result["details"]["total_ram"] = _format_bytes(total_ram)
|
|
888
|
+
|
|
889
|
+
if estimated_ram > total_ram * 0.8:
|
|
890
|
+
result["can_run"] = False
|
|
891
|
+
result["warnings"].append(
|
|
892
|
+
f"Model may be too large for your system. "
|
|
893
|
+
f"Requires ~{_format_bytes(int(estimated_ram))} RAM, "
|
|
894
|
+
f"but system has {_format_bytes(total_ram)}."
|
|
895
|
+
)
|
|
896
|
+
elif estimated_ram > total_ram * 0.5:
|
|
897
|
+
result["warnings"].append(
|
|
898
|
+
f"Model will use significant RAM (~{_format_bytes(int(estimated_ram))})."
|
|
899
|
+
)
|
|
900
|
+
except Exception:
|
|
901
|
+
result["details"]["total_ram"] = "unknown"
|
|
902
|
+
# If we can't check RAM, warn for large models
|
|
903
|
+
if model_size_bytes > 8 * 1024 ** 3: # > 8GB models
|
|
904
|
+
result["warnings"].append(
|
|
905
|
+
f"Large model (~{size_estimate}). Ensure you have sufficient RAM."
|
|
906
|
+
)
|
|
907
|
+
|
|
908
|
+
return result
|
|
909
|
+
|
|
910
|
+
|
|
911
|
+
def pull_ollama_model(model: str, host: str = "localhost", port: int = 11434, auto_confirm: bool = False) -> bool:
|
|
912
|
+
"""
|
|
913
|
+
Pull/download a model in Ollama.
|
|
914
|
+
|
|
915
|
+
Args:
|
|
916
|
+
model: Model name to pull (e.g., "llama3.2", "mistral")
|
|
917
|
+
host: Hostname where Ollama is running
|
|
918
|
+
port: Port number
|
|
919
|
+
auto_confirm: If True, skip confirmation prompt
|
|
920
|
+
|
|
921
|
+
Returns:
|
|
922
|
+
True if model was pulled successfully, False otherwise
|
|
923
|
+
"""
|
|
924
|
+
# Get size estimate and check system resources
|
|
925
|
+
size_estimate = get_ollama_model_size_estimate(model)
|
|
926
|
+
resources = check_system_resources(model)
|
|
927
|
+
|
|
928
|
+
print(f"\n{'='*60}")
|
|
929
|
+
print(f" Model '{model}' not found locally")
|
|
930
|
+
print(f"{'='*60}")
|
|
931
|
+
print(f" Model size: {size_estimate}")
|
|
932
|
+
if resources["details"].get("estimated_ram"):
|
|
933
|
+
print(f" RAM required: ~{resources['details']['estimated_ram']}")
|
|
934
|
+
if resources["details"].get("free_disk_space"):
|
|
935
|
+
print(f" Free disk space: {resources['details']['free_disk_space']}")
|
|
936
|
+
if resources["details"].get("total_ram"):
|
|
937
|
+
print(f" System RAM: {resources['details']['total_ram']}")
|
|
938
|
+
|
|
939
|
+
# Show warnings
|
|
940
|
+
if resources["warnings"]:
|
|
941
|
+
print(f"\n {'!'*50}")
|
|
942
|
+
for warning in resources["warnings"]:
|
|
943
|
+
print(f" Warning: {warning}")
|
|
944
|
+
print(f" {'!'*50}")
|
|
945
|
+
|
|
946
|
+
# Block if can't download
|
|
947
|
+
if not resources["can_download"]:
|
|
948
|
+
print(f"\n Cannot download: insufficient disk space.")
|
|
949
|
+
print(f" Free up disk space and try again.")
|
|
950
|
+
return False
|
|
951
|
+
|
|
952
|
+
# Warn but allow if can't run (user might want to try anyway)
|
|
953
|
+
if not resources["can_run"]:
|
|
954
|
+
print(f"\n Warning: Model may not run on this system.")
|
|
955
|
+
print(f" Consider a smaller model variant (e.g., '{model}:1b' or '{model}:3b').")
|
|
956
|
+
|
|
957
|
+
print(f"{'='*60}")
|
|
958
|
+
|
|
959
|
+
# Ask for confirmation
|
|
960
|
+
if not auto_confirm:
|
|
961
|
+
try:
|
|
962
|
+
if not resources["can_run"]:
|
|
963
|
+
prompt = f"\n Download anyway? [y/N]: "
|
|
964
|
+
else:
|
|
965
|
+
prompt = f"\n Download '{model}'? [y/N]: "
|
|
966
|
+
response = input(prompt).strip().lower()
|
|
967
|
+
if response not in ['y', 'yes']:
|
|
968
|
+
print(" Download cancelled.")
|
|
969
|
+
return False
|
|
970
|
+
except (EOFError, KeyboardInterrupt):
|
|
971
|
+
print("\n Download cancelled.")
|
|
972
|
+
return False
|
|
973
|
+
|
|
974
|
+
print(f"\n Downloading from Ollama registry...")
|
|
975
|
+
print(f" (Press Ctrl+C to cancel)\n")
|
|
976
|
+
|
|
977
|
+
try:
|
|
978
|
+
# Ollama pull endpoint streams the response
|
|
979
|
+
response = requests.post(
|
|
980
|
+
f"http://{host}:{port}/api/pull",
|
|
981
|
+
json={"name": model},
|
|
982
|
+
stream=True,
|
|
983
|
+
timeout=None # No timeout - large models can take a while
|
|
984
|
+
)
|
|
985
|
+
|
|
986
|
+
if response.status_code != 200:
|
|
987
|
+
print(f"Failed to pull model: HTTP {response.status_code}")
|
|
988
|
+
return False
|
|
989
|
+
|
|
990
|
+
# Process streaming response to show progress
|
|
991
|
+
last_status = ""
|
|
992
|
+
total_size_shown = False
|
|
993
|
+
|
|
994
|
+
for line in response.iter_lines():
|
|
995
|
+
if line:
|
|
996
|
+
try:
|
|
997
|
+
data = json.loads(line)
|
|
998
|
+
status = data.get("status", "")
|
|
999
|
+
|
|
1000
|
+
# Show progress for downloads
|
|
1001
|
+
if "completed" in data and "total" in data:
|
|
1002
|
+
completed = data["completed"]
|
|
1003
|
+
total = data["total"]
|
|
1004
|
+
pct = (completed / total * 100) if total > 0 else 0
|
|
1005
|
+
|
|
1006
|
+
# Show actual total size on first progress update
|
|
1007
|
+
if not total_size_shown and total > 0:
|
|
1008
|
+
print(f" Actual size: {_format_bytes(total)}")
|
|
1009
|
+
total_size_shown = True
|
|
1010
|
+
|
|
1011
|
+
print(f"\r {status}: {pct:.1f}% ({_format_bytes(completed)}/{_format_bytes(total)})", end="", flush=True)
|
|
1012
|
+
elif status != last_status:
|
|
1013
|
+
if last_status and "completed" in str(last_status):
|
|
1014
|
+
print() # newline after progress bar
|
|
1015
|
+
print(f" {status}")
|
|
1016
|
+
last_status = status
|
|
1017
|
+
|
|
1018
|
+
# Check for errors
|
|
1019
|
+
if "error" in data:
|
|
1020
|
+
print(f"\n Error: {data['error']}")
|
|
1021
|
+
return False
|
|
1022
|
+
|
|
1023
|
+
except json.JSONDecodeError:
|
|
1024
|
+
continue
|
|
1025
|
+
|
|
1026
|
+
print(f"\n Model '{model}' downloaded successfully!")
|
|
1027
|
+
return True
|
|
1028
|
+
|
|
1029
|
+
except KeyboardInterrupt:
|
|
1030
|
+
print(f"\n\n Download cancelled by user.")
|
|
1031
|
+
return False
|
|
1032
|
+
except requests.exceptions.Timeout:
|
|
1033
|
+
print(f"\n Timeout while downloading model '{model}'.")
|
|
1034
|
+
print(f" Try again or download manually: ollama pull {model}")
|
|
1035
|
+
return False
|
|
1036
|
+
except requests.exceptions.RequestException as e:
|
|
1037
|
+
print(f"\n Error pulling model: {e}")
|
|
1038
|
+
return False
|
|
1039
|
+
|
|
1040
|
+
|
|
1041
|
+
# =============================================================================
|
|
1042
|
+
# Claude Code CLI Functions
|
|
1043
|
+
# =============================================================================
|
|
1044
|
+
|
|
1045
|
+
def check_claude_cli_available():
    """Return True when an executable named ``claude`` can be found on PATH, else False."""
    from shutil import which

    claude_path = which("claude")
    return claude_path is not None
|