api-service-handler 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,219 @@
1
+ """Usage tracking and concurrent usage management."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ from collections import defaultdict
7
+ from contextlib import asynccontextmanager
8
+ from datetime import datetime, timezone
9
+ from typing import AsyncIterator, Optional
10
+
11
+ from .enums import Provider, KeyStatus
12
+ from .exceptions import MaxConcurrentExceededError, KeyNotFoundError
13
+ from .models import APIKey, UsageStats
14
+ from .storage.base import StorageBackend
15
+
16
+
17
class UsageTracker:
    """Tracks API key usage including counts and concurrent usage.

    Provides:
    - Usage recording (daily, monthly, total increments)
    - Concurrent usage acquire/release with max enforcement
    - Context manager for safe concurrent usage
    - Usage statistics aggregation

    All counters are read from and written to the injected storage backend;
    the tracker itself only keeps per-key ``asyncio.Lock`` objects.
    """

    def __init__(self, storage: StorageBackend) -> None:
        """Initialize the usage tracker.

        Args:
            storage: The storage backend for persisting usage data.
        """
        self._storage = storage
        # Per-key locks make the check-and-increment in acquire() atomic.
        # They are in-process asyncio locks, so they do not coordinate
        # separate processes sharing the same storage.
        # NOTE(review): entries are never removed, so this dict grows with
        # the number of distinct key ids ever passed to acquire().
        self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

    async def record_usage(self, key_id: str, count: int = 1) -> None:
        """Record usage for a key.

        Increments the daily, monthly, and total counters by the same amount.

        Args:
            key_id: The key to record usage for.
            count: Number of uses to record (default: 1).
        """
        await self._storage.increment_usage(
            key_id, daily=count, monthly=count, total=count
        )

    async def acquire(self, key_id: str) -> int:
        """Acquire a concurrent usage slot for a key.

        Checks the max_concurrent limit, then increments the
        concurrent_usage counter by 1.

        Args:
            key_id: The key to acquire a slot for.

        Returns:
            The new concurrent_usage value.

        Raises:
            MaxConcurrentExceededError: If max_concurrent would be exceeded.
            KeyNotFoundError: If the key doesn't exist.
        """
        # Hold the per-key lock across the read, the limit check and the
        # increment. This stops two coroutines from both passing the check
        # for the last free slot.
        async with self._locks[key_id]:
            key = await self._storage.get_key(key_id)

            if (
                key.max_concurrent is not None
                and key.concurrent_usage >= key.max_concurrent
            ):
                raise MaxConcurrentExceededError(
                    key_id=key_id,
                    max_concurrent=key.max_concurrent,
                )

            return await self._storage.update_concurrent_usage(key_id, delta=1)

    async def release(self, key_id: str) -> int:
        """Release a concurrent usage slot for a key.

        Decrements the concurrent_usage counter by 1. The original contract
        says the value bottoms out at 0; presumably the storage backend
        enforces that.

        The per-key lock is not taken. A decrement landing inside acquire()'s
        critical section only makes its limit check more conservative; it
        can never let acquire() over-admit.

        Args:
            key_id: The key to release a slot for.

        Returns:
            The new concurrent_usage value.
        """
        return await self._storage.update_concurrent_usage(key_id, delta=-1)

    @asynccontextmanager
    async def use(self, key_id: str, record: bool = True) -> AsyncIterator[APIKey]:
        """Context manager for using a key with automatic acquire/release.

        Acquires a concurrent slot on entry and releases it on exit.
        Optionally records usage when the body completes without raising.

        Args:
            key_id: The key to use.
            record: If True, record a usage event on successful exit.

        Yields:
            The APIKey being used.

        Example:
            async with tracker.use(key.id) as key:
                # concurrent_usage is incremented
                response = await call_api(key.key_value)
            # concurrent_usage is decremented, usage recorded
        """
        await self.acquire(key_id)
        try:
            # Re-read after acquiring so the yielded key presumably reflects
            # the incremented concurrent_usage.
            key = await self._storage.get_key(key_id)
            yield key
            # Reached only when the ``async with`` body did not raise.
            if record:
                await self.record_usage(key_id)
        finally:
            # Always hand the slot back, even if get_key(), the body, or
            # record_usage() raised.
            await self.release(key_id)

    async def get_concurrent_usage(self, key_id: str) -> int:
        """Get the current concurrent usage count for a key.

        Args:
            key_id: The key to check.

        Returns:
            Current concurrent_usage value.
        """
        key = await self._storage.get_key(key_id)
        return key.concurrent_usage

    @staticmethod
    def _build_stats(key: APIKey) -> UsageStats:
        """Build a UsageStats snapshot from an already-loaded key.

        Each remaining quota is ``None`` when its limit is unset. Otherwise it
        is the limit minus the usage, floored at 0.
        """
        daily_remaining = None
        if key.daily_limit is not None:
            daily_remaining = max(0, key.daily_limit - key.daily_usage_count)

        monthly_remaining = None
        if key.monthly_limit is not None:
            monthly_remaining = max(0, key.monthly_limit - key.monthly_usage_count)

        return UsageStats(
            key_id=key.id,
            provider=key.provider,
            alias=key.alias,
            daily_usage_count=key.daily_usage_count,
            monthly_usage_count=key.monthly_usage_count,
            total_usage_count=key.total_usage_count,
            concurrent_usage=key.concurrent_usage,
            daily_limit=key.daily_limit,
            monthly_limit=key.monthly_limit,
            max_concurrent=key.max_concurrent,
            daily_remaining=daily_remaining,
            monthly_remaining=monthly_remaining,
            last_used_at=key.last_used_at,
            status=key.status,
        )

    async def get_usage_stats(self, key_id: str) -> UsageStats:
        """Get aggregated usage statistics for a key.

        Args:
            key_id: The key to get stats for.

        Returns:
            UsageStats with all usage information.
        """
        key = await self._storage.get_key(key_id)
        return self._build_stats(key)

    async def get_provider_stats(self, provider: Provider | str) -> list[UsageStats]:
        """Get usage stats for all keys of a provider.

        Args:
            provider: The provider to get stats for; a string is converted
                with ``Provider(provider)``.

        Returns:
            List of UsageStats for all keys of the provider.
        """
        if isinstance(provider, str):
            provider = Provider(provider)

        keys = await self._storage.get_keys_by_provider(provider)
        # Build stats from the keys already fetched instead of refetching each
        # one by id. This saves a storage round-trip per key, and a key
        # deleted concurrently can no longer abort the whole call with
        # KeyNotFoundError.
        return [self._build_stats(key) for key in keys]

    async def reset_concurrent(self, key_id: str) -> None:
        """Reset concurrent usage to 0 for a key.

        Useful for recovery after crashes where release wasn't called.

        Args:
            key_id: The key to reset.
        """
        key = await self._storage.get_key(key_id)
        if key.concurrent_usage > 0:
            # The storage API only exposes relative updates, so reach 0 by
            # subtracting the value just read.
            await self._storage.update_concurrent_usage(key_id, delta=-key.concurrent_usage)

    async def reset_all_concurrent(self) -> int:
        """Reset concurrent usage to 0 for all keys.

        Useful for application startup recovery.

        Returns:
            Number of keys whose concurrent usage was non-zero and got reset.
        """
        all_keys = await self._storage.get_all_keys()
        count = 0
        for key in all_keys:
            if key.concurrent_usage > 0:
                await self._storage.update_concurrent_usage(
                    key.id, delta=-key.concurrent_usage
                )
                count += 1
        return count
@@ -0,0 +1,322 @@
1
+ """Utility helpers for the api-service-handler library.
2
+
3
+ Provides common helper functions used across the library including
4
+ ID generation, date/time utilities, string masking, connection string
5
+ validation, metadata filtering, and more.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import re
11
+ import uuid
12
+ from collections.abc import Iterator
13
+ from datetime import date, datetime, timedelta, timezone
14
+ from typing import Any
15
+
16
+
17
def generate_id() -> str:
    """Return a fresh random identifier.

    Returns:
        The canonical hyphenated, lowercase-hex text of a new UUID4,
        e.g. ``'1b4e28ba-2fa1-41d2-883f-0016d3cca427'``.
    """
    fresh = uuid.uuid4()
    return str(fresh)
24
+
25
+
26
def now_utc() -> datetime:
    """Return the current moment as a timezone-aware UTC datetime.

    Returns:
        A :class:`datetime.datetime` whose ``tzinfo`` is
        :data:`datetime.timezone.utc`.
    """
    utc = timezone.utc
    return datetime.now(utc)


def today_utc() -> date:
    """Return today's calendar date in UTC (not local time).

    Returns:
        The :class:`datetime.date` part of :func:`now_utc`.
    """
    moment = now_utc()
    return moment.date()
42
+
43
+
44
def is_same_day(d1: date, d2: date) -> bool:
    """Tell whether *d1* and *d2* fall on the same calendar day.

    Only the year, month and day fields are compared, so ``datetime``
    values with different times of day still match.

    Args:
        d1: First date.
        d2: Second date.

    Returns:
        ``True`` when the (year, month, day) triples are equal.
    """
    first = (d1.year, d1.month, d1.day)
    second = (d2.year, d2.month, d2.day)
    return first == second
55
+
56
+
57
def is_same_month(d1: date, d2: date) -> bool:
    """Tell whether *d1* and *d2* fall in the same month of the same year.

    Args:
        d1: First date.
        d2: Second date.

    Returns:
        ``True`` when the (year, month) pairs are equal; the day is ignored.
    """
    return (d1.year, d1.month) == (d2.year, d2.month)
68
+
69
+
70
def mask_key_value(key: str, visible_chars: int = 8) -> str:
    """Mask a sensitive string, showing only leading characters.

    The result never contains the whole of *key*: at least one trailing
    character is always hidden.

    * An empty *key* becomes ``'***'``.
    * If *key* is no longer than *visible_chars*, at most the first 2
      characters are shown (fewer for 1–2 character keys), then ``'***'``.
    * Otherwise the first *visible_chars* characters are shown, then
      ``'***'``.

    Args:
        key: The sensitive string to mask.
        visible_chars: Number of leading characters to keep visible for
            long keys. Defaults to ``8``. Negative values are treated as ``0``.

    Returns:
        The masked string.

    Examples:
        >>> mask_key_value("sk-abc123456789")
        'sk-abc12***'
        >>> mask_key_value("short")
        'sh***'
        >>> mask_key_value("12345678")
        '12***'
    """
    if not key:
        return "***"

    # A negative slice bound would expose all but the last few characters.
    visible = max(visible_chars, 0)

    if len(key) <= visible:
        # Showing `visible` characters here would expose the entire key, so
        # fall back to a short prefix that still hides at least one char.
        return key[: min(2, len(key) - 1)] + "***"

    return key[:visible] + "***"
98
+
99
+
100
+ # ---------------------------------------------------------------------------
101
+ # Connection-string validation
102
+ # ---------------------------------------------------------------------------
103
+
104
# Accepted URI prefixes per backend name. Aliases that share a database get
# equal but independent prefix lists.
_BACKEND_PREFIXES: dict[str, list[str]] = {
    alias: list(prefixes)
    for aliases, prefixes in (
        (("sqlite",), ("sqlite:///", "sqlite://")),
        (("mongo", "mongodb"), ("mongodb://", "mongodb+srv://")),
        (
            ("pg", "postgres", "postgresql"),
            ("postgresql://", "postgresql+asyncpg://", "postgres://"),
        ),
    )
    for alias in aliases
}


def validate_connection_string(backend: str, connection_string: str) -> bool:
    """Check that *connection_string* looks like a URI for *backend*.

    This is only a prefix test. Nothing is connected to, and the rest of
    the URI is not parsed.

    Recognised backend names (case-insensitive, surrounding whitespace
    ignored):

    * ``sqlite`` — ``sqlite://`` URIs
    * ``mongo`` / ``mongodb`` — ``mongodb://`` or ``mongodb+srv://``
    * ``pg`` / ``postgres`` / ``postgresql`` — ``postgresql://``,
      ``postgresql+asyncpg://``, or ``postgres://``

    Args:
        backend: The database backend name.
        connection_string: The connection URI to validate.

    Returns:
        ``True`` when the trimmed, lowercased URI begins with one of the
        backend's prefixes. ``False`` otherwise, including for empty inputs
        and unrecognised backends.
    """
    # Empty or None inputs can never match.
    if not (backend and connection_string):
        return False

    allowed = _BACKEND_PREFIXES.get(backend.strip().lower())
    if allowed is None:
        # Unknown backends are rejected rather than raising.
        return False

    # str.startswith accepts a tuple and succeeds if any prefix matches.
    return connection_string.strip().lower().startswith(tuple(allowed))
147
+
148
+
149
+ # ---------------------------------------------------------------------------
150
+ # Metadata filter parsing
151
+ # ---------------------------------------------------------------------------
152
+
153
+
154
def _resolve_dotted_key(data: dict[str, Any], dotted_key: str) -> tuple[bool, Any]:
    """Walk *data* along the dot-separated path in *dotted_key*.

    Each segment must name a key of a ``dict`` at that depth. Walking into
    any non-dict value counts as a miss.

    Args:
        data: The dictionary to traverse.
        dotted_key: A path such as ``'team.name'``; a key without dots is a
            single-segment path.

    Returns:
        ``(True, value)`` when the whole path exists, otherwise
        ``(False, None)``.
    """
    node: Any = data
    for segment in dotted_key.split("."):
        if isinstance(node, dict) and segment in node:
            node = node[segment]
        else:
            return False, None
    return True, node


def parse_metadata_filter(metadata_filter: dict[str, Any], metadata: dict[str, Any]) -> bool:
    """Test whether *metadata* satisfies every criterion in *metadata_filter*.

    Filter keys may use dot notation to reach nested dictionaries, so
    ``{"team.name": "backend"}`` matches ``{"team": {"name": "backend"}}``.
    A criterion fails when its path is missing or the value found there is
    not equal to the expected one.

    Args:
        metadata_filter: Flat mapping of (possibly dotted) key paths to
            expected values. An empty or ``None`` filter matches anything.
        metadata: The metadata dict to match against.

    Returns:
        ``True`` only if every criterion holds.

    Examples:
        >>> parse_metadata_filter({"env": "prod"}, {"env": "prod", "version": 2})
        True
        >>> parse_metadata_filter({"team.name": "api"}, {"team": {"name": "api"}})
        True
        >>> parse_metadata_filter({"env": "prod"}, {"env": "staging"})
        False
    """
    # No criteria means no constraints (this also tolerates None).
    if not metadata_filter:
        return True

    def _mismatch(path: str, wanted: Any) -> bool:
        present, value = _resolve_dotted_key(metadata, path)
        return not present or value != wanted

    # any() stops at the first failing criterion.
    return not any(_mismatch(path, wanted) for path, wanted in metadata_filter.items())
208
+
209
+
210
+ # ---------------------------------------------------------------------------
211
+ # Timedelta formatting
212
+ # ---------------------------------------------------------------------------
213
+
214
+
215
def format_timedelta(td: timedelta) -> str:
    """Render *td* as a compact duration string such as ``'2h 30m'``.

    Whole days, hours and minutes are emitted in that order. Any unit whose
    value is zero is skipped, so one day and five minutes is ``'1d 5m'``.
    Seconds and fractions of a second are discarded. A negative duration
    gets a leading ``'-'``; anything under one minute in magnitude renders
    as ``'0m'``.

    Args:
        td: The timedelta to format.

    Returns:
        The formatted duration, or ``'0m'`` when no unit is non-zero.

    Examples:
        >>> from datetime import timedelta
        >>> format_timedelta(timedelta(hours=2, minutes=30))
        '2h 30m'
        >>> format_timedelta(timedelta(days=5, hours=12))
        '5d 12h'
        >>> format_timedelta(timedelta(seconds=90))
        '1m'
        >>> format_timedelta(timedelta(0))
        '0m'
    """
    # int() truncates toward zero, so -59.9s becomes -59s (and then "0m").
    seconds = int(td.total_seconds())
    sign = "-" if seconds < 0 else ""
    seconds = abs(seconds)

    units = (
        (seconds // 86400, "d"),
        (seconds % 86400 // 3600, "h"),
        (seconds % 3600 // 60, "m"),
    )
    pieces = [f"{amount}{suffix}" for amount, suffix in units if amount]
    if not pieces:
        return "0m"
    return sign + " ".join(pieces)
263
+
264
+
265
+ # ---------------------------------------------------------------------------
266
+ # List chunking
267
+ # ---------------------------------------------------------------------------
268
+
269
+
270
def chunks(lst: list[Any], n: int) -> Iterator[list[Any]]:
    """Lazily split *lst* into consecutive slices of length *n*.

    Every slice has exactly *n* items except possibly the last, which holds
    whatever remains. Because this is a generator, an invalid *n* is only
    reported when iteration starts.

    Args:
        lst: The list to split.
        n: Maximum slice length; must be at least ``1``.

    Yields:
        Consecutive sub-lists of at most *n* elements.

    Raises:
        ValueError: If *n* is less than ``1``.

    Examples:
        >>> list(chunks([1, 2, 3, 4, 5], 2))
        [[1, 2], [3, 4], [5]]
    """
    if n < 1:
        raise ValueError(f"Chunk size must be >= 1, got {n}")

    start, stop = 0, len(lst)
    while start < stop:
        yield lst[start : start + n]
        start += n
295
+
296
+
297
+ # ---------------------------------------------------------------------------
298
+ # Provider name sanitization
299
+ # ---------------------------------------------------------------------------
300
+
301
# One or more whitespace characters and/or hyphens, treated as a single separator.
_SANITIZE_PATTERN = re.compile(r"[\s\-]+")


def sanitize_provider_name(name: str) -> str:
    """Bring a provider name into its canonical underscore form.

    Surrounding whitespace is trimmed and the text is lowercased. Each
    remaining run of whitespace and/or hyphens is then collapsed into one
    underscore.

    Args:
        name: The raw provider name.

    Returns:
        The canonical provider name.

    Examples:
        >>> sanitize_provider_name("  Open AI  ")
        'open_ai'
        >>> sanitize_provider_name("azure-openai")
        'azure_openai'
    """
    words = _SANITIZE_PATTERN.split(name.strip().lower())
    return "_".join(words)