cendor-core 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,22 @@
1
+ """cendor.core — the shared foundation. Keep this public surface small and stable."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from . import bus, otel, prices, protocols, tokens
6
+ from .instrument import Reroute, instrument, instrument_tool
7
+ from .types import LLMCall, Money, ToolCall, Usage
8
+
9
# Public surface of cendor.core, grouped by role (order is unchanged from the original list).
__all__ = [
    # normalized event / value types
    "LLMCall", "ToolCall", "Usage", "Money",
    # helper submodules
    "bus", "tokens", "prices", "otel", "protocols",
    # instrumentation entry points and the reroute directive
    "instrument", "instrument_tool", "Reroute",
]
cendor/core/bus.py ADDED
@@ -0,0 +1,66 @@
1
+ """In-process pub/sub event bus: one instrument() emits, many tools subscribe. docs/core.md §6.
2
+
3
+ Thread-safe within a process: the subscriber list is guarded by a lock for registration changes,
4
+ and :func:`emit` fans out over a snapshot taken under that lock — so subscribers may (un)subscribe
5
+ from other threads (or from inside a callback) without corrupting the list or deadlocking.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import threading
11
+ from collections.abc import Callable
12
+ from typing import Any
13
+
14
# Guards every change to ``_subscribers``; emit() holds it only long enough to take a snapshot.
_lock = threading.Lock()
# Registered subscriber callables, in registration order (emit() calls them in this order).
_subscribers: list[Callable[[Any], None]] = []
16
+
17
+
18
def subscribe(fn: Callable[[Any], None]) -> Callable[[Any], None]:
    """Add ``fn`` to the bus and hand it back unchanged, so it also works as a decorator.

    Registering a callable that is already subscribed does nothing. A sibling tool can
    therefore "ensure" its subscription without risking duplicate delivery.
    """
    with _lock:
        already_registered = fn in _subscribers
        if not already_registered:
            _subscribers.append(fn)
    return fn
25
+
26
+
27
def unsubscribe(fn: Callable[[Any], None]) -> None:
    """Detach ``fn`` from the bus. Does nothing if it was never subscribed.

    This is the counterpart of :func:`subscribe`. A tool can install a short-lived subscriber
    (e.g. cassette's recorder) and remove it again without touching the private list.
    """
    with _lock:
        if fn not in _subscribers:
            return
        _subscribers.remove(fn)
35
+
36
+
37
def emit(event: Any) -> None:
    """Publish ``event`` to every subscriber, synchronously and in registration order.

    Every subscriber runs even if an earlier one raises, so one tool's failure can't starve
    another (a logging subscriber's bug must not skip ``tokenguard``'s enforcement, or vice versa).
    The first ``Exception`` raised is re-raised after all subscribers have run. Intentional
    control flow (e.g. ``tokenguard``'s post-flight ``BudgetExceeded``) therefore still reaches
    the caller.

    Later failures are not dropped silently. Each one is attached to that first exception as a
    note (:meth:`BaseException.add_note`), so it appears in the traceback without changing the
    exception type or message callers see.

    ``BaseException`` (``KeyboardInterrupt``/``SystemExit``) is not caught; it propagates at once.

    The fan-out runs over a snapshot taken under the lock and releases the lock before invoking
    any subscriber. A subscriber can therefore (un)subscribe without deadlocking, and a slow one
    never blocks registration on another thread.

    Raises:
        Exception: the first exception raised by any subscriber, after all of them have run.
    """
    with _lock:
        subscribers = list(_subscribers)
    first_exc: Exception | None = None
    for fn in subscribers:
        try:
            fn(event)
        except Exception as exc:  # noqa: BLE001 - isolate subscribers, re-raise first after all run
            if first_exc is None:
                first_exc = exc
            elif exc is not first_exc:
                # Record the extra failure on the exception we re-raise instead of losing it.
                who = getattr(fn, "__qualname__", None) or type(fn).__name__
                first_exc.add_note(f"cendor.bus: subscriber {who} also raised {exc!r}")
    if first_exc is not None:
        raise first_exc
61
+
62
+
63
def _reset() -> None:
    """Drop every registered subscriber. Intended for test isolation."""
    with _lock:
        del _subscribers[:]
@@ -0,0 +1,578 @@
1
+ """Single interception point: wrap a provider client (or tool) once; emit normalized events.
2
+
3
+ docs/core.md §6. Idempotent (re-wrapping is a no-op) and additive (coexists with other
4
+ instrumentation like OpenLLMetry). Supports sync and async, and **streaming** responses
5
+ (``stream=True``): the chunk iterator is passed through unchanged while usage is accumulated, so
6
+ the ``LLMCall`` is emitted once with usage/cost/latency when the stream completes — not the
7
+ unconsumed iterator. Uses duck typing — the provider SDKs are never imported here, so they stay
8
+ optional.
9
+
10
+ Two cooperation hooks (used by ``cassette``; harmless otherwise):
11
+ * **record** — the raw provider response is attached at ``call.metadata["response"]`` before
12
+ the event is emitted, so a subscriber can persist it.
13
+ * **replay** — registered *interceptors* run *before* the real call; one may return a response
14
+ to short-circuit it (returning :data:`MISS` to decline). This is how record/replay avoids a
15
+ second instrumentation point: tools cooperate through ``core``, they never patch the client.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import functools
21
+ import inspect
22
+ import threading
23
+ import time
24
+ import uuid
25
+ from collections.abc import Callable
26
+ from datetime import UTC, datetime
27
+ from decimal import Decimal
28
+ from typing import Any, TypeVar
29
+
30
+ from . import bus, prices, tokens
31
+ from .types import LLMCall, Money, ToolCall, Usage
32
+
33
#: Type variable for ``instrument``'s "same object in, same object out" signature.
T = TypeVar("T")

#: Marker attribute set on every wrapper we build. Finding it makes re-wrapping a no-op.
_WRAPPED = "_cendor_wrapped"

#: Sentinel an interceptor returns to mean "not mine, run the real call". It is a fresh
#: ``object()`` rather than ``None`` because a recorded response can itself be ``None``.
MISS: Any = object()
40
+
41
+
42
class Reroute:
    """Interceptor directive: modify the outgoing request, then run the real call.

    ``tokenguard`` returns e.g. ``Reroute(model="gpt-4o-mini")`` for ``on_exceed="downgrade"``.
    Every keyword given here is merged into the call's kwargs, overriding the caller's value,
    before the provider is invoked.

    Attributes:
        updates: the keyword overrides to apply, exactly as passed to the constructor.
    """

    def __init__(self, **updates: Any) -> None:
        self.updates = updates

    def __repr__(self) -> str:
        # Show the overrides so a logged or inspected directive explains itself.
        shown = ", ".join(f"{key}={value!r}" for key, value in self.updates.items())
        return f"{type(self).__name__}({shown})"
51
+
52
+
53
# Guards registration changes to ``_interceptors``; _intercept() works on a snapshot copy.
_interceptors_lock = threading.Lock()
# Pre-call hooks, consulted by _intercept() in registration order.
_interceptors: list[Callable[[Any], Any]] = []
55
+
56
+
57
def add_interceptor(fn: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Install a pre-call hook and return it unchanged, so it is usable as a decorator.

    The hook receives the pending event (``LLMCall``/``ToolCall``) before the real call. It may:

    * return a response to short-circuit the call,
    * return a :class:`Reroute` to modify the call, or
    * return :data:`MISS` to let the call proceed.

    Adding the same hook twice is a no-op.

    Thread-safe: the registry is mutated under ``_interceptors_lock``, and :func:`_intercept`
    iterates a snapshot.
    """
    with _interceptors_lock:
        if fn in _interceptors:
            return fn
        _interceptors.append(fn)
    return fn
67
+
68
+
69
def remove_interceptor(fn: Callable[[Any], Any]) -> None:
    """Uninstall a hook added by :func:`add_interceptor`. Unknown hooks are ignored."""
    with _interceptors_lock:
        try:
            _interceptors.remove(fn)
        except ValueError:
            pass  # not registered: removal is a no-op by contract
74
+
75
+
76
def _intercept(event: Any) -> Any:
    """Offer ``event`` to each registered interceptor, in registration order.

    Returns the first result that is not :data:`MISS`, and later interceptors are then not
    called. Returns :data:`MISS` if every interceptor declines. An interceptor's exception
    propagates to the caller unchanged. The registry is copied under ``_interceptors_lock``,
    and the interceptors run after the lock is released.
    """
    with _interceptors_lock:
        snapshot = tuple(_interceptors)
    for hook in snapshot:
        answer = hook(event)
        if answer is MISS:
            continue
        return answer
    return MISS
84
+
85
+
86
+ # --------------------------------------------------------------------------- model clients
87
+
88
+
89
def instrument(client: T) -> T:
    """Wrap a provider client in place so each model call emits an ``LLMCall`` on the bus.

    Detection is structural (see :func:`_find_target`) and probed in this order:

    * ``chat.completions.create`` → OpenAI-style;
    * ``messages.create`` → Anthropic-style;
    * ``converse`` → AWS Bedrock;
    * ``generate_content`` → Google Gemini;
    * a callable ``chat`` → Ollama.

    Unknown clients are returned untouched. Wrapping is idempotent: a method that already
    carries the ``_WRAPPED`` marker is left alone.

    Args:
        client: the SDK client (or Gemini ``GenerativeModel``) to instrument.

    Returns:
        ``client`` itself, with the detected method replaced by an instrumented wrapper.
    """
    target = _find_target(client)
    if target is None:
        return client
    owner, attr, provider = target
    fn = getattr(owner, attr)
    if getattr(fn, _WRAPPED, False):
        return client  # already instrumented: no double-wrap
    model_default = ""
    if provider == "google":
        # Gemini binds the model to the GenerativeModel object (not the call kwargs), so read it
        # here; the LLMCall then carries a real model id and can be priced. Strip "models/".
        # The trailing `or ""` ensures a missing/None name becomes "", never the string "None".
        name = getattr(client, "model_name", None) or getattr(client, "_model_name", None) or ""
        model_default = str(name).removeprefix("models/")
    setattr(owner, attr, _wrap(fn, provider, model_default))
    return client
111
+
112
+
113
def _find_target(client: Any) -> tuple[Any, str, str] | None:
    """Locate the method to wrap on ``client``.

    Returns ``(owner, attribute_name, provider)``, or ``None`` when nothing matches. Shapes are
    probed in priority order:

    1. OpenAI ``chat.completions.create``;
    2. Anthropic ``messages.create``;
    3. Bedrock ``converse``;
    4. Gemini ``generate_content``;
    5. Ollama, whose ``chat`` is itself callable (unlike OpenAI's ``chat`` namespace).
    """
    chat = getattr(client, "chat", None)
    # getattr on None falls back to the default, so a missing namespace simply yields None.
    completions = getattr(chat, "completions", None)
    if callable(getattr(completions, "create", None)):
        return completions, "create", "openai"
    messages = getattr(client, "messages", None)
    if callable(getattr(messages, "create", None)):
        return messages, "create", "anthropic"
    for method_name, provider in (("converse", "bedrock"), ("generate_content", "google")):
        if callable(getattr(client, method_name, None)):
            return client, method_name, provider
    return (client, "chat", "ollama") if callable(chat) else None
128
+
129
+
130
def _wrap(fn: Any, provider: str, model_default: str = "") -> Any:
    """Build the instrumented replacement for a provider method ``fn`` (sync or async).

    For each call the wrapper builds the ``LLMCall`` (:func:`_pre`), asks OpenAI streams for a
    usage chunk, and consults the interceptors. It then takes one of three paths:

    * **replay**: an interceptor answered with a response. The provider is not called, the
      event is flagged ``metadata["replayed"]``, and the recorded response is returned (or,
      with ``stream=True``, a re-iteration of the recorded chunks).
    * **reroute**: a :class:`Reroute` patches the kwargs, then the real call runs.
    * **pass-through**: the real call runs unchanged.

    For a real call, a streamed response is proxied, and the event fires when the stream ends.
    A regular response is finalized and emitted at once via :func:`_post`.

    The returned wrapper carries the ``_WRAPPED`` marker so :func:`instrument` never
    double-wraps.
    """
    if inspect.iscoroutinefunction(fn):

        @functools.wraps(fn)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            call, started = _pre(provider, args, kwargs, model_default)
            _ensure_stream_usage_options(provider, kwargs)
            decision = _intercept(call)
            if decision is not MISS and not isinstance(decision, Reroute):
                call.metadata["replayed"] = True
                if kwargs.get("stream"):
                    return _areplay_stream(call, decision, provider, started)
                _post(call, decision, provider, started)
                return decision
            if isinstance(decision, Reroute):
                _apply_reroute(call, kwargs, decision)
            response = await fn(*args, **kwargs)
            if kwargs.get("stream"):
                return _aproxy_stream(call, response, provider, started)
            _post(call, response, provider, started)
            return response

        setattr(async_wrapper, _WRAPPED, True)
        return async_wrapper

    @functools.wraps(fn)
    def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
        call, started = _pre(provider, args, kwargs, model_default)
        _ensure_stream_usage_options(provider, kwargs)
        decision = _intercept(call)
        if decision is not MISS and not isinstance(decision, Reroute):
            call.metadata["replayed"] = True
            if kwargs.get("stream"):
                return _replay_stream(call, decision, provider, started)
            _post(call, decision, provider, started)
            return decision
        if isinstance(decision, Reroute):
            _apply_reroute(call, kwargs, decision)
        response = fn(*args, **kwargs)
        if kwargs.get("stream"):
            return _proxy_stream(call, response, provider, started)
        _post(call, response, provider, started)
        return response

    setattr(sync_wrapper, _WRAPPED, True)
    return sync_wrapper
180
+
181
+
182
def _apply_reroute(call: LLMCall, kwargs: dict, directive: Reroute) -> None:
    """Merge a :class:`Reroute`'s overrides into the live request kwargs and mirror them on ``call``.

    Only a ``model`` override is reflected on the event, so pricing uses the model actually
    called. The event is flagged ``metadata["rerouted"]`` in every case.
    """
    overrides = directive.updates
    for key, value in overrides.items():
        kwargs[key] = value
    if "model" in overrides:
        call.model = overrides["model"]
    call.metadata["rerouted"] = True
187
+
188
+
189
def _pre(
    provider: str, args: tuple, kwargs: dict, model_default: str = ""
) -> tuple[LLMCall, float]:
    """Create the ``LLMCall`` for one request and start its latency clock.

    Returns ``(call, perf_counter_start)``.
    """
    model, messages = _extract_request(provider, args, kwargs, model_default)
    call = LLMCall(
        id=uuid.uuid4().hex,
        provider=provider,
        model=model,
        messages=messages,
        ts=datetime.now(UTC),
    )
    # This is the live kwargs dict, not a copy. Pre-flight interceptors can read e.g.
    # max_tokens, and later in-place edits (stream_options, a Reroute) stay visible here too.
    call.metadata["request_kwargs"] = kwargs
    started = time.perf_counter()
    return call, started
202
+
203
+
204
def _extract_request(
    provider: str, args: tuple, kwargs: dict, model_default: str = ""
) -> tuple[str, list[dict]]:
    """Normalize ``(model, messages)`` out of a provider's call signature.

    * bedrock: ``modelId=`` and ``messages=``.
    * google: ``contents`` (keyword, or the first positional argument). A list is used as the
      messages; any other truthy value becomes a single user message. The model comes from
      ``model=`` when given, otherwise ``model_default`` (bound on the GenerativeModel).
    * openai / anthropic / ollama: ``model=`` and ``messages=``.

    The returned messages list is always a fresh (shallow) copy. A subscriber that mutates
    ``LLMCall.messages`` therefore never alters the caller's request object.
    """
    if provider == "bedrock":
        return kwargs.get("modelId", ""), list(kwargs.get("messages") or [])
    if provider == "google":
        contents = kwargs.get("contents")
        if contents is None and args:
            contents = args[0]
        if isinstance(contents, list):
            messages = list(contents)  # copy, like the other providers: no aliasing of the request
        elif contents:
            messages = [{"role": "user", "content": str(contents)}]
        else:
            messages = []
        # model id is bound to the GenerativeModel object (model_default), not the call kwargs
        return kwargs.get("model") or model_default, messages
    # openai / anthropic / ollama all take model= + messages=
    return kwargs.get("model", ""), list(kwargs.get("messages") or [])
224
+
225
+
226
def _post(call: LLMCall, response: Any, provider: str, start: float) -> None:
    """Complete a non-streamed call and emit it on the bus.

    Fills in latency, usage and cost (a reported cost is preferred over an estimate), and
    attaches the raw response for recorders such as cassette.
    """
    call.latency_ms = (time.perf_counter() - start) * 1000.0
    call.usage = usage = _extract_usage(response, provider)
    _set_cost(call, usage, _extract_reported_cost(response))
    # A reference, not a copy: recorders persist exactly what the provider returned.
    call.metadata["response"] = response
    bus.emit(call)
233
+
234
+
235
def _set_cost(call: LLMCall, usage: Usage | None, reported: Money | None) -> None:
    """Set ``call.cost`` and record where the figure came from.

    A provider- or gateway-reported amount (e.g. OpenRouter's ``usage.cost``) wins and is
    tagged ``metadata["cost_reported"]``. Otherwise, when usage is known, the cost is priced
    from the snapshot price table and tagged ``metadata["cost_estimated"]``. Audits can thus
    tell a billed cost from an estimate, mirroring the ``usage_estimated`` flag.

    A model missing from the price table (``KeyError``) with no reported cost sets
    ``cost = None``. With neither a reported cost nor usage, ``call`` is left unchanged.
    """
    if reported is not None:
        call.cost = reported
        call.metadata["cost_reported"] = True
    elif usage is not None:
        try:
            estimate = prices.estimate(
                call.model, usage.input_tokens, usage.output_tokens, usage.cached_tokens
            )
        except KeyError:
            call.cost = None  # unpriced model: no cost rather than a guess
        else:
            call.cost = estimate
            call.metadata["cost_estimated"] = True
256
+
257
+
258
def _extract_reported_cost(response: Any) -> Money | None:
    """Read a provider- or gateway-reported cost off a response (or stream chunk), if present.

    Gateways like OpenRouter attach the real billed cost at ``usage.cost``. Standard OpenAI/
    Anthropic SDK responses carry no cost, so this returns ``None`` and the caller falls back to
    the offline estimate.

    Candidates are checked in this order: ``usage.cost``, ``usage.total_cost``, ``cost``,
    ``total_cost``. The first finite, non-negative numeric value is returned as ``Money``.
    Non-numeric, negative, NaN and infinite values are skipped, so a malformed gateway field
    can never break the call.
    """
    usage = _get(response, "usage")
    candidates = []
    if usage is not None:
        candidates += [_get(usage, "cost"), _get(usage, "total_cost")]
    candidates += [_get(response, "cost"), _get(response, "total_cost")]
    for value in candidates:
        if value is None:
            continue
        try:
            amount = Decimal(str(value))
        except (ArithmeticError, ValueError, TypeError):
            continue
        # Check is_finite() first: an ordering comparison on a NaN Decimal raises
        # InvalidOperation, and an infinite "cost" is not a billable amount.
        if amount.is_finite() and amount >= 0:
            return Money(amount)
    return None
280
+
281
+
282
def _ensure_stream_usage_options(provider: str, kwargs: dict) -> None:
    """Make an OpenAI stream end with a usage chunk, so streamed usage is the billed count.

    Only for ``provider == "openai"`` with a truthy ``stream``. It adds
    ``stream_options={"include_usage": True}`` when the caller passed no ``stream_options`` key
    at all; a caller-supplied value (even ``None``) is never touched. Other providers are left
    alone. docs/core.md §6.
    """
    if provider != "openai" or not kwargs.get("stream"):
        return
    kwargs.setdefault("stream_options", {"include_usage": True})
292
+
293
+
294
+ # --------------------------------------------------------------------------- streaming
295
+
296
+
297
def _proxy_stream(call: LLMCall, stream: Any, provider: str, start: float) -> Any:
    """Pass a sync streaming response through unchanged, collecting chunks; emit once at the end.

    The caller iterates exactly the provider's chunks. Without this proxy, a streamed call would
    return an unconsumed iterator and emit no usable usage.

    Iteration can end by exhaustion, by an error, or by the consumer closing the generator early
    (e.g. ``break`` then ``close()``, or garbage collection). In every case:

    1. the provider's stream is closed too, when it exposes ``close()``, so its underlying
       connection is released instead of lingering;
    2. the ``LLMCall`` (now with usage, cost and true end-to-end latency) is emitted once.
    """

    def gen() -> Any:
        chunks: list[Any] = []
        try:
            for chunk in stream:
                chunks.append(chunk)
                yield chunk
        finally:
            # Release the provider's stream even when the consumer bailed out early. Closing an
            # already-exhausted stream is expected to be harmless. Best-effort: a failing close
            # must not prevent the event from being emitted.
            close = getattr(stream, "close", None)
            if callable(close):
                try:
                    close()
                except Exception:  # noqa: BLE001 - cleanup is best-effort
                    pass
            _finalize_stream(call, chunks, provider, start)

    return gen()
315
+
316
+
317
def _aproxy_stream(call: LLMCall, stream: Any, provider: str, start: float) -> Any:
    """Async counterpart of :func:`_proxy_stream` for ``async for`` streaming responses.

    This is a plain function that *returns* an async generator (not ``async def``). The wrapper
    can thus hand the async iterator straight back to the caller's ``async for``, rather than an
    un-awaited coroutine.

    Iteration can end by exhaustion, by an error, or by the consumer calling ``aclose()`` early.
    In every case:

    1. the provider's stream is closed too, via ``aclose()`` or ``close()``, awaiting the result
       when it is awaitable, so its connection isn't left open;
    2. the ``LLMCall`` is emitted once.
    """

    async def agen() -> Any:
        chunks: list[Any] = []
        try:
            async for chunk in stream:
                chunks.append(chunk)
                yield chunk
        finally:
            # Release the provider's stream even when the consumer bailed out early. Closing an
            # already-exhausted stream is expected to be harmless. Best-effort: a failing close
            # must not prevent the event from being emitted.
            closer = getattr(stream, "aclose", None) or getattr(stream, "close", None)
            if callable(closer):
                try:
                    closing = closer()
                    if inspect.isawaitable(closing):
                        await closing
                except Exception:  # noqa: BLE001 - cleanup is best-effort
                    pass
            _finalize_stream(call, chunks, provider, start)

    return agen()
334
+
335
+
336
def _replay_stream(call: LLMCall, recorded: Any, provider: str, start: float) -> Any:
    """Serve a recorded stream back, so a replayed ``stream=True`` call still iterates.

    ``recorded`` is turned into a list up front; ``None`` counts as an empty stream. The
    returned generator yields each recorded chunk. When it finishes or is closed, it finalizes
    the call exactly as a live stream would (usage, cost, emit).
    """
    replayed_chunks = [] if recorded is None else list(recorded)

    def gen() -> Any:
        try:
            for chunk in replayed_chunks:
                yield chunk
        finally:
            _finalize_stream(call, replayed_chunks, provider, start)

    return gen()
347
+
348
+
349
def _areplay_stream(call: LLMCall, recorded: Any, provider: str, start: float) -> Any:
    """Async counterpart of :func:`_replay_stream`: yields the recorded chunks for ``async for``.

    Like :func:`_aproxy_stream`, this is a plain function that returns an async generator
    instead of a coroutine. ``None`` counts as an empty recording, and the call is finalized
    when iteration ends or the generator is closed.
    """
    replayed_chunks = [] if recorded is None else list(recorded)

    async def agen() -> Any:
        try:
            for chunk in replayed_chunks:
                yield chunk
        finally:
            _finalize_stream(call, replayed_chunks, provider, start)

    return agen()
364
+
365
+
366
def _finalize_stream(call: LLMCall, chunks: list, provider: str, start: float) -> None:
    """Close out a streamed call, then emit it on the bus once.

    Steps:

    1. stop the latency clock;
    2. take the provider-reported usage, or else an offline estimate;
    3. price it, preferring a cost reported on any chunk;
    4. tag ``metadata["streamed"]`` and attach the collected chunks as the response for
       recorders;
    5. emit.
    """
    call.latency_ms = (time.perf_counter() - start) * 1000.0
    usage = _stream_usage(chunks, provider)
    if usage is None:
        usage = _estimate_stream_usage(call, chunks, provider)
    call.usage = usage
    # A gateway may attach its billed cost to one chunk (usually the final usage chunk).
    # Use the first one found.
    reported = next(
        (cost for cost in map(_extract_reported_cost, chunks) if cost is not None), None
    )
    _set_cost(call, usage, reported)
    call.metadata["streamed"] = True
    call.metadata["response"] = chunks
    bus.emit(call)
382
+
383
+
384
def _stream_usage(chunks: list, provider: str) -> Usage | None:
    """Recover real usage from streamed chunks, where the provider reports it.

    * anthropic: input (and cache reads) arrive on ``message_start``; output on
      ``message_delta``.
    * bedrock: usage rides a ``metadata`` event (camelCase keys).
    * openai / ollama / google: one chunk (usually the last) is shaped like a full response, so
      :func:`_extract_usage` reads it directly.

    Returns ``None`` when no chunk reports usage (e.g. OpenAI without
    ``stream_options={"include_usage": True}``).
    """
    if provider == "anthropic":
        input_tokens = output_tokens = None
        cache_reads = 0
        for event in chunks:
            kind = _get(event, "type")
            if kind == "message_start":
                start_usage = _get(_get(event, "message"), "usage")
                input_tokens = _get(start_usage, "input_tokens", input_tokens)
                cache_reads = _get(start_usage, "cache_read_input_tokens", 0) or 0
            elif kind == "message_delta":
                delta_usage = _get(event, "usage")
                if delta_usage is not None:
                    output_tokens = _get(delta_usage, "output_tokens", output_tokens)
        if input_tokens is None:
            return None
        return Usage(int(input_tokens), int(output_tokens or 0), int(cache_reads or 0))

    if provider == "bedrock":
        for event in chunks:
            meta_usage = _get(_get(event, "metadata"), "usage")
            if meta_usage is not None:
                return Usage(
                    int(_get(meta_usage, "inputTokens", 0) or 0),
                    int(_get(meta_usage, "outputTokens", 0) or 0),
                )
        return None

    per_chunk = (_extract_usage(chunk, provider) for chunk in chunks)
    return next((usage for usage in per_chunk if usage is not None), None)
421
+
422
+
423
def _estimate_stream_usage(call: LLMCall, chunks: list, provider: str) -> Usage | None:
    """Offline fallback when a stream reports no usage.

    Counts the request messages as input and the streamed text as output. Sets
    ``call.metadata["usage_estimated"] = True`` so downstream tools and audits know the figure
    is not the provider's billed count. The count is exact with the ``[tiktoken]`` extra for
    OpenAI and heuristic otherwise (see :mod:`cendor.core.tokens`). Returns ``None`` (and sets
    no flag) when there is neither input nor streamed text.
    """
    streamed_text = "".join(_stream_text(chunk, provider) for chunk in chunks)
    has_input = bool(call.messages)
    if not (streamed_text or has_input):
        return None
    input_tokens = tokens.count(call.messages, call.model) if has_input else 0
    output_tokens = tokens.count(streamed_text, call.model) if streamed_text else 0
    call.metadata["usage_estimated"] = True
    return Usage(int(input_tokens), int(output_tokens))
437
+
438
+
439
def _stream_text(chunk: Any, provider: str) -> str:
    """Best-effort text carried by one streamed chunk. Feeds only the offline usage estimate.

    Unknown providers, chunks without text, and any exception raised while reading the chunk
    all yield ``""``; the estimate must never interfere with passing chunks through.
    """

    def as_text(value: Any) -> str:
        return str(value or "")

    try:
        if provider == "openai":
            return "".join(
                as_text(_get(_get(choice, "delta"), "content", ""))
                for choice in _get(chunk, "choices") or []
            )
        if provider == "anthropic":
            if _get(chunk, "type") != "content_block_delta":
                return ""
            return as_text(_get(_get(chunk, "delta"), "text", ""))
        if provider == "ollama":
            return as_text(_get(_get(chunk, "message"), "content", ""))
        if provider == "google":
            return as_text(_get(chunk, "text", ""))
        if provider == "bedrock":
            return as_text(_get(_get(_get(chunk, "contentBlockDelta"), "delta"), "text", ""))
    except Exception:  # noqa: BLE001 - estimation must never break the passthrough
        return ""
    return ""
458
+
459
+
460
def _get(obj: Any, name: str, default: Any = None) -> Any:
    """Read ``name`` as a dict key (for dicts) or an attribute (anything else).

    Returns ``default`` when it is absent. This lets the extractors treat SDK response objects
    and plain dicts (e.g. recorded payloads) uniformly. An object lacking the attribute,
    including ``None``, simply yields ``default``.
    """
    return obj.get(name, default) if isinstance(obj, dict) else getattr(obj, name, default)
464
+
465
+
466
def _extract_usage(response: Any, provider: str) -> Usage | None:
    """Normalize a provider response (or full-response-shaped chunk) into ``Usage``.

    * google: ``usage_metadata``. ``thoughts_token_count`` is billed as output but reported
      separately from ``candidates_token_count``, so it is folded into output and also
      surfaced as reasoning.
    * ollama: top-level ``prompt_eval_count`` / ``eval_count``.
    * openai: ``usage.prompt_tokens`` / ``completion_tokens``, cached from
      ``prompt_tokens_details``, and reasoning (a subset of completion tokens) from
      ``completion_tokens_details``.
    * bedrock: camelCase ``usage.inputTokens`` / ``outputTokens``.
    * anthropic (default): ``usage.input_tokens`` / ``output_tokens`` /
      ``cache_read_input_tokens``. Thinking is already inside ``output_tokens``.

    Returns ``None`` when the input count is unavailable.
    """
    cached_count = 0
    reasoning_count = 0  # tokens spent reasoning/thinking; a subset of output (see Usage)

    if provider == "google":
        meta = _get(response, "usage_metadata")
        input_count = _get(meta, "prompt_token_count")
        reasoning_count = _get(meta, "thoughts_token_count", 0) or 0
        output_count = (_get(meta, "candidates_token_count", 0) or 0) + reasoning_count
    elif provider == "ollama":
        input_count = _get(response, "prompt_eval_count")
        output_count = _get(response, "eval_count", 0) or 0
    else:
        usage = _get(response, "usage")
        if usage is None:
            return None
        if provider == "openai":
            input_count = _get(usage, "prompt_tokens")
            output_count = _get(usage, "completion_tokens", 0) or 0
            prompt_details = _get(usage, "prompt_tokens_details")
            if prompt_details is not None:
                cached_count = _get(prompt_details, "cached_tokens", 0) or 0
            completion_details = _get(usage, "completion_tokens_details")
            if completion_details is not None:
                reasoning_count = _get(completion_details, "reasoning_tokens", 0) or 0
        elif provider == "bedrock":
            input_count = _get(usage, "inputTokens")
            output_count = _get(usage, "outputTokens", 0) or 0
        else:
            input_count = _get(usage, "input_tokens")
            output_count = _get(usage, "output_tokens", 0) or 0
            cached_count = _get(usage, "cache_read_input_tokens", 0) or 0

    if input_count is None:
        return None
    return Usage(
        input_tokens=int(input_count),
        output_tokens=int(output_count),
        cached_tokens=int(cached_count),
        reasoning_tokens=int(reasoning_count),
    )
507
+
508
+
509
+ # --------------------------------------------------------------------------- tools
510
+
511
+
512
def instrument_tool(name: str | Callable | None = None) -> Callable:
    """Wrap a tool/function so each invocation emits a ``ToolCall`` on the bus.

    Works bare (``@instrument_tool``) or with an explicit name (``@instrument_tool("search")``).
    Without a (truthy) name, the function's ``__name__`` is used, falling back to ``"tool"``.
    Like :func:`instrument`, it is idempotent, sync + async, and replay-aware. The return value
    is stored on ``ToolCall.result`` so ``cassette`` can record/replay tool side effects.
    """

    def decorator(fn: Callable) -> Callable:
        return _wrap_tool(fn, name or str(getattr(fn, "__name__", "tool")))

    if callable(name):  # used bare: `name` is actually the function being decorated
        target = name
        return _wrap_tool(target, str(getattr(target, "__name__", "tool")))
    return decorator
526
+
527
+
528
def _wrap_tool(fn: Callable, tool_name: str) -> Callable:
    """Return an instrumented wrapper for tool ``fn``, or ``fn`` itself if already wrapped.

    On each invocation, a ``ToolCall`` is built and offered to the interceptors:

    * a non-:data:`MISS` answer is returned as the result without running ``fn``, and the
      event is flagged ``metadata["replayed"]``;
    * otherwise ``fn`` runs normally.

    Either way, the result is recorded and the event is emitted.
    """
    if getattr(fn, _WRAPPED, False):
        return fn

    if not inspect.iscoroutinefunction(fn):

        @functools.wraps(fn)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            event, started = _pre_tool(tool_name, args, kwargs)
            answer = _intercept(event)
            if answer is MISS:
                result = fn(*args, **kwargs)
            else:
                event.metadata["replayed"] = True
                result = answer
            _post_tool(event, result, started)
            return result

        setattr(sync_wrapper, _WRAPPED, True)
        return sync_wrapper

    @functools.wraps(fn)
    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
        event, started = _pre_tool(tool_name, args, kwargs)
        answer = _intercept(event)
        if answer is MISS:
            result = await fn(*args, **kwargs)
        else:
            event.metadata["replayed"] = True
            result = answer
        _post_tool(event, result, started)
        return result

    setattr(async_wrapper, _WRAPPED, True)
    return async_wrapper
563
+
564
+
565
def _pre_tool(name: str, args: tuple, kwargs: dict) -> tuple[ToolCall, float]:
    """Create the ``ToolCall`` for one tool invocation and start its latency clock.

    Arguments are captured as shallow copies in the form
    ``{"args": [...], "kwargs": {...}}``. Returns ``(tool_call, perf_counter_start)``.
    """
    captured = {"args": list(args), "kwargs": dict(kwargs)}
    event = ToolCall(
        id=uuid.uuid4().hex,
        name=name,
        arguments=captured,
        ts=datetime.now(UTC),
    )
    return event, time.perf_counter()
573
+
574
+
575
def _post_tool(tc: ToolCall, result: Any, start: float) -> None:
    """Stamp latency and the tool's return value onto ``tc``, then publish it on the bus."""
    elapsed_ms = (time.perf_counter() - start) * 1000.0
    tc.latency_ms = elapsed_ms
    tc.result = result
    bus.emit(tc)