kailash 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/nodes/__init__.py +2 -1
- kailash/nodes/ai/__init__.py +26 -0
- kailash/nodes/ai/ai_providers.py +1272 -0
- kailash/nodes/ai/embedding_generator.py +853 -0
- kailash/nodes/ai/llm_agent.py +1166 -0
- kailash/nodes/api/auth.py +3 -3
- kailash/nodes/api/graphql.py +2 -2
- kailash/nodes/api/http.py +391 -44
- kailash/nodes/api/rate_limiting.py +2 -2
- kailash/nodes/api/rest.py +464 -56
- kailash/nodes/base.py +71 -12
- kailash/nodes/code/python.py +2 -1
- kailash/nodes/data/__init__.py +7 -0
- kailash/nodes/data/readers.py +28 -26
- kailash/nodes/data/retrieval.py +178 -0
- kailash/nodes/data/sharepoint_graph.py +7 -7
- kailash/nodes/data/sources.py +65 -0
- kailash/nodes/data/sql.py +4 -2
- kailash/nodes/data/writers.py +6 -3
- kailash/nodes/logic/operations.py +2 -1
- kailash/nodes/mcp/__init__.py +11 -0
- kailash/nodes/mcp/client.py +558 -0
- kailash/nodes/mcp/resource.py +682 -0
- kailash/nodes/mcp/server.py +571 -0
- kailash/nodes/transform/__init__.py +16 -1
- kailash/nodes/transform/chunkers.py +78 -0
- kailash/nodes/transform/formatters.py +96 -0
- kailash/runtime/docker.py +6 -6
- kailash/sdk_exceptions.py +24 -10
- kailash/tracking/metrics_collector.py +2 -1
- kailash/utils/templates.py +6 -6
- {kailash-0.1.1.dist-info → kailash-0.1.2.dist-info}/METADATA +344 -46
- {kailash-0.1.1.dist-info → kailash-0.1.2.dist-info}/RECORD +37 -26
- {kailash-0.1.1.dist-info → kailash-0.1.2.dist-info}/WHEEL +0 -0
- {kailash-0.1.1.dist-info → kailash-0.1.2.dist-info}/entry_points.txt +0 -0
- {kailash-0.1.1.dist-info → kailash-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.1.1.dist-info → kailash-0.1.2.dist-info}/top_level.txt +0 -0
kailash/nodes/ai/ai_providers.py (new file)
@@ -0,0 +1,1272 @@
"""Unified AI provider implementations for LLM and embedding operations.

This module provides a unified interface for AI providers that support both
language model chat operations and text embedding generation. It reduces
redundancy by consolidating common functionality while maintaining clean
separation between LLM and embedding capabilities.
"""

import hashlib
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union


class BaseAIProvider(ABC):
    """
    Base class for all AI provider implementations.

    This abstract class defines the common interface and shared functionality
    for providers that may support LLM operations, embedding operations, or both.

    Design Philosophy:
    - Single source of truth for provider availability
    - Shared client management and initialization
    - Common error handling patterns
    - Flexible support for providers with different capabilities
    """

    def __init__(self):
        """Initialize base provider state."""
        self._client = None
        self._available = None
        self._capabilities = {"chat": False, "embeddings": False}

    @abstractmethod
    def is_available(self) -> bool:
        """
        Check if the provider is available and properly configured.

        This method should verify:
        - Required dependencies are installed
        - API keys or credentials are configured
        - Services are accessible (for local services)

        Returns:
            bool: True if the provider can be used, False otherwise
        """
        pass

    def get_capabilities(self) -> Dict[str, bool]:
        """
        Get the capabilities supported by this provider.

        Returns:
            Dict[str, bool]: Dictionary indicating support for:
                - chat: LLM chat operations
                - embeddings: Text embedding generation
        """
        return self._capabilities.copy()

    def supports_chat(self) -> bool:
        """Check if this provider supports LLM chat operations."""
        return self._capabilities.get("chat", False)

    def supports_embeddings(self) -> bool:
        """Check if this provider supports embedding generation."""
        return self._capabilities.get("embeddings", False)


class LLMProvider(BaseAIProvider):
    """
    Abstract base class for providers that support LLM chat operations.

    Providers that support chat operations should inherit from this class
    and implement the chat() method.
    """

    def __init__(self):
        super().__init__()
        self._capabilities["chat"] = True

    @abstractmethod
    def chat(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
        """
        Generate a chat completion using the provider's LLM.

        Args:
            messages: Conversation messages in OpenAI format
            **kwargs: Provider-specific parameters

        Returns:
            Dict containing the standardized response
        """
        pass


class EmbeddingProvider(BaseAIProvider):
    """
    Abstract base class for providers that support embedding generation.

    Providers that support embedding operations should inherit from this class
    and implement the embed() and get_model_info() methods.
    """

    def __init__(self):
        super().__init__()
        self._capabilities["embeddings"] = True

    @abstractmethod
    def embed(self, texts: List[str], **kwargs) -> List[List[float]]:
        """
        Generate embeddings for a list of texts.

        Args:
            texts: List of texts to embed
            **kwargs: Provider-specific parameters

        Returns:
            List of embedding vectors
        """
        pass

    @abstractmethod
    def get_model_info(self, model: str) -> Dict[str, Any]:
        """
        Get information about a specific embedding model.

        Args:
            model: Model identifier

        Returns:
            Dict containing model information
        """
        pass


class UnifiedAIProvider(LLMProvider, EmbeddingProvider):
    """
    Abstract base class for providers that support both LLM and embedding operations.

    Providers like OpenAI and Ollama that support both capabilities should
    inherit from this class.
    """

    def __init__(self):
        super().__init__()
        self._capabilities = {"chat": True, "embeddings": True}

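# --- Illustrative sketch (not part of the released file) --------------------
# A custom provider only needs to pick the matching base class and implement
# its abstract methods. `EchoProvider` below is a hypothetical example invented
# for illustration; it is not shipped in kailash 0.1.2.
class EchoProvider(LLMProvider):
    """Toy chat provider that echoes the last user message."""

    def is_available(self) -> bool:
        return True  # no dependencies or credentials required

    def chat(self, messages, **kwargs):
        last = next(
            (m["content"] for m in reversed(messages) if m.get("role") == "user"), ""
        )
        return {"content": f"echo: {last}", "role": "assistant", "finish_reason": "stop"}
# -----------------------------------------------------------------------------
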
# ============================================================================
# Unified Provider Implementations
# ============================================================================


class OllamaProvider(UnifiedAIProvider):
    """Ollama provider for both LLM and embedding operations.

    Ollama runs models locally on your machine, supporting both chat and
    embedding operations with various open-source models.

    Prerequisites:
    * Install Ollama: https://ollama.ai
    * Pull models:
        * LLM: ``ollama pull llama3.1:8b-instruct-q8_0``
        * Embeddings: ``ollama pull snowflake-arctic-embed2``
    * Ensure Ollama service is running

    Supported LLM models:
    * llama3.1:* (various quantizations)
    * mixtral:* (various quantizations)
    * mistral:* (various quantizations)
    * qwen2.5:* (various sizes and quantizations)

    Supported embedding models:
    * snowflake-arctic-embed2 (1024 dimensions)
    * avr/sfr-embedding-mistral (4096 dimensions)
    * nomic-embed-text (768 dimensions)
    * mxbai-embed-large (1024 dimensions)
    """

    def __init__(self):
        super().__init__()
        self._model_cache = {}

    def is_available(self) -> bool:
        """Check if Ollama is available."""
        if self._available is not None:
            return self._available

        try:
            import ollama

            # Check if Ollama is running
            ollama.list()
            self._available = True
        except Exception:
            self._available = False

        return self._available

    def chat(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
        """Generate a chat completion using Ollama.

        Args:
            messages: Conversation messages in OpenAI format.
            **kwargs: Additional arguments including:
                model (str): Ollama model name (default: "llama3.1:8b-instruct-q8_0")
                generation_config (dict): Generation parameters including:
                    * temperature, max_tokens, top_p, top_k, repeat_penalty
                    * seed, stop, num_ctx, num_batch, num_thread
                    * tfs_z, typical_p, mirostat, mirostat_tau, mirostat_eta

        Returns:
            Dict containing the standardized response.
        """
        try:
            import ollama

            model = kwargs.get("model", "llama3.1:8b-instruct-q8_0")
            generation_config = kwargs.get("generation_config", {})

            # Map generation_config to Ollama options
            options = {
                "temperature": generation_config.get("temperature", 0.7),
                "top_p": generation_config.get("top_p", 0.9),
                "top_k": generation_config.get("top_k"),
                "repeat_penalty": generation_config.get("repeat_penalty"),
                "seed": generation_config.get("seed"),
                "stop": generation_config.get("stop"),
                "tfs_z": generation_config.get("tfs_z", 1.0),
                "num_predict": generation_config.get("max_tokens", 500),
                "num_ctx": generation_config.get("num_ctx"),
                "num_batch": generation_config.get("num_batch"),
                "num_thread": generation_config.get("num_thread"),
                "typical_p": generation_config.get("typical_p"),
                "mirostat": generation_config.get("mirostat"),
                "mirostat_tau": generation_config.get("mirostat_tau"),
                "mirostat_eta": generation_config.get("mirostat_eta"),
            }

            # Remove None values
            options = {k: v for k, v in options.items() if v is not None}

            # Call Ollama
            response = ollama.chat(model=model, messages=messages, options=options)

            # Format response to match standard structure
            return {
                "id": f"ollama_{hash(str(messages))}",
                "content": response["message"]["content"],
                "role": "assistant",
                "model": model,
                "created": response.get("created_at"),
                "tool_calls": [],
                "finish_reason": "stop",
                "usage": {
                    "prompt_tokens": response.get("prompt_eval_count", 0),
                    "completion_tokens": response.get("eval_count", 0),
                    "total_tokens": response.get("prompt_eval_count", 0)
                    + response.get("eval_count", 0),
                },
                "metadata": {
                    "duration_ms": response.get("total_duration", 0) / 1e6,
                    "load_duration_ms": response.get("load_duration", 0) / 1e6,
                    "eval_duration_ms": response.get("eval_duration", 0) / 1e6,
                },
            }

        except ImportError:
            raise RuntimeError(
                "Ollama library not installed. Install with: pip install ollama"
            )
        except Exception as e:
            raise RuntimeError(f"Ollama error: {str(e)}")

    def embed(self, texts: List[str], **kwargs) -> List[List[float]]:
        """
        Generate embeddings using Ollama.

        Supported kwargs:
        - model (str): Ollama model name (default: "snowflake-arctic-embed2")
        - normalize (bool): Normalize embeddings to unit length
        """
        try:
            import ollama

            model = kwargs.get("model", "snowflake-arctic-embed2")
            normalize = kwargs.get("normalize", False)

            embeddings = []
            for text in texts:
                response = ollama.embeddings(model=model, prompt=text)
                embedding = response.get("embedding", [])

                if normalize and embedding:
                    # Normalize to unit length
                    magnitude = sum(x * x for x in embedding) ** 0.5
                    if magnitude > 0:
                        embedding = [x / magnitude for x in embedding]

                embeddings.append(embedding)

            return embeddings

        except ImportError:
            raise RuntimeError(
                "Ollama library not installed. Install with: pip install ollama"
            )
        except Exception as e:
            raise RuntimeError(f"Ollama embedding error: {str(e)}")

    def get_model_info(self, model: str) -> Dict[str, Any]:
        """Get information about an Ollama embedding model."""
        if model in self._model_cache:
            return self._model_cache[model]

        # Known embedding model dimensions
        known_models = {
            "snowflake-arctic-embed2": {"dimensions": 1024, "max_tokens": 512},
            "avr/sfr-embedding-mistral": {"dimensions": 4096, "max_tokens": 512},
            "nomic-embed-text": {"dimensions": 768, "max_tokens": 8192},
            "mxbai-embed-large": {"dimensions": 1024, "max_tokens": 512},
        }

        if model in known_models:
            info = known_models[model].copy()
            info["description"] = f"Ollama embedding model: {model}"
            info["capabilities"] = {
                "batch_processing": True,
                "gpu_acceleration": True,
                "normalize": True,
            }
            self._model_cache[model] = info
            return info

        # Default for unknown models
        return {
            "dimensions": 1536,
            "max_tokens": 512,
            "description": f"Ollama model: {model}",
            "capabilities": {"batch_processing": True},
        }

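# --- Illustrative sketch (not part of the released file) --------------------
# Hedged usage example: assumes a local Ollama daemon is running with the
# default models already pulled; the provider raises RuntimeError otherwise.
ollama_provider = OllamaProvider()
if ollama_provider.is_available():
    reply = ollama_provider.chat(
        [{"role": "user", "content": "Summarize what a provider registry does."}],
        model="llama3.1:8b-instruct-q8_0",
        generation_config={"temperature": 0.2, "max_tokens": 200},
    )
    vectors = ollama_provider.embed(["hello world"], normalize=True)
    print(reply["content"], len(vectors[0]))
# -----------------------------------------------------------------------------
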
class OpenAIProvider(UnifiedAIProvider):
    """
    OpenAI provider for both LLM and embedding operations.

    Prerequisites:
    - Set OPENAI_API_KEY environment variable
    - Install openai package: `pip install openai`

    Supported LLM models:
    - gpt-4-turbo (latest GPT-4 Turbo)
    - gpt-4 (standard GPT-4)
    - gpt-4-32k (32k context window)
    - gpt-3.5-turbo (latest GPT-3.5)
    - gpt-3.5-turbo-16k (16k context window)

    Supported embedding models:
    - text-embedding-3-large (3072 dimensions, configurable)
    - text-embedding-3-small (1536 dimensions, configurable)
    - text-embedding-ada-002 (1536 dimensions, legacy)
    """

    def is_available(self) -> bool:
        """Check if OpenAI is available."""
        if self._available is not None:
            return self._available

        try:
            import os

            # Check for API key
            self._available = bool(os.getenv("OPENAI_API_KEY"))
        except ImportError:
            self._available = False

        return self._available

    def chat(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
        """
        Generate a chat completion using OpenAI.

        Supported kwargs:
        - model (str): OpenAI model name (default: "gpt-4")
        - generation_config (dict): Generation parameters
        - tools (List[Dict]): Function/tool definitions for function calling
        """
        try:
            import openai

            model = kwargs.get("model", "gpt-4")
            generation_config = kwargs.get("generation_config", {})
            tools = kwargs.get("tools", [])

            # Initialize client if needed
            if self._client is None:
                self._client = openai.OpenAI()

            # Prepare request
            request_params = {
                "model": model,
                "messages": messages,
                "temperature": generation_config.get("temperature", 0.7),
                "max_tokens": generation_config.get("max_tokens", 500),
                "top_p": generation_config.get("top_p", 0.9),
                "frequency_penalty": generation_config.get("frequency_penalty"),
                "presence_penalty": generation_config.get("presence_penalty"),
                "stop": generation_config.get("stop"),
                "n": generation_config.get("n", 1),
                "stream": kwargs.get("stream", False),
                "logit_bias": generation_config.get("logit_bias"),
                "user": generation_config.get("user"),
                "response_format": generation_config.get("response_format"),
                "seed": generation_config.get("seed"),
            }

            # Remove None values
            request_params = {k: v for k, v in request_params.items() if v is not None}

            # Add tools if provided
            if tools:
                request_params["tools"] = tools
                request_params["tool_choice"] = generation_config.get(
                    "tool_choice", "auto"
                )

            # Call OpenAI
            response = self._client.chat.completions.create(**request_params)

            # Format response
            choice = response.choices[0]
            return {
                "id": response.id,
                "content": choice.message.content,
                "role": choice.message.role,
                "model": response.model,
                "created": response.created,
                "tool_calls": (
                    choice.message.tool_calls
                    if hasattr(choice.message, "tool_calls")
                    else []
                ),
                "finish_reason": choice.finish_reason,
                "usage": {
                    "prompt_tokens": response.usage.prompt_tokens,
                    "completion_tokens": response.usage.completion_tokens,
                    "total_tokens": response.usage.total_tokens,
                },
                "metadata": {},
            }

        except ImportError:
            raise RuntimeError(
                "OpenAI library not installed. Install with: pip install openai"
            )
        except Exception as e:
            raise RuntimeError(f"OpenAI error: {str(e)}")

    def embed(self, texts: List[str], **kwargs) -> List[List[float]]:
        """
        Generate embeddings using OpenAI.

        Supported kwargs:
        - model (str): OpenAI model name (default: "text-embedding-3-small")
        - dimensions (int): Desired dimensions (only for v3 models)
        - user (str): Unique user identifier for tracking
        """
        try:
            import openai

            model = kwargs.get("model", "text-embedding-3-small")
            dimensions = kwargs.get("dimensions")
            user = kwargs.get("user")

            # Initialize client if needed
            if self._client is None:
                self._client = openai.OpenAI()

            # Prepare request
            request_params = {"model": model, "input": texts}

            # Add optional parameters
            if dimensions and "embedding-3" in model:
                request_params["dimensions"] = dimensions
            if user:
                request_params["user"] = user

            # Call OpenAI
            response = self._client.embeddings.create(**request_params)

            # Extract embeddings
            embeddings = [item.embedding for item in response.data]

            return embeddings

        except ImportError:
            raise RuntimeError(
                "OpenAI library not installed. Install with: pip install openai"
            )
        except Exception as e:
            raise RuntimeError(f"OpenAI embedding error: {str(e)}")

    def get_model_info(self, model: str) -> Dict[str, Any]:
        """Get information about an OpenAI embedding model."""
        models = {
            "text-embedding-3-large": {
                "dimensions": 3072,
                "max_tokens": 8191,
                "description": "Most capable embedding model, supports dimensions",
                "capabilities": {
                    "variable_dimensions": True,
                    "min_dimensions": 256,
                    "max_dimensions": 3072,
                },
            },
            "text-embedding-3-small": {
                "dimensions": 1536,
                "max_tokens": 8191,
                "description": "Efficient embedding model, supports dimensions",
                "capabilities": {
                    "variable_dimensions": True,
                    "min_dimensions": 256,
                    "max_dimensions": 1536,
                },
            },
            "text-embedding-ada-002": {
                "dimensions": 1536,
                "max_tokens": 8191,
                "description": "Legacy embedding model",
                "capabilities": {"variable_dimensions": False},
            },
        }

        return models.get(
            model,
            {
                "dimensions": 1536,
                "max_tokens": 8191,
                "description": f"OpenAI model: {model}",
                "capabilities": {},
            },
        )

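# --- Illustrative sketch (not part of the released file) --------------------
# Hedged example: the "dimensions" kwarg is only forwarded for the
# text-embedding-3-* models, as embed() above shows. Requires OPENAI_API_KEY.
openai_provider = OpenAIProvider()
if openai_provider.is_available():
    short_vectors = openai_provider.embed(
        ["retrieval chunk one", "retrieval chunk two"],
        model="text-embedding-3-small",
        dimensions=256,  # reduced embedding size supported by v3 models
    )
    # Expect 2 vectors of length 256 if the API honors the request.
    print(len(short_vectors), len(short_vectors[0]))
# -----------------------------------------------------------------------------
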
class AnthropicProvider(LLMProvider):
    """
    Anthropic provider for Claude LLM models.

    Note: Anthropic currently only provides LLM capabilities, not embeddings.

    Prerequisites:
    - Set ANTHROPIC_API_KEY environment variable
    - Install anthropic package: `pip install anthropic`

    Supported models:
    - claude-3-opus-20240229 (Most capable, slower)
    - claude-3-sonnet-20240229 (Balanced performance)
    - claude-3-haiku-20240307 (Fastest, most affordable)
    - claude-2.1 (Previous generation)
    - claude-2.0
    - claude-instant-1.2
    """

    def is_available(self) -> bool:
        """Check if Anthropic is available."""
        if self._available is not None:
            return self._available

        try:
            import os

            # Check for API key
            self._available = bool(os.getenv("ANTHROPIC_API_KEY"))
        except ImportError:
            self._available = False

        return self._available

    def chat(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
        """Generate a chat completion using Anthropic."""
        try:
            import anthropic

            model = kwargs.get("model", "claude-3-sonnet-20240229")
            generation_config = kwargs.get("generation_config", {})

            # Initialize client if needed
            if self._client is None:
                self._client = anthropic.Anthropic()

            # Convert messages to Anthropic format
            system_message = None
            user_messages = []

            for msg in messages:
                if msg["role"] == "system":
                    system_message = msg["content"]
                else:
                    user_messages.append(msg)

            # Call Anthropic
            response = self._client.messages.create(
                model=model,
                messages=user_messages,
                system=system_message,
                max_tokens=generation_config.get("max_tokens", 500),
                temperature=generation_config.get("temperature", 0.7),
                top_p=generation_config.get("top_p"),
                top_k=generation_config.get("top_k"),
                stop_sequences=generation_config.get("stop_sequences"),
                metadata=generation_config.get("metadata"),
            )

            # Format response
            return {
                "id": response.id,
                "content": response.content[0].text,
                "role": "assistant",
                "model": response.model,
                "created": None,  # Anthropic doesn't provide this
                "tool_calls": [],  # Handle tool use if needed
                "finish_reason": response.stop_reason,
                "usage": {
                    "prompt_tokens": response.usage.input_tokens,
                    "completion_tokens": response.usage.output_tokens,
                    "total_tokens": response.usage.input_tokens
                    + response.usage.output_tokens,
                },
                "metadata": {},
            }

        except ImportError:
            raise RuntimeError(
                "Anthropic library not installed. Install with: pip install anthropic"
            )
        except Exception as e:
            raise RuntimeError(f"Anthropic error: {str(e)}")

class CohereProvider(EmbeddingProvider):
    """
    Cohere provider for embedding operations.

    Note: This implementation focuses on embeddings. Cohere also provides
    LLM capabilities which could be added in the future.

    Prerequisites:
    - Set COHERE_API_KEY environment variable
    - Install cohere package: `pip install cohere`

    Supported embedding models:
    - embed-english-v3.0 (1024 dimensions)
    - embed-multilingual-v3.0 (1024 dimensions)
    - embed-english-light-v3.0 (384 dimensions)
    - embed-multilingual-light-v3.0 (384 dimensions)
    """

    def is_available(self) -> bool:
        """Check if Cohere is available."""
        if self._available is not None:
            return self._available

        try:
            import os

            # Check for API key
            self._available = bool(os.getenv("COHERE_API_KEY"))
        except ImportError:
            self._available = False

        return self._available

    def embed(self, texts: List[str], **kwargs) -> List[List[float]]:
        """Generate embeddings using Cohere."""
        try:
            import cohere

            model = kwargs.get("model", "embed-english-v3.0")
            input_type = kwargs.get("input_type", "search_document")
            truncate = kwargs.get("truncate", "END")

            # Initialize client if needed
            if self._client is None:
                self._client = cohere.Client()

            # Call Cohere
            response = self._client.embed(
                texts=texts, model=model, input_type=input_type, truncate=truncate
            )

            # Extract embeddings
            embeddings = response.embeddings

            return embeddings

        except ImportError:
            raise RuntimeError(
                "Cohere library not installed. Install with: pip install cohere"
            )
        except Exception as e:
            raise RuntimeError(f"Cohere embedding error: {str(e)}")

    def get_model_info(self, model: str) -> Dict[str, Any]:
        """Get information about a Cohere embedding model."""
        models = {
            "embed-english-v3.0": {
                "dimensions": 1024,
                "max_tokens": 512,
                "description": "English embedding model v3",
                "capabilities": {
                    "input_types": [
                        "search_query",
                        "search_document",
                        "classification",
                        "clustering",
                    ],
                    "languages": ["en"],
                },
            },
            "embed-multilingual-v3.0": {
                "dimensions": 1024,
                "max_tokens": 512,
                "description": "Multilingual embedding model v3",
                "capabilities": {
                    "input_types": [
                        "search_query",
                        "search_document",
                        "classification",
                        "clustering",
                    ],
                    "languages": [
                        "en",
                        "es",
                        "fr",
                        "de",
                        "it",
                        "pt",
                        "ja",
                        "ko",
                        "zh",
                        "ar",
                        "hi",
                        "tr",
                    ],
                },
            },
            "embed-english-light-v3.0": {
                "dimensions": 384,
                "max_tokens": 512,
                "description": "Lightweight English embedding model v3",
                "capabilities": {
                    "input_types": [
                        "search_query",
                        "search_document",
                        "classification",
                        "clustering",
                    ],
                    "languages": ["en"],
                },
            },
            "embed-multilingual-light-v3.0": {
                "dimensions": 384,
                "max_tokens": 512,
                "description": "Lightweight multilingual embedding model v3",
                "capabilities": {
                    "input_types": [
                        "search_query",
                        "search_document",
                        "classification",
                        "clustering",
                    ],
                    "languages": [
                        "en",
                        "es",
                        "fr",
                        "de",
                        "it",
                        "pt",
                        "ja",
                        "ko",
                        "zh",
                        "ar",
                        "hi",
                        "tr",
                    ],
                },
            },
        }

        return models.get(
            model,
            {
                "dimensions": 1024,
                "max_tokens": 512,
                "description": f"Cohere embedding model: {model}",
                "capabilities": {},
            },
        )

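# --- Illustrative sketch (not part of the released file) --------------------
# Hedged example: Cohere v3 models distinguish how the text will be used, so an
# asymmetric retrieval setup embeds documents and queries with different
# input_type values (both listed in get_model_info above). Requires COHERE_API_KEY.
cohere_provider = CohereProvider()
if cohere_provider.is_available():
    doc_vectors = cohere_provider.embed(
        ["Kailash nodes wrap reusable workflow steps."],
        model="embed-english-v3.0",
        input_type="search_document",
    )
    query_vectors = cohere_provider.embed(
        ["what is a kailash node?"],
        model="embed-english-v3.0",
        input_type="search_query",
    )
# -----------------------------------------------------------------------------
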
class HuggingFaceProvider(EmbeddingProvider):
    """
    HuggingFace provider for embedding operations.

    This provider can use both the HuggingFace Inference API and local models.

    Prerequisites for API:
    - Set HUGGINGFACE_API_KEY environment variable
    - Install requests: `pip install requests`

    Prerequisites for local:
    - Install transformers: `pip install transformers torch`

    Supported embedding models:
    - sentence-transformers/all-MiniLM-L6-v2 (384 dimensions)
    - sentence-transformers/all-mpnet-base-v2 (768 dimensions)
    - BAAI/bge-large-en-v1.5 (1024 dimensions)
    - thenlper/gte-large (1024 dimensions)
    """

    def __init__(self):
        super().__init__()
        self._models = {}
        self._available_api = None
        self._available_local = None

    def is_available(self) -> bool:
        """Check if HuggingFace is available (either API or local)."""
        # Check API availability
        if self._available_api is None:
            try:
                import os

                self._available_api = bool(os.getenv("HUGGINGFACE_API_KEY"))
            except Exception:
                self._available_api = False

        # Check local availability
        if self._available_local is None:
            try:
                # Check if torch and transformers are available
                import importlib.util

                torch_spec = importlib.util.find_spec("torch")
                transformers_spec = importlib.util.find_spec("transformers")
                self._available_local = (
                    torch_spec is not None and transformers_spec is not None
                )
            except ImportError:
                self._available_local = False

        return self._available_api or self._available_local

    def embed(self, texts: List[str], **kwargs) -> List[List[float]]:
        """Generate embeddings using HuggingFace."""
        model = kwargs.get("model", "sentence-transformers/all-MiniLM-L6-v2")
        use_api = kwargs.get("use_api", self._available_api)
        normalize = kwargs.get("normalize", True)

        if use_api and self._available_api:
            return self._embed_api(texts, model, normalize)
        elif self._available_local:
            device = kwargs.get("device", "cpu")
            return self._embed_local(texts, model, device, normalize)
        else:
            raise RuntimeError(
                "Neither HuggingFace API nor local transformers available"
            )

    def _embed_api(
        self, texts: List[str], model: str, normalize: bool
    ) -> List[List[float]]:
        """Generate embeddings using HuggingFace Inference API."""
        try:
            import os

            import requests

            api_key = os.getenv("HUGGINGFACE_API_KEY")
            headers = {"Authorization": f"Bearer {api_key}"}

            api_url = f"https://api-inference.huggingface.co/models/{model}"

            embeddings = []
            for text in texts:
                response = requests.post(
                    api_url, headers=headers, json={"inputs": text}
                )

                if response.status_code != 200:
                    raise RuntimeError(f"API error: {response.text}")

                embedding = response.json()
                if isinstance(embedding, list) and isinstance(embedding[0], list):
                    embedding = embedding[0]  # Extract from nested list

                if normalize:
                    magnitude = sum(x * x for x in embedding) ** 0.5
                    if magnitude > 0:
                        embedding = [x / magnitude for x in embedding]

                embeddings.append(embedding)

            return embeddings

        except Exception as e:
            raise RuntimeError(f"HuggingFace API error: {str(e)}")

    def _embed_local(
        self, texts: List[str], model: str, device: str, normalize: bool
    ) -> List[List[float]]:
        """Generate embeddings using local HuggingFace model."""
        try:
            import torch
            from transformers import AutoModel, AutoTokenizer

            # Load model if not cached
            if model not in self._models:
                tokenizer = AutoTokenizer.from_pretrained(model)
                model_obj = AutoModel.from_pretrained(model)
                model_obj.to(device)
                model_obj.eval()  # noqa: PGH001
                self._models[model] = (tokenizer, model_obj)

            tokenizer, model_obj = self._models[model]

            embeddings = []
            with torch.no_grad():
                for text in texts:
                    # Tokenize
                    inputs = tokenizer(
                        text, padding=True, truncation=True, return_tensors="pt"
                    ).to(device)

                    # Generate embeddings
                    outputs = model_obj(**inputs)

                    # Mean pooling
                    attention_mask = inputs["attention_mask"]
                    token_embeddings = outputs.last_hidden_state
                    input_mask_expanded = (
                        attention_mask.unsqueeze(-1)
                        .expand(token_embeddings.size())
                        .float()
                    )
                    embedding = torch.sum(
                        token_embeddings * input_mask_expanded, 1
                    ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

                    # Convert to list
                    embedding = embedding.squeeze().cpu().numpy().tolist()

                    if normalize:
                        magnitude = sum(x * x for x in embedding) ** 0.5
                        if magnitude > 0:
                            embedding = [x / magnitude for x in embedding]

                    embeddings.append(embedding)

            return embeddings

        except ImportError:
            raise RuntimeError(
                "Transformers library not installed. Install with: pip install transformers torch"
            )
        except Exception as e:
            raise RuntimeError(f"HuggingFace local error: {str(e)}")

    def get_model_info(self, model: str) -> Dict[str, Any]:
        """Get information about a HuggingFace embedding model."""
        models = {
            "sentence-transformers/all-MiniLM-L6-v2": {
                "dimensions": 384,
                "max_tokens": 256,
                "description": "Efficient sentence transformer model",
                "capabilities": {
                    "languages": ["en"],
                    "use_cases": ["semantic_search", "clustering", "classification"],
                },
            },
            "sentence-transformers/all-mpnet-base-v2": {
                "dimensions": 768,
                "max_tokens": 384,
                "description": "High-quality sentence transformer model",
                "capabilities": {
                    "languages": ["en"],
                    "use_cases": ["semantic_search", "clustering", "classification"],
                },
            },
            "BAAI/bge-large-en-v1.5": {
                "dimensions": 1024,
                "max_tokens": 512,
                "description": "BAAI General Embedding model",
                "capabilities": {
                    "languages": ["en"],
                    "use_cases": ["retrieval", "reranking", "classification"],
                },
            },
            "thenlper/gte-large": {
                "dimensions": 1024,
                "max_tokens": 512,
                "description": "General Text Embeddings model",
                "capabilities": {
                    "languages": ["en"],
                    "use_cases": ["retrieval", "similarity", "clustering"],
                },
            },
        }

        return models.get(
            model,
            {
                "dimensions": 768,  # Common default
                "max_tokens": 512,
                "description": f"HuggingFace model: {model}",
                "capabilities": {},
            },
        )

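# --- Illustrative sketch (not part of the released file) --------------------
# Hedged example: embed() prefers the Inference API when HUGGINGFACE_API_KEY is
# set; passing use_api=False forces the local transformers path (assuming torch
# and transformers are installed).
hf_provider = HuggingFaceProvider()
if hf_provider.is_available():
    local_vectors = hf_provider.embed(
        ["mean pooling over token embeddings"],
        model="sentence-transformers/all-MiniLM-L6-v2",
        use_api=False,
        device="cpu",
    )
    print(len(local_vectors[0]))  # 384 for all-MiniLM-L6-v2
# -----------------------------------------------------------------------------
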
class MockProvider(UnifiedAIProvider):
    """
    Mock provider for testing and development.

    This provider generates deterministic mock responses for both LLM and
    embedding operations without making actual API calls.

    Features:
    - Always available (no dependencies)
    - Generates consistent responses based on input
    - Zero latency
    - Supports both chat and embedding operations
    """

    def is_available(self) -> bool:
        """Mock provider is always available."""
        return True

    def chat(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
        """Generate mock LLM response."""
        last_user_message = ""
        for msg in reversed(messages):
            if msg.get("role") == "user":
                last_user_message = msg.get("content", "")
                break

        # Generate contextual mock response
        if "analyze" in last_user_message.lower():
            response_content = "Based on the provided data and context, I can see several key patterns..."
        elif "create" in last_user_message.lower():
            response_content = "I'll help you create that. Based on the requirements..."
        elif "?" in last_user_message:
            response_content = f"Regarding your question about '{last_user_message[:50]}...', here's what I found..."
        else:
            response_content = f"I understand you want me to work with: '{last_user_message[:100]}...'."

        return {
            "id": f"mock_{hash(last_user_message)}",
            "content": response_content,
            "role": "assistant",
            "model": kwargs.get("model", "mock-model"),
            "created": 1701234567,
            "tool_calls": [],
            "finish_reason": "stop",
            "usage": {
                "prompt_tokens": len(
                    " ".join(msg.get("content", "") for msg in messages)
                )
                // 4,
                "completion_tokens": len(response_content) // 4,
                "total_tokens": 0,  # Will be calculated
            },
            "metadata": {},
        }

    def embed(self, texts: List[str], **kwargs) -> List[List[float]]:
        """Generate mock embeddings."""
        model = kwargs.get("model", "mock-embedding")
        dimensions = kwargs.get("dimensions", 1536)
        normalize = kwargs.get("normalize", True)

        embeddings = []
        for text in texts:
            # Generate deterministic embedding based on text hash
            seed = int(hashlib.md5(f"{model}:{text}".encode()).hexdigest()[:8], 16)

            import random

            random.seed(seed)

            # Generate random vector
            embedding = [random.gauss(0, 1) for _ in range(dimensions)]

            # Normalize if requested
            if normalize:
                magnitude = sum(x * x for x in embedding) ** 0.5
                if magnitude > 0:
                    embedding = [x / magnitude for x in embedding]

            embeddings.append(embedding)

        return embeddings

    def get_model_info(self, model: str) -> Dict[str, Any]:
        """Get information about a mock embedding model."""
        models = {
            "mock-embedding-small": {"dimensions": 384, "max_tokens": 512},
            "mock-embedding": {"dimensions": 1536, "max_tokens": 8192},
            "mock-embedding-large": {"dimensions": 3072, "max_tokens": 8192},
        }

        return models.get(
            model,
            {
                "dimensions": 1536,
                "max_tokens": 8192,
                "description": f"Mock embedding model: {model}",
                "capabilities": {"all_features": True},
            },
        )

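# --- Illustrative sketch (not part of the released file) --------------------
# Hedged example: mock embeddings are seeded from a hash of the model name and
# text, so the same input always yields the same vector, which makes them
# convenient for offline unit tests.
mock = MockProvider()
a = mock.embed(["deterministic input"], dimensions=8)[0]
b = mock.embed(["deterministic input"], dimensions=8)[0]
assert a == b  # identical across calls, no network required
# -----------------------------------------------------------------------------
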
# ============================================================================
# Provider Registry and Factory
# ============================================================================

# Provider registry mapping names to classes
PROVIDERS = {
    "ollama": OllamaProvider,
    "openai": OpenAIProvider,
    "anthropic": AnthropicProvider,
    "cohere": CohereProvider,
    "huggingface": HuggingFaceProvider,
    "mock": MockProvider,
}


def get_provider(
    provider_name: str, provider_type: Optional[str] = None
) -> Union[BaseAIProvider, LLMProvider, EmbeddingProvider]:
    """
    Get an AI provider instance by name.

    This factory function creates and returns the appropriate provider instance
    based on the provider name. It can optionally check for specific capabilities.

    Args:
        provider_name (str): Name of the provider to instantiate.
            Valid options: "ollama", "openai", "anthropic", "cohere", "huggingface", "mock"
            Case-insensitive.
        provider_type (str, optional): Required capability - "chat", "embeddings", or None for any.
            If specified, will raise an error if the provider doesn't support it.

    Returns:
        Provider instance with the requested capabilities.

    Raises:
        ValueError: If the provider name is not recognized or doesn't support the requested type.

    Examples:

        Get any provider::

            provider = get_provider("openai")
            if provider.supports_chat():
                # Use for chat
            if provider.supports_embeddings():
                # Use for embeddings

        Get chat-only provider::

            chat_provider = get_provider("anthropic", "chat")
            response = chat_provider.chat(messages, model="claude-3-sonnet")

        Get embedding-only provider::

            embed_provider = get_provider("cohere", "embeddings")
            embeddings = embed_provider.embed(texts, model="embed-english-v3.0")

        Check provider capabilities::

            provider = get_provider("ollama")
            capabilities = provider.get_capabilities()
            print(f"Chat: {capabilities['chat']}, Embeddings: {capabilities['embeddings']}")
    """
    provider_class = PROVIDERS.get(provider_name.lower())
    if not provider_class:
        raise ValueError(
            f"Unknown provider: {provider_name}. Available: {list(PROVIDERS.keys())}"
        )

    provider = provider_class()

    # Check for required capability if specified
    if provider_type:
        if provider_type == "chat" and not provider.supports_chat():
            raise ValueError(
                f"Provider {provider_name} does not support chat operations"
            )
        elif provider_type == "embeddings" and not provider.supports_embeddings():
            raise ValueError(
                f"Provider {provider_name} does not support embedding operations"
            )
        elif provider_type not in ["chat", "embeddings"]:
            raise ValueError(
                f"Invalid provider_type: {provider_type}. Must be 'chat', 'embeddings', or None"
            )

    return provider

def get_available_providers(
    provider_type: Optional[str] = None,
) -> Dict[str, Dict[str, Any]]:
    """
    Get information about all available providers.

    Args:
        provider_type (str, optional): Filter by capability - "chat", "embeddings", or None for all.

    Returns:
        Dict mapping provider names to their availability and capabilities.

    Examples:

        Get all providers::

            all_providers = get_available_providers()
            for name, info in all_providers.items():
                print(f"{name}: Available={info['available']}, Chat={info['chat']}, Embeddings={info['embeddings']}")

        Get only chat providers::

            chat_providers = get_available_providers("chat")

        Get only embedding providers::

            embed_providers = get_available_providers("embeddings")
    """
    results = {}

    for name in PROVIDERS:
        try:
            provider = get_provider(name)
            capabilities = provider.get_capabilities()

            # Apply filter if specified
            if provider_type == "chat" and not capabilities.get("chat"):
                continue
            elif provider_type == "embeddings" and not capabilities.get("embeddings"):
                continue

            results[name] = {
                "available": provider.is_available(),
                "chat": capabilities.get("chat", False),
                "embeddings": capabilities.get("embeddings", False),
                "description": (
                    provider.__class__.__doc__.split("\n")[1].strip()
                    if provider.__class__.__doc__
                    else ""
                ),
            }
        except Exception as e:
            results[name] = {
                "available": False,
                "error": str(e),
                "chat": False,
                "embeddings": False,
            }

    return results
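# --- Illustrative sketch (not part of the released file) --------------------
# Hedged example of a typical caller pattern: ask the registry which embedding
# providers are usable and fall back to the always-available mock provider when
# none are configured (e.g. in CI without API keys).
available = get_available_providers("embeddings")
name = next((n for n, info in available.items() if info["available"]), "mock")
embedder = get_provider(name, "embeddings")
print(f"Using embedding provider: {name}")
# -----------------------------------------------------------------------------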