marqeta-diva-mcp 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@
1
+ """Marqeta DiVA API MCP Server."""
2
+
3
+ __version__ = "0.1.0"
@@ -0,0 +1,6 @@
"""Main entry point for running the package as a module."""

from .server import main

# Start the server only when this module is executed as the program entry
# point (running the package as a module), not when it is merely imported.
if __name__ == "__main__":
    main()
@@ -0,0 +1,471 @@
1
+ """Marqeta DiVA API Client."""
2
+
3
+ import csv
4
+ import difflib
5
+ import json
6
+ import sys
7
+ import time
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List, Optional, Union
10
+ from urllib.parse import urlencode
11
+
12
+ import httpx
13
+
14
+
15
class RateLimiter:
    """Sliding-window rate limiter that blocks until a request slot is free.

    The defaults (300 requests per 300 seconds) match the DiVA API limit of
    300 requests per 5 minutes.

    Args:
        max_requests: Maximum number of requests allowed inside any window (>= 1).
        time_window: Window length in seconds.

    Raises:
        ValueError: If ``max_requests`` is less than 1.
    """

    def __init__(self, max_requests: int = 300, time_window: int = 300):
        if max_requests < 1:
            # With 0 slots the limiter could never admit a request (and the
            # old code crashed indexing an empty history).
            raise ValueError("max_requests must be at least 1")
        self.max_requests = max_requests
        self.time_window = time_window
        # time.monotonic() timestamps of requests still inside the window,
        # oldest first (timestamps are only ever appended).
        self.requests: List[float] = []

    def _prune(self, now: float) -> None:
        """Drop timestamps that have aged out of the window relative to *now*."""
        self.requests = [stamp for stamp in self.requests if now - stamp < self.time_window]

    def wait_if_needed(self) -> None:
        """Block, if necessary, so the next request stays within the limit, then record it."""
        # monotonic() is immune to wall-clock changes (NTP sync, manual
        # adjustment), which would otherwise make the window arithmetic wrong
        # and cause the limiter to oversleep or stop throttling.
        now = time.monotonic()
        self._prune(now)

        if len(self.requests) >= self.max_requests:
            # Sleep until the oldest request leaves the window.
            wait_time = self.time_window - (now - self.requests[0])
            if wait_time > 0:
                time.sleep(wait_time + 0.1)  # Add small buffer
                now = time.monotonic()
                self._prune(now)

        self.requests.append(now)
39
+
40
+
41
class DiVAAPIError(Exception):
    """Error raised when a DiVA API interaction fails.

    Attributes:
        status_code: HTTP status code associated with the failure.
        message: Human-readable description of the failure.
        response: Parsed error payload, if one was available.
    """

    def __init__(self, status_code: int, message: str, response: Optional[Dict[str, Any]] = None):
        summary = f"DiVA API Error {status_code}: {message}"
        super().__init__(summary)
        self.status_code, self.message, self.response = status_code, message, response
49
+
50
+
51
class DiVAClient:
    """Client for interacting with the Marqeta DiVA API.

    Requests are authenticated with HTTP basic auth (app token / access token),
    pass through a client-side ``RateLimiter`` first, and failures surface as
    ``DiVAAPIError``. Use as a context manager (or call ``close()``) to release
    the underlying ``httpx.Client``.
    """

    BASE_URL = "https://diva-api.marqeta.com/data/v2"

    # DiVA API limit on records per JSON response; also used as the default count.
    MAX_RECORDS_PER_REQUEST = 10000

    # Messages for DiVA's 403 sub-codes, keyed by the "error_code" field.
    _FORBIDDEN_MESSAGES = {
        "403001": "Forbidden - Field access denied",
        "403002": "Forbidden - Filter not allowed",
        "403003": "Forbidden - Program access denied",
    }

    def __init__(self, app_token: str, access_token: str, program: str):
        """
        Initialize the DiVA API client.

        Args:
            app_token: Application token for authentication (basic-auth username)
            access_token: Access token for authentication (basic-auth password)
            program: Default program name(s) used when a call does not pass ``program``
        """
        self.app_token = app_token
        self.access_token = access_token
        self.program = program
        self.rate_limiter = RateLimiter()
        self.client = httpx.Client(
            auth=(app_token, access_token),
            headers={"Content-Type": "application/json"},
            timeout=30.0,
        )
        # Schema responses keyed by "view:aggregation", to avoid repeated API calls.
        self._schema_cache: Dict[str, Any] = {}
        # Views that support date range parameters.
        # NOTE(review): not read anywhere in this class; presumably used by callers — confirm.
        self._date_range_views = {
            "authorizations", "settlements", "clearings", "declines", "loads",
            "programbalances", "programbalancessettlement", "activitybalances",
            "chargebacks"
        }

    def _find_similar_fields(self, invalid_field: str, valid_fields: List[str], cutoff: float = 0.6) -> List[str]:
        """
        Find similar field names using fuzzy matching.

        Args:
            invalid_field: The invalid field name
            valid_fields: List of valid field names
            cutoff: Similarity threshold (0-1)

        Returns:
            Up to three closest field names, best match first
        """
        return difflib.get_close_matches(invalid_field, valid_fields, n=3, cutoff=cutoff)

    @staticmethod
    def _schema_field_names(schema: Any) -> List[str]:
        """
        Extract field names from a schema response.

        Entries are expected to be dicts carrying a "field" key. Any other shape
        yields an empty list so callers can skip validation instead of crashing.
        NOTE(review): confirm the exact payload shape returned by the /schema endpoint.
        """
        if not isinstance(schema, list):
            return []
        return [entry["field"] for entry in schema if isinstance(entry, dict) and "field" in entry]

    def _validate_filters(
        self,
        view_name: str,
        aggregation: str,
        filters: Optional[Dict[str, Any]]
    ) -> None:
        """
        Validate filter fields against the view schema (best effort).

        Args:
            view_name: Name of the view
            aggregation: Aggregation level
            filters: Filter dictionary to validate

        Raises:
            DiVAAPIError: If invalid filter fields are detected
        """
        if not filters:
            return

        cache_key = f"{view_name}:{aggregation}"
        if cache_key not in self._schema_cache:
            try:
                self._schema_cache[cache_key] = self.get_view_schema(view_name, aggregation)
            except Exception:
                # Validation is best-effort: if the schema cannot be fetched,
                # leave it to the API to reject bad fields.
                return

        valid_fields = self._schema_field_names(self._schema_cache[cache_key])
        if not valid_fields:
            # Unrecognised schema shape: skip rather than reject every filter.
            return

        valid_set = set(valid_fields)
        invalid_fields = [name for name in filters if name not in valid_set]
        if not invalid_fields:
            return

        suggestions: Dict[str, List[str]] = {}
        for invalid_field in invalid_fields:
            similar = self._find_similar_fields(invalid_field, valid_fields)
            if similar:
                suggestions[invalid_field] = similar

        error_msg = f"/{view_name}/{aggregation} does not have the following column(s): {', '.join(repr(f) for f in invalid_fields)}"

        if suggestions:
            suggestion_text = "; ".join(
                f"{repr(invalid)} -> did you mean {' or '.join(repr(s) for s in similar)}?"
                for invalid, similar in suggestions.items()
            )
            error_msg += f"\nSuggestions: {suggestion_text}"

        error_msg += f"\nUse get_view_schema('{view_name}', '{aggregation}') to see all valid fields."

        raise DiVAAPIError(400, "Invalid filter fields", {"message": error_msg})

    def _estimate_response_size(
        self,
        view_name: str,
        count: Optional[int],
        fields: Optional[List[str]]
    ) -> tuple[int, str]:
        """
        Estimate response size and return a warning if it might be large.

        Args:
            view_name: Name of the view
            count: Number of records requested; None means the API default
            fields: Specific fields requested

        Returns:
            Tuple of (estimated_tokens, warning_message); the warning is "" when
            the estimate is at most 20,000 tokens.
        """
        # Rough token estimates per record for different views
        tokens_per_record = {
            "authorizations": 100,
            "settlements": 120,
            "clearings": 150,
            "declines": 140,
            "cards": 80,
            "users": 70,
            "chargebacks": 200,
        }

        # Mirror what _build_query_params actually sends: None means the default
        # count, and anything larger is capped at the per-request maximum.
        if count is None:
            count = self.MAX_RECORDS_PER_REQUEST
        else:
            count = min(count, self.MAX_RECORDS_PER_REQUEST)

        base_tokens = tokens_per_record.get(view_name, 100)

        # If specific fields requested, reduce estimate by ~60%
        if fields:
            base_tokens = int(base_tokens * 0.4)

        estimated = base_tokens * count
        warning = ""

        if estimated > 20000:
            warning = (
                f"Warning: Requesting {count} records may return ~{estimated:,} tokens "
                f"(limit is 25,000). Consider reducing 'count' or using more specific 'fields'."
            )

        return estimated, warning

    def _build_query_params(
        self,
        program: Optional[str] = None,
        fields: Optional[List[str]] = None,
        filters: Optional[Dict[str, Any]] = None,
        sort_by: Optional[str] = None,
        count: Optional[int] = None,
        group_by: Optional[str] = None,
        expand: Optional[str] = None,
    ) -> Dict[str, str]:
        """Build query parameters for API requests.

        Note: Date filtering should be done through the filters parameter using the
        actual date field names (e.g., transaction_timestamp, post_date) with operators.
        Example: filters={"transaction_timestamp": ">=2023-10-20"}

        List filter values are joined with commas; other values are passed
        through ``str``. Filters are applied last, so a filter key named like a
        reserved parameter (e.g. "count") overrides it.
        """
        params: Dict[str, str] = {}

        # Program is required for most endpoints
        params["program"] = program or self.program

        if fields:
            params["fields"] = ",".join(fields)

        if sort_by:
            params["sort_by"] = sort_by

        # Default to, and cap at, the per-request maximum.
        limit = self.MAX_RECORDS_PER_REQUEST if count is None else min(count, self.MAX_RECORDS_PER_REQUEST)
        params["count"] = str(limit)

        if group_by:
            params["group_by"] = group_by

        if expand:
            params["expand"] = expand

        if filters:
            for key, value in filters.items():
                if isinstance(value, list):
                    params[key] = ",".join(str(v) for v in value)
                else:
                    params[key] = str(value)

        return params

    @staticmethod
    def _safe_json(response: Any) -> Any:
        """Decode a response body as JSON; return None if it is empty or not valid JSON."""
        if not response.text:
            return None
        try:
            return response.json()
        except ValueError:
            # json.JSONDecodeError subclasses ValueError; error responses are not
            # guaranteed to be JSON (e.g. an HTML gateway error page).
            return None

    def _make_request(self, endpoint: str, params: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
        """
        Make a rate-limited GET request to the DiVA API.

        Args:
            endpoint: API endpoint path (e.g., '/views/authorizations/detail')
            params: Query parameters

        Returns:
            Decoded JSON body of a 200 response

        Raises:
            DiVAAPIError: For any non-200 status, a 200 response whose body is not
                valid JSON, or a network failure (reported with status_code 0)
        """
        self.rate_limiter.wait_if_needed()

        url = f"{self.BASE_URL}{endpoint}"

        try:
            response = self.client.get(url, params=params)
        except httpx.RequestError as e:
            raise DiVAAPIError(0, f"Network error: {str(e)}") from e

        status = response.status_code
        if status == 200:
            try:
                return response.json()
            except ValueError as e:
                raise DiVAAPIError(200, f"Invalid JSON in response body: {response.text[:200]}") from e

        body = self._safe_json(response)
        # Error details are only read from JSON objects; anything else is ignored.
        error_data: Dict[str, Any] = body if isinstance(body, dict) else {}

        if status == 400:
            error_message = str(error_data.get("message") or "")

            # Provide helpful context for common errors
            enhanced_message = "Bad Request - Malformed query or filter parameters"
            if "does not have a column" in error_message:
                enhanced_message += (
                    "\n\nNote: This error usually means you're using an invalid field name in 'filters'. "
                    "Parameters like 'count', 'sort_by', 'program' should NOT be in 'filters'. "
                    "Only actual data field names belong in 'filters'. "
                    "For date filtering, use the actual date field name (e.g., 'transaction_timestamp') with operators like '>=2023-10-20'."
                )

            raise DiVAAPIError(400, enhanced_message, error_data)

        if status == 403:
            # str() so a numeric error_code still matches the known sub-codes.
            error_code = str(error_data.get("error_code", ""))
            msg = self._FORBIDDEN_MESSAGES.get(error_code, "Forbidden - Unauthorized access")
            raise DiVAAPIError(403, msg, error_data)

        if status == 404:
            raise DiVAAPIError(404, "Not Found - Malformed URL or endpoint does not exist")

        if status == 429:
            raise DiVAAPIError(
                429,
                "Rate limit exceeded - Maximum 300 requests per 5 minutes",
            )

        raise DiVAAPIError(
            status,
            f"Unexpected error: {response.text}",
            body if isinstance(body, dict) else None,
        )

    @staticmethod
    def _write_records(output_file: Path, records: List[Dict[str, Any]], fmt: str) -> None:
        """
        Write records to a file as CSV (fmt == "csv") or indented JSON (anything else).

        The file is always created, even when there are no records. CSV columns
        are the union of keys across all records in first-seen order, so records
        with differing keys do not make ``csv.DictWriter`` raise; missing values
        are written as empty cells.
        """
        if fmt == "csv":
            # dict preserves insertion order, giving an ordered union of keys.
            columns: Dict[str, None] = {}
            for record in records:
                columns.update(dict.fromkeys(record))
            with open(output_file, "w", newline="", encoding="utf-8") as f:
                if columns:
                    writer = csv.DictWriter(f, fieldnames=list(columns))
                    writer.writeheader()
                    writer.writerows(records)
        else:  # json
            with open(output_file, "w", encoding="utf-8") as f:
                json.dump(records, f, indent=2)

    def export_to_file(
        self,
        view_name: str,
        aggregation: str,
        output_path: str,
        format: str = "json",
        max_records: Optional[int] = None,
        **kwargs: Any
    ) -> Dict[str, Any]:
        """
        Export a view's records to a file.

        The DiVA API does not support offset-based pagination, so this makes a
        single request capped at the per-request maximum; when the response
        reports ``is_more``, a warning is printed to stderr.

        Args:
            view_name: Name of the view
            aggregation: Aggregation level
            output_path: Path where file will be written (parent dirs are created)
            format: Output format ('csv'; any other value writes JSON)
            max_records: Maximum total records to export (None or 0 = per-request maximum)
            **kwargs: Additional query parameters (any 'count' is overridden)

        Returns:
            Dictionary with export metadata
        """
        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)

        print(f"[DiVA Export] Starting export to {output_path}...", file=sys.stderr)

        batch_size = self.MAX_RECORDS_PER_REQUEST
        kwargs["count"] = min(max_records, batch_size) if max_records else batch_size

        endpoint = f"/views/{view_name}/{aggregation}"
        params = self._build_query_params(**kwargs)
        response = self._make_request(endpoint, params)

        records = list(response.get("records") or [])
        # Defensive: truncate if the API returned more than requested.
        if max_records and len(records) > max_records:
            records = records[:max_records]
        total_fetched = len(records)

        print(f"[DiVA Export] Fetched {total_fetched} records", file=sys.stderr)

        if response.get("is_more", False):
            print("[DiVA Export] Warning: More records available but DiVA API does not support offset pagination.", file=sys.stderr)
            print("[DiVA Export] To get more data, use narrower date ranges or more specific filters.", file=sys.stderr)

        self._write_records(output_file, records, format)

        file_size = output_file.stat().st_size
        print(f"[DiVA Export] Complete! Wrote {total_fetched} records to {output_path}", file=sys.stderr)

        return {
            "success": True,
            "file_path": str(output_file.absolute()),
            "format": format,
            "records_exported": total_fetched,
            "file_size_bytes": file_size
        }

    def get_view(
        self,
        view_name: str,
        aggregation: str = "detail",
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """
        Get data from a specific view.

        Args:
            view_name: Name of the view (e.g., 'authorizations', 'settlements')
            aggregation: Aggregation level ('detail', 'day', 'week', 'month')
            **kwargs: Additional query parameters (program, filters, count, sort_by, etc.)
                Note: For date filtering, use filters with the actual date field name.
                Example: filters={"transaction_timestamp": ">=2023-10-20"}

        Returns:
            API response containing records and metadata
        """
        # Client-side filter validation (_validate_filters) is intentionally not
        # run here so the API remains the authority on valid field names.

        # Size estimate is advisory only: logged, never blocks the request.
        _, warning = self._estimate_response_size(view_name, kwargs.get("count"), kwargs.get("fields"))
        if warning:
            print(f"[DiVA Client Warning] {warning}", file=sys.stderr)

        endpoint = f"/views/{view_name}/{aggregation}"
        params = self._build_query_params(**kwargs)
        return self._make_request(endpoint, params)

    def get_views_list(self) -> Dict[str, Any]:
        """Get list of all available views with metadata."""
        return self._make_request("/views")

    def get_view_schema(self, view_name: str, aggregation: str = "detail") -> Dict[str, Any]:
        """
        Get the schema for a specific view.

        Args:
            view_name: Name of the view
            aggregation: Aggregation level

        Returns:
            Decoded JSON schema response (expected to describe fields; see _schema_field_names)
        """
        endpoint = f"/views/{view_name}/{aggregation}/schema"
        return self._make_request(endpoint)

    def close(self) -> None:
        """Close the HTTP client."""
        self.client.close()

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: close the HTTP client."""
        self.close()
@@ -0,0 +1,131 @@
1
+ """Transaction embedding generation for RAG capabilities."""
2
+
3
+ import sys
4
+ from typing import Any, Dict, List
5
+ from sentence_transformers import SentenceTransformer
6
+
7
+
8
class TransactionEmbedder:
    """Generates embeddings for financial transactions."""

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """
        Initialize the transaction embedder by loading a sentence-transformers model.

        Args:
            model_name: Name of the sentence-transformers model to use.
                Default: all-MiniLM-L6-v2 (fast, lightweight, 384 dimensions)
        """
        print(f"[Embeddings] Loading model: {model_name}...", file=sys.stderr)
        self.model = SentenceTransformer(model_name)
        self.model_name = model_name
        self.embedding_dim = self.model.get_sentence_embedding_dimension()
        print(f"[Embeddings] Model loaded. Embedding dimension: {self.embedding_dim}", file=sys.stderr)

    def format_transaction_text(self, transaction: Dict[str, Any]) -> str:
        """
        Format a transaction into text for embedding.

        Args:
            transaction: Transaction dictionary from DiVA API

        Returns:
            " | "-joined "Label: value" parts for the attributes that are present;
            "" if none are.
        """
        parts: List[str] = []

        # Merchant name (most important). `or` (rather than a .get default) so
        # the acquirer name is also used when merchant_name is present but
        # null/empty, as JSON nulls decode to None.
        merchant = transaction.get("merchant_name") or transaction.get("acquirer_merchant_name")
        if merchant:
            parts.append(f"Merchant: {merchant}")

        # Transaction amount (0 is a valid amount, hence the explicit None check)
        amount = transaction.get("transaction_amount")
        if amount is not None:
            # NOTE(review): defaulting to USD may mislabel non-USD programs — confirm.
            currency_code = transaction.get("currency_code") or "USD"
            parts.append(f"Amount: {amount} {currency_code}")

        # Transaction type/status
        txn_type = transaction.get("transaction_type")
        if txn_type:
            parts.append(f"Type: {txn_type}")

        state = transaction.get("state") or transaction.get("transaction_status")
        if state:
            parts.append(f"Status: {state}")

        # Merchant category
        mcc = transaction.get("merchant_category_code")
        if mcc:
            parts.append(f"MCC: {mcc}")

        # Cardholder presence
        card_presence = transaction.get("card_presence_indicator")
        if card_presence:
            parts.append(f"Card Presence: {card_presence}")

        # Network
        network = transaction.get("network")
        if network:
            parts.append(f"Network: {network}")

        # If we have very little info, at least include the transaction token
        if len(parts) < 2:
            txn_token = transaction.get("transaction_token")
            if txn_token:
                parts.append(f"Transaction: {txn_token}")

        return " | ".join(parts)

    def embed_transaction(self, transaction: Dict[str, Any]) -> List[float]:
        """
        Generate embedding for a single transaction.

        Args:
            transaction: Transaction dictionary

        Returns:
            Embedding vector as list of floats
        """
        text = self.format_transaction_text(transaction)
        embedding = self.model.encode(text, convert_to_numpy=True)
        return embedding.tolist()

    def embed_transactions_batch(self, transactions: List[Dict[str, Any]]) -> List[List[float]]:
        """
        Generate embeddings for multiple transactions in one model call.

        Args:
            transactions: List of transaction dictionaries

        Returns:
            List of embedding vectors, in input order ([] for empty input)
        """
        if not transactions:
            # Nothing to encode; avoid invoking the model with an empty batch.
            return []
        texts = [self.format_transaction_text(txn) for txn in transactions]
        embeddings = self.model.encode(texts, convert_to_numpy=True, show_progress_bar=True)
        return embeddings.tolist()

    def embed_query(self, query: str) -> List[float]:
        """
        Generate embedding for a search query.

        Args:
            query: Natural language search query

        Returns:
            Embedding vector as list of floats
        """
        embedding = self.model.encode(query, convert_to_numpy=True)
        return embedding.tolist()
120
+
121
+
122
# Shared TransactionEmbedder for the process; stays None until get_embedder() first runs.
_embedder: TransactionEmbedder | None = None


def get_embedder() -> TransactionEmbedder:
    """Return the shared TransactionEmbedder, constructing it on the first call.

    Construction loads the sentence-transformers model, so it is deferred until
    an embedder is actually requested. NOTE(review): no lock guards the first
    call; concurrent first calls could each construct an instance.
    """
    global _embedder
    instance = _embedder
    if instance is None:
        instance = _embedder = TransactionEmbedder()
    return instance