banko_ai_assistant-1.0.0-py3-none-any.whl
This diff shows the content of a publicly available package version as released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as published.
- banko_ai/__init__.py +19 -0
- banko_ai/__main__.py +10 -0
- banko_ai/ai_providers/__init__.py +18 -0
- banko_ai/ai_providers/aws_provider.py +337 -0
- banko_ai/ai_providers/base.py +175 -0
- banko_ai/ai_providers/factory.py +84 -0
- banko_ai/ai_providers/gemini_provider.py +340 -0
- banko_ai/ai_providers/openai_provider.py +295 -0
- banko_ai/ai_providers/watsonx_provider.py +591 -0
- banko_ai/cli.py +374 -0
- banko_ai/config/__init__.py +5 -0
- banko_ai/config/settings.py +216 -0
- banko_ai/static/Anallytics.png +0 -0
- banko_ai/static/Graph.png +0 -0
- banko_ai/static/Graph2.png +0 -0
- banko_ai/static/ai-status.png +0 -0
- banko_ai/static/banko-ai-assistant-watsonx.gif +0 -0
- banko_ai/static/banko-db-ops.png +0 -0
- banko_ai/static/banko-response.png +0 -0
- banko_ai/static/cache-stats.png +0 -0
- banko_ai/static/creditcard.png +0 -0
- banko_ai/static/profilepic.jpeg +0 -0
- banko_ai/static/query_watcher.png +0 -0
- banko_ai/static/roach-logo.svg +54 -0
- banko_ai/static/watsonx-icon.svg +1 -0
- banko_ai/templates/base.html +59 -0
- banko_ai/templates/dashboard.html +569 -0
- banko_ai/templates/index.html +1499 -0
- banko_ai/templates/login.html +41 -0
- banko_ai/utils/__init__.py +8 -0
- banko_ai/utils/cache_manager.py +525 -0
- banko_ai/utils/database.py +202 -0
- banko_ai/utils/migration.py +123 -0
- banko_ai/vector_search/__init__.py +18 -0
- banko_ai/vector_search/enrichment.py +278 -0
- banko_ai/vector_search/generator.py +329 -0
- banko_ai/vector_search/search.py +463 -0
- banko_ai/web/__init__.py +13 -0
- banko_ai/web/app.py +668 -0
- banko_ai/web/auth.py +73 -0
- banko_ai_assistant-1.0.0.dist-info/METADATA +414 -0
- banko_ai_assistant-1.0.0.dist-info/RECORD +46 -0
- banko_ai_assistant-1.0.0.dist-info/WHEEL +5 -0
- banko_ai_assistant-1.0.0.dist-info/entry_points.txt +2 -0
- banko_ai_assistant-1.0.0.dist-info/licenses/LICENSE +21 -0
- banko_ai_assistant-1.0.0.dist-info/top_level.txt +1 -0
banko_ai/ai_providers/gemini_provider.py
@@ -0,0 +1,340 @@
+"""
+Google Gemini AI provider implementation.
+
+This module provides Google Vertex AI/Gemini integration for vector search and RAG responses.
+"""
+
+import os
+import json
+from typing import List, Dict, Any, Optional
+from sentence_transformers import SentenceTransformer
+from sqlalchemy import create_engine, text
+
+try:
+    from google.cloud import aiplatform
+    from google.oauth2 import service_account
+    GEMINI_AVAILABLE = True
+except ImportError:
+    GEMINI_AVAILABLE = False
+
+from .base import AIProvider, SearchResult, RAGResponse, AIConnectionError, AIAuthenticationError
+
+
+class GeminiProvider(AIProvider):
+    """Google Gemini AI provider implementation."""
+
+    def __init__(self, config: Dict[str, Any], cache_manager=None):
+        """Initialize Gemini provider."""
+        if not GEMINI_AVAILABLE:
+            raise AIConnectionError("Google Cloud AI Platform not available. Install with: pip install google-cloud-aiplatform")
+
+        self.cache_manager = cache_manager
+
+        self.project_id = config.get("project_id")
+        self.location = config.get("location", "us-central1")
+        self.model_name = "gemini-1.5-pro"
+        self.embedding_model = None
+        self.db_engine = None
+        self.vertex_client = None
+        super().__init__(config)
+
+    def _validate_config(self) -> None:
+        """Validate Gemini configuration."""
+        if not self.project_id:
+            raise AIAuthenticationError("Google project ID is required")
+
+        # Initialize Vertex AI
+        try:
+            aiplatform.init(project=self.project_id, location=self.location)
+            self.vertex_client = aiplatform.gapic.PredictionServiceClient()
+        except Exception as e:
+            raise AIConnectionError(f"Failed to initialize Vertex AI: {str(e)}")
+
+    def get_default_model(self) -> str:
+        """Get the default Gemini model."""
+        return "gemini-1.5-pro"
+
+    def get_available_models(self) -> List[str]:
+        """Get available Gemini models."""
+        return [
+            "gemini-1.5-pro",
+            "gemini-1.5-flash",
+            "gemini-1.0-pro"
+        ]
+
+    def _get_embedding_model(self) -> SentenceTransformer:
+        """Get or create the embedding model."""
+        if self.embedding_model is None:
+            try:
+                self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
+            except Exception as e:
+                raise AIConnectionError(f"Failed to load embedding model: {str(e)}")
+        return self.embedding_model
+
+    def _get_db_engine(self):
+        """Get database engine."""
+        if self.db_engine is None:
+            database_url = os.getenv("DATABASE_URL", "cockroachdb://root@localhost:26257/defaultdb?sslmode=disable")
+            try:
+                self.db_engine = create_engine(database_url)
+            except Exception as e:
+                raise AIConnectionError(f"Failed to connect to database: {str(e)}")
+        return self.db_engine
+
+    def search_expenses(
+        self,
+        query: str,
+        user_id: Optional[str] = None,
+        limit: int = 10,
+        threshold: float = 0.7
+    ) -> List[SearchResult]:
+        """Search for expenses using vector similarity."""
+        try:
+            # Generate query embedding
+            embedding_model = self._get_embedding_model()
+            query_embedding = embedding_model.encode([query])[0]
+
+            # Convert to PostgreSQL vector format
+            search_embedding = json.dumps(query_embedding.tolist())
+
+            # Build SQL query
+            sql = """
+                SELECT
+                    expense_id,
+                    user_id,
+                    description,
+                    merchant,
+                    expense_amount,
+                    expense_date,
+                    1 - (embedding <-> %s) as similarity_score
+                FROM expenses
+                WHERE 1 - (embedding <-> %s) > %s
+            """
+
+            params = [search_embedding, search_embedding, threshold]
+
+            if user_id:
+                sql += " AND user_id = %s"
+                params.append(user_id)
+
+            sql += " ORDER BY similarity_score DESC LIMIT %s"
+            params.append(limit)
+
+            # Execute query
+            engine = self._get_db_engine()
+            with engine.connect() as conn:
+                result = conn.execute(text(sql), params)
+                rows = result.fetchall()
+
+            # Convert to SearchResult objects
+            results = []
+            for row in rows:
+                results.append(SearchResult(
+                    expense_id=str(row[0]),
+                    user_id=str(row[1]),
+                    description=row[2] or "",
+                    merchant=row[3] or "",
+                    amount=float(row[4]),
+                    date=str(row[5]),
+                    similarity_score=float(row[6]),
+                    metadata={}
+                ))
+
+            return results
+
+        except Exception as e:
+            raise AIConnectionError(f"Search failed: {str(e)}")
+
+    def generate_rag_response(
+        self,
+        query: str,
+        context: List[SearchResult],
+        user_id: Optional[str] = None,
+        language: str = "en"
+    ) -> RAGResponse:
+        """Generate RAG response using Google Gemini."""
+        try:
+            print(f"\n🤖 GOOGLE GEMINI RAG (with caching):")
+            print(f"1. Query: '{query[:60]}...'")
+
+            # Check for cached response first
+            if self.cache_manager:
+                # Convert SearchResult objects to dict format for cache lookup
+                search_results_dict = []
+                for result in context:
+                    search_results_dict.append({
+                        'expense_id': result.expense_id,
+                        'user_id': result.user_id,
+                        'description': result.description,
+                        'merchant': result.merchant,
+                        'expense_amount': result.amount,
+                        'expense_date': result.date,
+                        'similarity_score': result.similarity_score,
+                        'shopping_type': result.metadata.get('shopping_type'),
+                        'payment_method': result.metadata.get('payment_method'),
+                        'recurring': result.metadata.get('recurring'),
+                        'tags': result.metadata.get('tags')
+                    })
+
+                cached_response = self.cache_manager.get_cached_response(
+                    query, search_results_dict, "gemini"
+                )
+                if cached_response:
+                    print(f"2. ✅ Response cache HIT! Returning cached response")
+                    return RAGResponse(
+                        response=cached_response,
+                        sources=context,
+                        metadata={
+                            'provider': 'gemini',
+                            'model': self.get_default_model(),
+                            'user_id': user_id,
+                            'language': language,
+                            'cached': True
+                        }
+                    )
+                print(f"2. ❌ Response cache MISS, generating fresh response")
+            else:
+                print(f"2. No cache manager available, generating fresh response")
+
+            # Prepare context
+            context_text = self._prepare_context(context)
+
+            # Prepare the prompt
+            prompt = f"""You are Banko, a financial assistant. Answer based on this expense data:
+
+Q: {query}
+
+Data:
+{context_text}
+
+Provide helpful insights with numbers, markdown formatting, and actionable advice."""
+
+            # Prepare the request
+            endpoint = f"projects/{self.project_id}/locations/{self.location}/publishers/google/models/{self.current_model}"
+
+            instances = [{
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": prompt
+                    }
+                ]
+            }]
+
+            parameters = {
+                "temperature": 0.7,
+                "maxOutputTokens": 1000,
+                "topP": 0.9,
+                "topK": 40
+            }
+
+            # Make prediction request
+            response = self.vertex_client.predict(
+                endpoint=endpoint,
+                instances=instances,
+                parameters=parameters
+            )
+
+            # Extract response
+            predictions = response.predictions
+            if predictions and len(predictions) > 0:
+                ai_response = predictions[0].get("candidates", [{}])[0].get("content", "")
+            else:
+                ai_response = "I apologize, but I couldn't generate a response at this time."
+
+            # Cache the response for future similar queries
+            if self.cache_manager and ai_response:
+                # Convert SearchResult objects to dict format for caching
+                search_results_dict = []
+                for result in context:
+                    search_results_dict.append({
+                        'expense_id': result.expense_id,
+                        'user_id': result.user_id,
+                        'description': result.description,
+                        'merchant': result.merchant,
+                        'expense_amount': result.amount,
+                        'expense_date': result.date,
+                        'similarity_score': result.similarity_score,
+                        'shopping_type': result.metadata.get('shopping_type'),
+                        'payment_method': result.metadata.get('payment_method'),
+                        'recurring': result.metadata.get('recurring'),
+                        'tags': result.metadata.get('tags')
+                    })
+
+                # Estimate token usage (rough approximation for Gemini)
+                prompt_tokens = len(query.split()) * 1.3  # ~1.3 tokens per word
+                response_tokens = len(ai_response.split()) * 1.3
+
+                self.cache_manager.cache_response(
+                    query, ai_response, search_results_dict, "gemini",
+                    int(prompt_tokens), int(response_tokens)
+                )
+                print(f"3. ✅ Cached response (est. {int(prompt_tokens + response_tokens)} tokens)")
+
+            return RAGResponse(
+                response=ai_response,
+                sources=context,
+                metadata={
+                    "model": self.model_name,
+                    "project_id": self.project_id,
+                    "location": self.location,
+                    "language": language
+                }
+            )
+
+        except Exception as e:
+            raise AIConnectionError(f"RAG response generation failed: {str(e)}")
+
+    def generate_embedding(self, text: str) -> List[float]:
+        """Generate embedding for text."""
+        try:
+            embedding_model = self._get_embedding_model()
+            embedding = embedding_model.encode([text])[0]
+            return embedding.tolist()
+        except Exception as e:
+            raise AIConnectionError(f"Embedding generation failed: {str(e)}")
+
+    def test_connection(self) -> bool:
+        """Test Gemini connection."""
+        try:
+            # Test with a simple completion
+            endpoint = f"projects/{self.project_id}/locations/{self.location}/publishers/google/models/{self.current_model}"
+
+            instances = [{
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "Hello"
+                    }
+                ]
+            }]
+
+            parameters = {
+                "temperature": 0.7,
+                "maxOutputTokens": 5
+            }
+
+            response = self.vertex_client.predict(
+                endpoint=endpoint,
+                instances=instances,
+                parameters=parameters
+            )
+
+            predictions = response.predictions
+            return predictions and len(predictions) > 0 and predictions[0].get("candidates")
+        except Exception:
+            return False
+
+    def _prepare_context(self, context: List[SearchResult]) -> str:
+        """Prepare context text from search results."""
+        if not context:
+            return "No relevant expense data found."
+
+        context_parts = []
+        for i, result in enumerate(context, 1):
+            context_parts.append(
+                f"• **{result.description}** at {result.merchant}: ${result.amount:.2f} "
+                f"({result.date}) - similarity: {result.similarity_score:.3f}"
+            )
+
+        return "\n".join(context_parts)
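Two implementation details in gemini_provider.py deserve a closer look. First, `search_expenses` mixes DB-API `%s` placeholders with SQLAlchemy's `text()` construct and passes the parameters as a positional list; `text()` only understands named `:param` binds supplied as a dict, so this query would likely fail at runtime. Second, `<->` is pgvector's Euclidean (L2) distance operator, so `1 - (embedding <-> x)` is not a bounded similarity score; the cosine-distance operator `<=>`, which the OpenAI provider below uses, is the one for which `1 - distance` is a genuine cosine similarity. A minimal corrected sketch of the query path (hypothetical, assuming SQLAlchemy 2.x and a pgvector-style VECTOR column, not the code shipped in this wheel) might look like:

import json
from sqlalchemy import text

def search_expenses_sketch(engine, query_embedding, user_id=None, limit=10, threshold=0.7):
    """Hypothetical rework of search_expenses using named bind parameters."""
    # <=> is pgvector's cosine-distance operator, so 1 - distance is cosine similarity.
    # The CAST assumes the embedding is sent as a '[...]' text literal.
    sql = """
        SELECT expense_id, user_id, description, merchant,
               expense_amount, expense_date,
               1 - (embedding <=> CAST(:emb AS VECTOR)) AS similarity_score
        FROM expenses
        WHERE 1 - (embedding <=> CAST(:emb AS VECTOR)) > :threshold
    """
    params = {"emb": json.dumps(query_embedding.tolist()), "threshold": threshold}
    if user_id:
        sql += " AND user_id = :user_id"
        params["user_id"] = user_id
    sql += " ORDER BY similarity_score DESC LIMIT :limit"
    params["limit"] = limit
    with engine.connect() as conn:
        return conn.execute(text(sql), params).fetchall()

Separately, the `messages` payload that `generate_rag_response` and `test_connection` send through `PredictionServiceClient.predict` is the legacy PaLM chat format; Gemini models are normally invoked through `generateContent` (for example via `vertexai.generative_models.GenerativeModel`), so these calls may not succeed against the `gemini-1.5-*` models the class advertises.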
banko_ai/ai_providers/openai_provider.py
@@ -0,0 +1,295 @@
+"""
+OpenAI AI provider implementation.
+
+This module provides OpenAI integration for vector search and RAG responses.
+"""
+
+import os
+from typing import List, Dict, Any, Optional
+from openai import OpenAI
+from sentence_transformers import SentenceTransformer
+import numpy as np
+from sqlalchemy import create_engine, text
+
+from .base import AIProvider, SearchResult, RAGResponse, AIConnectionError, AIAuthenticationError
+
+
+class OpenAIProvider(AIProvider):
+    """OpenAI AI provider implementation."""
+
+    def __init__(self, config: Dict[str, Any], cache_manager=None):
+        """Initialize OpenAI provider."""
+        self.api_key = config.get("api_key")
+        self.client = None
+        self.embedding_model = None
+        self.db_engine = None
+        self.cache_manager = cache_manager
+        super().__init__(config)
+
+    def _validate_config(self) -> None:
+        """Validate OpenAI configuration."""
+        if not self.api_key:
+            raise AIAuthenticationError("OpenAI API key is required")
+
+        # Initialize OpenAI client
+        try:
+            self.client = OpenAI(api_key=self.api_key)
+        except Exception as e:
+            raise AIConnectionError(f"Failed to initialize OpenAI client: {str(e)}")
+
+    def get_default_model(self) -> str:
+        """Get the default OpenAI model."""
+        return "gpt-3.5-turbo"
+
+    def get_available_models(self) -> List[str]:
+        """Get available OpenAI models."""
+        return [
+            "gpt-3.5-turbo",
+            "gpt-3.5-turbo-16k",
+            "gpt-4",
+            "gpt-4-turbo",
+            "gpt-4o",
+            "gpt-4o-mini"
+        ]
+
+    def _get_embedding_model(self) -> SentenceTransformer:
+        """Get or create the embedding model."""
+        if self.embedding_model is None:
+            try:
+                self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
+            except Exception as e:
+                raise AIConnectionError(f"Failed to load embedding model: {str(e)}")
+        return self.embedding_model
+
+    def _get_db_engine(self):
+        """Get database engine."""
+        if self.db_engine is None:
+            database_url = os.getenv("DATABASE_URL", "cockroachdb://root@localhost:26257/defaultdb?sslmode=disable")
+            try:
+                self.db_engine = create_engine(database_url)
+            except Exception as e:
+                raise AIConnectionError(f"Failed to connect to database: {str(e)}")
+        return self.db_engine
+
+    def search_expenses(
+        self,
+        query: str,
+        user_id: Optional[str] = None,
+        limit: int = 10,
+        threshold: float = 0.7
+    ) -> List[SearchResult]:
+        """Search for expenses using vector similarity."""
+        try:
+            # Generate query embedding
+            embedding_model = self._get_embedding_model()
+            query_embedding = embedding_model.encode([query])[0]
+
+            # Build SQL query
+            sql = """
+                SELECT
+                    expense_id,
+                    user_id,
+                    description,
+                    merchant,
+                    expense_amount,
+                    expense_date,
+                    1 - (embedding <=> %s) as similarity_score
+                FROM expenses
+                WHERE 1 - (embedding <=> %s) > %s
+            """
+
+            params = [query_embedding.tolist(), query_embedding.tolist(), threshold]
+
+            if user_id:
+                sql += " AND user_id = %s"
+                params.append(user_id)
+
+            sql += " ORDER BY similarity_score DESC LIMIT %s"
+            params.append(limit)
+
+            # Execute query
+            engine = self._get_db_engine()
+            with engine.connect() as conn:
+                result = conn.execute(text(sql), params)
+                rows = result.fetchall()
+
+            # Convert to SearchResult objects
+            results = []
+            for row in rows:
+                results.append(SearchResult(
+                    expense_id=str(row[0]),
+                    user_id=str(row[1]),
+                    description=row[2] or "",
+                    merchant=row[3] or "",
+                    amount=float(row[4]),
+                    date=str(row[5]),
+                    similarity_score=float(row[6]),
+                    metadata={}
+                ))
+
+            return results
+
+        except Exception as e:
+            raise AIConnectionError(f"Search failed: {str(e)}")
+
+    def generate_rag_response(
+        self,
+        query: str,
+        context: List[SearchResult],
+        user_id: Optional[str] = None,
+        language: str = "en"
+    ) -> RAGResponse:
+        """Generate RAG response using OpenAI."""
+        try:
+            print(f"\n🤖 OPENAI RAG (with caching):")
+            print(f"1. Query: '{query[:60]}...'")
+
+            # Check for cached response first
+            if self.cache_manager:
+                # Convert SearchResult objects to dict format for cache lookup
+                search_results_dict = []
+                for result in context:
+                    search_results_dict.append({
+                        'expense_id': result.expense_id,
+                        'user_id': result.user_id,
+                        'description': result.description,
+                        'merchant': result.merchant,
+                        'expense_amount': result.amount,
+                        'expense_date': result.date,
+                        'similarity_score': result.similarity_score,
+                        'shopping_type': result.metadata.get('shopping_type'),
+                        'payment_method': result.metadata.get('payment_method'),
+                        'recurring': result.metadata.get('recurring'),
+                        'tags': result.metadata.get('tags')
+                    })
+
+                cached_response = self.cache_manager.get_cached_response(
+                    query, search_results_dict, "openai"
+                )
+                if cached_response:
+                    print(f"2. ✅ Response cache HIT! Returning cached response")
+                    return RAGResponse(
+                        response=cached_response,
+                        sources=context,
+                        metadata={
+                            'provider': 'openai',
+                            'model': self.get_default_model(),
+                            'user_id': user_id,
+                            'language': language,
+                            'cached': True
+                        }
+                    )
+                print(f"2. ❌ Response cache MISS, generating fresh response")
+            else:
+                print(f"2. No cache manager available, generating fresh response")
+
+            # Prepare context
+            context_text = self._prepare_context(context)
+
+            # Prepare system message
+            system_message = f"""You are a helpful AI assistant for expense analysis.
+You have access to the user's expense data and can help answer questions about their spending patterns.
+
+Please respond in {language} if requested, otherwise use English.
+
+Use the provided expense data to answer questions accurately and helpfully."""
+
+            # Prepare user message
+            user_message = f"""Query: {query}
+
+Relevant expense data:
+{context_text}
+
+Please provide a helpful response based on the expense data above."""
+
+            # Generate response
+            response = self.client.chat.completions.create(
+                model=self.current_model,
+                messages=[
+                    {"role": "system", "content": system_message},
+                    {"role": "user", "content": user_message}
+                ],
+                max_tokens=500,
+                temperature=0.7
+            )
+
+            response_text = response.choices[0].message.content
+
+            # Cache the response for future similar queries
+            if self.cache_manager and response_text:
+                # Convert SearchResult objects to dict format for caching
+                search_results_dict = []
+                for result in context:
+                    search_results_dict.append({
+                        'expense_id': result.expense_id,
+                        'user_id': result.user_id,
+                        'description': result.description,
+                        'merchant': result.merchant,
+                        'expense_amount': result.amount,
+                        'expense_date': result.date,
+                        'similarity_score': result.similarity_score,
+                        'shopping_type': result.metadata.get('shopping_type'),
+                        'payment_method': result.metadata.get('payment_method'),
+                        'recurring': result.metadata.get('recurring'),
+                        'tags': result.metadata.get('tags')
+                    })
+
+                # Use actual token counts from OpenAI response
+                prompt_tokens = response.usage.prompt_tokens if response.usage else 0
+                response_tokens = response.usage.completion_tokens if response.usage else 0
+
+                self.cache_manager.cache_response(
+                    query, response_text, search_results_dict, "openai",
+                    prompt_tokens, response_tokens
+                )
+                print(f"3. ✅ Cached response ({prompt_tokens + response_tokens} tokens)")
+
+            return RAGResponse(
+                response=response_text,
+                sources=context,
+                metadata={
+                    "model": "gpt-3.5-turbo",
+                    "tokens_used": response.usage.total_tokens if response.usage else 0,
+                    "language": language
+                }
+            )
+
+        except Exception as e:
+            raise AIConnectionError(f"RAG response generation failed: {str(e)}")
+
+    def generate_embedding(self, text: str) -> List[float]:
+        """Generate embedding for text."""
+        try:
+            embedding_model = self._get_embedding_model()
+            embedding = embedding_model.encode([text])[0]
+            return embedding.tolist()
+        except Exception as e:
+            raise AIConnectionError(f"Embedding generation failed: {str(e)}")
+
+    def test_connection(self) -> bool:
+        """Test OpenAI connection."""
+        try:
+            # Test with a simple completion
+            response = self.client.chat.completions.create(
+                model=self.current_model,
+                messages=[{"role": "user", "content": "Hello"}],
+                max_tokens=5
+            )
+            return response.choices[0].message.content is not None
+        except Exception:
+            return False
+
+    def _prepare_context(self, context: List[SearchResult]) -> str:
+        """Prepare context text from search results."""
+        if not context:
+            return "No relevant expense data found."
+
+        context_parts = []
+        for i, result in enumerate(context, 1):
+            context_parts.append(
+                f"{i}. {result.description} at {result.merchant} - "
+                f"${result.amount:.2f} on {result.date} "
+                f"(similarity: {result.similarity_score:.3f})"
+            )
+
+        return "\n".join(context_parts)
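Both providers follow the same lifecycle: the `AIProvider` base constructor evidently triggers `_validate_config` (the subclasses call `super().__init__(config)` after setting their fields), embeddings come from a lazily loaded local all-MiniLM-L6-v2 sentence-transformer rather than the vendor's embedding API, and `generate_rag_response` consults the cache manager before paying for a completion. Note that `self.current_model` is read but never assigned in either file, so it presumably comes from the base class in base.py (not shown in this diff), and the OpenAI provider hardcodes "gpt-3.5-turbo" in its response metadata even when `self.current_model` served the request. A hypothetical end-to-end call, based only on the constructors and method signatures shown above, might look like:

import os
from banko_ai.ai_providers.openai_provider import OpenAIProvider

# Hypothetical wiring; the config keys mirror the constructor shown in the diff.
provider = OpenAIProvider(
    config={"api_key": os.environ["OPENAI_API_KEY"]},
    cache_manager=None,  # optional; enables the response cache shown above
)

if provider.test_connection():
    # Vector search over the expenses table, then a RAG answer grounded in the hits.
    hits = provider.search_expenses("coffee shops", user_id="user-123", limit=5)
    answer = provider.generate_rag_response("How much did I spend on coffee?", hits)
    print(answer.response)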