memra 0.2.13-py3-none-any.whl → 0.2.15-py3-none-any.whl
This diff shows the contents of package versions that were publicly released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- memra/cli.py +322 -51
- {memra-0.2.13.dist-info → memra-0.2.15.dist-info}/METADATA +1 -1
- {memra-0.2.13.dist-info → memra-0.2.15.dist-info}/RECORD +7 -61
- memra-0.2.15.dist-info/top_level.txt +1 -0
- memra-0.2.13.dist-info/top_level.txt +0 -4
- memra-ops/app.py +0 -808
- memra-ops/config/config.py +0 -25
- memra-ops/config.py +0 -34
- memra-ops/logic/__init__.py +0 -1
- memra-ops/logic/file_tools.py +0 -43
- memra-ops/logic/invoice_tools.py +0 -668
- memra-ops/logic/invoice_tools_fix.py +0 -66
- memra-ops/mcp_bridge_server.py +0 -1178
- memra-ops/scripts/check_database.py +0 -37
- memra-ops/scripts/clear_database.py +0 -48
- memra-ops/scripts/monitor_database.py +0 -67
- memra-ops/scripts/release.py +0 -133
- memra-ops/scripts/reset_database.py +0 -65
- memra-ops/scripts/start_memra.py +0 -334
- memra-ops/scripts/stop_memra.py +0 -132
- memra-ops/server_tool_registry.py +0 -190
- memra-ops/tests/test_llm_text_to_sql.py +0 -115
- memra-ops/tests/test_llm_vs_pattern.py +0 -130
- memra-ops/tests/test_mcp_schema_aware.py +0 -124
- memra-ops/tests/test_schema_aware_sql.py +0 -139
- memra-ops/tests/test_schema_aware_sql_simple.py +0 -66
- memra-ops/tests/test_text_to_sql_demo.py +0 -140
- memra-ops/tools/mcp_bridge_server.py +0 -851
- memra-sdk/examples/accounts_payable.py +0 -215
- memra-sdk/examples/accounts_payable_client.py +0 -217
- memra-sdk/examples/accounts_payable_mcp.py +0 -200
- memra-sdk/examples/ask_questions.py +0 -123
- memra-sdk/examples/invoice_processing.py +0 -116
- memra-sdk/examples/propane_delivery.py +0 -87
- memra-sdk/examples/simple_text_to_sql.py +0 -158
- memra-sdk/memra/__init__.py +0 -31
- memra-sdk/memra/discovery.py +0 -15
- memra-sdk/memra/discovery_client.py +0 -49
- memra-sdk/memra/execution.py +0 -481
- memra-sdk/memra/models.py +0 -99
- memra-sdk/memra/tool_registry.py +0 -343
- memra-sdk/memra/tool_registry_client.py +0 -106
- memra-sdk/scripts/release.py +0 -133
- memra-sdk/setup.py +0 -52
- memra-workflows/accounts_payable/accounts_payable.py +0 -215
- memra-workflows/accounts_payable/accounts_payable_client.py +0 -216
- memra-workflows/accounts_payable/accounts_payable_mcp.py +0 -200
- memra-workflows/accounts_payable/accounts_payable_smart.py +0 -221
- memra-workflows/invoice_processing/invoice_processing.py +0 -116
- memra-workflows/invoice_processing/smart_invoice_processor.py +0 -220
- memra-workflows/logic/__init__.py +0 -1
- memra-workflows/logic/file_tools.py +0 -50
- memra-workflows/logic/invoice_tools.py +0 -501
- memra-workflows/logic/propane_agents.py +0 -52
- memra-workflows/mcp_bridge_server.py +0 -230
- memra-workflows/propane_delivery/propane_delivery.py +0 -87
- memra-workflows/text_to_sql/complete_invoice_workflow_with_queries.py +0 -208
- memra-workflows/text_to_sql/complete_text_to_sql_system.py +0 -266
- memra-workflows/text_to_sql/file_discovery_demo.py +0 -156
- {memra-0.2.13.dist-info → memra-0.2.15.dist-info}/LICENSE +0 -0
- {memra-0.2.13.dist-info → memra-0.2.15.dist-info}/WHEEL +0 -0
- {memra-0.2.13.dist-info → memra-0.2.15.dist-info}/entry_points.txt +0 -0
memra-ops/tools/mcp_bridge_server.py
@@ -1,851 +0,0 @@
-#!/usr/bin/env python3
-"""
-Simple MCP Bridge Server for local tool execution
-"""
-
-import os
-import json
-import hmac
-import hashlib
-import logging
-import asyncio
-import psycopg2
-import re
-from decimal import Decimal
-from aiohttp import web, web_request
-from typing import Dict, Any, Optional
-
-# Add Hugging Face imports
-try:
-    from huggingface_hub import InferenceClient
-    HF_AVAILABLE = True
-except ImportError:
-    HF_AVAILABLE = False
-    print("Warning: huggingface_hub not available. Install with: pip install huggingface_hub")
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-class MCPBridgeServer:
-    def __init__(self, postgres_url: str, bridge_secret: str):
-        self.postgres_url = postgres_url
-        self.bridge_secret = bridge_secret
-
-        # Hugging Face configuration
-        self.hf_api_key = os.getenv("HUGGINGFACE_API_KEY", "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")
-        self.hf_model = os.getenv("HUGGINGFACE_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
-        self.hf_client = None
-
-        # Initialize Hugging Face client if available
-        if HF_AVAILABLE and self.hf_api_key:
-            try:
-                self.hf_client = InferenceClient(
-                    model=self.hf_model,
-                    token=self.hf_api_key
-                )
-                logger.info(f"Initialized Hugging Face client with model: {self.hf_model}")
-            except Exception as e:
-                logger.warning(f"Failed to initialize Hugging Face client: {e}")
-                self.hf_client = None
-        else:
-            logger.warning("Hugging Face client not available - using fallback pattern matching")
-
-    def verify_signature(self, request_body: str, signature: str) -> bool:
-        """Verify HMAC signature"""
-        expected = hmac.new(
-            self.bridge_secret.encode(),
-            request_body.encode(),
-            hashlib.sha256
-        ).hexdigest()
-        return hmac.compare_digest(expected, signature)
-
-    async def execute_tool(self, request: web_request.Request) -> web.Response:
-        """Execute MCP tool endpoint"""
-        try:
-            # Get request body
-            body = await request.text()
-            data = json.loads(body)
-
-            # Verify signature
-            signature = request.headers.get('X-Bridge-Secret')
-            if not signature or signature != self.bridge_secret:
-                logger.warning("Invalid or missing bridge secret")
-                return web.json_response({
-                    "success": False,
-                    "error": "Invalid authentication"
-                }, status=401)
-
-            tool_name = data.get('tool_name')
-            input_data = data.get('input_data', {})
-
-            logger.info(f"Executing MCP tool: {tool_name}")
-
-            if tool_name == "DataValidator":
-                result = await self.data_validator(input_data)
-            elif tool_name == "PostgresInsert":
-                result = await self.postgres_insert(input_data)
-            elif tool_name == "SQLExecutor":
-                result = await self.sql_executor(input_data)
-            elif tool_name == "TextToSQLGenerator":
-                result = await self.text_to_sql_generator(input_data)
-            else:
-                return web.json_response({
-                    "success": False,
-                    "error": f"Unknown tool: {tool_name}"
-                }, status=400)
-
-            return web.json_response(result)
-
-        except Exception as e:
-            logger.error(f"Tool execution failed: {str(e)}")
-            return web.json_response({
-                "success": False,
-                "error": str(e)
-            }, status=500)
-
-    async def data_validator(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
-        """Validate data against schema"""
-        try:
-            invoice_data = input_data.get('invoice_data', {})
-
-            # Perform basic validation
-            validation_errors = []
-
-            # Check required fields
-            required_fields = ['headerSection', 'billingDetails', 'chargesSummary']
-            for field in required_fields:
-                if field not in invoice_data:
-                    validation_errors.append(f"Missing required field: {field}")
-
-            # Validate header section
-            if 'headerSection' in invoice_data:
-                header = invoice_data['headerSection']
-                if not header.get('vendorName'):
-                    validation_errors.append("Missing vendor name in header")
-                if not header.get('subtotal'):
-                    validation_errors.append("Missing subtotal in header")
-
-            # Validate billing details
-            if 'billingDetails' in invoice_data:
-                billing = invoice_data['billingDetails']
-                if not billing.get('invoiceNumber'):
-                    validation_errors.append("Missing invoice number")
-                if not billing.get('invoiceDate'):
-                    validation_errors.append("Missing invoice date")
-
-            is_valid = len(validation_errors) == 0
-
-            logger.info(f"Data validation completed: {'valid' if is_valid else 'invalid'}")
-
-            return {
-                "success": True,
-                "data": {
-                    "is_valid": is_valid,
-                    "validation_errors": validation_errors,
-                    "validated_data": invoice_data
-                }
-            }
-
-        except Exception as e:
-            logger.error(f"Data validation failed: {str(e)}")
-            return {
-                "success": False,
-                "error": str(e)
-            }
-
-    async def postgres_insert(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
-        """Insert data into PostgreSQL"""
-        try:
-            invoice_data = input_data.get('invoice_data', {})
-            table_name = input_data.get('table_name', 'invoices')
-
-            # Extract key fields from invoice data
-            header = invoice_data.get('headerSection', {})
-            billing = invoice_data.get('billingDetails', {})
-            charges = invoice_data.get('chargesSummary', {})
-
-            # Prepare insert data
-            insert_data = {
-                'invoice_number': billing.get('invoiceNumber', ''),
-                'vendor_name': header.get('vendorName', ''),
-                'invoice_date': billing.get('invoiceDate', ''),
-                'total_amount': charges.get('document_total', 0),
-                'tax_amount': charges.get('secondary_tax', 0),
-                'line_items': json.dumps(charges.get('lineItemsBreakdown', [])),
-                'status': 'processed'
-            }
-
-            # Connect to database and insert
-            conn = psycopg2.connect(self.postgres_url)
-            cursor = conn.cursor()
-
-            # Build insert query
-            columns = ', '.join(insert_data.keys())
-            placeholders = ', '.join(['%s'] * len(insert_data))
-            query = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders}) RETURNING id"
-
-            cursor.execute(query, list(insert_data.values()))
-            record_id = cursor.fetchone()[0]
-
-            conn.commit()
-            cursor.close()
-            conn.close()
-
-            logger.info(f"Successfully inserted record with ID: {record_id}")
-
-            return {
-                "success": True,
-                "data": {
-                    "success": True,
-                    "record_id": record_id,
-                    "database_table": table_name,
-                    "inserted_data": insert_data
-                }
-            }
-
-        except Exception as e:
-            logger.error(f"Database insert failed: {str(e)}")
-            return {
-                "success": False,
-                "error": str(e)
-            }
-
-    async def sql_executor(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
-        """Execute SQL query against PostgreSQL"""
-        try:
-            sql_query = input_data.get('sql_query', '')
-
-            if not sql_query:
-                return {
-                    "success": False,
-                    "error": "No SQL query provided"
-                }
-
-            # Connect to database and execute query
-            conn = psycopg2.connect(self.postgres_url)
-            cursor = conn.cursor()
-
-            # Execute the query
-            cursor.execute(sql_query)
-
-            # Fetch results if it's a SELECT query
-            if sql_query.strip().upper().startswith('SELECT'):
-                results = cursor.fetchall()
-                column_names = [desc[0] for desc in cursor.description]
-
-                # Convert to list of dictionaries
-                formatted_results = []
-                for row in results:
-                    row_dict = dict(zip(column_names, row))
-                    # Convert date/datetime objects to strings for JSON serialization
-                    for key, value in row_dict.items():
-                        if hasattr(value, 'isoformat'):  # datetime, date objects
-                            row_dict[key] = value.isoformat()
-                        elif isinstance(value, Decimal):  # Decimal objects
-                            row_dict[key] = float(value)
-                    formatted_results.append(row_dict)
-
-                logger.info(f"SQL query executed successfully, returned {len(results)} rows")
-
-                return {
-                    "success": True,
-                    "data": {
-                        "query": sql_query,
-                        "results": formatted_results,
-                        "row_count": len(results),
-                        "columns": column_names
-                    }
-                }
-            else:
-                # For non-SELECT queries (INSERT, UPDATE, DELETE)
-                conn.commit()
-                affected_rows = cursor.rowcount
-
-                logger.info(f"SQL query executed successfully, affected {affected_rows} rows")
-
-                return {
-                    "success": True,
-                    "data": {
-                        "query": sql_query,
-                        "affected_rows": affected_rows,
-                        "message": "Query executed successfully"
-                    }
-                }
-
-        except Exception as e:
-            logger.error(f"SQL execution failed: {str(e)}")
-            return {
-                "success": False,
-                "error": str(e)
-            }
-        finally:
-            if 'cursor' in locals():
-                cursor.close()
-            if 'conn' in locals():
-                conn.close()
-
-    async def text_to_sql_generator(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
-        """Generate SQL from natural language using LLM or fallback to pattern matching"""
-        try:
-            question = input_data.get('question', '')
-            schema_info = input_data.get('schema_info', {})
-
-            if not question:
-                return {
-                    "success": False,
-                    "error": "No question provided"
-                }
-
-            # If no schema provided or incomplete, fetch it dynamically
-            if not schema_info or not schema_info.get('schema', {}).get('invoices', {}).get('columns'):
-                logger.info("No schema provided, fetching dynamically from database")
-                schema_info = await self.get_table_schema("invoices")
-
-            # Try LLM-based generation first
-            if self.hf_client:
-                try:
-                    return await self._llm_text_to_sql(question, schema_info)
-                except Exception as e:
-                    logger.warning(f"LLM text-to-SQL failed, falling back to pattern matching: {e}")
-
-            # Fallback to pattern matching
-            return await self._pattern_text_to_sql(question, schema_info)
-
-        except Exception as e:
-            logger.error(f"Text-to-SQL generation failed: {str(e)}")
-            return {
-                "success": False,
-                "error": str(e)
-            }
-
-    async def _llm_text_to_sql(self, question: str, schema_info: Dict[str, Any]) -> Dict[str, Any]:
-        """Generate SQL using Hugging Face LLM"""
-
-        # Extract schema information
-        tables = schema_info.get('schema', {})
-        table_name = 'invoices'  # Default table
-        columns = []
-
-        # Get column information from schema
-        if table_name in tables:
-            table_info = tables[table_name]
-            if 'columns' in table_info:
-                columns = [f"{col['name']} ({col['type']})" for col in table_info['columns']]
-
-        # If no schema info, use default columns
-        if not columns:
-            columns = [
-                'id (integer)',
-                'vendor_name (text)',
-                'invoice_number (text)',
-                'invoice_date (date)',
-                'total_amount (numeric)',
-                'tax_amount (numeric)',
-                'line_items (jsonb)',
-                'status (text)'
-            ]
-
-        # Create the prompt for the LLM
-        schema_text = f"Table: {table_name}\nColumns: {', '.join(columns)}"
-
-        # Comprehensive prompt with detailed instructions and examples
-        prompt = f"""You are a PostgreSQL SQL query generator. Convert natural language questions into valid PostgreSQL queries.
-
-IMPORTANT RULES:
-1. ALWAYS return a complete, valid SQL query
-2. Use ONLY the table and columns provided in the schema
-3. Use PostgreSQL syntax (ILIKE for case-insensitive matching)
-4. For aggregations with GROUP BY, don't include non-aggregated columns in ORDER BY unless they're in GROUP BY
-5. Use appropriate aliases for calculated columns (as count, as total, as average, etc.)
-6. For date queries, use proper date functions and comparisons
-
-TABLE SCHEMA:
-Table: invoices
-Columns: {', '.join(columns)}
-
-QUERY PATTERNS AND EXAMPLES:
-
-1. COUNT QUERIES:
-Q: How many invoices are there?
-A: SELECT COUNT(*) as count FROM invoices
-
-Q: How many invoices from Air Liquide?
-A: SELECT COUNT(*) as count FROM invoices WHERE vendor_name ILIKE '%air liquide%'
-
-2. VENDOR FILTERING:
-Q: Show me all invoices from Air Liquide
-A: SELECT * FROM invoices WHERE vendor_name ILIKE '%air liquide%'
-
-Q: Find invoices from Microsoft
-A: SELECT * FROM invoices WHERE vendor_name ILIKE '%microsoft%'
-
-3. AGGREGATION QUERIES:
-Q: What is the total amount of all invoices?
-A: SELECT SUM(total_amount) as total FROM invoices
-
-Q: What is the average invoice amount?
-A: SELECT AVG(total_amount) as average FROM invoices
-
-Q: What is the highest invoice amount?
-A: SELECT MAX(total_amount) as max_amount FROM invoices
-
-4. GROUPING QUERIES:
-Q: Show me invoices grouped by date
-A: SELECT invoice_date, COUNT(*) as count FROM invoices GROUP BY invoice_date ORDER BY invoice_date
-
-Q: Show me invoice counts by vendor
-A: SELECT vendor_name, COUNT(*) as count FROM invoices GROUP BY vendor_name ORDER BY count DESC
-
-Q: Who is the primary vendor?
-A: SELECT vendor_name, COUNT(*) as count FROM invoices GROUP BY vendor_name ORDER BY count DESC LIMIT 1
-
-5. SORTING AND LIMITING:
-Q: Show me the 3 most recent invoices
-A: SELECT * FROM invoices ORDER BY invoice_date DESC LIMIT 3
-
-Q: Show me the oldest invoice
-A: SELECT * FROM invoices ORDER BY invoice_date ASC LIMIT 1
-
-6. AMOUNT FILTERING:
-Q: Find invoices with amounts greater than 1000
-A: SELECT * FROM invoices WHERE total_amount > 1000
-
-Q: Show me invoices under 500
-A: SELECT * FROM invoices WHERE total_amount < 500
-
-7. DATE QUERIES:
-Q: What is the most recent invoice date?
-A: SELECT MAX(invoice_date) as latest_date FROM invoices
-
-Q: Show me invoices from this year
-A: SELECT * FROM invoices WHERE EXTRACT(YEAR FROM invoice_date) = EXTRACT(YEAR FROM CURRENT_DATE)
-
-Q: What are the invoices created this month?
-A: SELECT * FROM invoices WHERE EXTRACT(YEAR FROM created_at) = EXTRACT(YEAR FROM CURRENT_DATE) AND EXTRACT(MONTH FROM created_at) = EXTRACT(MONTH FROM CURRENT_DATE)
-
-Q: Show me invoices from last month
-A: SELECT * FROM invoices WHERE EXTRACT(YEAR FROM invoice_date) = EXTRACT(YEAR FROM CURRENT_DATE - INTERVAL '1 month') AND EXTRACT(MONTH FROM invoice_date) = EXTRACT(MONTH FROM CURRENT_DATE - INTERVAL '1 month')
-
-8. DISTINCT QUERIES:
-Q: Show me all the vendors
-A: SELECT DISTINCT vendor_name FROM invoices ORDER BY vendor_name
-
-Q: What are all the different invoice dates?
-A: SELECT DISTINCT invoice_date FROM invoices ORDER BY invoice_date
-
-9. COMPLEX VENDOR ANALYSIS:
-Q: Which vendor has the highest total invoice amount?
-A: SELECT vendor_name, SUM(total_amount) as total FROM invoices GROUP BY vendor_name ORDER BY total DESC LIMIT 1
-
-Q: Show me vendor totals
-A: SELECT vendor_name, SUM(total_amount) as total, COUNT(*) as count FROM invoices GROUP BY vendor_name ORDER BY total DESC
-
-10. LINE ITEMS (JSONB):
-Q: Show me all the line item costs
-A: SELECT vendor_name, invoice_number, line_items FROM invoices WHERE line_items IS NOT NULL
-
-Q: What are the line item details?
-A: SELECT id, vendor_name, line_items FROM invoices WHERE line_items IS NOT NULL AND line_items != '[]'
-
-Q: Which invoice contains a line item for 'Electricity'?
-A: SELECT * FROM invoices WHERE line_items::text ILIKE '%electricity%'
-
-Q: Find invoices with line items containing 'PROPANE'
-A: SELECT * FROM invoices WHERE line_items::text ILIKE '%propane%'
-
-IMPORTANT:
-- Always return a complete SQL query starting with SELECT
-- Never return partial queries or just "SELECT"
-- Use proper PostgreSQL syntax
-- Include appropriate WHERE, GROUP BY, ORDER BY, and LIMIT clauses as needed
-- For vendor searches, use ILIKE with % wildcards for partial matching
-
-Question: {question}
-SQL Query:
-"""
-
-        try:
-            # Call Hugging Face API with improved parameters
-            response = self.hf_client.text_generation(
-                prompt,
-                max_new_tokens=150,  # Increased for complex queries
-                temperature=0.05,  # Lower temperature for more deterministic output
-                do_sample=True,
-                stop_sequences=["\n\n", "Q:", "Question:", "Examples:"],  # Reduced to 4 sequences
-                return_full_text=False  # Only return the generated part
-            )
-
-            # Extract SQL from response
-            sql_query = response.strip()
-
-            # Clean up the response - remove any extra text and extract SQL
-            if "SELECT" in sql_query.upper():
-                # Find the SQL query part
-                lines = sql_query.split('\n')
-                for line in lines:
-                    line = line.strip()
-                    if line.upper().startswith('SELECT'):
-                        sql_query = line.rstrip(';')
-                        break
-                else:
-                    # If no line starts with SELECT, try to extract from the whole response
-                    sql_match = re.search(r'(SELECT[^;]+)', sql_query, re.IGNORECASE | re.DOTALL)
-                    if sql_match:
-                        sql_query = sql_match.group(1).strip()
-
-            # Final cleanup
-            sql_query = sql_query.replace('\n', ' ').strip()
-
-            # Validate the SQL contains basic components
-            if not sql_query.upper().strip().startswith('SELECT'):
-                raise ValueError(f"Generated response is not a valid SQL query: '{sql_query}'")
-
-            # Check for incomplete queries
-            if len(sql_query.strip()) < 15 or sql_query.upper().strip() == 'SELECT':
-                raise ValueError(f"Generated incomplete SQL query: '{sql_query}'")
-
-            logger.info(f"LLM generated SQL for question: '{question}' -> {sql_query}")
-
-            return {
-                "success": True,
-                "data": {
-                    "question": question,
-                    "generated_sql": sql_query,
-                    "explanation": f"Generated using {self.hf_model} LLM with schema context",
-                    "confidence": "high",
-                    "method": "llm",
-                    "schema_used": {
-                        "table": table_name,
-                        "columns": [col.split(' (')[0] for col in columns]
-                    }
-                }
-            }
-
-        except Exception as e:
-            logger.error(f"LLM text-to-SQL generation failed: {str(e)}")
-            raise e
-
-    async def _pattern_text_to_sql(self, question: str, schema_info: Dict[str, Any]) -> Dict[str, Any]:
-        """Generate SQL using pattern matching (fallback method)"""
-
-        # Extract schema information for better SQL generation
-        tables = schema_info.get('schema', {})
-        table_name = 'invoices'  # Default table
-        columns = []
-
-        # Get column information from schema
-        if table_name in tables:
-            table_info = tables[table_name]
-            if 'columns' in table_info:
-                columns = [col['name'] for col in table_info['columns']]
-                column_types = {col['name']: col['type'] for col in table_info['columns']}
-
-        # If no schema info, use default columns
-        if not columns:
-            columns = ['id', 'vendor_name', 'invoice_number', 'invoice_date', 'total_amount', 'line_items']
-            column_types = {
-                'id': 'integer',
-                'vendor_name': 'text',
-                'invoice_number': 'text',
-                'invoice_date': 'date',
-                'total_amount': 'numeric',
-                'line_items': 'jsonb'
-            }
-
-        # Generate SQL based on question and schema context
-        question_lower = question.lower()
-
-        # Determine what columns to select based on question
-        select_clause = "*"  # default
-
-        if 'total amount' in question_lower or 'sum' in question_lower:
-            if 'total_amount' in columns:
-                select_clause = "SUM(total_amount) as total"
-            else:
-                select_clause = "SUM(amount) as total"  # fallback
-        elif 'count' in question_lower or 'how many' in question_lower:
-            select_clause = "COUNT(*) as count"
-        elif 'average' in question_lower or 'avg' in question_lower:
-            if 'total_amount' in columns:
-                select_clause = "AVG(total_amount) as average"
-            else:
-                select_clause = "AVG(amount) as average"  # fallback
-        elif 'vendors' in question_lower or 'companies' in question_lower:
-            if 'who are' in question_lower or 'all the' in question_lower or 'list' in question_lower:
-                select_clause = "DISTINCT vendor_name"
-        elif 'last' in question_lower or 'latest' in question_lower or 'most recent' in question_lower:
-            if 'date' in question_lower:
-                select_clause = "MAX(invoice_date) as latest_date"
-            elif 'invoice' in question_lower:
-                # For "last invoice" or "latest invoice", show the most recent one
-                select_clause = "*"
-                # Will be handled in ORDER BY section
-        elif 'first' in question_lower or 'earliest' in question_lower:
-            if 'date' in question_lower:
-                select_clause = "MIN(invoice_date) as earliest_date"
-            elif 'invoice' in question_lower:
-                select_clause = "*"
-                # Will be handled in ORDER BY section
-        elif 'max' in question_lower or 'maximum' in question_lower or 'highest' in question_lower:
-            if 'amount' in question_lower:
-                select_clause = "MAX(total_amount) as max_amount"
-            elif 'date' in question_lower:
-                select_clause = "MAX(invoice_date) as max_date"
-        elif 'min' in question_lower or 'minimum' in question_lower or 'lowest' in question_lower:
-            if 'amount' in question_lower:
-                select_clause = "MIN(total_amount) as min_amount"
-            elif 'date' in question_lower:
-                select_clause = "MIN(invoice_date) as min_date"
-
-        # Build WHERE clause based on question
-        where_clause = ""
-
-        # Look for vendor filtering patterns
-        vendor_patterns = [
-            ('from', 'from'),  # "invoices from Air Liquide"
-            ('by', 'by'),  # "invoices by Microsoft"
-            ('for', 'for'),  # "invoices for Apple"
-        ]
-
-        vendor_name = None
-        for pattern, keyword in vendor_patterns:
-            if keyword in question_lower:
-                parts = question_lower.split(keyword)
-                if len(parts) > 1:
-                    # Extract vendor name after the keyword
-                    vendor_part = parts[1].strip()
-                    # Remove common trailing words
-                    vendor_part = vendor_part.replace(' invoices', '').replace(' invoice', '').strip()
-                    # Take first few words as vendor name
-                    vendor_words = vendor_part.split()[:3]  # Max 3 words for vendor name
-                    if vendor_words:
-                        vendor_name = ' '.join(vendor_words).strip('"\'.,?!')
-                        break
-
-        # Also check for direct company name patterns
-        if not vendor_name:
-            # Look for patterns like "Air Liquide invoices" or "Microsoft invoices"
-            # Match capitalized words that might be company names
-            company_pattern = r'\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)\s+invoices?'
-            match = re.search(company_pattern, question)
-            if match:
-                vendor_name = match.group(1)
-
-        if vendor_name:
-            if 'vendor_name' in columns:
-                where_clause = f"WHERE vendor_name ILIKE '%{vendor_name}%'"
-            elif 'vendor' in columns:
-                where_clause = f"WHERE vendor ILIKE '%{vendor_name}%'"
-
-        # Build ORDER BY clause
-        order_clause = ""
-        limit_clause = ""
-
-        if 'recent' in question_lower or 'latest' in question_lower or 'last' in question_lower:
-            if 'invoice_date' in columns:
-                order_clause = "ORDER BY invoice_date DESC"
-            elif 'date' in columns:
-                order_clause = "ORDER BY date DESC"
-            elif 'created_at' in columns:
-                order_clause = "ORDER BY created_at DESC"
-
-            # Add LIMIT for recent queries if it's asking for specific invoices
-            if 'invoice' in question_lower and select_clause == "*":
-                # Extract number if specified
-                numbers = re.findall(r'\d+', question)
-                limit = numbers[0] if numbers else "1"  # Default to 1 for "last invoice"
-                limit_clause = f"LIMIT {limit}"
-
-        elif 'first' in question_lower or 'earliest' in question_lower:
-            if 'invoice_date' in columns:
-                order_clause = "ORDER BY invoice_date ASC"
-            elif 'date' in columns:
-                order_clause = "ORDER BY date ASC"
-            elif 'created_at' in columns:
-                order_clause = "ORDER BY created_at ASC"
-
-            # Add LIMIT for earliest queries if it's asking for specific invoices
-            if 'invoice' in question_lower and select_clause == "*":
-                numbers = re.findall(r'\d+', question)
-                limit = numbers[0] if numbers else "1"
-                limit_clause = f"LIMIT {limit}"
-
-        elif re.search(r'\d+', question) and ('recent' in question_lower or 'latest' in question_lower or 'last' in question_lower):
-            # Handle "Show me the 5 most recent invoices"
-            if 'invoice_date' in columns:
-                order_clause = "ORDER BY invoice_date DESC"
-            numbers = re.findall(r'\d+', question)
-            if numbers:
-                limit_clause = f"LIMIT {numbers[0]}"
-
-        # Construct the final SQL query
-        sql_parts = [f"SELECT {select_clause}", f"FROM {table_name}"]
-
-        if where_clause:
-            sql_parts.append(where_clause)
-
-        if order_clause:
-            sql_parts.append(order_clause)
-
-        if limit_clause:
-            sql_parts.append(limit_clause)
-
-        sql_query = " ".join(sql_parts)
-
-        # Generate explanation based on schema context
-        explanation_parts = [f"Generated SQL query for table '{table_name}' using pattern matching"]
-        if columns:
-            explanation_parts.append(f"Available columns: {', '.join(columns)}")
-        if where_clause:
-            explanation_parts.append("Applied filtering based on question context")
-        if order_clause:
-            explanation_parts.append("Added sorting and/or limiting based on question")
-
-        explanation = ". ".join(explanation_parts)
-
-        logger.info(f"Pattern matching generated SQL for question: '{question}' -> {sql_query}")
-        logger.info(f"Used schema with columns: {columns}")
-
-        return {
-            "success": True,
-            "data": {
-                "question": question,
-                "generated_sql": sql_query,
-                "explanation": explanation,
-                "confidence": "medium",  # Lower confidence for pattern matching
-                "method": "pattern_matching",
-                "schema_used": {
-                    "table": table_name,
-                    "columns": columns,
-                    "column_types": column_types if 'column_types' in locals() else {}
-                }
-            }
-        }
-
-    async def health_check(self, request: web_request.Request) -> web.Response:
-        """Health check endpoint"""
-        return web.json_response({"status": "healthy", "service": "mcp-bridge"})
-
-    async def get_schema(self, request: web_request.Request) -> web.Response:
-        """Get database schema endpoint"""
-        try:
-            table_name = request.query.get('table', 'invoices')
-            schema = await self.get_table_schema(table_name)
-            return web.json_response({
-                "success": True,
-                "data": schema
-            })
-        except Exception as e:
-            logger.error(f"Schema fetch failed: {str(e)}")
-            return web.json_response({
-                "success": False,
-                "error": str(e)
-            }, status=500)
-
-    def create_app(self) -> web.Application:
-        """Create aiohttp application"""
-        app = web.Application()
-
-        # Add routes
-        app.router.add_post('/execute_tool', self.execute_tool)
-        app.router.add_get('/health', self.health_check)
-        app.router.add_get('/get_schema', self.get_schema)
-
-        return app
-
-    async def start(self, port: int = 8081):
-        """Start the server"""
-        app = self.create_app()
-        runner = web.AppRunner(app)
-        await runner.setup()
-
-        site = web.TCPSite(runner, 'localhost', port)
-        await site.start()
-
-        logger.info(f"MCP Bridge Server started on http://localhost:{port}")
-        logger.info(f"Available endpoints:")
-        logger.info(f"  POST /execute_tool - Execute MCP tools")
-        logger.info(f"  GET /health - Health check")
-        logger.info(f"  GET /get_schema - Get database schema")
-
-        # Keep running
-        try:
-            await asyncio.Future()  # Run forever
-        except KeyboardInterrupt:
-            logger.info("Shutting down server...")
-        finally:
-            await runner.cleanup()
-
-    async def get_table_schema(self, table_name: str = "invoices") -> Dict[str, Any]:
-        """Dynamically fetch table schema from database"""
-        try:
-            conn = psycopg2.connect(self.postgres_url)
-            cursor = conn.cursor()
-
-            # Get column information
-            query = """
-                SELECT column_name, data_type, is_nullable, column_default
-                FROM information_schema.columns
-                WHERE table_name = %s
-                ORDER BY ordinal_position
-            """
-            cursor.execute(query, (table_name,))
-            columns = cursor.fetchall()
-
-            schema = {
-                "schema": {
-                    table_name: {
-                        "columns": [
-                            {
-                                "name": col[0],
-                                "type": col[1],
-                                "nullable": col[2] == "YES",
-                                "default": col[3]
-                            }
-                            for col in columns
-                        ]
-                    }
-                }
-            }
-
-            cursor.close()
-            conn.close()
-
-            logger.info(f"Dynamically fetched schema for table '{table_name}' with {len(columns)} columns")
-            return schema
-
-        except Exception as e:
-            logger.error(f"Failed to fetch schema for table '{table_name}': {str(e)}")
-            # Fallback to basic schema
-            return {
-                "schema": {
-                    table_name: {
-                        "columns": [
-                            {"name": "id", "type": "integer"},
-                            {"name": "vendor_name", "type": "character varying"},
-                            {"name": "invoice_date", "type": "date"},
-                            {"name": "total_amount", "type": "numeric"}
-                        ]
-                    }
-                }
-            }
-
-def main():
-    # Get configuration from environment
-    postgres_url = os.getenv('MCP_POSTGRES_URL', 'postgresql://memra:memra123@localhost:5432/memra_invoice_db')
-    bridge_secret = os.getenv('MCP_BRIDGE_SECRET', 'test-secret-for-development')
-    hf_api_key = os.getenv('HUGGINGFACE_API_KEY', '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')
-    hf_model = os.getenv('HUGGINGFACE_MODEL', 'meta-llama/Llama-3.1-8B-Instruct')
-
-    logger.info(f"Starting MCP Bridge Server...")
-    logger.info(f"PostgreSQL URL: {postgres_url}")
-    logger.info(f"Bridge Secret: {'*' * len(bridge_secret)}")
-    logger.info(f"Hugging Face Model: {hf_model}")
-    logger.info(f"Hugging Face API Key: {'*' * (len(hf_api_key) - 8) + hf_api_key[-8:] if hf_api_key else 'Not set'}")
-
-    # Create and start server
-    server = MCPBridgeServer(postgres_url, bridge_secret)
-    asyncio.run(server.start())
-
-if __name__ == '__main__':
-    main()