marqeta-diva-mcp 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- marqeta_diva_mcp/__init__.py +3 -0
- marqeta_diva_mcp/__main__.py +6 -0
- marqeta_diva_mcp/client.py +471 -0
- marqeta_diva_mcp/embeddings.py +131 -0
- marqeta_diva_mcp/local_storage.py +348 -0
- marqeta_diva_mcp/rag_tools.py +366 -0
- marqeta_diva_mcp/server.py +940 -0
- marqeta_diva_mcp/vector_store.py +274 -0
- marqeta_diva_mcp-0.2.0.dist-info/METADATA +515 -0
- marqeta_diva_mcp-0.2.0.dist-info/RECORD +13 -0
- marqeta_diva_mcp-0.2.0.dist-info/WHEEL +4 -0
- marqeta_diva_mcp-0.2.0.dist-info/entry_points.txt +2 -0
- marqeta_diva_mcp-0.2.0.dist-info/licenses/LICENSE +21 -0
marqeta_diva_mcp/client.py
@@ -0,0 +1,471 @@
"""Marqeta DiVA API Client."""

import csv
import difflib
import json
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlencode

import httpx


class RateLimiter:
    """Simple rate limiter to prevent exceeding API limits (300 requests per 5 minutes)."""

    def __init__(self, max_requests: int = 300, time_window: int = 300):
        self.max_requests = max_requests
        self.time_window = time_window
        self.requests: List[float] = []

    def wait_if_needed(self) -> None:
        """Wait if necessary to comply with rate limits."""
        now = time.time()
        # Remove requests older than the time window
        self.requests = [req_time for req_time in self.requests if now - req_time < self.time_window]

        if len(self.requests) >= self.max_requests:
            # Need to wait
            oldest_request = self.requests[0]
            wait_time = self.time_window - (now - oldest_request)
            if wait_time > 0:
                time.sleep(wait_time + 0.1)  # Add small buffer
                now = time.time()
                self.requests = [req_time for req_time in self.requests if now - req_time < self.time_window]

        self.requests.append(now)
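
# [Editor's note - illustrative sketch, not part of the published module]
# The sliding-window limiter above can be exercised on its own; the limits
# below are shrunk from the defaults so the pause is observable:
#
#     limiter = RateLimiter(max_requests=3, time_window=5)
#     for i in range(5):
#         limiter.wait_if_needed()  # sleeps once 3 calls land within 5 seconds
#         print(f"request {i} sent")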


class DiVAAPIError(Exception):
    """Base exception for DiVA API errors."""

    def __init__(self, status_code: int, message: str, response: Optional[Dict[str, Any]] = None):
        self.status_code = status_code
        self.message = message
        self.response = response
        super().__init__(f"DiVA API Error {status_code}: {message}")


class DiVAClient:
    """Client for interacting with the Marqeta DiVA API."""

    BASE_URL = "https://diva-api.marqeta.com/data/v2"

    def __init__(self, app_token: str, access_token: str, program: str):
        """
        Initialize the DiVA API client.

        Args:
            app_token: Application token for authentication
            access_token: Access token for authentication
            program: Default program name(s) for API requests
        """
        self.app_token = app_token
        self.access_token = access_token
        self.program = program
        self.rate_limiter = RateLimiter()
        self.client = httpx.Client(
            auth=(app_token, access_token),
            headers={"Content-Type": "application/json"},
            timeout=30.0,
        )
        # Cache for schemas to avoid repeated API calls
        self._schema_cache: Dict[str, List[Dict[str, Any]]] = {}
        # Views that support date range parameters
        self._date_range_views = {
            "authorizations", "settlements", "clearings", "declines", "loads",
            "programbalances", "programbalancessettlement", "activitybalances",
            "chargebacks"
        }
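
    # [Editor's note - illustrative sketch, not part of the published module]
    # Constructing the client; the tokens and program below are placeholders:
    #
    #     client = DiVAClient(
    #         app_token="app-xxxx",
    #         access_token="token-xxxx",
    #         program="my_program",
    #     )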

    def _find_similar_fields(self, invalid_field: str, valid_fields: List[str], cutoff: float = 0.6) -> List[str]:
        """
        Find similar field names using fuzzy matching.

        Args:
            invalid_field: The invalid field name
            valid_fields: List of valid field names
            cutoff: Similarity threshold (0-1)

        Returns:
            List of similar field names
        """
        matches = difflib.get_close_matches(invalid_field, valid_fields, n=3, cutoff=cutoff)
        return matches

    def _validate_filters(
        self,
        view_name: str,
        aggregation: str,
        filters: Optional[Dict[str, Any]]
    ) -> None:
        """
        Validate filter fields against the view schema.

        Args:
            view_name: Name of the view
            aggregation: Aggregation level
            filters: Filter dictionary to validate

        Raises:
            DiVAAPIError: If invalid filter fields are detected
        """
        if not filters:
            return

        # Check cache first
        cache_key = f"{view_name}:{aggregation}"
        if cache_key not in self._schema_cache:
            try:
                schema = self.get_view_schema(view_name, aggregation)
                self._schema_cache[cache_key] = schema
            except Exception:
                # If schema fetch fails, skip validation
                return

        schema = self._schema_cache[cache_key]
        valid_fields = [field["field"] for field in schema]

        invalid_fields = [field for field in filters.keys() if field not in valid_fields]

        if invalid_fields:
            suggestions = {}
            for invalid_field in invalid_fields:
                similar = self._find_similar_fields(invalid_field, valid_fields)
                if similar:
                    suggestions[invalid_field] = similar

            error_msg = f"/{view_name}/{aggregation} does not have the following column(s): {', '.join(repr(f) for f in invalid_fields)}"

            if suggestions:
                suggestion_text = "; ".join(
                    f"{repr(invalid)} -> did you mean {' or '.join(repr(s) for s in similar)}?"
                    for invalid, similar in suggestions.items()
                )
                error_msg += f"\nSuggestions: {suggestion_text}"

            error_msg += f"\nUse get_view_schema('{view_name}', '{aggregation}') to see all valid fields."

            raise DiVAAPIError(400, "Invalid filter fields", {"message": error_msg})
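
    # [Editor's note - illustrative sketch, not part of the published module]
    # The suggestions are plain difflib fuzzy matches; for example, with
    # hypothetical schema fields:
    #
    #     difflib.get_close_matches("merchant_nam",
    #                               ["merchant_name", "transaction_amount"],
    #                               n=3, cutoff=0.6)
    #     # -> ['merchant_name']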

    def _estimate_response_size(
        self,
        view_name: str,
        count: int,
        fields: Optional[List[str]]
    ) -> tuple[int, str]:
        """
        Estimate response size and return warning if it might be large.

        Args:
            view_name: Name of the view
            count: Number of records requested
            fields: Specific fields requested

        Returns:
            Tuple of (estimated_tokens, warning_message)
        """
        # Rough token estimates per record for different views
        tokens_per_record = {
            "authorizations": 100,
            "settlements": 120,
            "clearings": 150,
            "declines": 140,
            "cards": 80,
            "users": 70,
            "chargebacks": 200,
        }

        base_tokens = tokens_per_record.get(view_name, 100)

        # If specific fields requested, reduce estimate by ~60%
        if fields:
            base_tokens = int(base_tokens * 0.4)

        estimated = base_tokens * count
        warning = ""

        if estimated > 20000:
            warning = (
                f"Warning: Requesting {count} records may return ~{estimated:,} tokens "
                f"(limit is 25,000). Consider reducing 'count' or using more specific 'fields'."
            )

        return estimated, warning
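
    # [Editor's note - worked example of the heuristic above]
    # For 500 'clearings' records with an explicit field list:
    #     base_tokens = int(150 * 0.4) = 60
    #     estimated   = 60 * 500      = 30,000 -> exceeds 20,000, so a warning is emitted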

    def _build_query_params(
        self,
        program: Optional[str] = None,
        fields: Optional[List[str]] = None,
        filters: Optional[Dict[str, Any]] = None,
        sort_by: Optional[str] = None,
        count: Optional[int] = None,
        group_by: Optional[str] = None,
        expand: Optional[str] = None,
    ) -> Dict[str, str]:
        """Build query parameters for API requests.

        Note: Date filtering should be done through the filters parameter using the
        actual date field names (e.g., transaction_timestamp, post_date) with operators.
        Example: filters={"transaction_timestamp": ">=2023-10-20"}
        """
        params: Dict[str, str] = {}

        # Program is required for most endpoints
        params["program"] = program or self.program

        # Field filtering
        if fields:
            params["fields"] = ",".join(fields)

        # Sorting
        if sort_by:
            params["sort_by"] = sort_by

        # Count/limit - apply max based on DiVA API limits
        if count is not None:
            # DiVA API limit: 10,000 for JSON responses
            params["count"] = str(min(count, 10000))
        else:
            # Default to 10,000 records (DiVA API default)
            params["count"] = "10000"

        # Grouping
        if group_by:
            params["group_by"] = group_by

        # Field expansion
        if expand:
            params["expand"] = expand

        # Custom filters
        if filters:
            for key, value in filters.items():
                if isinstance(value, list):
                    params[key] = ",".join(str(v) for v in value)
                else:
                    params[key] = str(value)

        return params
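
    # [Editor's note - illustrative input/output, not part of the published module]
    # List filters are comma-joined and the count default is filled in; with a
    # client whose default program is the placeholder "my_program":
    #
    #     client._build_query_params(
    #         filters={"state": ["COMPLETED", "PENDING"],
    #                  "transaction_timestamp": ">=2023-10-20"},
    #         sort_by="-transaction_timestamp",
    #     )
    #     # -> {"program": "my_program", "sort_by": "-transaction_timestamp",
    #     #     "count": "10000", "state": "COMPLETED,PENDING",
    #     #     "transaction_timestamp": ">=2023-10-20"}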

    def _make_request(self, endpoint: str, params: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
        """
        Make a request to the DiVA API.

        Args:
            endpoint: API endpoint path (e.g., '/views/authorizations/detail')
            params: Query parameters

        Returns:
            API response as dictionary

        Raises:
            DiVAAPIError: If the API returns an error
        """
        # Apply rate limiting
        self.rate_limiter.wait_if_needed()

        url = f"{self.BASE_URL}{endpoint}"

        try:
            response = self.client.get(url, params=params)

            # Handle different error codes
            if response.status_code == 200:
                return response.json()
            elif response.status_code == 400:
                error_data = response.json() if response.text else {}
                error_message = error_data.get("message", "")

                # Provide helpful context for common errors
                enhanced_message = "Bad Request - Malformed query or filter parameters"
                if "does not have a column" in error_message:
                    enhanced_message += (
                        "\n\nNote: This error usually means you're using an invalid field name in 'filters'. "
                        "Parameters like 'count', 'sort_by', 'program' should NOT be in 'filters'. "
                        "Only actual data field names belong in 'filters'. "
                        "For date filtering, use the actual date field name (e.g., 'transaction_timestamp') with operators like '>=2023-10-20'."
                    )

                raise DiVAAPIError(400, enhanced_message, error_data)
            elif response.status_code == 403:
                error_data = response.json() if response.text else {}
                error_code = error_data.get("error_code", "")
                if error_code == "403001":
                    msg = "Forbidden - Field access denied"
                elif error_code == "403002":
                    msg = "Forbidden - Filter not allowed"
                elif error_code == "403003":
                    msg = "Forbidden - Program access denied"
                else:
                    msg = "Forbidden - Unauthorized access"
                raise DiVAAPIError(403, msg, error_data)
            elif response.status_code == 404:
                raise DiVAAPIError(404, "Not Found - Malformed URL or endpoint does not exist")
            elif response.status_code == 429:
                raise DiVAAPIError(
                    429,
                    "Rate limit exceeded - Maximum 300 requests per 5 minutes",
                )
            else:
                raise DiVAAPIError(
                    response.status_code,
                    f"Unexpected error: {response.text}",
                    response.json() if response.text else None,
                )

        except httpx.RequestError as e:
            raise DiVAAPIError(0, f"Network error: {str(e)}")

    def export_to_file(
        self,
        view_name: str,
        aggregation: str,
        output_path: str,
        format: str = "json",
        max_records: Optional[int] = None,
        **kwargs: Any
    ) -> Dict[str, Any]:
        """
        Export large datasets to a file with automatic pagination.

        Args:
            view_name: Name of the view
            aggregation: Aggregation level
            output_path: Path where file will be written
            format: Output format ('json' or 'csv')
            max_records: Maximum total records to export (None = all)
            **kwargs: Additional query parameters

        Returns:
            Dictionary with export metadata
        """
        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)

        all_records = []
        batch_size = 10000  # DiVA API limit: 10,000 for JSON responses
        total_fetched = 0

        print(f"[DiVA Export] Starting export to {output_path}...", file=sys.stderr)

        # DiVA API does not support offset-based pagination
        # Fetch one batch with the specified count limit
        if max_records:
            kwargs['count'] = min(max_records, batch_size)
        else:
            kwargs['count'] = batch_size

        endpoint = f"/views/{view_name}/{aggregation}"
        params = self._build_query_params(**kwargs)
        response = self._make_request(endpoint, params)

        records = response.get('records', [])
        if records:
            all_records.extend(records)
            total_fetched = len(records)

        # Truncate if we got more than max_records
        if max_records and total_fetched > max_records:
            all_records = all_records[:max_records]
            total_fetched = len(all_records)

        print(f"[DiVA Export] Fetched {total_fetched} records", file=sys.stderr)

        # Warn if there are more records available
        if response.get('is_more', False):
            print(f"[DiVA Export] Warning: More records available but DiVA API does not support offset pagination.", file=sys.stderr)
            print(f"[DiVA Export] To get more data, use narrower date ranges or more specific filters.", file=sys.stderr)

        # Write to file
        if format == "csv":
            if all_records:
                with open(output_file, 'w', newline='') as f:
                    writer = csv.DictWriter(f, fieldnames=all_records[0].keys())
                    writer.writeheader()
                    writer.writerows(all_records)
        else:  # json
            with open(output_file, 'w') as f:
                json.dump(all_records, f, indent=2)

        file_size = output_file.stat().st_size
        print(f"[DiVA Export] Complete! Wrote {total_fetched} records to {output_path}", file=sys.stderr)

        return {
            "success": True,
            "file_path": str(output_file.absolute()),
            "format": format,
            "records_exported": total_fetched,
            "file_size_bytes": file_size
        }
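
    # [Editor's note - illustrative sketch, not part of the published module]
    # Exporting up to 5,000 authorization rows to CSV; the path and filter
    # values are placeholders:
    #
    #     result = client.export_to_file(
    #         "authorizations", "detail",
    #         output_path="/tmp/auths.csv",
    #         format="csv",
    #         max_records=5000,
    #         filters={"transaction_timestamp": ">=2023-10-20"},
    #     )
    #     print(result["records_exported"], result["file_path"])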

    def get_view(
        self,
        view_name: str,
        aggregation: str = "detail",
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """
        Get data from a specific view.

        Args:
            view_name: Name of the view (e.g., 'authorizations', 'settlements')
            aggregation: Aggregation level ('detail', 'day', 'week', 'month')
            **kwargs: Additional query parameters (program, filters, count, sort_by, etc.)
                Note: For date filtering, use filters with the actual date field name.
                Example: filters={"transaction_timestamp": ">=2023-10-20"}

        Returns:
            API response containing records and metadata
        """
        # Validate filter fields before making the request
        # Note: Validation is disabled to allow API to determine valid fields
        # filters = kwargs.get("filters")
        # if filters:
        #     self._validate_filters(view_name, aggregation, filters)

        # Check response size estimate
        count = kwargs.get("count", 10000)  # Default is 10,000 per DiVA API
        fields = kwargs.get("fields")
        estimated_tokens, warning = self._estimate_response_size(view_name, count, fields)

        if warning:
            # Log warning but don't block the request
            print(f"[DiVA Client Warning] {warning}", file=sys.stderr)

        endpoint = f"/views/{view_name}/{aggregation}"
        params = self._build_query_params(**kwargs)
        return self._make_request(endpoint, params)

    def get_views_list(self) -> Dict[str, Any]:
        """Get list of all available views with metadata."""
        return self._make_request("/views")

    def get_view_schema(self, view_name: str, aggregation: str = "detail") -> Dict[str, Any]:
        """
        Get the schema for a specific view.

        Args:
            view_name: Name of the view
            aggregation: Aggregation level

        Returns:
            Schema definition with field types and descriptions
        """
        endpoint = f"/views/{view_name}/{aggregation}/schema"
        return self._make_request(endpoint)

    def close(self) -> None:
        """Close the HTTP client."""
        self.client.close()

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.close()
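
Since the class defines __enter__/__exit__, the natural entry point is a with-block. A minimal sketch against the code above, assuming placeholder credentials (this snippet is editorial, not part of the package):

from marqeta_diva_mcp.client import DiVAClient

with DiVAClient("app-xxxx", "token-xxxx", "my_program") as client:
    # Fetch 100 daily settlement rows filtered on a date field
    page = client.get_view(
        "settlements", "day",
        filters={"post_date": ">=2023-10-01"},
        count=100,
    )
    for record in page.get("records", []):
        print(record)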

marqeta_diva_mcp/embeddings.py
@@ -0,0 +1,131 @@
"""Transaction embedding generation for RAG capabilities."""

import sys
from typing import Any, Dict, List
from sentence_transformers import SentenceTransformer


class TransactionEmbedder:
    """Generates embeddings for financial transactions."""

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """
        Initialize the transaction embedder.

        Args:
            model_name: Name of the sentence-transformers model to use.
                Default: all-MiniLM-L6-v2 (fast, lightweight, 384 dimensions)
        """
        print(f"[Embeddings] Loading model: {model_name}...", file=sys.stderr)
        self.model = SentenceTransformer(model_name)
        self.model_name = model_name
        self.embedding_dim = self.model.get_sentence_embedding_dimension()
        print(f"[Embeddings] Model loaded. Embedding dimension: {self.embedding_dim}", file=sys.stderr)

    def format_transaction_text(self, transaction: Dict[str, Any]) -> str:
        """
        Format a transaction into text for embedding.

        Args:
            transaction: Transaction dictionary from DiVA API

        Returns:
            Formatted text string combining key transaction attributes
        """
        parts = []

        # Merchant name (most important)
        merchant = transaction.get("merchant_name", transaction.get("acquirer_merchant_name", ""))
        if merchant:
            parts.append(f"Merchant: {merchant}")

        # Transaction amount
        amount = transaction.get("transaction_amount")
        if amount is not None:
            currency_code = transaction.get("currency_code", "USD")
            parts.append(f"Amount: {amount} {currency_code}")

        # Transaction type/status
        txn_type = transaction.get("transaction_type")
        if txn_type:
            parts.append(f"Type: {txn_type}")

        state = transaction.get("state", transaction.get("transaction_status"))
        if state:
            parts.append(f"Status: {state}")

        # Merchant category
        mcc = transaction.get("merchant_category_code")
        if mcc:
            parts.append(f"MCC: {mcc}")

        # Cardholder presence
        card_presence = transaction.get("card_presence_indicator")
        if card_presence:
            parts.append(f"Card Presence: {card_presence}")

        # Network
        network = transaction.get("network")
        if network:
            parts.append(f"Network: {network}")

        # If we have very little info, at least include tokens
        if len(parts) < 2:
            txn_token = transaction.get("transaction_token")
            if txn_token:
                parts.append(f"Transaction: {txn_token}")

        return " | ".join(parts)
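
    # [Editor's note - illustrative input/output, not part of the published module]
    # A typical record yields a compact pipe-delimited string:
    #
    #     format_transaction_text({
    #         "merchant_name": "COFFEE SHOP #42",
    #         "transaction_amount": 4.5,
    #         "transaction_type": "purchase",
    #         "state": "COMPLETED",
    #     })
    #     # -> "Merchant: COFFEE SHOP #42 | Amount: 4.5 USD | Type: purchase | Status: COMPLETED"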

    def embed_transaction(self, transaction: Dict[str, Any]) -> List[float]:
        """
        Generate embedding for a single transaction.

        Args:
            transaction: Transaction dictionary

        Returns:
            Embedding vector as list of floats
        """
        text = self.format_transaction_text(transaction)
        embedding = self.model.encode(text, convert_to_numpy=True)
        return embedding.tolist()

    def embed_transactions_batch(self, transactions: List[Dict[str, Any]]) -> List[List[float]]:
        """
        Generate embeddings for multiple transactions efficiently.

        Args:
            transactions: List of transaction dictionaries

        Returns:
            List of embedding vectors
        """
        texts = [self.format_transaction_text(txn) for txn in transactions]
        embeddings = self.model.encode(texts, convert_to_numpy=True, show_progress_bar=True)
        return embeddings.tolist()

    def embed_query(self, query: str) -> List[float]:
        """
        Generate embedding for a search query.

        Args:
            query: Natural language search query

        Returns:
            Embedding vector as list of floats
        """
        embedding = self.model.encode(query, convert_to_numpy=True)
        return embedding.tolist()


# Global embedder instance (lazy-loaded)
_embedder: TransactionEmbedder | None = None


def get_embedder() -> TransactionEmbedder:
    """Get or create the global transaction embedder instance."""
    global _embedder
    if _embedder is None:
        _embedder = TransactionEmbedder()
    return _embedder
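
Putting the two pieces together, a hedged end-to-end sketch: embed a few transactions and a query, then rank by cosine similarity. The transaction values are invented, and numpy is assumed available as a dependency of sentence-transformers (this snippet is editorial, not part of the package):

import numpy as np

from marqeta_diva_mcp.embeddings import get_embedder

embedder = get_embedder()  # loads all-MiniLM-L6-v2 on first call
txns = [
    {"merchant_name": "COFFEE SHOP #42", "transaction_amount": 4.5},
    {"merchant_name": "SKY AIRLINES", "transaction_amount": 412.0},
]
vectors = np.array(embedder.embed_transactions_batch(txns))
query = np.array(embedder.embed_query("small coffee purchase"))

# Cosine similarity of the query against each transaction embedding
scores = vectors @ query / (np.linalg.norm(vectors, axis=1) * np.linalg.norm(query))
print(txns[int(np.argmax(scores))]["merchant_name"])  # likely the coffee shop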