kite-agent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kite/__init__.py +46 -0
- kite/ab_testing.py +384 -0
- kite/agent.py +556 -0
- kite/agents/__init__.py +3 -0
- kite/agents/plan_execute.py +191 -0
- kite/agents/react_agent.py +509 -0
- kite/agents/reflective_agent.py +90 -0
- kite/agents/rewoo.py +119 -0
- kite/agents/tot.py +151 -0
- kite/conversation.py +125 -0
- kite/core.py +974 -0
- kite/data_loaders.py +111 -0
- kite/embedding_providers.py +372 -0
- kite/llm_providers.py +1278 -0
- kite/memory/__init__.py +6 -0
- kite/memory/advanced_rag.py +333 -0
- kite/memory/graph_rag.py +719 -0
- kite/memory/session_memory.py +423 -0
- kite/memory/vector_memory.py +579 -0
- kite/monitoring.py +611 -0
- kite/observers.py +107 -0
- kite/optimization/__init__.py +9 -0
- kite/optimization/resource_router.py +80 -0
- kite/persistence.py +42 -0
- kite/pipeline/__init__.py +5 -0
- kite/pipeline/deterministic_pipeline.py +323 -0
- kite/pipeline/reactive_pipeline.py +171 -0
- kite/pipeline_manager.py +15 -0
- kite/routing/__init__.py +6 -0
- kite/routing/aggregator_router.py +325 -0
- kite/routing/llm_router.py +149 -0
- kite/routing/semantic_router.py +228 -0
- kite/safety/__init__.py +6 -0
- kite/safety/circuit_breaker.py +360 -0
- kite/safety/guardrails.py +82 -0
- kite/safety/idempotency_manager.py +304 -0
- kite/safety/kill_switch.py +75 -0
- kite/tool.py +183 -0
- kite/tool_registry.py +87 -0
- kite/tools/__init__.py +21 -0
- kite/tools/code_execution.py +53 -0
- kite/tools/contrib/__init__.py +19 -0
- kite/tools/contrib/calculator.py +26 -0
- kite/tools/contrib/datetime_utils.py +20 -0
- kite/tools/contrib/linkedin.py +428 -0
- kite/tools/contrib/web_search.py +30 -0
- kite/tools/mcp/__init__.py +31 -0
- kite/tools/mcp/database_mcp.py +267 -0
- kite/tools/mcp/gdrive_mcp_server.py +503 -0
- kite/tools/mcp/gmail_mcp_server.py +601 -0
- kite/tools/mcp/postgres_mcp_server.py +490 -0
- kite/tools/mcp/slack_mcp_server.py +538 -0
- kite/tools/mcp/stripe_mcp_server.py +219 -0
- kite/tools/search.py +90 -0
- kite/tools/system_tools.py +54 -0
- kite/tools_manager.py +27 -0
- kite_agent-0.1.0.dist-info/METADATA +621 -0
- kite_agent-0.1.0.dist-info/RECORD +61 -0
- kite_agent-0.1.0.dist-info/WHEEL +5 -0
- kite_agent-0.1.0.dist-info/licenses/LICENSE +21 -0
- kite_agent-0.1.0.dist-info/top_level.txt +1 -0
kite/llm_providers.py
ADDED
"""
LLM Provider Abstraction Layer
Supports multiple LLM providers with priority on open source.

Supported Providers:
- Ollama (Local, Free) - PRIORITY
- LM Studio (Local, Free)
- vLLM (Local, Free)
- Anthropic Claude
- OpenAI GPT
- Google Gemini
- Mistral AI
- Groq (Fast inference)
- Together AI
- Replicate
"""

from typing import Dict, List, Optional, Any
from abc import ABC, abstractmethod
import os
import logging
import json
import asyncio
import time
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type


class BaseLLMProvider(ABC):
    """Base class for all LLM providers."""

    @abstractmethod
    def complete(self, prompt: str, **kwargs) -> str:
        """
        Generate completion.

        Args:
            prompt: Input text.
            response_schema: Optional Dict/Schema to enforce JSON output.
        """
        pass

    @abstractmethod
    def chat(self, messages: List[Dict], **kwargs) -> str:
        """Chat completion."""
        pass

    @abstractmethod
    def embed(self, text: str) -> List[float]:
        """Generate embeddings."""
        pass

    @abstractmethod
    async def stream_complete(self, prompt: str, **kwargs):
        """Stream completion."""
        pass

    @abstractmethod
    async def stream_chat(self, messages: List[Dict], **kwargs):
        """Stream chat completion."""
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Provider name."""
        pass

    async def complete_async(self, prompt: str, **kwargs) -> str:
        """Async version of complete."""
        import asyncio
        return await asyncio.to_thread(self.complete, prompt, **kwargs)

    async def chat_async(self, messages: List[Dict], **kwargs) -> str:
        """Async version of chat."""
        import asyncio
        return await asyncio.to_thread(self.chat, messages, **kwargs)

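# Illustrative sketch of the minimal BaseLLMProvider contract: only the five
# abstract methods plus `name` must be implemented; complete_async/chat_async
# fall back to the asyncio.to_thread wrappers inherited from the base class.
# The echo behaviour and class name here are hypothetical, shown only as an example.
class _EchoProviderSketch(BaseLLMProvider):
    """Echoes input back; exists only to show the minimal subclass shape."""

    def complete(self, prompt: str, **kwargs) -> str:
        return prompt

    def chat(self, messages: List[Dict], **kwargs) -> str:
        return messages[-1]["content"] if messages else ""

    def embed(self, text: str) -> List[float]:
        return [float(len(text))]

    async def stream_complete(self, prompt: str, **kwargs):
        yield prompt

    async def stream_chat(self, messages: List[Dict], **kwargs):
        yield self.chat(messages, **kwargs)

    @property
    def name(self) -> str:
        return "Echo/sketch"
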
# ============================================================================
# LOCAL / OPENSOURCE PROVIDERS (PRIORITY)
# ============================================================================

class OllamaProvider(BaseLLMProvider):
    """
    Ollama - Run LLMs locally (FREE, OPENSOURCE)

    Models: llama3, mistral, codellama, phi, gemma, etc.
    Installation: curl -fsSL https://ollama.com/install.sh | sh
    """

    def __init__(self,
                 model: str = "llama3",
                 base_url: str = "http://localhost:11434",
                 timeout: float = 600.0,
                 **kwargs):
        self.model = model
        self.base_url = base_url
        self.timeout = timeout
        self.logger = logging.getLogger("Ollama")

        try:
            import httpx
            self._test_connection()
        except ImportError:
            raise ImportError("pip install httpx")

    def _test_connection(self):
        """Test Ollama connection."""
        import httpx
        try:
            with httpx.Client(timeout=2.0) as client:
                response = client.get(f"{self.base_url}/api/tags")
                if response.status_code == 200:
                    self.logger.info(f"[OK] Ollama connected: {self.model}")
                else:
                    raise ConnectionError(f"Ollama returned {response.status_code}")
        except Exception as e:
            raise ConnectionError(f"Ollama not running at {self.base_url}: {e}")

    def _sanitize_ollama_params(self, kwargs: Dict) -> Dict:
        """Helper to ensure only valid Ollama parameters are sent."""
        valid_top_level = {'model', 'prompt', 'messages', 'stream', 'format', 'options', 'keep_alive', 'tools'}
        # Common model options that should go into 'options' dict
        valid_options = {
            'num_keep', 'seed', 'num_predict', 'top_k', 'top_p', 'tfs_z',
            'typical_p', 'repeat_last_n', 'temperature', 'repeat_penalty',
            'presence_penalty', 'frequency_penalty', 'mirostat', 'mirostat_eta',
            'mirostat_tau', 'num_ctx', 'num_batch', 'num_gqa', 'num_gpu',
            'main_gpu', 'low_vram', 'f16_kv', 'logits_all', 'vocab_only',
            'use_mmap', 'use_mlock', 'num_thread'
        }

        sanitized = {}
        options = kwargs.get('options', {})

        for k, v in kwargs.items():
            if k in valid_top_level:
                sanitized[k] = v
            elif k in valid_options:
                options[k] = v

        if options:
            sanitized['options'] = options

        # Handle Structured Output (JSON Schema)
        if 'response_schema' in kwargs:
            sanitized['format'] = 'json'
            # We trust the caller to put instructions in the prompt.

        return sanitized

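    # Illustrative behaviour of _sanitize_ollama_params (values are hypothetical):
    #   self._sanitize_ollama_params({"prompt": "hi", "temperature": 0.2, "top_p": 0.9,
    #                                 "response_schema": {...}, "metrics": recorder})
    #   -> {"prompt": "hi", "options": {"temperature": 0.2, "top_p": 0.9}, "format": "json"}
    # Keys that are neither valid top-level fields nor model options (e.g. "metrics",
    # "response_schema" itself) are dropped from the outgoing request body.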
    def complete(self, prompt: str, **kwargs) -> str:
        """Generate completion (sync)."""
        import httpx
        import threading

        params = self._sanitize_ollama_params({
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            **kwargs
        })

        # Simple heartbeat thread
        stop_heartbeat = threading.Event()

        def heartbeat():
            start = time.time()
            while not stop_heartbeat.wait(30):
                self.logger.info(f"Ollama is still thinking... ({int(time.time() - start)}s elapsed)")

        h_thread = threading.Thread(target=heartbeat, daemon=True)
        h_thread.start()

        try:
            timeout = httpx.Timeout(self.timeout, read=None)
            with httpx.Client(timeout=timeout) as client:
                response = client.post(
                    f"{self.base_url}/api/generate",
                    json=params
                )
                data = response.json()
                if "response" in data:
                    if kwargs.get("metrics"):
                        kwargs["metrics"].record_llm_usage(
                            provider="ollama",
                            model=self.model,
                            prompt_tokens=data.get("prompt_eval_count", 0),
                            completion_tokens=data.get("eval_count", 0)
                        )
                    return data["response"]
                elif "message" in data and "content" in data["message"]:
                    return data["message"]["content"]
                raise KeyError(f"Unexpected Ollama response format: {data}")
        finally:
            stop_heartbeat.set()

    async def complete_async(self, prompt: str, **kwargs) -> str:
        """Native async complete."""
        import httpx

        params = self._sanitize_ollama_params({
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            **kwargs
        })

        async def heartbeat(task):
            start = time.time()
            while not task.done():
                await asyncio.sleep(30)
                if not task.done():
                    self.logger.info(f"Ollama is still thinking... ({int(time.time() - start)}s elapsed)")

        timeout = httpx.Timeout(self.timeout, read=None)
        async with httpx.AsyncClient(timeout=timeout) as client:
            request_task = asyncio.create_task(client.post(
                f"{self.base_url}/api/generate",
                json=params
            ))
            heartbeat_task = asyncio.create_task(heartbeat(request_task))

            try:
                response = await request_task
                response.raise_for_status()
                data = response.json()
                if "response" in data:
                    res = data["response"]
                elif "message" in data and "content" in data["message"]:
                    res = data["message"]["content"]
                else:
                    raise KeyError(f"Unexpected Ollama response format: {data}")

                if not res and data.get("done") and data.get("done_reason") == "load":
                    self.logger.warning("Ollama returned 'load' reason. Retrying in 2s...")
                    await asyncio.sleep(2)
                    return await self.complete_async(prompt, **kwargs)

                if not res or res.strip() == "":
                    raise ValueError(f"Ollama returned empty response for model {self.model}")
                return res
            finally:
                heartbeat_task.cancel()

    def chat(self, messages: List[Dict], **kwargs) -> str:
        """Chat completion (sync)."""
        import httpx
        import threading

        params = self._sanitize_ollama_params({
            "model": self.model,
            "messages": messages,
            "stream": False,
            **kwargs
        })

        # Inject Schema into System Prompt if provided and format is JSON
        if kwargs.get("response_schema") and params.get("format") == "json":
            schema_str = json.dumps(kwargs["response_schema"], indent=2)
            sys_msg = f"\n\nIMPORTANT: Output data MUST be valid JSON matching this schema:\n{schema_str}"
            if params["messages"] and params["messages"][0]["role"] == "system":
                params["messages"][0]["content"] += sys_msg
            else:
                params["messages"].insert(0, {"role": "system", "content": f"You are a helpful assistant.{sys_msg}"})

        stop_heartbeat = threading.Event()

        def heartbeat():
            start = time.time()
            while not stop_heartbeat.wait(30):
                self.logger.info(f"Ollama is still thinking... ({int(time.time() - start)}s elapsed)")

        h_thread = threading.Thread(target=heartbeat, daemon=True)
        h_thread.start()

        try:
            timeout = httpx.Timeout(self.timeout, read=None)
            with httpx.Client(timeout=timeout) as client:
                response = client.post(
                    f"{self.base_url}/api/chat",
                    json=params
                )
                if response.status_code != 200:
                    self.logger.error(f"Ollama Chat Error {response.status_code}: {response.text}")

                data = response.json()
                if "message" in data:
                    if kwargs.get("metrics"):
                        kwargs["metrics"].record_llm_usage(
                            provider="ollama",
                            model=self.model,
                            prompt_tokens=data.get("prompt_eval_count", 0),
                            completion_tokens=data.get("eval_count", 0)
                        )
                    msg = data["message"]
                    if msg.get("tool_calls"):
                        return {"content": msg.get("content"), "tool_calls": msg["tool_calls"]}
                    return msg.get("content", "")
                elif "response" in data:
                    return data["response"]
                raise KeyError(f"Unexpected Ollama chat response format: {data}")
        finally:
            stop_heartbeat.set()

    async def chat_async(self, messages: List[Dict], **kwargs) -> str:
        """Native async chat."""
        import httpx

        params = self._sanitize_ollama_params({
            "model": self.model,
            "messages": messages,
            "stream": False,
            **kwargs
        })

        if kwargs.get("response_schema") and params.get("format") == "json":
            schema_str = json.dumps(kwargs["response_schema"], indent=2)
            sys_msg = f"\n\nIMPORTANT: Output data MUST be valid JSON matching this schema:\n{schema_str}"
            if params["messages"] and params["messages"][0]["role"] == "system":
                params["messages"][0]["content"] += sys_msg
            else:
                params["messages"].insert(0, {"role": "system", "content": f"You are a helpful assistant.{sys_msg}"})

        async def heartbeat(task):
            start = time.time()
            while not task.done():
                await asyncio.sleep(30)
                if not task.done():
                    self.logger.info(f"Ollama is still thinking... ({int(time.time() - start)}s elapsed)")

        timeout = httpx.Timeout(self.timeout, read=None)
        async with httpx.AsyncClient(timeout=timeout) as client:
            request_task = asyncio.create_task(client.post(
                f"{self.base_url}/api/chat",
                json=params
            ))
            heartbeat_task = asyncio.create_task(heartbeat(request_task))

            try:
                response = await request_task
                if response.status_code != 200:
                    error_text = response.text
                    self.logger.error(f"Ollama Chat Async Error {response.status_code}: {error_text}")
                    response.raise_for_status()

                data = response.json()
                res = None
                if "message" in data:
                    if kwargs.get("metrics"):
                        kwargs["metrics"].record_llm_usage(
                            provider="ollama",
                            model=self.model,
                            prompt_tokens=data.get("prompt_eval_count", 0),
                            completion_tokens=data.get("eval_count", 0)
                        )
                    msg = data["message"]
                    if msg.get("tool_calls"):
                        res = {"content": msg.get("content"), "tool_calls": msg["tool_calls"]}
                    else:
                        res = msg.get("content")
                elif "response" in data:
                    res = data["response"]

                if not res and data.get("done") and data.get("done_reason") == "load":
                    self.logger.warning("Ollama returned 'load' reason in chat. Retrying in 2s...")
                    await asyncio.sleep(2)
                    return await self.chat_async(messages, **kwargs)

                if res is None:
                    raise ValueError(f"Ollama returned empty response for chat model {self.model}")
                return res
            finally:
                heartbeat_task.cancel()

    def embed(self, text: str) -> List[float]:
        """Generate embeddings."""
        import httpx
        with httpx.Client(timeout=self.timeout) as client:
            response = client.post(
                f"{self.base_url}/api/embeddings",
                json={
                    "model": self.model,
                    "prompt": text
                }
            )
            return response.json()["embedding"]

    async def stream_complete(self, prompt: str, **kwargs):
        """Stream completion."""
        import httpx
        import json
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            async with client.stream(
                "POST",
                f"{self.base_url}/api/generate",
                json={
                    "model": self.model,
                    "prompt": prompt,
                    "stream": True,
                    **kwargs
                }
            ) as response:
                async for line in response.aiter_lines():
                    if line:
                        chunk = json.loads(line)
                        yield chunk.get("response", "")
                        if chunk.get("done"):
                            break

    async def stream_chat(self, messages: List[Dict], **kwargs):
        """Stream chat completion."""
        import httpx
        import json
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            async with client.stream(
                "POST",
                f"{self.base_url}/api/chat",
                json={
                    "model": self.model,
                    "messages": messages,
                    "stream": True,
                    **kwargs
                }
            ) as response:
                async for line in response.aiter_lines():
                    if line:
                        chunk = json.loads(line)
                        yield chunk.get("message", {}).get("content", "")
                        if chunk.get("done"):
                            break

    @property
    def name(self) -> str:
        return f"Ollama/{self.model}"

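# Usage sketch (assumes an Ollama server on localhost:11434 with `llama3` pulled;
# the schema below is made up for illustration): passing response_schema switches
# the request to JSON mode and injects the schema into the system prompt.
def _ollama_structured_chat_example() -> str:
    llm = OllamaProvider(model="llama3")
    schema = {"type": "object", "properties": {"city": {"type": "string"}}}
    return llm.chat(
        [{"role": "user", "content": "Reply with JSON naming one city."}],
        response_schema=schema,
        temperature=0.2,
    )
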
class LMStudioProvider(BaseLLMProvider):
    """
    LM Studio - Local LLM with GUI (FREE, OPENSOURCE)

    Download: https://lmstudio.ai/
    Compatible with OpenAI API format.
    """

    def __init__(self,
                 model: str = "local-model",
                 base_url: str = "http://localhost:1234/v1",
                 **kwargs):
        self.model = model
        self.base_url = base_url
        self.logger = logging.getLogger("LMStudio")

        try:
            import openai
            import requests
            import httpx
            self.requests = requests
            self.client = openai.OpenAI(
                base_url=base_url,
                api_key="lm-studio"
            )
            self.async_client = openai.AsyncOpenAI(
                base_url=base_url,
                api_key="lm-studio"
            )
            self._test_connection()
            self.logger.info(f"[OK] LM Studio connected: {model}")
        except Exception as e:
            raise ConnectionError(f"LM Studio not found: {e}")

    def _test_connection(self):
        """Test LM Studio connection."""
        try:
            response = self.requests.get(f"{self.base_url}/models", timeout=1)
            if response.status_code != 200:
                raise ConnectionError(f"LM Studio returned {response.status_code}")
        except Exception as e:
            raise ConnectionError(f"LM Studio not running at {self.base_url}: {e}")

    def _sanitize_params(self, kwargs: Dict) -> Dict:
        """Translate Kite params (format='json') to LM Studio (response_format)."""
        clean = {}
        # LM Studio is OpenAI compatible
        valid = {'temperature', 'max_tokens', 'top_p', 'stream', 'stop', 'response_format', 'seed'}
        for k, v in kwargs.items():
            if k == 'format' and v == 'json':
                clean['response_format'] = {"type": "json_object"}
            elif k in valid:
                clean[k] = v
        return clean

    def complete(self, prompt: str, **kwargs) -> str:
        """Generate completion."""
        response = self.client.completions.create(
            model=self.model,
            prompt=prompt,
            **kwargs
        )
        return response.choices[0].text

    async def complete_async(self, prompt: str, **kwargs) -> str:
        """Async completion."""
        response = await self.async_client.completions.create(
            model=self.model,
            prompt=prompt,
            **kwargs
        )
        return response.choices[0].text

    async def chat_async(self, messages: List[Dict], **kwargs) -> str:
        """Async chat completion."""
        response = await self.async_client.chat.completions.create(
            model=self.model,
            messages=messages,
            **kwargs
        )
        return response.choices[0].message.content

    def chat(self, messages: List[Dict], **kwargs) -> str:
        """Chat completion."""
        params = self._sanitize_params(kwargs)
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            **params
        )
        return response.choices[0].message.content

    def embed(self, text: str) -> List[float]:
        """Generate embeddings."""
        response = self.client.embeddings.create(
            model=self.model,
            input=text
        )
        return response.data[0].embedding

    async def stream_complete(self, prompt: str, **kwargs):
        """Stream completion."""
        params = self._sanitize_params(kwargs)
        stream = await self.async_client.completions.create(
            model=self.model,
            prompt=prompt,
            stream=True,
            **params
        )
        async for chunk in stream:
            if hasattr(chunk, 'choices') and chunk.choices and chunk.choices[0].text:
                yield chunk.choices[0].text

    async def stream_chat(self, messages: List[Dict], **kwargs):
        """Stream chat completion."""
        params = self._sanitize_params(kwargs)
        stream = await self.async_client.chat.completions.create(
            model=self.model,
            messages=messages,
            stream=True,
            **params
        )
        async for chunk in stream:
            if hasattr(chunk, 'choices') and chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    @property
    def name(self) -> str:
        return f"LMStudio/{self.model}"

class VLLMProvider(BaseLLMProvider):
    """
    vLLM - Fast inference server (FREE, OPENSOURCE)
    """

    def __init__(self,
                 model: str = "meta-llama/Llama-2-7b-hf",
                 base_url: str = "http://localhost:8000",
                 **kwargs):
        self.model = model
        self.base_url = base_url
        self.logger = logging.getLogger("vLLM")

        try:
            import requests
            self.requests = requests
            self._test_connection()
            self.logger.info(f"[OK] vLLM connected: {model}")
        except ImportError:
            raise ImportError("pip install requests")
        except Exception as e:
            raise ConnectionError(f"vLLM server not found at {base_url}: {e}")

    def _test_connection(self):
        try:
            response = self.requests.get(f"{self.base_url}/models", timeout=1)
            if response.status_code != 200:
                response = self.requests.get(f"{self.base_url}/health", timeout=1)
                if response.status_code != 200:
                    raise ConnectionError(f"vLLM returned {response.status_code}")
        except Exception as e:
            raise ConnectionError(f"vLLM not running at {self.base_url}")

    async def complete_async(self, prompt: str, **kwargs) -> str:
        import httpx
        params = {
            "prompt": prompt,
            "max_tokens": kwargs.get("max_tokens", 512),
            **kwargs
        }
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{self.base_url}/generate",
                json=params
            )
            return response.json()["text"][0]

    async def chat_async(self, messages: List[Dict], **kwargs) -> str:
        prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages])
        return await self.complete_async(prompt, **kwargs)

    def complete(self, prompt: str, **kwargs) -> str:
        response = self.requests.post(
            f"{self.base_url}/generate",
            json={
                "prompt": prompt,
                "max_tokens": kwargs.get("max_tokens", 512),
                **kwargs
            }
        )
        return response.json()["text"][0]

    def chat(self, messages: List[Dict], **kwargs) -> str:
        prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages])
        return self.complete(prompt, **kwargs)

    async def stream_complete(self, prompt: str, **kwargs):
        import httpx
        import json
        params = {
            "prompt": prompt,
            "stream": True,
            "max_tokens": kwargs.get("max_tokens", 512),
            **kwargs
        }
        async with httpx.AsyncClient(timeout=60.0) as client:
            async with client.stream("POST", f"{self.base_url}/generate", json=params) as response:
                async for line in response.aiter_lines():
                    if line:
                        chunk = json.loads(line)
                        yield chunk.get("text", [""])[0]

    async def stream_chat(self, messages: List[Dict], **kwargs):
        prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages])
        async for chunk in self.stream_complete(prompt, **kwargs):
            yield chunk

    def embed(self, text: str) -> List[float]:
        from .embedding_providers import EmbeddingFactory
        if not hasattr(self, '_embedding_fallback'):
            self._embedding_fallback = EmbeddingFactory.auto_detect()
        return self._embedding_fallback.embed(text)

    @property
    def name(self) -> str:
        return f"vLLM/{self.model}"

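# Streaming sketch: every provider above exposes the same async-generator
# interface, so consumption code is provider-agnostic. The prompt is a placeholder.
async def _stream_example(provider: BaseLLMProvider) -> str:
    chunks = []
    async for token in provider.stream_chat(
        [{"role": "user", "content": "Say hello."}]
    ):
        chunks.append(token)
    return "".join(chunks)
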
class MockLLMProvider(BaseLLMProvider):
    """Mock LLM Provider for testing."""

    def __init__(self, model: str = "mock-model", **kwargs):
        self.model = model

    def complete(self, prompt: str, **kwargs) -> str:
        return f"Mock response to: {prompt[:50]}..."

    def chat(self, messages: List[Dict], **kwargs) -> str:
        last_msg = messages[-1]["content"]
        if "ORD-001" in last_msg and kwargs.get("tools"):
            return {
                "content": "",
                "tool_calls": [{
                    "function": {
                        "name": "search_order",
                        "arguments": '{"order_id": "ORD-001"}'
                    }
                }]
            }

        return "Mock chat response: I'm here to help with your agentic tasks!"

    def embed(self, text: str) -> List[float]:
        import random
        return [random.random() for _ in range(1536)]

    async def chat_async(self, messages: List[Dict], **kwargs):
        return self.chat(messages, **kwargs)

    async def stream_complete(self, prompt: str, **kwargs):
        yield f"Mock stream response to: {prompt[:20]}..."

    async def stream_chat(self, messages: List[Dict], **kwargs):
        yield "Mock chat stream: "
        yield "I'm "
        yield "helping!"

    @property
    def name(self) -> str:
        return "Mock/LLM"


# ============================================================================
# CLOUD OPENSOURCE PROVIDERS
# ============================================================================

class GroqProvider(BaseLLMProvider):
    """
    Groq - Ultra-fast inference (FREE tier, OPENSOURCE models)
    """

    def __init__(self,
                 model: str = "llama3-70b-8192",
                 api_key: Optional[str] = None,
                 **kwargs):
        self.model = model
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        self.logger = logging.getLogger("Groq")

        if not self.api_key:
            raise ValueError("GROQ_API_KEY required")

        try:
            from groq import Groq, AsyncGroq
            self.client = Groq(api_key=self.api_key)
            self.async_client = AsyncGroq(api_key=self.api_key)
            self.logger.info(f"[OK] Groq connected: {model}")
        except ImportError:
            raise ImportError("pip install groq")

    def _sanitize_params(self, kwargs: Dict) -> Dict:
        """Translate Kite params (format='json') to Groq/OpenAI (response_format)."""
        clean = {}
        # Valid Groq params
        valid = {'temperature', 'max_tokens', 'top_p', 'stream', 'stop', 'response_format', 'seed', 'tools', 'tool_choice'}

        for k, v in kwargs.items():
            if k == 'format' and v == 'json':
                clean['response_format'] = {"type": "json_object"}
            elif k in valid:
                clean[k] = v
        return clean

    async def complete_async(self, prompt: str, **kwargs) -> str:
        """Async completion."""
        return await self.chat_async([{"role": "user", "content": prompt}], **kwargs)

    @retry(wait=wait_exponential(multiplier=1, min=4, max=60), stop=stop_after_attempt(5), reraise=True)
    async def chat_async(self, messages: List[Dict], **kwargs) -> str:
        """Async chat completion."""
        params = self._sanitize_params(kwargs)
        try:
            response = await self.async_client.chat.completions.create(
                model=self.model,
                messages=messages,
                **params
            )
        except Exception as e:
            self.logger.error(f"Groq Chat Async Error: {e}")
            raise

        msg = response.choices[0].message
        if msg.tool_calls:
            # Convert to dicts same as sync version
            tool_calls = []
            for tc in msg.tool_calls:
                tool_calls.append({
                    "id": tc.id,
                    "type": tc.type,
                    "function": {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments
                    }
                })
            return {"content": msg.content, "tool_calls": tool_calls}

        return msg.content

    def complete(self, prompt: str, **kwargs) -> str:
        """Generate completion."""
        return self.chat([{"role": "user", "content": prompt}], **kwargs)

    @retry(wait=wait_exponential(multiplier=1, min=4, max=60), stop=stop_after_attempt(5), reraise=True)
    def chat(self, messages: List[Dict], **kwargs) -> str:
        """Chat completion."""
        params = self._sanitize_params(kwargs)
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                **params
            )
        except Exception as e:
            self.logger.error(f"Groq Chat Error: {e}")
            raise

        msg = response.choices[0].message
        if msg.tool_calls:
            # Convert objects to dicts for Agent.run
            tool_calls = []
            for tc in msg.tool_calls:
                tool_calls.append({
                    "id": tc.id,
                    "type": tc.type,
                    "function": {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments
                    }
                })
            return {"content": msg.content, "tool_calls": tool_calls}

        return msg.content

    async def stream_complete(self, prompt: str, **kwargs):
        """Stream completion."""
        params = self._sanitize_params(kwargs)
        stream = await self.async_client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            stream=True,
            **params
        )
        async for chunk in stream:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    async def stream_chat(self, messages: List[Dict], **kwargs):
        """Stream chat completion."""
        params = self._sanitize_params(kwargs)
        stream = await self.async_client.chat.completions.create(
            model=self.model,
            messages=messages,
            stream=True,
            **params
        )
        async for chunk in stream:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    def embed(self, text: str) -> List[float]:
        """Generate embeddings (fallback)."""
        from .embedding_providers import EmbeddingFactory
        if not hasattr(self, '_embedding_fallback'):
            self._embedding_fallback = EmbeddingFactory.auto_detect()
        return self._embedding_fallback.embed(text)

    @property
    def name(self) -> str:
        return f"Groq/{self.model}"

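# Tool-calling sketch: when `tools` are passed through, GroqProvider.chat may
# return a dict with "tool_calls" instead of a plain string. The tool name and
# argument schema below are hypothetical, shown only to illustrate the shape.
def _groq_tool_call_example(llm: GroqProvider):
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]
    result = llm.chat(
        [{"role": "user", "content": "What's the weather in Paris?"}],
        tools=tools,
    )
    if isinstance(result, dict):
        # Each entry: {"id": ..., "type": ..., "function": {"name": ..., "arguments": ...}}
        return result["tool_calls"]
    return result
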
class TogetherProvider(BaseLLMProvider):
    """
    Together AI - Opensource models (PAID, but cheap)
    """

    def __init__(self,
                 model: str = "mistralai/Mixtral-8x7B-Instruct-v0.1",
                 api_key: Optional[str] = None,
                 **kwargs):
        self.model = model
        self.api_key = api_key or os.getenv("TOGETHER_API_KEY")
        self.logger = logging.getLogger("Together")

        if not self.api_key:
            raise ValueError("TOGETHER_API_KEY required")

        try:
            import together
            together.api_key = self.api_key
            self.together = together
            self.logger.info(f"[OK] Together AI connected: {model}")
        except ImportError:
            raise ImportError("pip install together")

    async def complete_async(self, prompt: str, **kwargs) -> str:
        import httpx
        params = {
            "model": self.model,
            "prompt": prompt,
            **kwargs
        }
        headers = {"Authorization": f"Bearer {self.api_key}"}
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.together.xyz/v1/completions",
                json=params,
                headers=headers
            )
            data = response.json()
            return data['choices'][0]['text']

    async def chat_async(self, messages: List[Dict], **kwargs) -> str:
        prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages])
        return await self.complete_async(prompt, **kwargs)

    def complete(self, prompt: str, **kwargs) -> str:
        response = self.together.Complete.create(
            prompt=prompt,
            model=self.model,
            **kwargs
        )
        return response['output']['choices'][0]['text']

    def chat(self, messages: List[Dict], **kwargs) -> str:
        prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages])
        return self.complete(prompt, **kwargs)

    def embed(self, text: str) -> List[float]:
        response = self.together.Embeddings.create(
            input=text,
            model="togethercomputer/m2-bert-80M-8k-retrieval"
        )
        return response['data'][0]['embedding']

    async def stream_complete(self, prompt: str, **kwargs):
        import httpx
        import json
        params = {
            "model": self.model,
            "prompt": prompt,
            "stream": True,
            **kwargs
        }
        headers = {"Authorization": f"Bearer {self.api_key}"}
        async with httpx.AsyncClient(timeout=60.0) as client:
            async with client.stream(
                "POST",
                "https://api.together.xyz/v1/completions",
                json=params,
                headers=headers
            ) as response:
                async for line in response.aiter_lines():
                    if line.startswith("data: "):
                        line = line[6:]
                    if not line or line == "[DONE]":
                        continue
                    try:
                        chunk = json.loads(line)
                        yield chunk['choices'][0]['text']
                    except Exception:
                        continue

    async def stream_chat(self, messages: List[Dict], **kwargs):
        prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages])
        async for chunk in self.stream_complete(prompt, **kwargs):
            yield chunk

    @property
    def name(self) -> str:
        return f"Together/{self.model}"


# ============================================================================
# COMMERCIAL PROVIDERS (Secondary)
# ============================================================================

class AnthropicProvider(BaseLLMProvider):
    """Anthropic Claude (PAID)."""

    def __init__(self,
                 model: str = "claude-3-sonnet-20240229",
                 api_key: Optional[str] = None,
                 **kwargs):
        self.model = model
        self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")

        try:
            import anthropic
            self.client = anthropic.Anthropic(api_key=self.api_key)
            self.async_client = anthropic.AsyncAnthropic(api_key=self.api_key)
        except ImportError:
            raise ImportError("pip install anthropic")

    def complete(self, prompt: str, **kwargs) -> str:
        return self.chat([{"role": "user", "content": prompt}], **kwargs)

    def chat(self, messages: List[Dict], **kwargs) -> str:
        response = self.client.messages.create(
            model=self.model,
            messages=messages,
            max_tokens=kwargs.get("max_tokens", 1024)
        )
        return response.content[0].text

    async def complete_async(self, prompt: str, **kwargs) -> str:
        return await self.chat_async([{"role": "user", "content": prompt}], **kwargs)

    async def chat_async(self, messages: List[Dict], **kwargs) -> str:
        response = await self.async_client.messages.create(
            model=self.model,
            messages=messages,
            max_tokens=kwargs.get("max_tokens", 1024)
        )
        return response.content[0].text

    async def stream_complete(self, prompt: str, **kwargs):
        async for chunk in self.stream_chat([{"role": "user", "content": prompt}], **kwargs):
            yield chunk

    async def stream_chat(self, messages: List[Dict], **kwargs):
        async with self.async_client.messages.stream(
            model=self.model,
            messages=messages,
            max_tokens=kwargs.get("max_tokens", 1024)
        ) as stream:
            async for text in stream.text_stream:
                yield text

    def embed(self, text: str) -> List[float]:
        from .embedding_providers import EmbeddingFactory
        if not hasattr(self, '_embedding_fallback'):
            self._embedding_fallback = EmbeddingFactory.auto_detect()
        return self._embedding_fallback.embed(text)

    @property
    def name(self) -> str:
        return f"Anthropic/{self.model}"

class OpenAIProvider(BaseLLMProvider):
    """OpenAI GPT (PAID)."""

    def __init__(self,
                 model: str = "gpt-4-turbo-preview",
                 api_key: Optional[str] = None,
                 **kwargs):
        self.model = model
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")

        try:
            import openai
            self.client = openai.OpenAI(api_key=self.api_key)
            self._test_connection()
        except Exception as e:
            raise ConnectionError(f"OpenAI failed: {e}")

    def _sanitize_params(self, kwargs: Dict) -> Dict:
        """Translate Kite params (format='json') to OpenAI (response_format)."""
        clean = {}
        # Valid OpenAI params
        valid = {'temperature', 'max_tokens', 'top_p', 'stream', 'stop', 'response_format', 'seed', 'n', 'presence_penalty', 'frequency_penalty'}

        for k, v in kwargs.items():
            if k == 'format' and v == 'json':
                clean['response_format'] = {"type": "json_object"}
            elif k in valid:
                clean[k] = v
        return clean

    def _test_connection(self):
        """Test OpenAI connection (auth check)."""
        if not self.api_key or self.api_key.startswith("sk-..."):
            raise ValueError("OpenAI API key is invalid or placeholder")
        try:
            self.client.models.list()
        except Exception as e:
            raise ConnectionError(f"OpenAI auth failed: {e}")

    async def chat_async(self, messages: List[Dict], **kwargs) -> str:
        """True async chat for OpenAI."""
        if not hasattr(self, 'async_client'):
            import openai
            self.async_client = openai.AsyncOpenAI(api_key=self.api_key)

        params = self._sanitize_params(kwargs)
        response = await self.async_client.chat.completions.create(
            model=self.model,
            messages=messages,
            **params
        )
        return response.choices[0].message.content

    async def complete_async(self, prompt: str, **kwargs) -> str:
        """True async complete for OpenAI."""
        return await self.chat_async([{"role": "user", "content": prompt}], **kwargs)

    def complete(self, prompt: str, **kwargs) -> str:
        return self.chat([{"role": "user", "content": prompt}], **kwargs)

    def chat(self, messages: List[Dict], **kwargs) -> str:
        params = self._sanitize_params(kwargs)
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            **params
        )
        return response.choices[0].message.content

    def embed(self, text: str) -> List[float]:
        response = self.client.embeddings.create(
            model="text-embedding-3-small",
            input=text
        )
        return response.data[0].embedding

    async def stream_complete(self, prompt: str, **kwargs):
        """Stream completion."""
        # Ensure async_client is initialized
        if not hasattr(self, 'async_client'):
            import openai
            self.async_client = openai.AsyncOpenAI(api_key=self.api_key)

        params = self._sanitize_params(kwargs)
        stream = await self.async_client.completions.create(
            model=self.model,
            prompt=prompt,
            stream=True,
            **params
        )
        async for chunk in stream:
            if chunk.choices[0].text:
                yield chunk.choices[0].text

    async def stream_chat(self, messages: List[Dict], **kwargs):
        """Stream chat completion."""
        # Ensure async_client is initialized
        if not hasattr(self, 'async_client'):
            import openai
            self.async_client = openai.AsyncOpenAI(api_key=self.api_key)

        params = self._sanitize_params(kwargs)
        stream = await self.async_client.chat.completions.create(
            model=self.model,
            messages=messages,
            stream=True,
            **params
        )
        async for chunk in stream:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    @property
    def name(self) -> str:
        return f"OpenAI/{self.model}"

class LLMFactory:
    """
    Factory for creating LLM providers.

    Priority order:
    1. Local/Free (Ollama, LM Studio, vLLM)
    2. Cloud Free Tier (Groq)
    3. Cloud Opensource (Together)
    4. Commercial (Claude, GPT)
    """

    PROVIDERS = {
        # Local (Priority)
        'ollama': OllamaProvider,
        'lmstudio': LMStudioProvider,
        'vllm': VLLMProvider,

        # Cloud Opensource
        'groq': GroqProvider,
        'together': TogetherProvider,

        # Commercial
        'anthropic': AnthropicProvider,
        'openai': OpenAIProvider,

        # Testing
        'mock': MockLLMProvider,
    }

    @classmethod
    def create(cls,
               provider: str = "ollama",
               model: Optional[str] = None,
               **kwargs) -> BaseLLMProvider:
        """
        Create LLM provider.

        Args:
            provider: Provider name
            model: Model name (optional, uses default)
            **kwargs: Provider-specific kwargs
        """
        if provider not in cls.PROVIDERS:
            raise ValueError(f"Unknown provider: {provider}. "
                             f"Available: {list(cls.PROVIDERS.keys())}")

        provider_class = cls.PROVIDERS[provider]

        # Create with model if specified
        if model:
            return provider_class(model=model, **kwargs)
        else:
            return provider_class(**kwargs)

    @classmethod
    def auto_detect(cls, timeout: float = 600.0) -> BaseLLMProvider:
        """
        Auto-detect best available provider.

        Priority:
        1. Try Ollama (local)
        2. Try LM Studio (local)
        3. Try Groq (cloud, free)
        4. Try OpenAI (fallback)
        """
        logger = logging.getLogger("LLMFactory")

        # Try Ollama
        try:
            provider = cls.create("ollama", timeout=timeout)
            logger.info("[OK] Using Ollama (local, free)")
            return provider
        except Exception:
            pass

        # Try LM Studio
        try:
            provider = cls.create("lmstudio", timeout=timeout)
            logger.info("[OK] Using LM Studio (local, free)")
            return provider
        except Exception:
            pass

        # Try Groq
        if os.getenv("GROQ_API_KEY"):
            try:
                provider = cls.create("groq", timeout=timeout)
                logger.info("[OK] Using Groq (cloud, free tier)")
                return provider
            except Exception:
                pass

        # Fallback to OpenAI
        if os.getenv("OPENAI_API_KEY"):
            try:
                provider = cls.create("openai", timeout=timeout)
                logger.warning(" Using OpenAI (commercial, paid)")
                return provider
            except Exception:
                pass

        # Try mock as ultimate fallback
        try:
            return cls.create("mock", timeout=timeout)
        except Exception:
            pass

        raise RuntimeError(
            "No LLM provider available. Install Ollama or set API keys."
        )

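# Async usage sketch: auto_detect() returns whichever provider is reachable
# first, and chat_async works the same way on every provider class above.
async def _auto_detect_async_example() -> str:
    llm = LLMFactory.auto_detect()
    return await llm.chat_async(
        [{"role": "user", "content": "In one line, what does an LLM provider abstraction buy you?"}]
    )
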
if __name__ == "__main__":
    print("LLM Provider Examples:\n")

    # Auto-detect
    print("1. Auto-detect:")
    llm = LLMFactory.auto_detect()
    print(f" Using: {llm.name}\n")

    # Specific providers
    print("2. Specific providers:")

    providers = [
        ("ollama", "llama3"),
        ("groq", "llama3-70b-8192"),
        ("together", "mistralai/Mixtral-8x7B-Instruct-v0.1"),
    ]

    for provider, model in providers:
        try:
            llm = LLMFactory.create(provider, model)
            print(f" [OK] {llm.name}")
        except Exception as e:
            print(f" {provider}: {e}")