banko-ai-assistant 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- banko_ai/__init__.py +19 -0
- banko_ai/__main__.py +10 -0
- banko_ai/ai_providers/__init__.py +18 -0
- banko_ai/ai_providers/aws_provider.py +337 -0
- banko_ai/ai_providers/base.py +175 -0
- banko_ai/ai_providers/factory.py +84 -0
- banko_ai/ai_providers/gemini_provider.py +340 -0
- banko_ai/ai_providers/openai_provider.py +295 -0
- banko_ai/ai_providers/watsonx_provider.py +591 -0
- banko_ai/cli.py +374 -0
- banko_ai/config/__init__.py +5 -0
- banko_ai/config/settings.py +216 -0
- banko_ai/static/Anallytics.png +0 -0
- banko_ai/static/Graph.png +0 -0
- banko_ai/static/Graph2.png +0 -0
- banko_ai/static/ai-status.png +0 -0
- banko_ai/static/banko-ai-assistant-watsonx.gif +0 -0
- banko_ai/static/banko-db-ops.png +0 -0
- banko_ai/static/banko-response.png +0 -0
- banko_ai/static/cache-stats.png +0 -0
- banko_ai/static/creditcard.png +0 -0
- banko_ai/static/profilepic.jpeg +0 -0
- banko_ai/static/query_watcher.png +0 -0
- banko_ai/static/roach-logo.svg +54 -0
- banko_ai/static/watsonx-icon.svg +1 -0
- banko_ai/templates/base.html +59 -0
- banko_ai/templates/dashboard.html +569 -0
- banko_ai/templates/index.html +1499 -0
- banko_ai/templates/login.html +41 -0
- banko_ai/utils/__init__.py +8 -0
- banko_ai/utils/cache_manager.py +525 -0
- banko_ai/utils/database.py +202 -0
- banko_ai/utils/migration.py +123 -0
- banko_ai/vector_search/__init__.py +18 -0
- banko_ai/vector_search/enrichment.py +278 -0
- banko_ai/vector_search/generator.py +329 -0
- banko_ai/vector_search/search.py +463 -0
- banko_ai/web/__init__.py +13 -0
- banko_ai/web/app.py +668 -0
- banko_ai/web/auth.py +73 -0
- banko_ai_assistant-1.0.0.dist-info/METADATA +414 -0
- banko_ai_assistant-1.0.0.dist-info/RECORD +46 -0
- banko_ai_assistant-1.0.0.dist-info/WHEEL +5 -0
- banko_ai_assistant-1.0.0.dist-info/entry_points.txt +2 -0
- banko_ai_assistant-1.0.0.dist-info/licenses/LICENSE +21 -0
- banko_ai_assistant-1.0.0.dist-info/top_level.txt +1 -0
banko_ai/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
1
|
+
"""Banko AI Assistant: AI-powered expense analysis and RAG system.

Provides expense analysis backed by CockroachDB vector search, with
pluggable AI providers (OpenAI, AWS Bedrock, IBM Watsonx, Google Gemini).
"""

# Package metadata.
__version__ = "1.0.0"
__author__ = "Virag Tripathi"
__email__ = "virag.tripathi@gmail.com"
|
11
|
+
|
12
|
+
from .config.settings import Config
|
13
|
+
|
14
|
+
def create_app():
    """Build and return the Flask application.

    The web module is imported only when this function is called, not when
    the package itself is imported.
    """
    from .web.app import create_app as _app_factory

    application = _app_factory()
    return application
|
18
|
+
|
19
|
+
# Names exported by ``from banko_ai import *``.
__all__ = [
    "Config",
    "create_app",
    "__version__",
]
|
banko_ai/ai_providers/__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
1
|
+
"""AI provider implementations for Banko AI Assistant."""
|
2
|
+
|
3
|
+
from .base import AIProvider, AIProviderError
|
4
|
+
from .openai_provider import OpenAIProvider
|
5
|
+
from .aws_provider import AWSProvider
|
6
|
+
from .watsonx_provider import WatsonxProvider
|
7
|
+
from .gemini_provider import GeminiProvider
|
8
|
+
from .factory import AIProviderFactory
|
9
|
+
|
10
|
+
# Public API re-exported by the ai_providers package.
__all__ = [
    "AIProvider",
    "AIProviderError",
    "OpenAIProvider",
    "AWSProvider",
    "WatsonxProvider",
    "GeminiProvider",
    "AIProviderFactory",
]
|
banko_ai/ai_providers/aws_provider.py
ADDED
@@ -0,0 +1,337 @@
|
|
1
|
+
"""
|
2
|
+
AWS Bedrock AI provider implementation.
|
3
|
+
|
4
|
+
This module provides AWS Bedrock integration for vector search and RAG responses.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import os
|
8
|
+
import json
|
9
|
+
from typing import List, Dict, Any, Optional
|
10
|
+
import boto3
|
11
|
+
from sentence_transformers import SentenceTransformer
|
12
|
+
from sqlalchemy import create_engine, text
|
13
|
+
|
14
|
+
from .base import AIProvider, SearchResult, RAGResponse, AIConnectionError, AIAuthenticationError
|
15
|
+
|
16
|
+
|
17
|
+
class AWSProvider(AIProvider):
    """AWS Bedrock AI provider implementation.

    Components:
    - Embeddings come from a local SentenceTransformer model.
    - Vector search runs against the ``expenses`` table through SQLAlchemy.
    - Answers are generated by an Anthropic Claude model on AWS Bedrock.

    When a cache manager is supplied, it is consulted before each Bedrock
    call and updated after it.
    """

    def __init__(self, config: Dict[str, Any], cache_manager=None):
        """Initialize the AWS provider.

        Args:
            config: Provider configuration. Reads ``access_key_id``,
                ``secret_access_key``, ``region`` (default ``"us-east-1"``)
                and, through the base class, ``model``.
            cache_manager: Optional response cache used by
                :meth:`generate_rag_response`.

        Raises:
            AIAuthenticationError: If the AWS credentials are missing.
            AIConnectionError: If the Bedrock client cannot be created.
        """
        # These must exist before super().__init__(), which calls
        # _validate_config() and therefore reads the credentials.
        self.access_key_id = config.get("access_key_id")
        self.secret_access_key = config.get("secret_access_key")
        self.region = config.get("region", "us-east-1")
        self.bedrock_client = None
        self.embedding_model = None
        self.db_engine = None
        self.cache_manager = cache_manager
        super().__init__(config)

    def _validate_config(self) -> None:
        """Validate the AWS credentials and create the Bedrock runtime client.

        Raises:
            AIAuthenticationError: If the access key ID or secret key is missing.
            AIConnectionError: If ``boto3`` fails to build the client.
        """
        if not self.access_key_id or not self.secret_access_key:
            raise AIAuthenticationError("AWS access key ID and secret access key are required")

        try:
            self.bedrock_client = boto3.client(
                "bedrock-runtime",
                aws_access_key_id=self.access_key_id,
                aws_secret_access_key=self.secret_access_key,
                region_name=self.region,
            )
        except Exception as e:
            raise AIConnectionError(f"Failed to initialize AWS Bedrock client: {str(e)}") from e

    def get_default_model(self) -> str:
        """Return the Bedrock model ID used when the config names none."""
        return "us.anthropic.claude-3-5-sonnet-20241022-v2:0"

    def get_available_models(self) -> List[str]:
        """Return the Bedrock model IDs this provider can be switched to."""
        return [
            "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
            "us.anthropic.claude-3-5-haiku-20241022-v1:0",
            "us.anthropic.claude-3-opus-20240229-v1:0",
            "us.anthropic.claude-3-sonnet-20240229-v1:0",
            "us.anthropic.claude-3-haiku-20240307-v1:0",
        ]

    def _get_embedding_model(self) -> SentenceTransformer:
        """Return the embedding model, loading it on first use.

        Raises:
            AIConnectionError: If the model cannot be loaded.
        """
        if self.embedding_model is None:
            try:
                self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            except Exception as e:
                raise AIConnectionError(f"Failed to load embedding model: {str(e)}") from e
        return self.embedding_model

    def _get_db_engine(self):
        """Return the SQLAlchemy engine, creating it on first use.

        The URL comes from ``DATABASE_URL``. If that is unset, it defaults to
        a local insecure CockroachDB node.

        Raises:
            AIConnectionError: If the engine cannot be created.
        """
        if self.db_engine is None:
            database_url = os.getenv("DATABASE_URL", "cockroachdb://root@localhost:26257/defaultdb?sslmode=disable")
            try:
                self.db_engine = create_engine(database_url)
            except Exception as e:
                raise AIConnectionError(f"Failed to connect to database: {str(e)}") from e
        return self.db_engine

    def search_expenses(
        self,
        query: str,
        user_id: Optional[str] = None,
        limit: int = 10,
        threshold: float = 0.7,
    ) -> List[SearchResult]:
        """Find expenses whose embedding is similar to ``query``.

        Args:
            query: Free-text search query, embedded with the local model.
            user_id: If given, restrict results to this user's expenses.
            limit: Maximum number of rows returned.
            threshold: Rows must have ``similarity_score`` strictly above this.

        Returns:
            SearchResult objects ordered by descending ``similarity_score``.

        Raises:
            AIConnectionError: On embedding, database or row-conversion failure.
        """
        try:
            embedding_model = self._get_embedding_model()
            query_embedding = embedding_model.encode([query])[0]

            # Vector text literal such as "[0.1, 0.2, ...]".
            search_embedding = json.dumps(query_embedding.tolist())

            # Why these binding choices:
            # - SQLAlchemy text() only binds named ":param" placeholders, and
            #   expects a dict of values.
            # - CAST(...) is used instead of "::" so the cast cannot
            #   interfere with text()'s ":name" parameter parsing.
            # NOTE(review): "<->" is the L2 (Euclidean) distance operator, so
            # "1 - distance" is not a bounded cosine similarity. Confirm
            # whether "<=>" (cosine distance) was intended before tuning
            # `threshold`.
            sql = """
SELECT
    expense_id,
    user_id,
    description,
    merchant,
    expense_amount,
    expense_date,
    1 - (embedding <-> CAST(:search_embedding AS VECTOR)) AS similarity_score
FROM expenses
WHERE 1 - (embedding <-> CAST(:search_embedding AS VECTOR)) > :threshold
"""
            params: Dict[str, Any] = {
                "search_embedding": search_embedding,
                "threshold": threshold,
                "limit": limit,
            }

            if user_id:
                sql += " AND user_id = :user_id"
                params["user_id"] = user_id

            sql += " ORDER BY similarity_score DESC LIMIT :limit"

            engine = self._get_db_engine()
            with engine.connect() as conn:
                rows = conn.execute(text(sql), params).fetchall()

            return [
                SearchResult(
                    expense_id=str(row[0]),
                    user_id=str(row[1]),
                    description=row[2] or "",
                    merchant=row[3] or "",
                    amount=float(row[4]),
                    date=str(row[5]),
                    similarity_score=float(row[6]),
                    metadata={},
                )
                for row in rows
            ]

        except Exception as e:
            raise AIConnectionError(f"Search failed: {str(e)}") from e

    def generate_rag_response(
        self,
        query: str,
        context: List[SearchResult],
        user_id: Optional[str] = None,
        language: str = "en",
    ) -> RAGResponse:
        """Answer ``query`` with Claude on Bedrock, grounded in ``context``.

        If the cache manager returns a response for this query and context,
        that response is returned. Otherwise Bedrock is invoked and the new
        answer is offered to the cache.

        Args:
            query: The user's question.
            context: Search results to ground the answer in.
            user_id: Optional user ID, recorded in the response metadata.
            language: Language code, recorded in the response metadata.

        Returns:
            RAGResponse whose metadata records provider, model, region,
            user_id, language and whether the answer came from the cache.

        Raises:
            AIConnectionError: If the cache lookup, Bedrock call or response
                parsing fails.
        """
        try:
            print("\n🤖 AWS BEDROCK RAG (with caching):")
            print(f"1. Query: '{query[:60]}...'")

            # The cache works with plain dicts. Build them once and reuse the
            # same list for both lookup and store.
            search_results_dict = (
                self._results_to_cache_dicts(context) if self.cache_manager else []
            )

            if self.cache_manager:
                cached_response = self.cache_manager.get_cached_response(
                    query, search_results_dict, "aws"
                )
                if cached_response:
                    print("2. ✅ Response cache HIT! Returning cached response")
                    return RAGResponse(
                        response=cached_response,
                        sources=context,
                        metadata=self._response_metadata(user_id, language, cached=True),
                    )
                print("2. ❌ Response cache MISS, generating fresh response")
            else:
                print("2. No cache manager available, generating fresh response")

            context_text = self._prepare_context(context)

            # NOTE(review): `language` is recorded in metadata but is not
            # passed to the model in this prompt.
            prompt = f"""You are Banko, a financial assistant. Answer based on this expense data:

Q: {query}

Data:
{context_text}

Provide helpful insights with numbers, markdown formatting, and actionable advice."""

            # Anthropic Messages request body for Bedrock.
            payload = {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": 1000,
                "top_k": 250,
                "stop_sequences": [],
                "temperature": 1,
                "top_p": 0.999,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": prompt,
                            }
                        ],
                    }
                ],
            }

            response = self.bedrock_client.invoke_model(
                modelId=self.current_model,
                contentType="application/json",
                accept="application/json",
                body=json.dumps(payload),
            )

            response_body = json.loads(response['body'].read())
            ai_response = response_body['content'][0]['text']

            if self.cache_manager and ai_response:
                self._cache_response(query, prompt, ai_response, search_results_dict)

            return RAGResponse(
                response=ai_response,
                sources=context,
                metadata=self._response_metadata(user_id, language, cached=False),
            )

        except Exception as e:
            raise AIConnectionError(f"RAG response generation failed: {str(e)}") from e

    @staticmethod
    def _results_to_cache_dicts(context: List[SearchResult]) -> List[Dict[str, Any]]:
        """Convert search results to the dict shape passed to the cache manager."""
        return [
            {
                'expense_id': result.expense_id,
                'user_id': result.user_id,
                'description': result.description,
                'merchant': result.merchant,
                'expense_amount': result.amount,
                'expense_date': result.date,
                'similarity_score': result.similarity_score,
                'shopping_type': result.metadata.get('shopping_type'),
                'payment_method': result.metadata.get('payment_method'),
                'recurring': result.metadata.get('recurring'),
                'tags': result.metadata.get('tags'),
            }
            for result in context
        ]

    def _response_metadata(self, user_id: Optional[str], language: str, cached: bool) -> Dict[str, Any]:
        """Build the metadata shared by cached and freshly generated responses."""
        return {
            'provider': 'aws',
            'model': self.current_model,
            'region': self.region,
            'user_id': user_id,
            'language': language,
            'cached': cached,
        }

    def _cache_response(
        self,
        query: str,
        prompt: str,
        ai_response: str,
        search_results_dict: List[Dict[str, Any]],
    ) -> None:
        """Store a freshly generated answer in the cache (best effort)."""
        # Rough estimate of ~1.3 tokens per word. The full prompt (query plus
        # context), not just the raw query, is what was sent to the model.
        prompt_tokens = int(len(prompt.split()) * 1.3)
        response_tokens = int(len(ai_response.split()) * 1.3)
        try:
            self.cache_manager.cache_response(
                query, ai_response, search_results_dict, "aws",
                prompt_tokens, response_tokens,
            )
        except Exception as e:
            # Caching is an optimization; a failed write must not discard an
            # answer that was already generated.
            print(f"3. ⚠️ Failed to cache response: {str(e)}")
            return
        print(f"3. ✅ Cached response (est. {prompt_tokens + response_tokens} tokens)")

    def generate_embedding(self, text: str) -> List[float]:
        """Return the local model's embedding of ``text`` as a list of floats.

        Raises:
            AIConnectionError: If the model cannot be loaded or encoding fails.
        """
        try:
            embedding_model = self._get_embedding_model()
            embedding = embedding_model.encode([text])[0]
            return embedding.tolist()
        except Exception as e:
            raise AIConnectionError(f"Embedding generation failed: {str(e)}") from e

    def test_connection(self) -> bool:
        """Send a tiny request to the current Bedrock model.

        Returns:
            True if the model answered with text, False on any error.
        """
        try:
            payload = {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": 5,
                "messages": [
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": "Hello"}],
                    }
                ],
            }

            response = self.bedrock_client.invoke_model(
                modelId=self.current_model,
                contentType="application/json",
                accept="application/json",
                body=json.dumps(payload),
            )

            response_body = json.loads(response['body'].read())
            return response_body['content'][0]['text'] is not None
        except Exception:
            return False

    def _prepare_context(self, context: List[SearchResult]) -> str:
        """Render search results as markdown bullet lines for the prompt."""
        if not context:
            return "No relevant expense data found."

        return "\n".join(
            f"• **{result.description}** at {result.merchant}: ${result.amount:.2f} "
            f"({result.date}) - similarity: {result.similarity_score:.3f}"
            for result in context
        )
|
banko_ai/ai_providers/base.py
ADDED
@@ -0,0 +1,175 @@
|
|
1
|
+
"""
|
2
|
+
Base AI provider interface and common functionality.
|
3
|
+
|
4
|
+
This module defines the abstract base class for AI providers and common error handling.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from abc import ABC, abstractmethod
|
8
|
+
from typing import List, Dict, Any, Optional
|
9
|
+
from dataclasses import dataclass
|
10
|
+
|
11
|
+
|
12
|
+
class AIProviderError(Exception):
    """Root of the AI provider exception hierarchy; catch it to handle any provider failure."""


class AIConnectionError(AIProviderError):
    """Raised when a connection to, or call through, an AI service fails."""


class AIAuthenticationError(AIProviderError):
    """Raised when an AI service rejects or lacks the required credentials."""


class AIQuotaExceededError(AIProviderError):
    """Raised when an AI service reports that its usage quota is exhausted."""
|
30
|
+
|
31
|
+
|
32
|
+
@dataclass
class SearchResult:
    """One expense row returned by a vector similarity search."""

    expense_id: str  # identifier of the matched expense
    user_id: str  # identifier of the expense's owner
    description: str
    merchant: str
    amount: float  # expense amount
    date: str  # expense date, already rendered as a string
    similarity_score: float  # score computed by the search query
    metadata: Dict[str, Any]  # provider-specific extras (may be empty)


@dataclass
class RAGResponse:
    """Answer produced by retrieval-augmented generation, plus its provenance."""

    response: str  # generated (or cached) answer text
    sources: List[SearchResult]  # search results the answer was grounded in
    metadata: Dict[str, Any]  # provider-specific details about the generation
|
51
|
+
|
52
|
+
|
53
|
+
class AIProvider(ABC):
    """Abstract base class for AI providers.

    Subclasses implement vector search, RAG answer generation, embedding
    generation and a connectivity check.

    ``__init__`` calls ``_validate_config()``. Any attribute that validation
    reads must therefore be set by the subclass *before* it calls
    ``super().__init__``.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store ``config``, choose the initial model and validate the config.

        Args:
            config: Provider configuration. The optional ``"model"`` key
                selects the initial model. When it is missing, ``None`` or
                empty, :meth:`get_default_model` is used.
        """
        self.config = config
        # config.get("model", default) would keep an explicit None (e.g. from
        # an unset setting) and leave the provider without a usable model.
        self.current_model = config.get("model") or self.get_default_model()
        self._validate_config()

    @abstractmethod
    def _validate_config(self) -> None:
        """Validate provider-specific configuration.

        Called at the end of ``__init__``, after ``config`` and
        ``current_model`` have been set.
        """

    @abstractmethod
    def search_expenses(
        self,
        query: str,
        user_id: Optional[str] = None,
        limit: int = 10,
        threshold: float = 0.7,
    ) -> List[SearchResult]:
        """
        Search for expenses using vector similarity.

        Args:
            query: Search query text
            user_id: Optional user ID to filter results
            limit: Maximum number of results to return
            threshold: Minimum similarity score threshold

        Returns:
            List of SearchResult objects
        """

    @abstractmethod
    def generate_rag_response(
        self,
        query: str,
        context: List[SearchResult],
        user_id: Optional[str] = None,
        language: str = "en",
    ) -> RAGResponse:
        """
        Generate a RAG response using the provided context.

        Args:
            query: User query
            context: List of relevant search results
            user_id: Optional user ID
            language: Response language code

        Returns:
            RAGResponse object
        """

    @abstractmethod
    def generate_embedding(self, text: str) -> List[float]:
        """
        Generate embedding vector for the given text.

        Args:
            text: Text to embed

        Returns:
            List of float values representing the embedding
        """

    @abstractmethod
    def test_connection(self) -> bool:
        """
        Test the connection to the AI service.

        Returns:
            True if connection is successful, False otherwise
        """

    def get_provider_name(self) -> str:
        """Return the class name without ``"Provider"``, lowercased (``AWSProvider`` -> ``"aws"``)."""
        return self.__class__.__name__.replace("Provider", "").lower()

    def get_provider_info(self) -> Dict[str, Any]:
        """Return the provider's name, connection status, model and config keys.

        Note: this calls :meth:`test_connection`, which a provider may
        implement as a live request to its AI service.
        """
        return {
            "name": self.get_provider_name(),
            "connected": self.test_connection(),
            "current_model": self.current_model,
            "config_keys": list(self.config.keys()),
        }

    @abstractmethod
    def get_default_model(self) -> str:
        """Return the model used when the configuration names none."""

    def get_available_models(self) -> List[str]:
        """Return the models this provider can switch to.

        The default implementation offers only the current model; providers
        override it to list more.
        """
        return [self.current_model]

    def set_model(self, model: str) -> bool:
        """
        Switch to a different model.

        Args:
            model: Model name to switch to

        Returns:
            True if the model is in :meth:`get_available_models` and was
            selected; False otherwise (the current model is unchanged).
        """
        if model in self.get_available_models():
            self.current_model = model
            return True
        return False

    def get_current_model(self) -> str:
        """Return the model currently in use."""
        return self.current_model
|
banko_ai/ai_providers/factory.py
ADDED
@@ -0,0 +1,84 @@
|
|
1
|
+
"""
|
2
|
+
AI Provider Factory for creating provider instances.
|
3
|
+
|
4
|
+
This module provides a factory pattern for creating AI provider instances
|
5
|
+
based on configuration.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from typing import Dict, Any, Type
|
9
|
+
from .base import AIProvider, AIProviderError
|
10
|
+
from .openai_provider import OpenAIProvider
|
11
|
+
from .aws_provider import AWSProvider
|
12
|
+
from .watsonx_provider import WatsonxProvider
|
13
|
+
from .gemini_provider import GeminiProvider
|
14
|
+
|
15
|
+
|
16
|
+
class AIProviderFactory:
    """Factory that maps service names to AI provider classes and builds instances."""

    # Registry keyed by normalized (stripped, lowercased) service name.
    _providers: Dict[str, Type[AIProvider]] = {
        "openai": OpenAIProvider,
        "aws": AWSProvider,
        "watsonx": WatsonxProvider,
        "gemini": GeminiProvider,
    }

    @classmethod
    def create_provider(cls, service_name: str, config: Dict[str, Any], cache_manager=None) -> AIProvider:
        """
        Create an AI provider instance.

        Args:
            service_name: Name of the AI service (openai, aws, watsonx, gemini).
                Matching ignores case and surrounding whitespace.
            config: Configuration dictionary for the provider
            cache_manager: Optional cache manager instance. It is passed only
                to providers whose constructor has a ``cache_manager``
                parameter.

        Returns:
            AIProvider instance

        Raises:
            AIProviderError: If the service is not supported (including a
                non-string name) or the provider constructor fails. The
                original exception is chained as ``__cause__``.
        """
        import inspect

        key = service_name.strip().lower() if isinstance(service_name, str) else ""

        if key not in cls._providers:
            available = ", ".join(cls._providers)
            shown = key or repr(service_name)
            raise AIProviderError(
                f"Unsupported AI service: {shown}. "
                f"Available services: {available}"
            )

        provider_class = cls._providers[key]

        try:
            # Pass the cache manager by keyword, so it works whether the
            # parameter is positional-or-keyword or keyword-only.
            if "cache_manager" in inspect.signature(provider_class.__init__).parameters:
                return provider_class(config, cache_manager=cache_manager)
            return provider_class(config)
        except Exception as e:
            raise AIProviderError(
                f"Failed to create {key} provider: {str(e)}"
            ) from e

    @classmethod
    def get_available_providers(cls) -> list[str]:
        """Return the registered service names."""
        return list(cls._providers.keys())

    @classmethod
    def register_provider(cls, name: str, provider_class: Type[AIProvider]) -> None:
        """
        Register a new AI provider, replacing any existing one with the same name.

        Args:
            name: Name of the provider. It is normalized the same way as in
                :meth:`create_provider` (stripped, lowercased).
            provider_class: Provider class that implements AIProvider
        """
        cls._providers[name.strip().lower()] = provider_class
|