token-limit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- token_limit/__init__.py +8 -0
- token_limit/config.py +35 -0
- token_limit/exceptions.py +5 -0
- token_limit/meter.py +479 -0
- token_limit/patches/__init__.py +14 -0
- token_limit/patches/_base_patch.py +397 -0
- token_limit/patches/anthropic_patch.py +627 -0
- token_limit/patches/deepseek_patch.py +707 -0
- token_limit/patches/google_patch.py +677 -0
- token_limit/patches/openai_patch.py +1199 -0
- token_limit/patches/openrouter_patch.py +400 -0
- token_limit/transport/http_client.py +311 -0
- token_limit/transport/queue.py +95 -0
- token_limit/types.py +92 -0
- token_limit-0.1.0.dist-info/METADATA +532 -0
- token_limit-0.1.0.dist-info/RECORD +18 -0
- token_limit-0.1.0.dist-info/WHEEL +4 -0
- token_limit-0.1.0.dist-info/licenses/LICENSE +35 -0
token_limit/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
__version__ = '0.1.0'
|
|
2
|
+
|
|
3
|
+
from token_limit.meter import Meter, PER_MONTH, PER_DAY
|
|
4
|
+
from token_limit.exceptions import LimitExceededException
|
|
5
|
+
|
|
6
|
+
from token_limit.meter import MeterConfig
|
|
7
|
+
|
|
8
|
+
# Names re-exported as the package's public API (all are imported above).
__all__ = ["Meter", "PER_MONTH", "PER_DAY", "LimitExceededException", "MeterConfig"]
|
token_limit/config.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import Callable, List, Optional
|
|
3
|
+
|
|
4
|
+
from .transport.http_client import INGEST_URL
|
|
5
|
+
from .types import LLMEvent
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class MeterConfig:
    """Settings bundle handed to ``Meter``.

    ``api_key`` is the only field without a default; everything else may
    be omitted.
    """

    # -- backend connection -------------------------------------------
    # API key for the metering backend.
    api_key: str
    # Ingest endpoint; defaults to the packaged INGEST_URL.
    url: str = INGEST_URL

    # -- batching (all three are forwarded to the event queue) ---------
    # Seconds between automatic flushes.
    flush_interval: float = 5.0
    # Queue length that triggers an early flush.
    max_batch_size: int = 50
    # Upper bound on queued events.
    max_queue_size: int = 1000

    # -- limit checks --------------------------------------------------
    # Seconds a limit-check result stays cached per tenant.
    limit_check_cache_ttl: float = 5.0

    # -- behaviour -----------------------------------------------------
    # If True, re-raise patch exceptions.
    # NOTE(review): consumed by the patch modules, not visible here.
    raise_on_error: bool = False
    # Enables the extra logger output emitted by Meter.
    debug: bool = False

    # -- optional hooks ------------------------------------------------
    # Invoked with each captured event.
    on_event: Optional[Callable[[LLMEvent], None]] = None
    # Handed to the ingest HTTP client as its error callback.
    on_flush_error: Optional[Callable[[Exception], None]] = None

    # -- provider patches installed by Meter.patch_all() ---------------
    patches: List[str] = field(
        default_factory=lambda: [
            "openai",
            "anthropic",
            "deepseek",
            "google",
            "openrouter",
        ]
    )
|
token_limit/meter.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
1
|
+
"""
|
|
2
|
+
token_limit.meter
|
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~
|
|
4
|
+
The single entry point developers import and initialise.
|
|
5
|
+
|
|
6
|
+
Usage
|
|
7
|
+
-----
|
|
8
|
+
from token_limit import Meter, MeterConfig
|
|
9
|
+
|
|
10
|
+
meter = Meter(MeterConfig(
|
|
11
|
+
api_key="sk-...",
|
|
12
|
+
url="https://api.yoursaas.com/v1/ingest",
|
|
13
|
+
))
|
|
14
|
+
meter.patch_all() # ← one line, all providers auto-instrumented
|
|
15
|
+
|
|
16
|
+
# Per-request tenant override:
|
|
17
|
+
with meter.for_tenant("stripe-inc"):
|
|
18
|
+
response = openai_client.chat.completions.create(...)
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import atexit
|
|
24
|
+
from datetime import datetime, timezone
|
|
25
|
+
import logging
|
|
26
|
+
import threading
|
|
27
|
+
import time
|
|
28
|
+
from contextlib import contextmanager
|
|
29
|
+
from contextvars import ContextVar
|
|
30
|
+
from typing import Dict, Generator, List, Optional, Tuple
|
|
31
|
+
|
|
32
|
+
try:
|
|
33
|
+
import aiohttp as _aiohttp # optional — only needed for async limit checks
|
|
34
|
+
except ImportError: # pragma: no cover
|
|
35
|
+
_aiohttp = None # type: ignore[assignment]
|
|
36
|
+
|
|
37
|
+
from .exceptions import LimitExceededException
|
|
38
|
+
|
|
39
|
+
from .config import MeterConfig
|
|
40
|
+
from .patches import PATCH_REGISTRY
|
|
41
|
+
from .patches._base_patch import BasePatch
|
|
42
|
+
from .transport.http_client import CHECK_LIMIT_URL, HttpClient, SET_LIMIT_URL, _TIMEOUT
|
|
43
|
+
from .transport.queue import EventQueue
|
|
44
|
+
from .types import LLMEvent
|
|
45
|
+
|
|
46
|
+
# Logger used by this module for debug, info and warning output.
logger = logging.getLogger("token_limit")

# Thread-/async-safe tenant_id override. None (the default) means no tenant
# is active, in which case Meter.check_limit() returns without checking.
_tenant_id_var: ContextVar[Optional[str]] = ContextVar("tenant_id", default=None)

# Values for the `frequency` argument of Meter.set_limit(), which defaults
# to PER_MONTH.
PER_DAY = "per_day"
PER_MONTH = "per_month"
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class Meter:
|
|
56
|
+
"""
|
|
57
|
+
Central SDK object. One instance per application.
|
|
58
|
+
|
|
59
|
+
Parameters
|
|
60
|
+
----------
|
|
61
|
+
config : MeterConfig
|
|
62
|
+
All configuration in one place.
|
|
63
|
+
"""
|
|
64
|
+
|
|
65
|
+
def __init__(self, config: MeterConfig) -> None:
    """Wire up the HTTP clients, limit-check state and the event queue.

    ``config`` is kept on ``self.config`` and read by the other methods.
    """
    self.config = config
    # provider name -> installed patch instance (filled in by patch())
    self._patches: Dict[str, BasePatch] = {}

    api_key = config.api_key

    # Client used by the event queue to ship batches.
    self._http = HttpClient(
        url=config.url,
        api_key=api_key,
        on_error=config.on_flush_error,
    )
    # One long-lived client per auxiliary endpoint, built here so that
    # check_limit() / set_limit() never construct an HttpClient per call.
    self._limit_check_http = HttpClient(url=CHECK_LIMIT_URL, api_key=api_key)
    self._set_limit_http = HttpClient(url=SET_LIMIT_URL, api_key=api_key)

    # aiohttp session for async_check_limit(); left unset until
    # _get_aiohttp_session() creates it on first use.
    self._aiohttp_session: Optional["_aiohttp.ClientSession"] = None

    # Per-tenant limit-check cache:
    #   {tenant_id: (checked_at_monotonic, limit_exceeded, remaining_tokens)}
    # Entries older than _limit_cache_ttl are treated as misses, so the
    # backend is queried at most once per TTL window per tenant instead
    # of on every patched call.
    self._limit_cache: Dict[str, Tuple[float, bool, Optional[int]]] = {}
    self._limit_cache_lock = threading.Lock()
    self._limit_cache_ttl: float = getattr(config, "limit_check_cache_ttl", 5.0)

    # Per-tenant token count accumulated since the last backend sync:
    #   {tenant_id: tokens}
    # Incremented by _accumulate_local_tokens(), zeroed by
    # _store_limit_result(); shares _limit_cache_lock with the cache.
    self._local_token_spend: Dict[str, int] = {}

    self._queue = EventQueue(
        flush_fn=self._http.send_batch,
        flush_interval=config.flush_interval,
        max_batch_size=config.max_batch_size,
        max_queue_size=config.max_queue_size,
    )

    # Shut the queue down when the interpreter exits.
    atexit.register(self._queue.shutdown)
|
|
125
|
+
|
|
126
|
+
# ── patching ─────────────────────────────────────────────────
|
|
127
|
+
|
|
128
|
+
def patch_all(self) -> "Meter":
    """Call patch() for each name in ``config.patches``; returns self for chaining."""
    configured = self.config.patches
    for provider_name in configured:
        self.patch(provider_name)
    return self
|
|
133
|
+
|
|
134
|
+
def patch(self, provider: str) -> "Meter":
    """Install a single provider patch by name. Chainable.

    Calling this again for a provider whose patch is already active is a
    no-op. Previously a second call built and installed a new patch
    instance and overwrote the stored one, leaving the first instance
    unreachable for unpatch(). A stored patch that is *not* active is
    still replaced, so a retry after a skipped install behaves as before.

    Raises:
        ValueError: if ``provider`` is not a key of PATCH_REGISTRY.
    """
    cls = PATCH_REGISTRY.get(provider)
    if cls is None:
        raise ValueError(
            f"Unknown provider '{provider}'. Available: {list(PATCH_REGISTRY)}"
        )

    existing = self._patches.get(provider)
    if existing is not None and getattr(existing, "is_active", False):
        # Already installed and live: keep the instance that knows how
        # to undo itself rather than patching a second time.
        return self

    instance = cls(self)
    instance.patch()
    self._patches[provider] = instance
    if self.config.debug:
        status = "active" if instance.is_active else "skipped (not installed)"
        logger.info("token_limit: patch '%s' %s", provider, status)
    return self
|
|
148
|
+
|
|
149
|
+
def unpatch_all(self) -> "Meter":
    """Undo every installed patch, forget them all, and return self."""
    installed = self._patches
    for installed_patch in installed.values():
        installed_patch.unpatch()
    installed.clear()
    return self
|
|
155
|
+
|
|
156
|
+
def unpatch(self, provider: str) -> "Meter":
    """Remove and undo the patch stored for ``provider``, if any; returns self."""
    removed = self._patches.pop(provider, None)
    if not removed:
        # Nothing was installed under that name.
        return self
    removed.unpatch()
    return self
|
|
161
|
+
|
|
162
|
+
# ── tenant context ─────────────────────────────────────────
|
|
163
|
+
|
|
164
|
+
@property
def current_tenant_id(self) -> Optional[str]:
    """Tenant id bound to the current context, or None when none is set."""
    active_tenant = _tenant_id_var.get()
    return active_tenant
|
|
167
|
+
|
|
168
|
+
@contextmanager
def for_tenant(self, tenant_id: str) -> Generator[None, None, None]:
    """
    Bind ``tenant_id`` as the active tenant for the duration of a
    ``with`` block, restoring the previous value on exit (even when the
    block raises). The binding lives in a ContextVar.

        with meter.for_tenant("acme-corp"):
            client.chat.completions.create(...)
    """
    restore_token = _tenant_id_var.set(tenant_id)
    try:
        yield
    finally:
        # Put back whatever tenant (or None) was active before the block.
        _tenant_id_var.reset(restore_token)
|
|
182
|
+
|
|
183
|
+
def set_tenant(self, tenant_id: str) -> None:
    """
    Imperative counterpart to for_tenant() for code that cannot use a
    ``with`` block. The tenant stays bound in the current context until
    it is set again; nothing restores the previous value automatically.
    """
    _tenant_id_var.set(tenant_id)
|
|
189
|
+
|
|
190
|
+
def _cached_limit_entry(
    self, tenant_id: str
) -> Optional[Tuple[bool, Optional[int]]]:
    """
    Look up the cached limit-check result for ``tenant_id``.

    Returns ``(limit_exceeded, remaining_tokens)`` when an entry exists
    and is no older than the configured TTL; returns None when there is
    no entry or it has expired.
    """
    with self._limit_cache_lock:
        record = self._limit_cache.get(tenant_id)
        if record is not None:
            stored_at, exceeded, remaining = record
            expired = time.monotonic() - stored_at > self._limit_cache_ttl
            if not expired:
                return exceeded, remaining
    return None
|
|
206
|
+
|
|
207
|
+
def _store_limit_result(
    self,
    tenant_id: str,
    limit_exceeded: bool,
    remaining_tokens: Optional[int] = None,
) -> None:
    """
    Record a fresh limit-check result for ``tenant_id`` and restart the
    tenant's local token counter at zero.

    The cache entry is timestamped with ``time.monotonic()``; both
    writes happen under the shared cache lock.
    """
    with self._limit_cache_lock:
        fresh_entry = (time.monotonic(), limit_exceeded, remaining_tokens)
        self._limit_cache[tenant_id] = fresh_entry
        # This result is the new sync point, so locally counted tokens
        # start over from here.
        self._local_token_spend[tenant_id] = 0
|
|
222
|
+
|
|
223
|
+
def _accumulate_local_tokens(self, data: dict) -> None:
    """
    Add one event's token usage to its tenant's running local total.

    The amount is ``data["total_tokens"]`` when that is truthy, otherwise
    the sum of ``input_tokens`` and ``output_tokens`` (missing/None count
    as 0). Events without a tenant_id or with zero tokens are ignored.
    Invoked by _capture() for every event.
    """
    tenant_id = data.get("tenant_id")
    if not tenant_id:
        return

    spent = data.get("total_tokens")
    if not spent:
        prompt_side = data.get("input_tokens") or 0
        completion_side = data.get("output_tokens") or 0
        spent = prompt_side + completion_side
    if not spent:
        return

    with self._limit_cache_lock:
        running_total = self._local_token_spend.get(tenant_id, 0)
        self._local_token_spend[tenant_id] = running_total + int(spent)
|
|
246
|
+
|
|
247
|
+
def check_limit(self) -> None:
    """
    Synchronous limit check for the currently active tenant.

    Does nothing when no tenant is active. Otherwise uses the cached
    result if one is still within its TTL, and only on a cache miss
    queries the limit-check endpoint and stores the answer.

    Raises LimitExceededException when the (cached or fresh) result says
    the tenant's limit has been exceeded.
    """
    tenant_id = self.current_tenant_id
    if not tenant_id:
        return

    cached = self._cached_limit_entry(tenant_id)
    if cached is None:
        # Cache miss: ask the backend and remember what it said.
        reply = self._limit_check_http._get(params={"tenant_id": tenant_id})
        exceeded = bool(reply.get("limit_exceeded"))
        remaining = reply.get("remaining_tokens")
        self._store_limit_result(tenant_id, exceeded, remaining)
    else:
        exceeded, _remaining = cached

    if exceeded:
        raise LimitExceededException()
|
|
274
|
+
|
|
275
|
+
async def async_check_limit(self) -> None:
    """
    Awaitable counterpart of check_limit(), sharing its per-tenant cache.

    Does nothing when no tenant is active. A cache hit is answered
    locally. On a miss the limit-check endpoint is queried through the
    shared aiohttp session and the answer is cached.

    When ``aiohttp`` is not importable, a warning is logged and the
    synchronous check_limit() is called instead (a blocking call made
    from inside the coroutine).

    Raises LimitExceededException when the result says the tenant's
    limit has been exceeded.
    """
    tenant_id = self.current_tenant_id
    if not tenant_id:
        return

    hit = self._cached_limit_entry(tenant_id)
    if hit is not None:
        exceeded, _remaining = hit
        if exceeded:
            raise LimitExceededException()
        return

    if _aiohttp is None:
        # No async HTTP client available: degrade to the sync path.
        logger.warning(
            "token_limit: aiohttp is not installed; falling back to "
            "synchronous limit check inside async wrapper. "
            "Install aiohttp to avoid blocking the event loop."
        )
        self.check_limit()
        return

    request_options = {
        "headers": {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json",
        },
        "params": {"tenant_id": tenant_id},
        "timeout": _aiohttp.ClientTimeout(total=_TIMEOUT),
    }
    session = self._get_aiohttp_session()
    async with session.get(CHECK_LIMIT_URL, **request_options) as resp:
        resp.raise_for_status()
        body = await resp.json()

    exceeded = bool(body.get("limit_exceeded"))
    remaining = body.get("remaining_tokens")
    self._store_limit_result(tenant_id, exceeded, remaining)
    if exceeded:
        raise LimitExceededException()
|
|
328
|
+
|
|
329
|
+
def _get_aiohttp_session(self) -> "_aiohttp.ClientSession":
    """
    Return the aiohttp.ClientSession shared by async_check_limit() calls,
    creating it on first use.

    A new session is created when none exists yet, when the existing one
    is closed, or when it was created under a different event loop than
    the one currently running. The last case matters for programs that
    call ``asyncio.run()`` more than once: the old session still reports
    ``closed == False`` but belongs to a loop that is gone, so reusing it
    would fail.
    """
    import asyncio

    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # Not inside a running loop; fall back to the original
        # "create if missing or closed" behaviour.
        running_loop = None

    session = self._aiohttp_session
    # __init__ does not define the owner attribute, hence getattr.
    owner_loop = getattr(self, "_aiohttp_session_loop", None)
    needs_new = (
        session is None
        or session.closed
        or (running_loop is not None and owner_loop is not running_loop)
    )
    if needs_new:
        session = _aiohttp.ClientSession()
        self._aiohttp_session = session
        self._aiohttp_session_loop = running_loop
    return session
|
|
340
|
+
|
|
341
|
+
def set_limit(
    self,
    tenant_id: str,
    limit_usd: float,
    frequency: str = PER_MONTH,
    effective_date: Optional[str] = None,
) -> None:
    """
    Set (or update) the spend limit for a tenant by posting it to the
    set-limit endpoint.

    Args:
        tenant_id: Tenant whose limit is being set.
        limit_usd: Limit value sent as ``limit_usd``.
        frequency: PER_MONTH (default) or PER_DAY.
        effective_date: ISO-8601 timestamp string. When omitted, the
            current UTC time at the moment of the call is used.

    After the request returns, the tenant's cached limit-check result
    and local token counter are dropped so the next check goes to the
    backend instead of waiting for the cache TTL.
    """
    if effective_date is None:
        # Computed per call. As a default-argument expression it would be
        # evaluated once at import time and every later call would send
        # that same stale timestamp.
        effective_date = datetime.now(tz=timezone.utc).isoformat()

    self._set_limit_http._post(
        payload={
            "tenant_id": tenant_id,
            "limit_usd": limit_usd,
            "frequency": frequency,
            "effective_date": effective_date,
        },
    )
    with self._limit_cache_lock:
        self._limit_cache.pop(tenant_id, None)
        self._local_token_spend.pop(tenant_id, None)
|
|
365
|
+
|
|
366
|
+
# ── openrouter ──────────────────────────────────────────
|
|
367
|
+
def register_openrouter_client(self, client) -> None:
    """
    Instrument an already-constructed client object for OpenRouter.

    Installs the "openrouter" patch first if it is not installed yet,
    then hands ``client`` to that patch's ``patch_instance``.
    """
    already_installed = "openrouter" in self._patches
    if not already_installed:
        self.patch("openrouter")  # stores the patch in self._patches
    openrouter_patch = self._patches["openrouter"]
    openrouter_patch.patch_instance(client)
|
|
372
|
+
|
|
373
|
+
def openrouter_client(self, api_key: str, **kwargs):
    """
    Build an ``openai.OpenAI`` client pointed at the OpenRouter base URL,
    register it with this meter, and return it.

    Extra keyword arguments are forwarded to the ``openai.OpenAI``
    constructor. ``openai`` is imported lazily, on call.
    """
    import openai

    router_client = openai.OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key,
        **kwargs,
    )
    self.register_openrouter_client(router_client)
    return router_client
|
|
383
|
+
|
|
384
|
+
def async_openrouter_client(self, api_key: str, **kwargs):
    """
    Build an ``openai.AsyncOpenAI`` client pointed at the OpenRouter base
    URL, register it with this meter, and return it.

    Extra keyword arguments are forwarded to the ``openai.AsyncOpenAI``
    constructor. ``openai`` is imported lazily, on call.
    """
    import openai

    router_client = openai.AsyncOpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key,
        **kwargs,
    )
    self.register_openrouter_client(router_client)
    return router_client
|
|
394
|
+
# ── manual tracking ──────────────────────────────────────────
|
|
395
|
+
|
|
396
|
+
def track_manually(
    self,
    *,
    tenant_id: Optional[str] = None,
    provider: str,
    model: str,
    input_tokens: int = 0,
    output_tokens: int = 0,
    **extra,
) -> LLMEvent:
    """
    Record an event by hand, e.g. for a provider that has no patch.

        meter.track_manually(
            provider="custom",
            model="my-model-v1",
            input_tokens=512,
            output_tokens=128,
            tenant_id="user_track_manually",
        )

    ``total_tokens`` is derived as input + output. When ``tenant_id`` is
    omitted (or falsy) the tenant active in the current context is used.
    Any ``extra`` keyword arguments are merged into the event data last.

    Note: per the original author's guidance, the provider/model pair
    should have a matching 'Model Pricing' entry in the dashboard;
    otherwise the event is recorded but its cost is not calculated.
    """
    combined_tokens = input_tokens + output_tokens
    effective_tenant = tenant_id or self.current_tenant_id
    event_data = dict(
        provider=provider,
        model=model,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=combined_tokens,
        tenant_id=effective_tenant,
    )
    event_data.update(extra)
    return self._capture(event_data)
|
|
430
|
+
|
|
431
|
+
# ── context manager support ──────────────────────────────────
|
|
432
|
+
|
|
433
|
+
def __enter__(self) -> "Meter":
    """Enter a ``with`` block; returns this meter unchanged (no patching happens here)."""
    return self
|
|
435
|
+
|
|
436
|
+
def __exit__(self, *_) -> None:
    """
    Leave a ``with`` block: remove all patches, shut the event queue
    down, and make a best-effort attempt to close the aiohttp session.

    Exceptions from the ``with`` body are not suppressed (returns None).
    A failure while closing the session never propagates; it is logged
    at debug level instead of being discarded silently.
    """
    self.unpatch_all()
    self._queue.shutdown()

    session = self._aiohttp_session
    if session is None or session.closed:
        return

    # __exit__ is synchronous but aiohttp sessions close asynchronously:
    # schedule the close on a running loop, or drive it to completion on
    # an idle one. The loop may already be gone at this point, hence the
    # broad except below.
    # NOTE(review): asyncio.get_event_loop() without a running loop is
    # deprecated in newer Python versions — confirm the supported range.
    try:
        import asyncio

        loop = asyncio.get_event_loop()
        if loop.is_running():
            loop.create_task(session.close())
        else:
            loop.run_until_complete(session.close())
    except Exception:
        logger.debug(
            "token_limit: could not close aiohttp session during __exit__",
            exc_info=True,
        )
|
|
454
|
+
|
|
455
|
+
# ── internal ─────────────────────────────────────────────────
|
|
456
|
+
|
|
457
|
+
def _capture(self, data: dict) -> LLMEvent:
    """
    Build an LLMEvent from ``data`` and enqueue it. Called by every patch
    and by track_manually().

    Steps: add the event's tokens to the tenant's local counter, drop any
    keys that are not LLMEvent fields (listed in a debug log line when
    ``config.debug`` is on), construct the event, call the optional
    ``config.on_event`` hook, then push the event onto the queue.

    An exception raised by the ``on_event`` hook is logged with its
    traceback and otherwise ignored; the event is still queued and
    returned.
    """
    self._accumulate_local_tokens(data)

    known_fields = LLMEvent.__dataclass_fields__
    debug = self.config.debug
    if debug:
        dropped = [k for k in data if k not in known_fields]
        if dropped:
            logger.debug(
                "token_limit: dropping unknown event field(s) %s "
                "(not present on LLMEvent) — check for typos in "
                "track_manually()/extractor kwargs.",
                dropped,
            )

    event = LLMEvent(**{k: v for k, v in data.items() if k in known_fields})
    if debug:
        logger.debug("token_limit captured: %s", event)

    on_event = self.config.on_event
    if on_event:
        try:
            on_event(event)
        except Exception:
            # A broken user hook must not stop metering, but it should
            # not vanish without a trace either.
            logger.warning(
                "token_limit: on_event hook raised; event is still queued",
                exc_info=True,
            )

    self._queue.push(event)
    return event
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from token_limit.patches.deepseek_patch import DeepSeekPatch
|
|
2
|
+
|
|
3
|
+
from .anthropic_patch import AnthropicPatch
|
|
4
|
+
from .openai_patch import OpenAIPatch
|
|
5
|
+
from .google_patch import GooglePatch
|
|
6
|
+
from .openrouter_patch import OpenRouterPatch
|
|
7
|
+
|
|
8
|
+
# Provider name -> patch class. Meter.patch() looks names up here and lists
# these keys (in this order) in its "Unknown provider" error message.
PATCH_REGISTRY = {
    "anthropic": AnthropicPatch,
    "openai": OpenAIPatch,
    "google": GooglePatch,
    "openrouter": OpenRouterPatch,
    "deepseek": DeepSeekPatch
}
|