api-service-handler 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,429 @@
1
+ """PostgreSQL storage backend using asyncpg."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from datetime import date, datetime, timezone
7
+ from typing import Optional
8
+
9
+ from ..enums import KeyStatus, Provider
10
+ from ..exceptions import DuplicateKeyError, KeyNotFoundError, StorageConnectionError
11
+ from ..models import APIKey, BulkOperationResult, KeyFilter, KeyUpdateRequest
12
+ from .base import StorageBackend
13
+
14
+ def _try_import_asyncpg():
15
+ """Lazy import asyncpg with helpful error."""
16
+ try:
17
+ import asyncpg
18
+ return asyncpg
19
+ except ImportError:
20
+ raise ImportError(
21
+ "PostgreSQL backend requires 'asyncpg'. "
22
+ "Install with: pip install api-service-handler[postgresql]"
23
+ )
24
+
25
class PostgreSQLStorageBackend(StorageBackend):
    """PostgreSQL storage backend using asyncpg.

    Enterprise grade relational storage. All keys live in a single
    ``api_keys`` table; ``metadata`` and ``tags`` are stored as JSONB.
    Call :meth:`initialize` before any other method and :meth:`close`
    when finished.

    Install with: pip install api-service-handler[postgresql]
    """

    # DDL run by initialize(); every statement is idempotent (IF NOT EXISTS).
    _SCHEMA_STATEMENTS = (
        """
        CREATE TABLE IF NOT EXISTS api_keys (
            id TEXT PRIMARY KEY,
            provider TEXT NOT NULL,
            key_value TEXT NOT NULL,
            alias TEXT,
            status TEXT NOT NULL DEFAULT 'active',
            environment TEXT NOT NULL DEFAULT 'production',
            daily_limit INTEGER,
            monthly_limit INTEGER,
            daily_usage_count INTEGER NOT NULL DEFAULT 0,
            monthly_usage_count INTEGER NOT NULL DEFAULT 0,
            total_usage_count INTEGER NOT NULL DEFAULT 0,
            concurrent_usage INTEGER NOT NULL DEFAULT 0,
            max_concurrent INTEGER,
            weight INTEGER NOT NULL DEFAULT 1,
            priority INTEGER NOT NULL DEFAULT 0,
            metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
            tags JSONB NOT NULL DEFAULT '[]'::jsonb,
            last_used_at TIMESTAMPTZ,
            last_reset_daily DATE NOT NULL,
            last_reset_monthly DATE NOT NULL,
            expires_at TIMESTAMPTZ,
            created_at TIMESTAMPTZ NOT NULL,
            updated_at TIMESTAMPTZ NOT NULL
        )
        """,
        "CREATE INDEX IF NOT EXISTS idx_pg_api_keys_provider ON api_keys(provider)",
        "CREATE INDEX IF NOT EXISTS idx_pg_api_keys_status ON api_keys(status)",
        "CREATE INDEX IF NOT EXISTS idx_pg_api_keys_provider_status ON api_keys(provider, status)",
        # Uniqueness only applies to non-revoked rows, which PostgreSQL can
        # express with a partial unique index (not a table constraint).
        """
        CREATE UNIQUE INDEX IF NOT EXISTS idx_pg_api_keys_unique_active
        ON api_keys (key_value, provider)
        WHERE status != 'revoked'
        """,
    )

    def __init__(
        self,
        connection_string: str,
        *,
        min_pool_size: int = 2,
        max_pool_size: int = 10,
    ) -> None:
        """Initialize PostgreSQL backend.

        No connection is opened here; call :meth:`initialize` first.

        Args:
            connection_string: PostgreSQL connection URI, e.g.,
                'postgresql://user:pass@localhost:5432/my_db'. A leading
                'postgres://' scheme is rewritten to 'postgresql://'.
            min_pool_size: Minimum connections kept in the pool (default 2).
            max_pool_size: Maximum connections in the pool (default 10).

        Raises:
            ImportError: If asyncpg is not installed.
        """
        self._asyncpg = _try_import_asyncpg()
        legacy_scheme = "postgres://"
        if connection_string.startswith(legacy_scheme):
            connection_string = "postgresql://" + connection_string[len(legacy_scheme):]
        self._connection_string = connection_string
        self._min_pool_size = min_pool_size
        self._max_pool_size = max_pool_size
        self._pool = None
        self._initialized = False

    # ── Internal helpers ───────────────────────────────────────────────────

    @staticmethod
    def _enum_value(value):
        """Return ``value.value`` when present (enum members), else ``value``."""
        return value.value if hasattr(value, "value") else value

    @staticmethod
    def _affected_rows(status: str) -> int:
        """Parse the row count from a command status such as ``'UPDATE 3'``.

        Returns 0 when the status string has no trailing integer.
        """
        parts = status.split() if status else []
        if parts and parts[-1].isdigit():
            return int(parts[-1])
        return 0

    @staticmethod
    def _escape_like(text: str) -> str:
        """Escape LIKE/ILIKE wildcards so ``text`` is matched literally.

        Backslash is PostgreSQL's default LIKE escape character, so queries
        using the result need no explicit ``ESCAPE`` clause.
        """
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    def _get_pool(self):
        """Return the connection pool, or raise if :meth:`initialize` has not run."""
        if self._pool is None:
            raise StorageConnectionError(
                backend="postgresql",
                detail="backend is not initialized; call initialize() first",
            )
        return self._pool

    # ── Lifecycle ──────────────────────────────────────────────────────────

    async def initialize(self) -> None:
        """Create the PostgreSQL connection pool and tables.

        Calling this on an already initialized backend is a no-op, so a
        second call does not leak the existing pool.

        Raises:
            StorageConnectionError: If the pool cannot be created or any
                schema statement fails. A pool created during the failed
                attempt is terminated before the error is raised.
        """
        if self._initialized:
            return
        pool = None
        try:
            pool = await self._asyncpg.create_pool(
                self._connection_string,
                min_size=self._min_pool_size,
                max_size=self._max_pool_size,
            )
            async with pool.acquire() as conn:
                for statement in self._SCHEMA_STATEMENTS:
                    await conn.execute(statement)
        except Exception as exc:
            # Don't leave a half-initialized pool holding open connections.
            if pool is not None:
                pool.terminate()
            raise StorageConnectionError(backend="postgresql", detail=str(exc)) from exc
        self._pool = pool
        self._initialized = True

    async def close(self) -> None:
        """Close the PostgreSQL connection pool (no-op if none is open)."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            self._initialized = False

    def _row_to_key(self, row) -> APIKey:
        """Convert a database row to an APIKey model."""
        data = dict(row)

        # asyncpg returns JSONB as str unless a JSON codec is registered on
        # the connection; accept both decoded and raw forms.
        for json_field in ("metadata", "tags"):
            if isinstance(data.get(json_field), str):
                data[json_field] = json.loads(data[json_field])

        # Defensive: treat any naive datetime as UTC.
        for dt_field in ("last_used_at", "expires_at", "created_at", "updated_at"):
            value = data.get(dt_field)
            if value is not None and value.tzinfo is None:
                data[dt_field] = value.replace(tzinfo=timezone.utc)

        return APIKey(**data)

    # ── CRUD ───────────────────────────────────────────────────────────────

    async def add_key(self, key: APIKey) -> APIKey:
        """Insert ``key`` and return it unchanged.

        Raises:
            DuplicateKeyError: If the insert violates a unique index (a
                non-revoked key with the same key_value/provider, or an
                existing id).
            StorageConnectionError: If the backend is not initialized.
        """
        provider_val = self._enum_value(key.provider)
        async with self._get_pool().acquire() as conn:
            try:
                await conn.execute(
                    """
                    INSERT INTO api_keys (
                        id, provider, key_value, alias, status, environment,
                        daily_limit, monthly_limit, daily_usage_count, monthly_usage_count, total_usage_count,
                        concurrent_usage, max_concurrent, weight, priority,
                        metadata, tags, last_used_at, last_reset_daily, last_reset_monthly,
                        expires_at, created_at, updated_at
                    ) VALUES (
                        $1, $2, $3, $4, $5, $6,
                        $7, $8, $9, $10, $11,
                        $12, $13, $14, $15,
                        $16::jsonb, $17::jsonb, $18, $19, $20,
                        $21, $22, $23
                    )
                    """,
                    key.id, provider_val, key.key_value, key.alias,
                    self._enum_value(key.status), self._enum_value(key.environment),
                    key.daily_limit, key.monthly_limit,
                    key.daily_usage_count, key.monthly_usage_count, key.total_usage_count,
                    key.concurrent_usage, key.max_concurrent, key.weight, key.priority,
                    json.dumps(key.metadata), json.dumps(key.tags),
                    key.last_used_at, key.last_reset_daily, key.last_reset_monthly,
                    key.expires_at, key.created_at, key.updated_at,
                )
            except self._asyncpg.exceptions.UniqueViolationError as exc:
                raise DuplicateKeyError(provider=provider_val, key_value=key.key_value) from exc
        return key

    async def get_key(self, key_id: str) -> APIKey:
        """Return the key with ``key_id``.

        Raises:
            KeyNotFoundError: If no row has this id.
        """
        async with self._get_pool().acquire() as conn:
            row = await conn.fetchrow("SELECT * FROM api_keys WHERE id = $1", key_id)
        if not row:
            raise KeyNotFoundError(key_id=key_id)
        return self._row_to_key(row)

    async def get_keys_by_provider(self, provider: Provider) -> list[APIKey]:
        """Return every key (any status) belonging to ``provider``."""
        async with self._get_pool().acquire() as conn:
            rows = await conn.fetch(
                "SELECT * FROM api_keys WHERE provider = $1", self._enum_value(provider)
            )
        return [self._row_to_key(row) for row in rows]

    async def get_all_keys(self, key_filter: Optional[KeyFilter] = None) -> list[APIKey]:
        """Return keys matching ``key_filter``, ordered by priority then creation time.

        provider, status, environment and alias_contains (case-insensitive
        literal substring) are applied in SQL. tags (match any),
        metadata_filter (match all) and has_capacity are applied in Python
        after the rows are loaded.
        """
        conditions: list[str] = []
        args: list = []

        def add_condition(expr: str, value) -> None:
            # Placeholders are numbered by position in ``args`` ($1, $2, ...).
            args.append(value)
            conditions.append(f"{expr} ${len(args)}")

        if key_filter:
            if key_filter.provider is not None:
                add_condition("provider =", self._enum_value(key_filter.provider))
            if key_filter.status is not None:
                add_condition("status =", self._enum_value(key_filter.status))
            if key_filter.environment is not None:
                add_condition("environment =", self._enum_value(key_filter.environment))
            if key_filter.alias_contains:
                # Escape wildcards so e.g. "50%" matches literally.
                add_condition("alias ILIKE", f"%{self._escape_like(key_filter.alias_contains)}%")

        query = "SELECT * FROM api_keys"
        if conditions:
            query += " WHERE " + " AND ".join(conditions)
        query += " ORDER BY priority ASC, created_at ASC"

        async with self._get_pool().acquire() as conn:
            rows = await conn.fetch(query, *args)
        keys = [self._row_to_key(row) for row in rows]

        # In-memory filtering for JSON/complex checks (could be moved to SQL later).
        if key_filter:
            if key_filter.tags:
                wanted_tags = key_filter.tags
                keys = [k for k in keys if any(tag in k.tags for tag in wanted_tags)]
            if key_filter.metadata_filter:
                wanted_meta = key_filter.metadata_filter.items()
                keys = [k for k in keys if all(k.metadata.get(mk) == mv for mk, mv in wanted_meta)]
            if key_filter.has_capacity is not None:
                keys = [k for k in keys if k.has_capacity == key_filter.has_capacity]

        return keys

    async def update_key(self, key_id: str, updates: KeyUpdateRequest) -> APIKey:
        """Apply the fields explicitly set on ``updates`` and return the updated key.

        The update and the read-back happen in one ``UPDATE ... RETURNING``
        statement, so there is no window between an existence check and the
        write. ``updated_at`` is always refreshed when anything changes.

        Raises:
            KeyNotFoundError: If no row has ``key_id``.
        """
        update_data = updates.model_dump(exclude_unset=True)
        if not update_data:
            return await self.get_key(key_id)

        update_data["updated_at"] = datetime.now(timezone.utc)

        # NOTE(review): column names come from KeyUpdateRequest's field names,
        # not from caller data; this assumes every field maps to a column.
        set_clauses = []
        args = []
        for column, value in update_data.items():
            if column in ("metadata", "tags"):
                args.append(json.dumps(value))
                set_clauses.append(f"{column} = ${len(args)}::jsonb")
            else:
                args.append(self._enum_value(value))
                set_clauses.append(f"{column} = ${len(args)}")

        args.append(key_id)
        query = (
            f"UPDATE api_keys SET {', '.join(set_clauses)} "
            f"WHERE id = ${len(args)} RETURNING *"
        )

        async with self._get_pool().acquire() as conn:
            row = await conn.fetchrow(query, *args)
        if row is None:
            raise KeyNotFoundError(key_id=key_id)
        return self._row_to_key(row)

    async def delete_key(self, key_id: str, soft: bool = True) -> bool:
        """Delete a key and return True.

        Args:
            key_id: Id of the key to delete.
            soft: When True, mark the row 'revoked' (freeing its key_value
                for reuse under the partial unique index); otherwise remove
                the row.

        Raises:
            KeyNotFoundError: If no row has ``key_id``.
        """
        async with self._get_pool().acquire() as conn:
            if soft:
                status = await conn.execute(
                    "UPDATE api_keys SET status = 'revoked', updated_at = $1 WHERE id = $2",
                    datetime.now(timezone.utc),
                    key_id,
                )
            else:
                status = await conn.execute("DELETE FROM api_keys WHERE id = $1", key_id)
        if self._affected_rows(status) == 0:
            raise KeyNotFoundError(key_id=key_id)
        return True

    # ── Usage Tracking ─────────────────────────────────────────────────────

    async def increment_usage(
        self,
        key_id: str,
        daily: int = 1,
        monthly: int = 1,
        total: int = 1,
    ) -> None:
        """Atomically add to the usage counters and stamp last_used_at.

        An unknown ``key_id`` matches no rows and is silently ignored.
        """
        now = datetime.now(timezone.utc)
        async with self._get_pool().acquire() as conn:
            await conn.execute(
                """
                UPDATE api_keys
                SET daily_usage_count = daily_usage_count + $1,
                    monthly_usage_count = monthly_usage_count + $2,
                    total_usage_count = total_usage_count + $3,
                    last_used_at = $4,
                    updated_at = $5
                WHERE id = $6
                """,
                daily, monthly, total, now, now, key_id,
            )

    async def update_concurrent_usage(self, key_id: str, delta: int) -> int:
        """Add ``delta`` to concurrent_usage (clamped at 0) and return the new value.

        Raises:
            KeyNotFoundError: If no row has ``key_id``.
        """
        now = datetime.now(timezone.utc)
        async with self._get_pool().acquire() as conn:
            row = await conn.fetchrow(
                """
                UPDATE api_keys
                SET concurrent_usage = GREATEST(0, concurrent_usage + $1),
                    updated_at = $2
                WHERE id = $3
                RETURNING concurrent_usage
                """,
                delta, now, key_id,
            )
        if not row:
            raise KeyNotFoundError(key_id=key_id)
        return row["concurrent_usage"]

    async def reset_daily_counts(self, before_date: date) -> int:
        """Reset daily counters for keys last reset before ``before_date``.

        Matching keys get daily_usage_count = 0 and last_reset_daily =
        ``before_date``. Then any 'rate_limited' key that is under BOTH its
        daily and monthly limits is set back to 'active'; keys still at or
        over a limit stay rate-limited. Both statements run in one
        transaction.

        Returns:
            Number of keys whose daily counters were reset.
        """
        now = datetime.now(timezone.utc)
        async with self._get_pool().acquire() as conn:
            async with conn.transaction():
                status = await conn.execute(
                    """
                    UPDATE api_keys
                    SET daily_usage_count = 0,
                        last_reset_daily = $1,
                        updated_at = $2
                    WHERE last_reset_daily < $1
                    """,
                    before_date, now,
                )
                await conn.execute(
                    """
                    UPDATE api_keys
                    SET status = 'active',
                        updated_at = $1
                    WHERE status = 'rate_limited'
                      AND (daily_limit IS NULL OR daily_usage_count < daily_limit)
                      AND (monthly_limit IS NULL OR monthly_usage_count < monthly_limit)
                    """,
                    now,
                )
        return self._affected_rows(status)

    async def reset_monthly_counts(self, before_date: date) -> int:
        """Reset monthly counters for keys last reset in an earlier month.

        "Earlier" compares (year, month) lexicographically, so calling this
        with a date from a past month does not wipe current counters.
        Rate-limited keys that are also under their daily limit are set
        back to 'active' in the same statement.

        Returns:
            Number of keys whose monthly counters were reset.
        """
        now = datetime.now(timezone.utc)
        async with self._get_pool().acquire() as conn:
            status = await conn.execute(
                """
                UPDATE api_keys
                SET monthly_usage_count = 0,
                    last_reset_monthly = $1,
                    status = CASE
                        WHEN status = 'rate_limited'
                             AND (daily_limit IS NULL OR daily_usage_count < daily_limit)
                        THEN 'active'
                        ELSE status
                    END,
                    updated_at = $2
                WHERE (EXTRACT(YEAR FROM last_reset_monthly), EXTRACT(MONTH FROM last_reset_monthly))
                    < (EXTRACT(YEAR FROM $1::date), EXTRACT(MONTH FROM $1::date))
                """,
                before_date, now,
            )
        return self._affected_rows(status)

    async def update_last_used(self, key_id: str) -> None:
        """Set last_used_at and updated_at to now (unknown ids are ignored)."""
        now = datetime.now(timezone.utc)
        async with self._get_pool().acquire() as conn:
            await conn.execute(
                "UPDATE api_keys SET last_used_at = $1, updated_at = $1 WHERE id = $2", now, key_id
            )

    # ── Bulk Operations ────────────────────────────────────────────────────

    async def bulk_add_keys(self, keys: list[APIKey]) -> BulkOperationResult:
        """Add each key independently; failures are recorded, not raised."""
        result = BulkOperationResult(total=len(keys))
        for key in keys:
            try:
                await self.add_key(key)
            except Exception as e:
                result.failed += 1
                result.errors.append(f"Key {key.alias or key.id}: {e}")
            else:
                result.successful += 1
                result.created_ids.append(key.id)
        return result

    async def bulk_delete_keys(self, key_ids: list[str], soft: bool = True) -> BulkOperationResult:
        """Delete each key independently; failures are recorded, not raised."""
        result = BulkOperationResult(total=len(key_ids))
        for key_id in key_ids:
            try:
                await self.delete_key(key_id, soft=soft)
            except Exception as e:
                result.failed += 1
                result.errors.append(f"Key {key_id}: {e}")
            else:
                result.successful += 1
        return result

    # ── Health ─────────────────────────────────────────────────────────────

    async def health_check(self) -> bool:
        """Return True if the backend is initialized and ``SELECT 1`` succeeds."""
        if not self._pool or not self._initialized:
            return False
        try:
            async with self._pool.acquire() as conn:
                await conn.execute("SELECT 1")
        except Exception:
            return False
        return True

    async def count_keys(
        self,
        provider: Optional[Provider] = None,
        status: Optional[KeyStatus] = None,
    ) -> int:
        """Count keys, optionally restricted to a provider and/or status."""
        conditions: list[str] = []
        args: list = []
        if provider is not None:
            args.append(self._enum_value(provider))
            conditions.append(f"provider = ${len(args)}")
        if status is not None:
            args.append(self._enum_value(status))
            conditions.append(f"status = ${len(args)}")

        query = "SELECT COUNT(*) FROM api_keys"
        if conditions:
            query += " WHERE " + " AND ".join(conditions)

        async with self._get_pool().acquire() as conn:
            return await conn.fetchval(query, *args)