token-limit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,8 @@
1
+ __version__ = '0.1.0'
2
+
3
+ from token_limit.meter import Meter, PER_MONTH, PER_DAY
4
+ from token_limit.exceptions import LimitExceededException
5
+
6
+ from token_limit.meter import MeterConfig
7
+
8
+ __all__ = ["Meter", "PER_MONTH", "PER_DAY", "LimitExceededException", "MeterConfig"]
token_limit/config.py ADDED
@@ -0,0 +1,35 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Callable, List, Optional
3
+
4
+ from .transport.http_client import INGEST_URL
5
+ from .types import LLMEvent
6
+
7
+
8
def _default_patches() -> List[str]:
    """Return a fresh list naming every provider patch enabled by default."""
    return ["openai", "anthropic", "deepseek", "google", "openrouter"]


@dataclass
class MeterConfig:
    """
    All settings for a Meter, gathered in one dataclass.

    Only ``api_key`` has no default; every other field may be omitted.
    The trailing comment on each field describes what it controls.
    """

    # -- connection ----------------------------------------------------
    api_key: str  # your SaaS API key (the only field without a default)
    url: str = INGEST_URL  # POST endpoint on your backend

    # -- batching ------------------------------------------------------
    flush_interval: float = 5.0  # seconds between auto-flushes
    max_batch_size: int = 50  # flush early if queue hits this
    max_queue_size: int = 1000  # drop oldest if queue overflows

    # -- limit checks --------------------------------------------------
    # seconds a check_limit() result is cached per tenant
    limit_check_cache_ttl: float = 5.0

    # -- behaviour -----------------------------------------------------
    raise_on_error: bool = False  # if True, re-raise patch exceptions
    debug: bool = False  # print captured events to stdout

    # -- hooks (optional) ----------------------------------------------
    # called after capture
    on_event: Optional[Callable[[LLMEvent], None]] = None
    # called with transport errors
    on_flush_error: Optional[Callable[[Exception], None]] = None

    # -- patches to install (default = all available) ------------------
    # A factory is used so each instance gets its own list object.
    patches: List[str] = field(default_factory=_default_patches)
@@ -0,0 +1,5 @@
1
class LimitExceededException(Exception):
    """
    Signals that a tenant's usage has gone past its configured threshold.

    Parameters
    ----------
    message : str
        Text carried by the exception; defaults to the fixed marker
        ``"tenant_usage_limit_exceeded"``.
    """

    def __init__(self, message: str = "tenant_usage_limit_exceeded"):
        # Hand the message to the base class so str(exc) and exc.args
        # both expose it.
        Exception.__init__(self, message)
token_limit/meter.py ADDED
@@ -0,0 +1,479 @@
1
+ """
2
+ token_limit.meter
3
+ ~~~~~~~~~~~~~~~~~~~~~
4
+ The single entry point developers import and initialise.
5
+
6
+ Usage
7
+ -----
8
+ from token_limit import Meter, MeterConfig
9
+
10
+ meter = Meter(MeterConfig(
11
+ api_key="sk-...",
12
+ url="https://api.yoursaas.com/v1/ingest",
13
+ ))
14
+ meter.patch_all() # ← one line, all providers auto-instrumented
15
+
16
+ # Per-request tenant override:
17
+ with meter.for_tenant("stripe-inc"):
18
+ response = openai_client.chat.completions.create(...)
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import atexit
24
+ from datetime import datetime, timezone
25
+ import logging
26
+ import threading
27
+ import time
28
+ from contextlib import contextmanager
29
+ from contextvars import ContextVar
30
+ from typing import Dict, Generator, List, Optional, Tuple
31
+
32
+ try:
33
+ import aiohttp as _aiohttp # optional — only needed for async limit checks
34
+ except ImportError: # pragma: no cover
35
+ _aiohttp = None # type: ignore[assignment]
36
+
37
+ from .exceptions import LimitExceededException
38
+
39
+ from .config import MeterConfig
40
+ from .patches import PATCH_REGISTRY
41
+ from .patches._base_patch import BasePatch
42
+ from .transport.http_client import CHECK_LIMIT_URL, HttpClient, SET_LIMIT_URL, _TIMEOUT
43
+ from .transport.queue import EventQueue
44
+ from .types import LLMEvent
45
+
46
logger = logging.getLogger("token_limit")

# Thread-/async-safe tenant_id override
_tenant_id_var: ContextVar[Optional[str]] = ContextVar("tenant_id", default=None)

# Accepted values for the `frequency` argument of Meter.set_limit().
PER_DAY = "per_day"
PER_MONTH = "per_month"


class Meter:
    """
    Central SDK object. One instance per application.

    Parameters
    ----------
    config : MeterConfig
        All configuration in one place.
    """

    def __init__(self, config: "MeterConfig") -> None:
        self.config = config
        # Installed patches, keyed by provider name.
        self._patches: Dict[str, BasePatch] = {}

        # Client used by the event queue to ship batches to the ingest URL.
        self._http = HttpClient(
            url=config.url,
            api_key=config.api_key,
            on_error=config.on_flush_error,
        )
        # Dedicated clients for the limit-check / set-limit endpoints,
        # built once here so a fresh HttpClient is not constructed on
        # every patched LLM call.
        self._limit_check_http = HttpClient(
            url=CHECK_LIMIT_URL,
            api_key=config.api_key,
        )
        self._set_limit_http = HttpClient(
            url=SET_LIMIT_URL,
            api_key=config.api_key,
        )
        # Lazily-created, reused aiohttp session for async_check_limit.
        # Created on first use rather than here, since aiohttp sessions
        # must be created inside a running event loop.
        self._aiohttp_session: Optional["_aiohttp.ClientSession"] = None

        # In-process cache for limit checks, keyed by tenant_id:
        #   {tenant_id: (checked_at_monotonic, limit_exceeded, remaining_tokens)}
        # The limit check runs before patched SDK calls, so without a
        # cache each call would pay a network round-trip to the backend.
        # The TTL is intentionally short: it trades a little enforcement
        # staleness for not adding latency to every LLM call.
        self._limit_cache: Dict[str, Tuple[float, bool, Optional[int]]] = {}
        self._limit_cache_lock = threading.Lock()
        # getattr keeps config objects that lack the field working.
        self._limit_cache_ttl: float = getattr(config, "limit_check_cache_ttl", 5.0)

        # Local token accounting, keyed by tenant_id:
        #   {tenant_id: tokens captured since the last real network sync}
        # Reset to 0 whenever _store_limit_result() records a fresh sync
        # and guarded by _limit_cache_lock together with _limit_cache.
        # NOTE(review): within this class the counter is only written
        # (accumulated, reset, popped); nothing here compares it against
        # remaining_tokens yet, so it does not enforce anything by itself.
        self._local_token_spend: Dict[str, int] = {}

        self._queue = EventQueue(
            flush_fn=self._http.send_batch,
            flush_interval=config.flush_interval,
            max_batch_size=config.max_batch_size,
            max_queue_size=config.max_queue_size,
        )

        # Make sure the queue is shut down when the interpreter exits.
        atexit.register(self._queue.shutdown)

    # ── patching ─────────────────────────────────────────────────

    def patch_all(self) -> "Meter":
        """Install all configured provider patches. Chainable."""
        for name in self.config.patches:
            self.patch(name)
        return self

    def patch(self, provider: str) -> "Meter":
        """
        Install a single provider patch by name. Chainable.

        Raises
        ------
        ValueError
            If `provider` is not a key of PATCH_REGISTRY.
        """
        cls = PATCH_REGISTRY.get(provider)
        if cls is None:
            raise ValueError(
                f"Unknown provider '{provider}'. " f"Available: {list(PATCH_REGISTRY)}"
            )
        instance = cls(self)
        instance.patch()
        self._patches[provider] = instance
        if self.config.debug:
            status = "active" if instance.is_active else "skipped (not installed)"
            logger.info("token_limit: patch '%s' %s", provider, status)
        return self

    def unpatch_all(self) -> "Meter":
        """Restore all original SDK methods. Chainable."""
        for patch in self._patches.values():
            patch.unpatch()
        self._patches.clear()
        return self

    def unpatch(self, provider: str) -> "Meter":
        """Remove one provider patch; a no-op if it was never installed."""
        patch = self._patches.pop(provider, None)
        if patch:
            patch.unpatch()
        return self

    # ── tenant context ─────────────────────────────────────────

    @property
    def current_tenant_id(self) -> Optional[str]:
        """Tenant id active in the current context, or None."""
        return _tenant_id_var.get()

    @contextmanager
    def for_tenant(self, tenant_id: str) -> Generator[None, None, None]:
        """
        Context manager that tags every event inside the block with tenant_id.
        Thread-safe and async-safe via contextvars.

            with meter.for_tenant("acme-corp"):
                client.chat.completions.create(...)
        """
        token = _tenant_id_var.set(tenant_id)
        try:
            yield
        finally:
            # Restore whatever tenant (if any) was active before the block.
            _tenant_id_var.reset(token)

    def set_tenant(self, tenant_id: str) -> None:
        """
        Imperative alternative — useful when you can't use a context manager.
        Sets the tenant for the current context until changed.
        """
        _tenant_id_var.set(tenant_id)

    # ── limit checks ─────────────────────────────────────────────

    def _cached_limit_entry(
        self, tenant_id: str
    ) -> Optional[Tuple[bool, Optional[int]]]:
        """
        Return the cached (limit_exceeded, remaining_tokens) pair for
        tenant_id if the cache entry is still within TTL, else None
        (cache miss/expired).
        """
        with self._limit_cache_lock:
            entry = self._limit_cache.get(tenant_id)
        if entry is None:
            return None
        checked_at, limit_exceeded, remaining_tokens = entry
        if time.monotonic() - checked_at > self._limit_cache_ttl:
            return None
        return limit_exceeded, remaining_tokens

    def _store_limit_result(
        self,
        tenant_id: str,
        limit_exceeded: bool,
        remaining_tokens: Optional[int] = None,
    ) -> None:
        """Record a fresh backend verdict and restart local accounting."""
        with self._limit_cache_lock:
            self._limit_cache[tenant_id] = (
                time.monotonic(),
                limit_exceeded,
                remaining_tokens,
            )
            # Fresh sync point, so the local counter starts over.
            self._local_token_spend[tenant_id] = 0

    def _enforce_cached_limit(self, tenant_id: str) -> bool:
        """
        Apply a still-fresh cached verdict for tenant_id.

        Returns True when a fresh entry exists and the tenant is within
        its limit, False on a cache miss or expired entry. Raises
        LimitExceededException when the cached verdict says the limit
        has been reached.
        """
        cached = self._cached_limit_entry(tenant_id)
        if cached is None:
            return False
        limit_exceeded, _ = cached
        if limit_exceeded:
            raise LimitExceededException()
        return True

    def _accumulate_local_tokens(self, data: dict) -> None:
        """
        Add this event's token usage to the tenant's running total since
        the last network sync.

            local_tokens = local_tokens + tokens_this_call

        Called from _capture() for every event. Uses `total_tokens` when
        present and truthy, otherwise input_tokens + output_tokens.
        Events without a tenant_id or without any tokens are ignored.
        """
        tenant_id = data.get("tenant_id")
        if not tenant_id:
            return
        total_tokens = data.get("total_tokens") or (
            (data.get("input_tokens") or 0) + (data.get("output_tokens") or 0)
        )
        if not total_tokens:
            return
        with self._limit_cache_lock:
            self._local_token_spend[tenant_id] = self._local_token_spend.get(
                tenant_id, 0
            ) + int(total_tokens)

    def check_limit(self) -> None:
        """
        Check current usage vs threshold for the active tenant (sync).
        Raises LimitExceededException if the tenant's usage limit is reached.
        No-ops silently when there is no active tenant context.

        Results are cached per-tenant for `limit_check_cache_ttl` seconds
        (default 5s) so high-throughput callers don't pay a network
        round-trip to the billing backend on every single LLM call.
        """
        tenant_id = self.current_tenant_id
        if not tenant_id:
            return
        # Cached verdict first; only fall through on a miss.
        if self._enforce_cached_limit(tenant_id):
            return
        # Real sync with the backend.
        response = self._limit_check_http._get(params={"tenant_id": tenant_id})
        limit_exceeded = bool(response.get("limit_exceeded"))
        remaining_tokens = response.get("remaining_tokens")
        self._store_limit_result(tenant_id, limit_exceeded, remaining_tokens)
        if limit_exceeded:
            raise LimitExceededException()

    async def async_check_limit(self) -> None:
        """
        Async variant of check_limit, so the event loop is not blocked
        by a synchronous HTTP call.
        Raises LimitExceededException if the tenant's usage limit is reached.
        No-ops silently when there is no active tenant context.

        Shares the same per-tenant TTL cache as check_limit(), so a mix
        of sync and async calls for the same tenant reuses one verdict
        per TTL window.

        Requires ``aiohttp`` to be installed. If it is absent the method
        falls back to the synchronous check (acceptable for low-concurrency
        use-cases, but will block the event loop on each call).
        """
        tenant_id = self.current_tenant_id
        if not tenant_id:
            return

        if self._enforce_cached_limit(tenant_id):
            return

        if _aiohttp is None:
            # Graceful degradation: fall back to sync check with a warning.
            logger.warning(
                "token_limit: aiohttp is not installed; falling back to "
                "synchronous limit check inside async wrapper. "
                "Install aiohttp to avoid blocking the event loop."
            )
            self.check_limit()
            return

        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json",
        }
        params = {"tenant_id": tenant_id}
        timeout = _aiohttp.ClientTimeout(total=_TIMEOUT)
        session = self._get_aiohttp_session()
        async with session.get(
            CHECK_LIMIT_URL, headers=headers, params=params, timeout=timeout
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()
        limit_exceeded = bool(data.get("limit_exceeded"))
        remaining_tokens = data.get("remaining_tokens")
        self._store_limit_result(tenant_id, limit_exceeded, remaining_tokens)
        if limit_exceeded:
            raise LimitExceededException()

    def _get_aiohttp_session(self) -> "_aiohttp.ClientSession":
        """
        Lazily create and reuse a single aiohttp.ClientSession across all
        async_check_limit() calls, instead of opening and tearing down a
        brand-new session on every call. A new session is created when
        none exists yet or the previous one has been closed.
        """
        if self._aiohttp_session is None or self._aiohttp_session.closed:
            self._aiohttp_session = _aiohttp.ClientSession()
        return self._aiohttp_session

    def set_limit(
        self,
        tenant_id: str,
        limit_usd: float,
        frequency: str = PER_MONTH,
        effective_date: Optional[str] = None,
    ) -> None:
        """
        Set (or update) the spend limit for a tenant.
        Calls the backend to upsert the threshold record.

        Parameters
        ----------
        tenant_id : str
            Tenant whose limit is being set.
        limit_usd : float
            Spend limit sent to the backend as `limit_usd`.
        frequency : str
            Period for the limit, e.g. PER_MONTH (default) or PER_DAY.
        effective_date : str, optional
            ISO-8601 timestamp. When omitted, the current UTC time is
            used, computed at call time.
        """
        if effective_date is None:
            # Computed per call: a default evaluated in the signature
            # would be frozen at import time and go stale in long-lived
            # processes.
            effective_date = datetime.now(tz=timezone.utc).isoformat()
        self._set_limit_http._post(
            payload={
                "tenant_id": tenant_id,
                "limit_usd": limit_usd,
                "frequency": frequency,
                "effective_date": effective_date,
            },
        )
        # Invalidate any cached limit-check result so the new threshold
        # is honored immediately rather than after the TTL expires.
        with self._limit_cache_lock:
            self._limit_cache.pop(tenant_id, None)
            self._local_token_spend.pop(tenant_id, None)

    # ── openrouter ──────────────────────────────────────────

    def register_openrouter_client(self, client) -> None:
        """Instrument an existing openai.OpenAI / AsyncOpenAI instance for OpenRouter."""
        if "openrouter" not in self._patches:
            self.patch("openrouter")  # creates & stores in self._patches
        self._patches["openrouter"].patch_instance(client)

    def openrouter_client(self, api_key: str, **kwargs):
        """Factory: create an openai.OpenAI client for OpenRouter + register it."""
        import openai

        client = openai.OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=api_key,
            **kwargs,
        )
        self.register_openrouter_client(client)
        return client

    def async_openrouter_client(self, api_key: str, **kwargs):
        """Factory: create an openai.AsyncOpenAI client for OpenRouter + register it."""
        import openai

        client = openai.AsyncOpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=api_key,
            **kwargs,
        )
        self.register_openrouter_client(client)
        return client

    # ── manual tracking ──────────────────────────────────────────

    def track_manually(
        self,
        *,
        tenant_id: Optional[str] = None,
        provider: str,
        model: str,
        input_tokens: int = 0,
        output_tokens: int = 0,
        **extra,
    ) -> "LLMEvent":
        """
        Manually record an event (for providers not yet patched, or custom logic).

            meter.track_manually(
                provider="custom",
                model="my-model-v1",
                input_tokens=512,
                output_tokens=128,
                tenant_id="user_track_manually",
            )

        When `tenant_id` is omitted, the tenant active in the current
        context is used. `total_tokens` is always input + output.

        Keep in mind: the provider/model sent through track_manually()
        should have matching 'Model Pricing' entries in the dashboard;
        otherwise the event is recorded but its cost is not calculated.
        """
        event_data = {
            "provider": provider,
            "model": model,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "tenant_id": tenant_id or self.current_tenant_id,
            **extra,
        }
        return self._capture(event_data)

    # ── context manager support ──────────────────────────────────

    def __enter__(self) -> "Meter":
        return self

    def __exit__(self, *_) -> None:
        """Unpatch everything, shut the queue down and close the aiohttp session."""
        self.unpatch_all()
        self._queue.shutdown()
        session = self._aiohttp_session
        if session is None or session.closed:
            return
        # __exit__ is sync while aiohttp sessions close asynchronously,
        # so this is best-effort: the loop may already be closed by the
        # time we get here. Failures are logged instead of raised so
        # cleanup never masks the exception that ended the with-block.
        try:
            import asyncio

            loop = asyncio.get_event_loop()
            if loop.is_running():
                loop.create_task(session.close())
            else:
                loop.run_until_complete(session.close())
        except Exception:
            logger.debug(
                "token_limit: could not close aiohttp session on exit",
                exc_info=True,
            )

    # ── internal ─────────────────────────────────────────────────

    def _capture(self, data: dict) -> "LLMEvent":
        """
        Build an LLMEvent from `data` and enqueue it.

        Keys of `data` that are not LLMEvent fields are dropped (and
        reported at debug level when config.debug is set). The optional
        on_event hook runs before the event is queued; an exception
        raised by the hook is logged and does not stop the event from
        being queued.
        """
        self._accumulate_local_tokens(data)
        known_fields = LLMEvent.__dataclass_fields__
        if self.config.debug:
            dropped = [k for k in data if k not in known_fields]
            if dropped:
                logger.debug(
                    "token_limit: dropping unknown event field(s) %s "
                    "(not present on LLMEvent) — check for typos in "
                    "track_manually()/extractor kwargs.",
                    dropped,
                )
        event = LLMEvent(**{k: v for k, v in data.items() if k in known_fields})
        if self.config.debug:
            logger.debug("token_limit captured: %s", event)
        if self.config.on_event:
            try:
                self.config.on_event(event)
            except Exception:
                # A faulty user hook must not break event capture, but
                # it should not fail invisibly either.
                logger.warning(
                    "token_limit: on_event hook raised; ignoring",
                    exc_info=True,
                )
        self._queue.push(event)
        return event
@@ -0,0 +1,14 @@
1
+ from token_limit.patches.deepseek_patch import DeepSeekPatch
2
+
3
+ from .anthropic_patch import AnthropicPatch
4
+ from .openai_patch import OpenAIPatch
5
+ from .google_patch import GooglePatch
6
+ from .openrouter_patch import OpenRouterPatch
7
+
8
# Provider name -> patch class. Insertion order is kept because the
# registry's keys are listed as-is when an unknown provider is requested.
PATCH_REGISTRY = dict(
    anthropic=AnthropicPatch,
    openai=OpenAIPatch,
    google=GooglePatch,
    openrouter=OpenRouterPatch,
    deepseek=DeepSeekPatch,
)