banko-ai-assistant 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- banko_ai/__init__.py +19 -0
- banko_ai/__main__.py +10 -0
- banko_ai/ai_providers/__init__.py +18 -0
- banko_ai/ai_providers/aws_provider.py +337 -0
- banko_ai/ai_providers/base.py +175 -0
- banko_ai/ai_providers/factory.py +84 -0
- banko_ai/ai_providers/gemini_provider.py +340 -0
- banko_ai/ai_providers/openai_provider.py +295 -0
- banko_ai/ai_providers/watsonx_provider.py +591 -0
- banko_ai/cli.py +374 -0
- banko_ai/config/__init__.py +5 -0
- banko_ai/config/settings.py +216 -0
- banko_ai/static/Anallytics.png +0 -0
- banko_ai/static/Graph.png +0 -0
- banko_ai/static/Graph2.png +0 -0
- banko_ai/static/ai-status.png +0 -0
- banko_ai/static/banko-ai-assistant-watsonx.gif +0 -0
- banko_ai/static/banko-db-ops.png +0 -0
- banko_ai/static/banko-response.png +0 -0
- banko_ai/static/cache-stats.png +0 -0
- banko_ai/static/creditcard.png +0 -0
- banko_ai/static/profilepic.jpeg +0 -0
- banko_ai/static/query_watcher.png +0 -0
- banko_ai/static/roach-logo.svg +54 -0
- banko_ai/static/watsonx-icon.svg +1 -0
- banko_ai/templates/base.html +59 -0
- banko_ai/templates/dashboard.html +569 -0
- banko_ai/templates/index.html +1499 -0
- banko_ai/templates/login.html +41 -0
- banko_ai/utils/__init__.py +8 -0
- banko_ai/utils/cache_manager.py +525 -0
- banko_ai/utils/database.py +202 -0
- banko_ai/utils/migration.py +123 -0
- banko_ai/vector_search/__init__.py +18 -0
- banko_ai/vector_search/enrichment.py +278 -0
- banko_ai/vector_search/generator.py +329 -0
- banko_ai/vector_search/search.py +463 -0
- banko_ai/web/__init__.py +13 -0
- banko_ai/web/app.py +668 -0
- banko_ai/web/auth.py +73 -0
- banko_ai_assistant-1.0.0.dist-info/METADATA +414 -0
- banko_ai_assistant-1.0.0.dist-info/RECORD +46 -0
- banko_ai_assistant-1.0.0.dist-info/WHEEL +5 -0
- banko_ai_assistant-1.0.0.dist-info/entry_points.txt +2 -0
- banko_ai_assistant-1.0.0.dist-info/licenses/LICENSE +21 -0
- banko_ai_assistant-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,202 @@
|
|
1
|
+
"""
|
2
|
+
Database management utilities.
|
3
|
+
|
4
|
+
This module provides database schema creation and management functionality.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import os
|
8
|
+
from typing import Optional
|
9
|
+
|
10
|
+
|
11
|
+
class DatabaseManager:
    """Database management utilities.

    Wraps a lazily created SQLAlchemy engine and provides helpers for
    creating, dropping, and inspecting the ``expenses`` table (including
    its vector indexes). SQLAlchemy is imported inside methods so this
    module can be imported even when the dependency is absent.
    """

    def __init__(self, database_url: Optional[str] = None):
        """Initialize database manager.

        Args:
            database_url: Explicit connection URL. When omitted, falls back
                to the DATABASE_URL environment variable and finally to a
                local insecure CockroachDB instance.
        """
        self.database_url = database_url or os.getenv(
            'DATABASE_URL',
            "cockroachdb://root@localhost:26257/banko_ai?sslmode=disable"
        )
        self._engine = None  # created on first .engine access

    @property
    def engine(self):
        """Get SQLAlchemy engine (lazy import, created once and cached)."""
        if self._engine is None:
            from sqlalchemy import create_engine
            from sqlalchemy.dialects.postgresql.base import PGDialect

            # Monkey patch version parsing to handle CockroachDB: its
            # version string is not parseable by the PostgreSQL dialect,
            # so fall back to a known-compatible tuple instead of failing.
            original_get_server_version_info = PGDialect._get_server_version_info

            def patched_get_server_version_info(self, connection):
                try:
                    return original_get_server_version_info(self, connection)
                except Exception:
                    return (25, 3, 0)  # Return compatible version tuple

            PGDialect._get_server_version_info = patched_get_server_version_info

            # Convert cockroachdb:// to postgresql:// for SQLAlchemy
            # compatibility (CockroachDB speaks the PostgreSQL wire protocol).
            database_url = self.database_url.replace("cockroachdb://", "postgresql://")

            self._engine = create_engine(
                database_url,
                connect_args={
                    "options": "-c default_transaction_isolation=serializable"
                },
                pool_pre_ping=True,  # validate pooled connections before use
                pool_recycle=300  # recycle connections after 5 minutes
            )
        return self._engine

    def create_tables(self) -> bool:
        """Create all required tables and indexes.

        Returns:
            True on success, False if any statement failed (error printed).
        """
        try:
            from sqlalchemy import text
            with self.engine.connect() as conn:
                # Create expenses table with vector support
                conn.execute(text("""
                    CREATE TABLE IF NOT EXISTS expenses (
                        expense_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                        user_id UUID NOT NULL,
                        expense_date DATE NOT NULL,
                        expense_amount DECIMAL(10,2) NOT NULL,
                        shopping_type STRING NOT NULL,
                        description STRING,
                        merchant STRING,
                        payment_method STRING NOT NULL,
                        recurring BOOL DEFAULT false,
                        tags STRING[],
                        embedding VECTOR(384),
                        created_at TIMESTAMP DEFAULT now()
                    )
                """))

                # Create vector index for general search
                conn.execute(text("""
                    CREATE VECTOR INDEX IF NOT EXISTS idx_expenses_embedding
                    ON expenses (embedding)
                """))

                # Create user-specific vector index
                conn.execute(text("""
                    CREATE VECTOR INDEX IF NOT EXISTS idx_expenses_user_embedding
                    ON expenses (user_id, embedding)
                """))

                # Create additional indexes for common queries
                conn.execute(text("""
                    CREATE INDEX IF NOT EXISTS idx_expenses_user_date
                    ON expenses (user_id, expense_date DESC)
                """))

                conn.execute(text("""
                    CREATE INDEX IF NOT EXISTS idx_expenses_merchant
                    ON expenses (merchant)
                """))

                conn.execute(text("""
                    CREATE INDEX IF NOT EXISTS idx_expenses_shopping_type
                    ON expenses (shopping_type)
                """))

                conn.commit()
                return True

        except Exception as e:
            print(f"Error creating tables: {e}")
            return False

    def drop_tables(self) -> bool:
        """Drop all tables. Returns True on success."""
        try:
            from sqlalchemy import text
            with self.engine.connect() as conn:
                conn.execute(text("DROP TABLE IF EXISTS expenses CASCADE"))
                conn.commit()
                return True
        except Exception as e:
            print(f"Error dropping tables: {e}")
            return False

    def table_exists(self, table_name: str) -> bool:
        """Check if a table exists (in any schema visible to the session)."""
        try:
            from sqlalchemy import text
            with self.engine.connect() as conn:
                result = conn.execute(text("""
                    SELECT EXISTS (
                        SELECT FROM information_schema.tables
                        WHERE table_name = :table_name
                    )
                """), {"table_name": table_name})
                # scalar() can be None on an empty result; normalize to bool.
                return bool(result.scalar())
        except Exception as e:
            print(f"Error checking table existence: {e}")
            return False

    def get_table_info(self, table_name: str) -> dict:
        """Get table information.

        Returns:
            Dict with "table_name", "columns" (name/type/nullable/default),
            and "indexes" (name/definition); empty dict on failure.
        """
        try:
            from sqlalchemy import text
            with self.engine.connect() as conn:
                # Get column information
                columns_result = conn.execute(text("""
                    SELECT column_name, data_type, is_nullable, column_default
                    FROM information_schema.columns
                    WHERE table_name = :table_name
                    ORDER BY ordinal_position
                """), {"table_name": table_name})

                columns = []
                for row in columns_result:
                    columns.append({
                        "name": row[0],
                        "type": row[1],
                        "nullable": row[2] == "YES",
                        "default": row[3]
                    })

                # Get index information
                indexes_result = conn.execute(text("""
                    SELECT indexname, indexdef
                    FROM pg_indexes
                    WHERE tablename = :table_name
                """), {"table_name": table_name})

                indexes = []
                for row in indexes_result:
                    indexes.append({
                        "name": row[0],
                        "definition": row[1]
                    })

                return {
                    "table_name": table_name,
                    "columns": columns,
                    "indexes": indexes
                }

        except Exception as e:
            print(f"Error getting table info: {e}")
            return {}

    def get_record_count(self, table_name: str) -> int:
        """Get record count for a table.

        The table name cannot be bound as a SQL parameter, so it is
        validated as a plain identifier before interpolation; this prevents
        SQL injection through this argument.
        """
        if not table_name.isidentifier():
            print(f"Error getting record count: invalid table name {table_name!r}")
            return 0
        try:
            from sqlalchemy import text
            with self.engine.connect() as conn:
                result = conn.execute(text(f"SELECT COUNT(*) FROM {table_name}"))
                count = result.scalar()
                # scalar() can be None on an empty result; normalize to int.
                return int(count) if count is not None else 0
        except Exception as e:
            print(f"Error getting record count: {e}")
            return 0

    def test_connection(self) -> bool:
        """Test database connection with a trivial query."""
        try:
            from sqlalchemy import text
            with self.engine.connect() as conn:
                conn.execute(text("SELECT 1"))
                return True
        except Exception as e:
            print(f"Database connection failed: {e}")
            return False
|
@@ -0,0 +1,123 @@
|
|
1
|
+
"""
|
2
|
+
Database migration utilities.
|
3
|
+
|
4
|
+
This module provides migration scripts to update the database schema
|
5
|
+
for user-specific vector indexing and other enhancements.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import os
|
9
|
+
from typing import Optional
|
10
|
+
|
11
|
+
|
12
|
+
class DatabaseMigration:
    """Schema migration helpers for the expenses database.

    Each migration inspects the current schema first and only applies the
    pieces that are missing, so migrations are safe to re-run. SQLAlchemy
    is imported lazily so the module loads without the dependency.
    """

    def __init__(self, database_url: Optional[str] = None):
        """Remember the connection URL; the engine is built on demand."""
        fallback = "cockroachdb://root@localhost:26257/defaultdb?sslmode=disable"
        self.database_url = database_url or os.getenv('DATABASE_URL', fallback)
        self._engine = None

    @property
    def engine(self):
        """Lazily import SQLAlchemy and cache a single engine instance."""
        if self._engine is None:
            from sqlalchemy import create_engine
            self._engine = create_engine(self.database_url)
        return self._engine

    def migrate_to_user_specific_indexing(self) -> bool:
        """Add the user_id column plus user-scoped vector indexes."""
        try:
            from sqlalchemy import text
            with self.engine.connect() as conn:
                # Does the user_id column already exist?
                probe = conn.execute(text("""
                    SELECT column_name
                    FROM information_schema.columns
                    WHERE table_name = 'expenses' AND column_name = 'user_id'
                """))

                if probe.fetchone() is None:
                    # Column is missing; add it with a generated default.
                    conn.execute(text("""
                        ALTER TABLE expenses
                        ADD COLUMN user_id UUID DEFAULT gen_random_uuid()
                    """))
                    print("Added user_id column to expenses table")

                # Vector index scoped to a single user's rows.
                conn.execute(text("""
                    CREATE INDEX IF NOT EXISTS idx_expenses_user_embedding
                    ON expenses (user_id, embedding)
                    USING ivfflat (embedding vector_cosine_ops)
                    WITH (lists = 100)
                """))
                print("Created user-specific vector index")

                # Regional locality is optional; tolerate clusters without it.
                try:
                    conn.execute(text("""
                        CREATE INDEX IF NOT EXISTS idx_expenses_user_embedding_regional
                        ON expenses (user_id, embedding)
                        LOCALITY REGIONAL BY ROW AS region
                    """))
                    print("Created regional user-specific vector index")
                except Exception as exc:
                    print(f"Regional indexing not supported: {exc}")

                # Composite index for per-user, newest-first queries.
                conn.execute(text("""
                    CREATE INDEX IF NOT EXISTS idx_expenses_user_date
                    ON expenses (user_id, expense_date DESC)
                """))
                print("Created user date index")

                conn.commit()
                return True

        except Exception as exc:
            print(f"Migration failed: {exc}")
            return False

    def add_created_at_column(self) -> bool:
        """Add the created_at timestamp column if it is not present yet."""
        try:
            from sqlalchemy import text
            with self.engine.connect() as conn:
                probe = conn.execute(text("""
                    SELECT column_name
                    FROM information_schema.columns
                    WHERE table_name = 'expenses' AND column_name = 'created_at'
                """))

                # Guard clause: nothing to do when the column exists.
                if probe.fetchone() is not None:
                    print("created_at column already exists")
                    return True

                conn.execute(text("""
                    ALTER TABLE expenses
                    ADD COLUMN created_at TIMESTAMP DEFAULT now()
                """))
                print("Added created_at column to expenses table")
                conn.commit()
                return True

        except Exception as exc:
            print(f"Failed to add created_at column: {exc}")
            return False

    def run_all_migrations(self) -> bool:
        """Run every migration in order and report overall success."""
        print("Running database migrations...")

        # Both migrations run unconditionally, in order; the overall result
        # is the conjunction of their individual outcomes.
        outcomes = [
            self.add_created_at_column(),
            self.migrate_to_user_specific_indexing(),
        ]
        success = all(outcomes)

        if success:
            print("All migrations completed successfully")
        else:
            print("Some migrations failed")

        return success
|
@@ -0,0 +1,18 @@
|
|
1
|
+
"""Vector search functionality for Banko AI Assistant."""
|
2
|
+
|
3
|
+
def get_data_enricher():
    """Return the DataEnricher class, importing it only on first use."""
    from .enrichment import DataEnricher

    return DataEnricher


def get_vector_search_engine():
    """Return the VectorSearchEngine class, importing it only on first use."""
    from .search import VectorSearchEngine

    return VectorSearchEngine


def get_enhanced_expense_generator():
    """Return the EnhancedExpenseGenerator class, importing it only on first use."""
    from .generator import EnhancedExpenseGenerator

    return EnhancedExpenseGenerator


# Public API: lazy accessors only, so importing this package stays cheap.
__all__ = ["get_data_enricher", "get_vector_search_engine", "get_enhanced_expense_generator"]
|
@@ -0,0 +1,278 @@
|
|
1
|
+
"""
|
2
|
+
Data enrichment module for improving vector search accuracy.
|
3
|
+
|
4
|
+
This module enriches expense descriptions with merchant context and other relevant
|
5
|
+
information to improve vector search accuracy and relevance.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from typing import Dict, Any, Optional
|
9
|
+
from datetime import datetime
|
10
|
+
import re
|
11
|
+
|
12
|
+
|
13
|
+
class DataEnricher:
    """Enriches expense data with contextual information for better vector search.

    Each private helper contributes a short natural-language fragment
    (merchant, amount, category, payment method, time) that is appended to
    the original description before the text is embedded.
    """

    def __init__(self):
        """Initialize the lookup tables used by the context helpers."""
        # Known merchants grouped by category; matched case-insensitively
        # by _get_merchant_category().
        self.merchant_categories = {
            "grocery": ["Whole Foods Market", "Trader Joe's", "Kroger", "Safeway", "Publix", "Walmart", "Target"],
            "retail": ["Amazon", "Best Buy", "Apple Store", "Home Depot", "Costco", "Target", "Walmart"],
            "dining": ["Starbucks", "McDonald's", "Chipotle", "Subway", "Pizza Hut", "Domino's"],
            "transportation": ["Shell Gas Station", "Exxon", "Uber", "Lyft", "Metro", "Parking"],
            "healthcare": ["CVS Pharmacy", "Walgreens", "Rite Aid", "Hospital", "Clinic"],
            "entertainment": ["Netflix", "Spotify", "Movie Theater", "Concert", "Gaming"],
            "utilities": ["Electric Company", "Internet Provider", "Phone Company", "Water Company"]
        }

        # Named amount bands. NOTE(review): _get_amount_context() uses its
        # own hard-coded thresholds rather than these ranges — confirm
        # which set is authoritative before consolidating.
        self.amount_ranges = {
            "low": (0, 25),
            "medium": (25, 100),
            "high": (100, 500),
            "very_high": (500, float('inf'))
        }

    def enrich_expense_description(
        self,
        description: str,
        merchant: str,
        amount: float,
        category: str,
        payment_method: str,
        date: datetime,
        **kwargs
    ) -> str:
        """
        Enrich expense description with contextual information.

        Args:
            description: Original expense description
            merchant: Merchant name
            amount: Expense amount
            category: Expense category
            payment_method: Payment method used
            date: Expense date (datetime or date)
            **kwargs: Additional metadata (currently unused here)

        Returns:
            Enriched description string with duplicate words removed.
        """
        # Start with the original description
        enriched_parts = [description]

        # Add merchant name and amount prominently
        enriched_parts.append(f"at {merchant} for ${amount:.2f}")

        # Add merchant context
        merchant_context = self._get_merchant_context(merchant, amount)
        if merchant_context:
            enriched_parts.append(merchant_context)

        # Add amount context
        amount_context = self._get_amount_context(amount)
        if amount_context:
            enriched_parts.append(amount_context)

        # Add category context
        category_context = self._get_category_context(category, merchant)
        if category_context:
            enriched_parts.append(category_context)

        # Add payment method context
        payment_context = self._get_payment_context(payment_method)
        if payment_context:
            enriched_parts.append(payment_context)

        # Add temporal context
        temporal_context = self._get_temporal_context(date)
        if temporal_context:
            enriched_parts.append(temporal_context)

        # Add merchant category context
        merchant_category = self._get_merchant_category(merchant)
        if merchant_category:
            enriched_parts.append(f"at {merchant_category} store")

        # Combine all parts
        enriched_description = " ".join(enriched_parts)

        # Clean up and format
        enriched_description = self._clean_description(enriched_description)

        return enriched_description

    def _get_merchant_context(self, merchant: str, amount: float) -> Optional[str]:
        """Get merchant-specific context, or None for unknown merchants."""
        merchant_lower = merchant.lower()

        # Gas stations
        if any(gas in merchant_lower for gas in ["shell", "exxon", "chevron", "bp", "gas"]):
            return f"fuel purchase at {merchant}"

        # Grocery stores
        if any(grocery in merchant_lower for grocery in ["whole foods", "trader joe", "kroger", "safeway"]):
            return f"grocery shopping at {merchant}"

        # Online retailers
        if merchant_lower == "amazon":
            return f"online purchase from {merchant}"

        # Coffee shops
        if any(coffee in merchant_lower for coffee in ["starbucks", "dunkin", "peet", "coffee"]):
            return f"coffee and food at {merchant}"

        # Fast food
        if any(fast in merchant_lower for fast in ["mcdonald", "burger", "pizza", "chipotle", "subway"]):
            return f"fast food at {merchant}"

        # Pharmacies
        if any(pharmacy in merchant_lower for pharmacy in ["cvs", "walgreens", "rite aid", "pharmacy"]):
            return f"pharmacy visit at {merchant}"

        # Home improvement
        if any(home in merchant_lower for home in ["home depot", "lowes", "ace hardware"]):
            return f"home improvement at {merchant}"

        return None

    def _get_amount_context(self, amount: float) -> Optional[str]:
        """Get amount-based context (always returns a phrase)."""
        if amount < 10:
            return "small purchase"
        elif amount < 50:
            return "moderate expense"
        elif amount < 200:
            return "significant purchase"
        elif amount < 500:
            return "major expense"
        else:
            return "large transaction"

    def _get_category_context(self, category: str, merchant: str) -> Optional[str]:
        """Get category-specific context, or None for unknown categories."""
        category_lower = category.lower()

        if category_lower == "groceries":
            return "food and household items"
        elif category_lower == "transportation":
            return "travel and commuting"
        elif category_lower == "dining":
            return "restaurant and food service"
        elif category_lower == "entertainment":
            return "leisure and recreation"
        elif category_lower == "healthcare":
            return "medical and wellness"
        elif category_lower == "shopping":
            return "retail and consumer goods"
        elif category_lower == "utilities":
            return "essential services"

        return None

    def _get_payment_context(self, payment_method: str) -> Optional[str]:
        """Get payment method context, or None for unrecognized methods."""
        payment_lower = payment_method.lower()

        if "credit" in payment_lower:
            return "paid with credit card"
        elif "debit" in payment_lower:
            return "paid with debit card"
        elif "cash" in payment_lower:
            return "paid with cash"
        elif "mobile" in payment_lower:
            return "paid with mobile payment"
        elif "bank" in payment_lower:
            return "paid via bank transfer"

        return None

    def _get_temporal_context(self, date: datetime) -> Optional[str]:
        """Get temporal context ("today", "this week", ...) based on date."""
        from datetime import date as date_type
        # Convert both to datetime for consistent comparison (a plain date
        # cannot be subtracted from datetime.now()).
        if isinstance(date, date_type):
            date = datetime.combine(date, datetime.min.time())
        now = datetime.now()
        days_ago = (now - date).days

        if days_ago == 0:
            return "today"
        elif days_ago == 1:
            return "yesterday"
        elif days_ago <= 7:
            return "this week"
        elif days_ago <= 30:
            return "this month"
        elif days_ago <= 90:
            return "recently"
        else:
            return "in the past"

    def _get_merchant_category(self, merchant: str) -> Optional[str]:
        """Get merchant category for additional context.

        The match is case-insensitive: known merchant names are lowercased
        before the substring test. (Previously the mixed-case names were
        compared against the lowercased input with a case-sensitive ``in``,
        so this method could never match and always returned None.)
        """
        merchant_lower = merchant.lower()

        for category, merchants in self.merchant_categories.items():
            if any(m.lower() in merchant_lower for m in merchants):
                return category

        return None

    def _clean_description(self, description: str) -> str:
        """Clean and format the enriched description."""
        # Remove extra spaces
        description = re.sub(r'\s+', ' ', description)

        # Remove duplicate words (case-insensitively, keeping first casing)
        words = description.split()
        seen = set()
        unique_words = []
        for word in words:
            if word.lower() not in seen:
                unique_words.append(word)
                seen.add(word.lower())

        return ' '.join(unique_words).strip()

    def create_searchable_text(
        self,
        description: str,
        merchant: str,
        amount: float,
        category: str,
        **kwargs
    ) -> str:
        """
        Create a comprehensive searchable text block for embedding.

        This creates a rich text representation that includes all relevant
        information for vector search.
        """
        # Extract required parameters from kwargs
        payment_method = kwargs.get('payment_method', '')
        date = kwargs.get('date', datetime.now())

        # Remove these from kwargs to avoid conflicts with the positional
        # arguments passed to enrich_expense_description below.
        filtered_kwargs = {k: v for k, v in kwargs.items() if k not in ['payment_method', 'date']}

        enriched_description = self.enrich_expense_description(
            description, merchant, amount, category, payment_method, date, **filtered_kwargs
        )

        # Create a comprehensive searchable text
        searchable_parts = [
            f"Spent ${amount:.2f} on {enriched_description}",
            f"Merchant: {merchant}",
            f"Category: {category}",
            f"Amount: ${amount:.2f}"
        ]

        # Add any additional context
        if kwargs.get('payment_method'):
            searchable_parts.append(f"Payment: {kwargs['payment_method']}")

        if kwargs.get('tags'):
            tags = ', '.join(kwargs['tags'])
            searchable_parts.append(f"Tags: {tags}")

        return " | ".join(searchable_parts)
|