debugerai 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
debugai/sdk.py ADDED
@@ -0,0 +1,1271 @@
1
+ """Level 2 integration — the one-line SDK wrapper (Architecture §3.2).
2
+
3
+ from debugai.sdk import wrap_llm
4
+ client = wrap_llm(OpenAI()) # or wrap_llm(Anthropic())
5
+ resp = client.chat.completions.create(...) # unchanged call site
6
+
7
+ ``wrap_llm`` returns a transparent proxy: every attribute access forwards to the
8
+ real client untouched *except* the terminal ``create`` call, which is
9
+ instrumented. After the real call returns, a CaptureRecord is built and handed
10
+ to a background worker for diagnosis — so the user's request is never blocked
11
+ (the only added latency is cheap dict-building, well under the 10ms budget).
12
+
13
+ Retrieval context (chunks + similarity scores) isn't visible from the LLM call
14
+ itself, so attach it either with the ``retrieval_context`` context manager or
15
+ by passing ``debugai_chunks`` / ``debugai_similarity_scores`` kwargs to
16
+ ``create`` (popped before the call is forwarded).
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import atexit
22
+ import contextlib
23
+ import contextvars
24
+ import logging
25
+ import queue
26
+ import threading
27
+ import time
28
+ from dataclasses import dataclass, field
29
+ from typing import Any, Callable
30
+
31
+ import hashlib
32
+ import json as _json_mod
33
+ import uuid
34
+ from concurrent.futures import ThreadPoolExecutor, as_completed
35
+
36
+ from debugai.analyze import analyze
37
+ from debugai.config import DebugAIConfig
38
+ from debugai.metrics import metrics as _global_metrics
39
+ from debugai.thresholds import DEFAULT_THRESHOLDS, Thresholds
40
+ from debugai.tracing import Span, Trace, scores_from_diagnosis, status_from_diagnosis
41
+
42
+ log = logging.getLogger("debugai.sdk")
43
+
44
# Retrieval payload for the current execution context. retrieval_context()
# sets it; ContextVar keeps the value separate per thread / asyncio task.
_retrieval: contextvars.ContextVar[dict | None]
_retrieval = contextvars.ContextVar("debugai_retrieval", default=None)

# Session id for the current execution context. session() sets it so traces
# produced inside the block can be grouped into one conversation.
_session: contextvars.ContextVar[str | None]
_session = contextvars.ContextVar("debugai_session", default=None)
52
+
53
+
54
@contextlib.contextmanager
def session(session_id: str):
    """Tag every wrapped LLM call made inside the ``with`` block with ``session_id``.

    The id is stored in the ``_session`` context variable for the duration of
    the block and the previous value is restored on exit, even if the block
    raises.
    """
    reset_token = _session.set(session_id)
    try:
        yield
    finally:
        # Restore whatever session id (if any) was active before the block.
        _session.reset(reset_token)
62
+
63
+
64
@contextlib.contextmanager
def retrieval_context(chunks: list[str], similarity_scores: list[float] | None = None,
                      retrieval_query: str | None = None):
    """Expose RAG context to wrapped LLM calls made inside the ``with`` block.

    Args:
        chunks: Retrieved text chunks; copied into a new list (a falsy value
            becomes an empty list).
        similarity_scores: Optional scores; copied the same way as ``chunks``.
        retrieval_query: Optional query string, stored as given.

    The payload lives in the ``_retrieval`` context variable and the previous
    value is restored when the block exits.
    """
    payload = {
        "retrieved_chunks": list(chunks) if chunks else [],
        "similarity_scores": list(similarity_scores) if similarity_scores else [],
        "retrieval_query": retrieval_query,
    }
    reset_token = _retrieval.set(payload)
    try:
        yield
    finally:
        _retrieval.reset(reset_token)
79
+
80
+
81
+ # --------------------------------------------------------------------------- #
82
+ # Provider adapters (duck-typed so fakes work without the real SDKs installed)
83
+ # --------------------------------------------------------------------------- #
84
@dataclass
class _Captured:
    """Provider-neutral snapshot of the request fields an adapter extracts."""

    # Text of the system instruction(s); defaults to the empty string.
    system_prompt: str = ""
    # Text of the user message(s); defaults to the empty string.
    user_prompt: str = ""
    # Value of the ``model`` request argument, when present.
    model_name: str | None = None
    # Value of the ``temperature`` request argument, when present.
    temperature: float | None = None
    # Value of the ``max_tokens`` request argument, when present.
    max_tokens: int | None = None
91
+
92
+
93
def _msg_text(content) -> str:
    """Flatten a chat message ``content`` value into plain text.

    ``None`` becomes ``""``, a string is returned unchanged, and a list of
    content parts is reduced to the space-joined non-empty texts of its
    string and dict parts (dicts contribute their ``"text"`` or, failing that,
    ``"content"`` value; any other part type is skipped). Anything else is
    passed through ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    pieces = []
    for item in content:
        if isinstance(item, str):
            piece = item
        elif isinstance(item, dict):
            piece = item.get("text") or item.get("content") or ""
        else:
            continue  # unknown part type: contributes nothing
        if piece:
            pieces.append(piece)
    return " ".join(pieces)
110
+
111
+
112
class _OpenAIAdapter:
    """Adapter for clients that expose ``client.chat.completions.create``."""

    # Attribute path from the client to the call that gets instrumented.
    create_path = ("chat", "completions", "create")

    @staticmethod
    def matches(client: Any) -> bool:
        """Return True when ``client.chat.completions.create`` is callable."""
        completions = getattr(getattr(client, "chat", None), "completions", None)
        return callable(getattr(completions, "create", None))

    @staticmethod
    def from_request(kwargs: dict) -> _Captured:
        """Build a ``_Captured`` from the keyword arguments of a ``create`` call."""
        messages = kwargs.get("messages", []) or []

        def joined(role: str) -> str:
            # Space-join the flattened content of every message with this role.
            return " ".join(
                _msg_text(m.get("content")) for m in messages if m.get("role") == role
            )

        system_text = joined("system")
        user_text = joined("user")
        return _Captured(
            system_prompt=system_text,
            user_prompt=user_text,
            model_name=kwargs.get("model"),
            temperature=kwargs.get("temperature"),
            max_tokens=kwargs.get("max_tokens"),
        )

    @staticmethod
    def from_response(resp: Any) -> tuple[str, dict]:
        """Return ``(output_text, usage)`` extracted from a response object.

        The text falls back to ``""`` when the first choice cannot be read;
        ``usage`` is ``{}`` when the response has no ``usage`` attribute (or
        it is ``None``).
        """
        try:
            text = resp.choices[0].message.content or ""
        except Exception:
            text = ""

        stats = getattr(resp, "usage", None)
        if stats is None:
            return text, {}
        return text, {
            "prompt": getattr(stats, "prompt_tokens", 0),
            "completion": getattr(stats, "completion_tokens", 0),
            "total": getattr(stats, "total_tokens", 0),
        }

    @staticmethod
    def extract_tool_calls(resp: Any) -> list[dict]:
        """B3 — list the function/tool calls of the first choice, or ``[]`` on any error."""
        try:
            found = []
            for call in resp.choices[0].message.tool_calls or []:
                found.append({
                    "name": call.function.name,
                    "input": call.function.arguments,
                    "id": getattr(call, "id", ""),
                })
            return found
        except Exception:
            return []
160
+
161
+
162
class _AnthropicAdapter:
    """Adapter for clients that expose ``client.messages.create`` and no ``.chat``."""

    # Attribute path from the client to the call that gets instrumented.
    create_path = ("messages", "create")

    @staticmethod
    def matches(client: Any) -> bool:
        """Return True for a client with a callable ``messages.create`` and no ``chat``."""
        msgs = getattr(client, "messages", None)
        # Distinguish from OpenAI (which has .chat); Anthropic has no .chat.
        return callable(getattr(msgs, "create", None)) and not hasattr(client, "chat")

    @staticmethod
    def from_request(kwargs: dict) -> _Captured:
        """Build a ``_Captured`` from the keyword arguments of a ``create`` call.

        The system prompt comes from the ``system`` argument (a string or a
        list of content parts); the user prompt is the space-joined content of
        every ``role == "user"`` message.
        """
        system = kwargs.get("system", "") or ""
        msgs = kwargs.get("messages", []) or []
        user = " ".join(_msg_text(m.get("content")) for m in msgs if m.get("role") == "user")
        return _Captured(
            system_prompt=system if isinstance(system, str) else _msg_text(system),
            user_prompt=user,
            model_name=kwargs.get("model"),
            temperature=kwargs.get("temperature"),
            max_tokens=kwargs.get("max_tokens"),
        )

    @staticmethod
    def from_response(resp: Any) -> tuple[str, dict]:
        """Return ``(output_text, usage)`` extracted from a response object.

        The text is the concatenation of every ``type == "text"`` block and
        falls back to ``""`` on any error. ``usage`` is ``{}`` when the
        response carries no usage object.
        """
        try:
            text = "".join(
                getattr(b, "text", "") for b in resp.content
                if getattr(b, "type", "") == "text"
            )
        except Exception:
            text = ""
        usage = {}
        u = getattr(resp, "usage", None)
        if u is not None:
            # ``or 0`` guards against a usage field that is present but None;
            # without it ``inp + out`` raises TypeError on the request path.
            inp = getattr(u, "input_tokens", 0) or 0
            out = getattr(u, "output_tokens", 0) or 0
            usage = {"prompt": inp, "completion": out, "total": inp + out}
        return text, usage

    @staticmethod
    def extract_tool_calls(resp: Any) -> list[dict]:
        """B3 — list the ``tool_use`` blocks of a response, or ``[]`` on any error."""
        try:
            return [{"name": b.name, "input": b.input, "id": getattr(b, "id", "")}
                    for b in resp.content if getattr(b, "type", "") == "tool_use"]
        except Exception:
            return []
210
+
211
+
212
class _OpenAICompatAdapter(_OpenAIAdapter):
    """Adapter for any client speaking the OpenAI chat-completions API.

    Covers Azure, Groq, Together AI, Mistral, Ollama (Qwen, Llama, Phi,
    DeepSeek…), OpenRouter, LM Studio and vLLM. Request/response handling is
    inherited unchanged from ``_OpenAIAdapter``; such clients differ only in
    the base_url they were constructed with.
    """

    create_path = ("chat", "completions", "create")

    @staticmethod
    def matches(client: Any) -> bool:
        """Return True for an OpenAI-shaped client that is not Anthropic-shaped."""
        if not _OpenAIAdapter.matches(client):
            return False
        return not _AnthropicAdapter.matches(client)


# Backward-compat alias.
_GenericOpenAICompatAdapter = _OpenAICompatAdapter
227
+
228
+
229
class _CohereAdapter:
    """Native Cohere SDK adapter (ClientV2). Requires: pip install cohere"""

    # Attribute path from the client to the call that gets instrumented.
    create_path = ("chat",)

    @staticmethod
    def matches(client: Any) -> bool:
        """Return True for a client with a callable ``chat``, an ``embed`` and no ``messages``."""
        return (callable(getattr(client, "chat", None)) and
                hasattr(client, "embed") and
                not hasattr(client, "messages"))

    @staticmethod
    def from_request(kwargs: dict) -> "_Captured":
        """Build a ``_Captured`` from the keyword arguments of a ``chat`` call.

        Each message's text is read from its ``"message"`` key, falling back
        to ``"content"``, and flattened with ``_msg_text`` so that ``None`` or
        list-of-parts content cannot break the join. Roles are compared
        case-insensitively; a missing or ``None`` role matches nothing.
        """
        msgs = kwargs.get("messages") or []

        def joined(role: str) -> str:
            return " ".join(
                _msg_text(m.get("message", m.get("content")))
                for m in msgs
                if (m.get("role") or "").upper() == role
            )

        return _Captured(
            system_prompt=joined("SYSTEM"),
            user_prompt=joined("USER"),
            model_name=kwargs.get("model"),
            temperature=kwargs.get("temperature"),
            max_tokens=kwargs.get("max_tokens"),
        )

    @staticmethod
    def from_response(resp: Any) -> tuple[str, dict]:
        """Return ``(output_text, usage)`` extracted from a response object.

        Reads ``resp.message.content[0].text`` first and falls back to
        ``resp.text``; usage comes from ``resp.meta.tokens``. Every lookup is
        best-effort: on failure the text is ``""`` and the usage is ``{}``.
        """
        text = ""
        try:
            text = resp.message.content[0].text or ""
        except Exception:
            # Not the message/content shape: try a plain ``.text`` attribute.
            try:
                text = resp.text or ""
            except Exception:
                pass  # deliberate best-effort: keep the empty default
        usage = {}
        try:
            u = resp.meta.tokens
            inp = getattr(u, "input_tokens", 0) or 0
            out = getattr(u, "output_tokens", 0) or 0
            usage = {"prompt": inp, "completion": out, "total": inp + out}
        except Exception:
            pass  # deliberate best-effort: usage stays empty
        return text, usage
274
+
275
+
276
# Built-in adapter classes, in the order they are listed here.
_ADAPTERS = [_AnthropicAdapter, _OpenAICompatAdapter, _OpenAIAdapter, _CohereAdapter]
# Adapter classes added at runtime through register_adapter().
_EXTRA_ADAPTERS: list = []


def register_adapter(adapter_class) -> None:
    """Register a custom adapter class for use with ``wrap_llm()``.

    The class is placed at the front of ``_EXTRA_ADAPTERS``, so the most
    recently registered adapter comes first.
    """
    _EXTRA_ADAPTERS[:0] = [adapter_class]
283
+
284
+
285
def _detect_adapter(client: Any):
    """Return the first adapter class whose ``matches(client)`` is true.

    User-registered adapters are tried before the built-in ones.

    Raises:
        TypeError: if no adapter recognises ``client``.
    """
    candidates = _EXTRA_ADAPTERS + _ADAPTERS
    chosen = next((cand for cand in candidates if cand.matches(client)), None)
    if chosen is None:
        raise TypeError(
            "wrap_llm: unrecognised client. Supported: OpenAI-compatible clients "
            "(.chat.completions.create), Anthropic (.messages.create), Cohere (.chat + .embed). "
            "For custom providers use register_adapter() or register_provider()."
        )
    return chosen
294
+
295
+
296
+ # --------------------------------------------------------------------------- #
297
+ # B4 – Budget manager
298
+ # --------------------------------------------------------------------------- #
299
class BudgetExceededError(Exception):
    """Signals that a completion() call would go over DebugAIConfig.budget_usd."""
301
+
302
+
303
+ # --------------------------------------------------------------------------- #
304
+ # B5 – TTL response cache
305
+ # --------------------------------------------------------------------------- #
306
class _TTLCache:
    """Thread-safe in-memory cache with per-entry TTL.

    Deadlines are measured with ``time.monotonic()`` so that adjustments of
    the system clock can neither expire entries early nor keep them alive
    past their TTL. Expired entries are removed when they are next looked up.
    """

    def __init__(self) -> None:
        # key -> (value, monotonic deadline)
        self._store: dict[str, tuple[Any, float]] = {}
        self._lock = threading.Lock()

    def get(self, key: str) -> Any:
        """Return the cached value for ``key``, or ``None`` if absent or expired."""
        with self._lock:
            entry = self._store.get(key)
            if entry is None:
                return None
            value, deadline = entry
            if time.monotonic() < deadline:
                return value
            del self._store[key]  # lazily evict the expired entry
            return None

    def set(self, key: str, value: Any, ttl: float) -> None:
        """Store ``value`` under ``key`` for ``ttl`` seconds.

        A non-positive ``ttl`` can never be read back, so instead of storing a
        dead entry any existing value for ``key`` is simply dropped.
        """
        with self._lock:
            if ttl <= 0:
                self._store.pop(key, None)
                return
            self._store[key] = (value, time.monotonic() + ttl)

    def clear(self) -> None:
        """Remove every entry."""
        with self._lock:
            self._store.clear()


# Process-wide response cache instance.
_response_cache = _TTLCache()
334
+
335
+
336
def _cache_key(model: str, messages: list) -> str:
    """Return the SHA-256 hex digest of ``model`` followed by the JSON of ``messages``.

    The messages are serialised with sorted keys so that dict ordering does
    not affect the key.
    """
    serialized = _json_mod.dumps(messages, sort_keys=True, ensure_ascii=False)
    digest = hashlib.sha256()
    digest.update((model + serialized).encode())
    return digest.hexdigest()
339
+
340
+
341
+ # --------------------------------------------------------------------------- #
342
+ # B9 – Model comparison
343
+ # --------------------------------------------------------------------------- #
344
@dataclass
class ComparisonResult:
    """Outcome for a single model in a ``debugai.compare()`` run."""

    # Model name the prompt was sent to.
    model: str
    # Response text ("" when the call failed).
    text: str
    # Cost reported for the call, in USD.
    cost_usd: float
    # Latency reported for the call, in milliseconds.
    latency_ms: int
    # Optional diagnosis payload.
    diagnosis: dict | None = None
    # Error message when the call raised; None on success.
    error: str | None = None
353
+
354
+
355
def compare(
    prompt: str,
    models: list[str],
    *,
    system: str = "",
    config: "DebugAIConfig | None" = None,
    max_workers: int = 4,
    **kwargs,
) -> list[ComparisonResult]:
    """Run the same prompt against multiple models in parallel and compare results.

        results = debugai.compare(
            prompt="Explain refraction.",
            models=["gpt-4o", "claude-haiku-4-5", "ollama/qwen2.5"],
        )
        for r in results:
            print(r.model, r.latency_ms, r.cost_usd, r.text[:80])

    Args:
        prompt: User message sent to every model.
        models: Model names to query; an empty list yields an empty result.
        system: Optional system message prepended to the conversation.
        config: Optional configuration forwarded to ``completion()``.
        max_workers: Upper bound on concurrent calls; values below 1 are
            treated as 1.
        **kwargs: Extra keyword arguments forwarded to ``completion()``.

    Returns:
        One ``ComparisonResult`` per model. Successful results come first,
        ordered by latency; failed calls follow with ``error`` set. Ties keep
        the order of ``models``.
    """
    # ThreadPoolExecutor rejects max_workers=0, so there is nothing to run
    # (and nothing to return) when no models were requested.
    if not models:
        return []

    messages: list[dict] = [{"role": "system", "content": system}] if system else []
    messages.append({"role": "user", "content": prompt})

    def _one(model: str) -> ComparisonResult:
        # A failure for one model must not abort the comparison of the others.
        try:
            resp = completion(model, messages, config=config, **kwargs)
            return ComparisonResult(
                model=model, text=resp.text, cost_usd=resp.cost_usd,
                latency_ms=resp.latency_ms,
            )
        except Exception as e:
            return ComparisonResult(model=model, text="", cost_usd=0.0,
                                    latency_ms=0, error=str(e))

    workers = max(1, min(max_workers, len(models)))
    with ThreadPoolExecutor(max_workers=workers) as ex:
        # map() yields results in input order, which makes the stable sort
        # below deterministic for equal keys.
        results = list(ex.map(_one, models))

    results.sort(key=lambda r: (r.error is not None, r.latency_ms))
    return results
395
+
396
+
397
def _prompt_hash(system_prompt: str) -> str:
    """B10 — return the first 12 hex characters of the prompt's SHA-256 digest.

    An empty (or otherwise falsy) prompt yields ``""``.
    """
    if system_prompt:
        full_digest = hashlib.sha256(system_prompt.encode()).hexdigest()
        return full_digest[:12]
    return ""
402
+
403
+
404
# JSON Schema primitive type name -> Python type(s) produced by json.loads.
_JSON_SCHEMA_PY_TYPES = {
    "object": dict,
    "array": list,
    "string": str,
    "number": (int, float),
    "integer": int,
    "boolean": bool,
}


def _json_type_ok(value: Any, declared: Any) -> bool:
    """Return True when ``value`` satisfies a schema ``type`` declaration.

    ``declared`` may be a single type name or a list of alternatives. Unknown
    or malformed declarations are accepted rather than reported.
    """
    if isinstance(declared, (list, tuple)):
        return any(_json_type_ok(value, name) for name in declared)
    if not isinstance(declared, str):
        return True
    if declared == "null":
        return value is None
    expected = _JSON_SCHEMA_PY_TYPES.get(declared)
    if expected is None:
        return True
    # bool is a subclass of int in Python, but JSON true/false is not a number.
    if declared in ("number", "integer") and isinstance(value, bool):
        return False
    return isinstance(value, expected)


def _json_type_label(declared: Any) -> str:
    """Render a ``type`` declaration for a violation message."""
    if isinstance(declared, (list, tuple)):
        return " or ".join(str(name) for name in declared)
    return str(declared)


def _validate_json_schema(output: str, schema: dict) -> list[str]:
    """Validate a JSON response against a JSON Schema dict (stdlib only).

    Only a subset of JSON Schema is checked: the top-level ``type``, the
    ``required`` property names, and the ``type`` of each declared property.

    Args:
        output: Raw model output expected to contain a JSON document.
        schema: JSON Schema dict; a falsy schema disables validation.

    Returns:
        A list of human-readable violation strings, empty when valid.
    """
    if not schema:
        return []
    # Step 1: is it valid JSON?
    try:
        data = _json_mod.loads(output.strip())
    except _json_mod.JSONDecodeError as e:
        return [f"Output is not valid JSON: {e}"]

    violations = []
    # Step 2: top-level type.
    schema_type = schema.get("type")
    if schema_type and not _json_type_ok(data, schema_type):
        violations.append(
            f"Expected JSON {_json_type_label(schema_type)}, got {type(data).__name__}")

    # Steps 3 and 4 only make sense for objects.
    if isinstance(data, dict):
        for req in schema.get("required") or []:
            if req not in data:
                violations.append(f"Missing required property: '{req}'")
        for prop, prop_schema in (schema.get("properties") or {}).items():
            if prop not in data or not isinstance(prop_schema, dict):
                continue
            ptype = prop_schema.get("type")
            if ptype and not _json_type_ok(data[prop], ptype):
                violations.append(
                    f"Property '{prop}' should be {_json_type_label(ptype)}, "
                    f"got {type(data[prop]).__name__}")
    return violations
440
+
441
+
442
+ # --------------------------------------------------------------------------- #
443
+ # Background diagnosis worker (async + batching, §5 step 'Async + batching')
444
+ # --------------------------------------------------------------------------- #
445
@dataclass
class _Job:
    """One captured LLM call, queued for background processing."""

    # Request fields extracted by the adapter.
    captured: _Captured
    # Output text extracted from the response.
    output: str
    # Token usage dict extracted from the response.
    usage: dict
    # Measured duration of the wrapped call, in milliseconds.
    latency_ms: int
    # Retrieval payload (chunks / scores / query) attached to the call, if any.
    retrieval: dict | None
    # Context window size passed through to the analysis step, if any.
    context_window: int | None
    # Session id the call belongs to, if any.
    session_id: str | None = None
    # B3 — tool calls extracted from the response.
    tool_calls: list = field(default_factory=list)
    # B7 — correlation id for the call.
    correlation_id: str | None = None
    # B6 — retry count recorded for the call.
    retry_count: int = 0
    # B5 — whether the response came from the cache.
    from_cache: bool = False
458
+
459
+
460
class _Diagnoser:
    """Single daemon worker that drains a queue and runs diagnosis off the
    request path. Configuration is read from a DebugAIConfig so the same
    worker respects enable_* flags, sampling, and sinks.

    User-supplied callbacks (``on_diagnosis``, ``on_trace``, ``on_metrics``,
    ``on_sla_breach``, ``on_schema_violation``) are each invoked in isolation:
    a callback that raises is logged and the remaining processing steps for
    the job still run.
    """

    def __init__(self, config: DebugAIConfig,
                 # Legacy positional compat (on_diagnosis, explain_with_llm, thresholds)
                 on_diagnosis: Callable | None = None,
                 explain_with_llm: bool = False,
                 thresholds: Thresholds | None = None,
                 batch_size: int = 16,
                 on_trace: Callable | None = None):
        """Start the worker thread.

        Args:
            config: Base configuration.
            on_diagnosis, explain_with_llm, thresholds, on_trace: Legacy
                overrides; when given they take precedence over ``config``.
            batch_size: Accepted for signature compatibility; not used here.
        """
        # Legacy kwargs override the config so existing call-sites still work.
        overrides: dict[str, Any] = {}
        if on_diagnosis is not None:
            overrides["on_diagnosis"] = on_diagnosis
        if on_trace is not None:
            overrides["on_trace"] = on_trace
        if thresholds is not None:
            overrides["thresholds"] = thresholds
        if explain_with_llm:
            overrides["enable_explain"] = True
        # Build a new config once (instead of once per override) so the
        # caller's config object is never mutated.
        self._cfg = (
            DebugAIConfig(**{**config.__dict__, **overrides}) if overrides else config
        )
        self._q: queue.Queue = queue.Queue(maxsize=config.max_queue_depth)
        self.recent: list[dict] = []
        self.recent_traces: list[dict] = []
        self._lock = threading.Lock()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        # Drain whatever is still queued when the interpreter exits.
        atexit.register(self.flush)

    def submit(self, job: _Job) -> None:
        """Enqueue ``job`` without blocking; the job is dropped if the queue is full."""
        try:
            self._q.put_nowait(job)
        except queue.Full:
            log.debug("diagnosis queue full (depth %d); dropping job", self._cfg.max_queue_depth)

    def _run(self) -> None:
        """Worker loop: process jobs until a ``None`` sentinel is received."""
        while True:
            job = self._q.get()
            if job is None:  # shutdown sentinel
                self._q.task_done()
                break
            try:
                self._process(job)
            except Exception as e:  # never let the worker die
                log.warning("diagnosis failed: %s", e)
            finally:
                self._q.task_done()

    @staticmethod
    def _invoke_callback(label: str, callback: Callable, *args: Any) -> None:
        """Call a user callback, logging (not propagating) any exception."""
        try:
            callback(*args)
        except Exception as e:
            log.warning("%s callback failed: %s", label, e)

    def _process(self, job: _Job) -> None:
        """Run every enabled processing step for one job, in a fixed order."""
        cfg = self._cfg
        result = self._diagnose(job) if cfg.enable_diagnosis else None
        if cfg.enable_traces:
            self._emit_trace(job, result or {})
        # Update global MetricsLedger after each request.
        if cfg.track_tokens or cfg.track_cost or cfg.track_latency:
            self._record_metrics(job, result)
        self._check_latency_sla(job)
        self._check_schema(job)

    def _diagnose(self, job: _Job):
        """Run ``analyze`` for the job, remember the result and notify ``on_diagnosis``."""
        cfg = self._cfg
        r = job.retrieval or {}
        result = analyze(
            prompt=job.captured.user_prompt,
            output=job.output,
            system_prompt=job.captured.system_prompt,
            chunks=r.get("retrieved_chunks"),
            similarity_scores=r.get("similarity_scores"),
            retrieval_query=r.get("retrieval_query"),
            model_name=job.captured.model_name,
            temperature=job.captured.temperature,
            max_tokens=job.captured.max_tokens,
            context_window=job.context_window,
            latency_ms=job.latency_ms,
            token_usage=job.usage,
            thresholds=cfg.thresholds,
            explain_with_llm=cfg.enable_explain,
            lazy=cfg.lazy,
        )
        with self._lock:
            self.recent.append(result)
            del self.recent[:-200]  # keep only the 200 most recent results
        if cfg.on_diagnosis is not None:
            self._invoke_callback("on_diagnosis", cfg.on_diagnosis, result)
        return result

    def _emit_trace(self, job: _Job, result: dict) -> None:
        """Build the trace, remember it, notify ``on_trace`` and post it to ``sink_url``."""
        cfg = self._cfg
        trace = self._build_trace(job, result)
        with self._lock:
            self.recent_traces.append(trace.to_dict())
            del self.recent_traces[:-200]  # keep only the 200 most recent traces
        if cfg.on_trace is not None:
            self._invoke_callback("on_trace", cfg.on_trace, trace)
        if cfg.sink_url:
            _http_post_trace(trace.to_dict(), cfg.sink_url, cfg.sink_token)

    def _record_metrics(self, job: _Job, result) -> None:
        """Record the call in the global metrics ledger and notify ``on_metrics``."""
        from debugai.tracing import estimate_cost
        cfg = self._cfg
        usage = job.usage or {}
        prompt_t = usage.get("prompt", 0)
        compl_t = usage.get("completion", 0)
        cost = estimate_cost(job.captured.model_name, prompt_t, compl_t) if cfg.track_cost else 0.0
        # A call counts as failed only when a diagnosis exists and is not healthy.
        failed = bool(result and not result.get("healthy"))
        _global_metrics.record(
            model=job.captured.model_name or "unknown",
            prompt_tokens=prompt_t if cfg.track_tokens else 0,
            completion_tokens=compl_t if cfg.track_tokens else 0,
            cost_usd=cost,
            latency_ms=float(job.latency_ms or 0),
            failed=failed,
            from_cache=job.from_cache,
        )
        if cfg.on_metrics is not None:
            self._invoke_callback("on_metrics", cfg.on_metrics, {
                "model": job.captured.model_name,
                "prompt_tokens": prompt_t,
                "completion_tokens": compl_t,
                "cost_usd": cost,
                "latency_ms": job.latency_ms,
                "failed": failed,
            })

    def _check_latency_sla(self, job: _Job) -> None:
        """B8 — notify ``on_sla_breach`` when the call exceeded ``latency_sla_ms``."""
        cfg = self._cfg
        if cfg.latency_sla_ms and job.latency_ms > cfg.latency_sla_ms:
            if cfg.on_sla_breach:
                self._invoke_callback("on_sla_breach", cfg.on_sla_breach, {
                    "model": job.captured.model_name,
                    "latency_ms": job.latency_ms,
                    "threshold_ms": cfg.latency_sla_ms,
                    "correlation_id": job.correlation_id,
                })

    def _check_schema(self, job: _Job) -> None:
        """B2 — validate the output against ``response_schema`` (runs regardless of other diagnosis)."""
        cfg = self._cfg
        if cfg.response_schema and job.output:
            violations = _validate_json_schema(job.output, cfg.response_schema)
            if violations and cfg.on_schema_violation:
                self._invoke_callback(
                    "on_schema_violation", cfg.on_schema_violation, job.output, violations)

    def _build_trace(self, job: _Job, result: dict) -> Trace:
        """Turn a captured call + its diagnosis into an observability trace."""
        t = Trace(name="llm.call", session_id=job.session_id, model=job.captured.model_name)
        # B7 — correlation ID.
        if job.correlation_id:
            t.metadata["correlation_id"] = job.correlation_id
        # B6 — retry count.
        if job.retry_count:
            t.metadata["retry_count"] = job.retry_count
        # B10 — prompt version hash.
        if job.captured.system_prompt:
            t.metadata["prompt_hash"] = _prompt_hash(job.captured.system_prompt)

        r = job.retrieval or {}
        if r.get("retrieved_chunks"):
            sp = Span(name="retrieval", kind="retrieval")
            sp.input = r.get("retrieval_query")
            sp.output = r.get("retrieved_chunks")
            sp.metadata = {"similarity_scores": r.get("similarity_scores")}
            sp.end_ms = sp.start_ms  # no duration is known for this span
            t.add_span(sp)

        usage = job.usage or {}
        gen = Span(name="generation", kind="generation", model=job.captured.model_name)
        gen.input = job.captured.user_prompt
        gen.output = job.output
        gen.set_usage(prompt=usage.get("prompt", 0),
                      completion=usage.get("completion", 0))
        gen.end_ms = gen.start_ms + float(job.latency_ms or 0)
        t.add_span(gen)

        # B3 — tool call child spans.
        for tc in (job.tool_calls or []):
            tool_span = Span(name=tc.get("name", "tool"), kind="tool")
            tool_span.input = tc.get("input")
            tool_span.metadata = {"id": tc.get("id", "")}
            tool_span.end_ms = tool_span.start_ms
            t.add_span(tool_span)

        t.diagnosis = result
        t.scores = scores_from_diagnosis(result)
        t.status = status_from_diagnosis(result)
        t.end()
        return t

    def flush(self) -> None:
        """Block until all queued jobs are processed (used in tests / shutdown)."""
        self._q.join()
652
+
653
+
654
+ # --------------------------------------------------------------------------- #
655
+ # Transparent proxy
656
+ # --------------------------------------------------------------------------- #
657
class _PathProxy:
    """Attribute-forwarding wrapper around ``target``.

    Attribute reads are delegated to ``target``. While walking the configured
    ``path`` each step returns another proxy for the remaining path, and the
    final path element resolves to the instrumented callable instead of the
    real one. Attribute writes are applied to ``target``.
    """

    def __init__(self, target: Any, path: tuple[str, ...], instrumented: Callable):
        # object.__setattr__ bypasses our own __setattr__, which would
        # otherwise write these slots onto the wrapped target.
        for slot, value in (("_t", target), ("_path", path), ("_instrumented", instrumented)):
            object.__setattr__(self, slot, value)

    def __getattr__(self, name: str) -> Any:
        read = object.__getattribute__
        remaining = read(self, "_path")
        resolved = getattr(read(self, "_t"), name)
        if not remaining or name != remaining[0]:
            return resolved  # off the instrumented path: forward untouched
        hook = read(self, "_instrumented")
        if len(remaining) == 1:
            return hook  # terminal create() method
        return _PathProxy(resolved, remaining[1:], hook)

    def __setattr__(self, name: str, value: Any) -> None:
        wrapped = object.__getattribute__(self, "_t")
        setattr(wrapped, name, value)
678
+
679
+
680
def _http_post_trace(trace_dict: dict, url: str, token: str | None) -> None:
    """Fire-and-forget HTTP POST for the sink_url option (stdlib only).

    Args:
        trace_dict: Trace payload; sent as a JSON body.
        url: Destination URL.
        token: Optional API key, sent in the ``X-API-Key`` header.

    Never raises: serialisation, request-construction and network errors are
    all logged at debug level and swallowed.
    """
    import json as _json
    import urllib.request
    headers = {"Content-Type": "application/json"}
    if token:
        headers["X-API-Key"] = token
    try:
        # Serialisation and Request() sit inside the try so that a
        # non-JSON-serialisable trace or a malformed URL is also best-effort.
        body = _json.dumps(trace_dict).encode()
        req = urllib.request.Request(url, data=body, headers=headers, method="POST")
        # The context manager closes the response (and its socket) promptly.
        with urllib.request.urlopen(req, timeout=5.0) as resp:
            resp.read()
    except Exception as e:  # pragma: no cover - network dependent
        log.debug("sink_url POST failed (%s)", e)
693
+
694
+
695
def wrap_llm(
    client: Any,
    *,
    config: "DebugAIConfig | None" = None,
    # Legacy individual kwargs — still work for backward compatibility.
    on_diagnosis: Callable | None = None,
    on_trace: Callable | None = None,
    session_id: str | None = None,
    explain_with_llm: bool = False,
    context_window: int | None = None,
    thresholds: Thresholds = DEFAULT_THRESHOLDS,
    sample_rate: float = 1.0,
) -> Any:
    """Wrap an OpenAI/Anthropic client so every ``create`` call is auto-diagnosed,
    auto-traced, and contributes to the metrics ledger.

    Drop-in replacement: call sites don't change. Pass a ``DebugAIConfig`` for full
    control, or use the individual legacy kwargs for backward compatibility.
    Diagnosis runs in a background thread, and any failure inside the
    instrumentation itself is logged and swallowed so that it can never break
    or mask the caller's real LLM call.

    Returns:
        A proxy around ``client``; its ``debugai`` attribute is the background
        ``_Diagnoser``.

    Raises:
        TypeError: if no adapter recognises ``client``.
    """
    # Build effective config: start from the provided config (or default),
    # then layer any explicit legacy kwargs on top.
    effective = config or DebugAIConfig(
        on_diagnosis=on_diagnosis,
        on_trace=on_trace,
        session_id=session_id,
        enable_explain=explain_with_llm,
        thresholds=thresholds,
        sample_rate=sample_rate,
    )
    # If individual kwargs provided alongside a config, they take precedence.
    if config is not None:
        overrides = {}
        if on_diagnosis is not None:
            overrides["on_diagnosis"] = on_diagnosis
        if on_trace is not None:
            overrides["on_trace"] = on_trace
        if session_id is not None:
            overrides["session_id"] = session_id
        if explain_with_llm:
            overrides["enable_explain"] = True
        if thresholds is not DEFAULT_THRESHOLDS:
            overrides["thresholds"] = thresholds
        if sample_rate != 1.0:
            overrides["sample_rate"] = sample_rate
        if overrides:
            import dataclasses
            effective = dataclasses.replace(effective, **overrides)

    adapter = _detect_adapter(client)
    diagnoser = _Diagnoser(effective)
    real_create = _resolve(client, adapter.create_path)
    _rate = effective.sample_rate
    counter = {"n": 0}

    def instrumented(*args, **kwargs):
        # Pop DebugAI-only kwargs so they never reach the real SDK.
        chunks = kwargs.pop("debugai_chunks", None)
        scores = kwargs.pop("debugai_similarity_scores", None)
        rquery = kwargs.pop("debugai_retrieval_query", None)

        # Capturing the request is best-effort: an unexpected message shape
        # must not prevent the real call from being made.
        try:
            captured = adapter.from_request(kwargs)
        except Exception as e:
            log.debug("request capture failed; call will not be diagnosed: %s", e)
            captured = None

        start = time.perf_counter()
        resp = real_create(*args, **kwargs)
        latency_ms = int((time.perf_counter() - start) * 1000)

        counter["n"] += 1
        # Rates >= 1 sample every call; otherwise a call is sampled when the
        # fractional part of (call number * rate) falls below the rate.
        sampled = _rate >= 1.0 or (counter["n"] * _rate) % 1 < _rate
        if sampled and captured is not None:
            # Likewise best-effort: the caller already has a valid response.
            try:
                output, usage = adapter.from_response(resp)
                tool_calls = (getattr(adapter, "extract_tool_calls", lambda r: [])(resp))
                retrieval = _retrieval.get()
                if chunks is not None:
                    # Explicit kwargs win over the ambient retrieval_context().
                    retrieval = {
                        "retrieved_chunks": list(chunks),
                        "similarity_scores": list(scores or []),
                        "retrieval_query": rquery,
                    }
                diagnoser.submit(_Job(
                    captured=captured, output=output, usage=usage,
                    latency_ms=latency_ms, retrieval=retrieval,
                    context_window=context_window,
                    session_id=_session.get() or effective.session_id,
                    tool_calls=tool_calls,
                    correlation_id=uuid.uuid4().hex[:16],
                ))
            except Exception as e:
                log.debug("response capture failed; call will not be diagnosed: %s", e)
        return resp

    proxy = _PathProxy(client, adapter.create_path, instrumented)
    # Bypass the proxy's forwarding __setattr__ so the handle lives on the
    # proxy itself rather than on the wrapped client.
    object.__setattr__(proxy, "debugai", diagnoser)
    return proxy
780
+
781
+
782
def _resolve(obj: Any, path: tuple[str, ...]) -> Any:
    """Follow ``path`` as a chain of attribute lookups starting at ``obj``.

    An empty path returns ``obj`` itself; a missing attribute raises
    ``AttributeError``.
    """
    current = obj
    for attr_name in path:
        current = getattr(current, attr_name)
    return current
786
+
787
+
788
+ # --------------------------------------------------------------------------- #
789
+ # CompletionResponse — normalized thin wrapper around any provider's response
790
+ # --------------------------------------------------------------------------- #
791
class _UsageInfo:
    """Token counts for one call: ``prompt``, ``completion`` and their sum ``total``."""

    def __init__(self, prompt: int, completion: int):
        self.prompt, self.completion = prompt, completion
        self.total = prompt + completion

    def __repr__(self):
        return f"Usage(prompt={self.prompt}, completion={self.completion})"
799
+
800
+
801
class CompletionResponse:
    """Provider-independent result of ``debugai.completion()`` / ``debugai.acompletion()``.

    Attributes:
        text              — extracted output text (works regardless of provider)
        usage             — token counts (prompt / completion / total)
        cost_usd          — estimated cost from the built-in pricing table
        latency_ms        — end-to-end measured latency
        model             — model name as returned by the provider
        raw               — the original native provider response (pass-through)
        fallback_attempts — list of string pairs, initially empty
        correlation_id    — B7, initially None
        from_cache        — B5, initially False
        retry_count       — B6, initially 0
    """

    def __init__(self, text: str, usage: _UsageInfo, cost_usd: float,
                 latency_ms: int, model: str, raw: Any):
        # Values supplied by the caller.
        self.text, self.usage = text, usage
        self.cost_usd, self.latency_ms = cost_usd, latency_ms
        self.model, self.raw = model, raw
        # Bookkeeping fields that start at their defaults.
        self.fallback_attempts: list[tuple[str, str]] = []
        self.correlation_id: str | None = None  # B7
        self.from_cache: bool = False           # B5
        self.retry_count: int = 0               # B6

    def __repr__(self):
        summary = (f"CompletionResponse(model={self.model!r}, "
                   f"tokens={self.usage.total}, cost=${self.cost_usd:.6f})")
        return summary
829
+
830
+
831
+ # --------------------------------------------------------------------------- #
832
+ # Provider routing — maps model name prefix → (client factory, adapter)
833
+ # --------------------------------------------------------------------------- #
834
+ _PROVIDER_REGISTRY: list[tuple[Callable, "type[_OpenAIAdapter]", Callable]] = []
835
+ # Each entry: (matches_model_fn, adapter_class, client_factory_fn)
836
+
837
+
838
def _default_providers():
    """Backward-compat shim used by tests that monkeypatch this function.
    Real routing now goes through the PROVIDER_ROUTES table in providers.py.

    Returns one ``(matches_model, adapter_class, client_factory)`` tuple per
    route: the matcher is a case-insensitive prefix test on the model name and
    the factory is a placeholder that returns ``None``.
    """
    from debugai.providers import PROVIDER_ROUTES, _ADAPTER_MAP

    return [
        (
            # Default arguments bind the current route's values per entry.
            lambda m, pfx=route.prefix: m.lower().startswith(pfx.lower()),
            _ADAPTER_MAP.get(route.adapter, _OpenAICompatAdapter),
            lambda r=route: None,  # unused in new path
        )
        for route in PROVIDER_ROUTES
    ]
853
+
854
+
855
def register_provider(
    matches: Callable[[str], bool],
    adapter,
    client_factory: Callable,
) -> None:
    """Add a custom provider route that ``debugai.completion()`` can dispatch to.

        debugai.register_provider(
            matches=lambda m: m.startswith("my-model"),
            adapter=MyAdapter,
            client_factory=lambda: MyClient(...),
        )

    The new route is placed at the front of the registry, so it is checked
    before any route registered earlier.
    """
    new_entry = (matches, adapter, client_factory)
    # Slice-assign at position 0: prepends in place, same list object.
    _PROVIDER_REGISTRY[:0] = [new_entry]
869
+
870
+
871
def _route_provider(model: str, config: "DebugAIConfig | None" = None):
    """Resolve a model name to an ``(adapter_class, client)`` pair.

    Lookup order:
      1. Routes added through register_provider() (first match wins)
      2. The built-in PROVIDER_ROUTES table in providers.py

    Raises:
        ValueError: when neither source has a route for ``model``.
    """
    # 1. User-registered overrides — stop at the first matcher that accepts the model.
    override = next(
        ((adapter, factory)
         for matches, adapter, factory in _PROVIDER_REGISTRY
         if matches(model)),
        None,
    )
    if override is not None:
        chosen_adapter, build_client = override
        return chosen_adapter, build_client()

    # 2. Built-in routing table.
    from debugai.providers import _ADAPTER_MAP, make_client, route_for

    route = route_for(model)
    if route is None:
        raise ValueError(
            f"No provider registered for model {model!r}. "
            "Supported prefixes: gpt-, claude-, gemini-, groq/, together/, "
            "mistral/, openrouter/, azure/, cohere/, ollama/, qwen*, llama*, "
            "phi*, deepseek*, gemma*, mixtral*. "
            "Or register your own: debugai.register_provider(...)."
        )

    # Unknown adapter names fall back to the OpenAI-compatible adapter.
    return (
        _ADAPTER_MAP.get(route.adapter, _OpenAICompatAdapter),
        make_client(route, config or DebugAIConfig()),
    )
898
+
899
+
900
# Module-level default config — completion() and acompletion() fall back to it
# when no ``config`` argument is passed. Assigned by set_default_config();
# ``None`` means no default has been set.
_default_config: "DebugAIConfig | None" = None
902
+
903
+
904
def set_default_config(config: "DebugAIConfig") -> None:
    """Install ``config`` as the module-wide default.

    The value is stored in the module-level ``_default_config`` variable,
    replacing whatever was stored there before.
    """
    global _default_config
    _default_config = config
908
+
909
+
910
def completion(model: str, messages: list, *, config: "DebugAIConfig | None" = None,
               **kwargs) -> CompletionResponse:
    """Universal LLM completion — works with any registered provider.

        import debugai
        resp = debugai.completion(model="gpt-4o", messages=[{"role":"user","content":"hi"}])
        print(resp.text, resp.cost_usd, resp.latency_ms)

    Args:
        model: Model name; used to pick the provider via ``_route_provider``.
        messages: Chat messages forwarded to the provider unchanged.
        config: Per-call config. Falls back to the module default set with
            ``set_default_config()``, then to ``DebugAIConfig()``.
        **kwargs: Extra provider arguments. A truthy ``stream`` switches to
            the streaming path.

    Returns:
        A ``CompletionResponse``; for ``stream=True`` the object returned by
        ``_stream_completion`` instead.

    Raises:
        BudgetExceededError: when ``cfg.budget_usd`` is set and already spent.
        ValueError: from ``_route_provider`` when no provider matches.
        Exception: the last provider error once all fallbacks are exhausted.
    """
    import copy

    cfg = config or _default_config or DebugAIConfig()

    # B4 — budget check BEFORE the call.
    if cfg.budget_usd is not None and _global_metrics.cost_usd >= cfg.budget_usd:
        if cfg.on_budget_exceeded:
            cfg.on_budget_exceeded(_global_metrics.cost_usd)
        # If callback didn't raise, we raise for safety.
        raise BudgetExceededError(
            f"Budget ${cfg.budget_usd:.4f} exceeded "
            f"(spent ${_global_metrics.cost_usd:.4f})")

    # B7 — correlation ID (unique per completion() call, threads through fallbacks).
    correlation_id = uuid.uuid4().hex[:16]

    streaming = bool(kwargs.get("stream"))

    # B5 — cache lookup. Streaming calls bypass the cache: a cached entry is a
    # finished CompletionResponse, not a stream, so returning it would hand the
    # caller the wrong kind of object.
    # NOTE(review): the key is built from model + messages only; other kwargs
    # (e.g. sampling parameters) are not passed to _cache_key here — confirm
    # that is intended.
    cache_key = (_cache_key(model, messages)
                 if cfg.cache_ttl_seconds and not streaming else None)
    if cache_key:
        cached = _response_cache.get(cache_key)
        if cached is not None:
            log.debug("cache hit for model=%s", model)
            # Return a (shallow) copy so stamping per-call fields does not
            # rewrite the object other callers already hold.
            hit = copy.copy(cached)
            hit.correlation_id = correlation_id
            hit.from_cache = True
            return hit

    adapter_cls, client = _route_provider(model, cfg)

    # Check for streaming — delegate to a different path if requested.
    if streaming:
        # Drop "stream" from the forwarded kwargs: the streaming helper passes
        # stream=True itself and a duplicate keyword would raise TypeError.
        stream_kwargs = {k: v for k, v in kwargs.items() if k != "stream"}
        return _stream_completion(model, messages, adapter_cls, client, cfg,
                                  stream_kwargs)

    # Fallback loop: try the primary model, then each fallback on error.
    _fallbacks = list(cfg.fallbacks or [])
    _attempted: list[tuple[str, str]] = []  # (model_name, error)
    _model, _adapter, _client = model, adapter_cls, client
    while True:
        try:
            resp = _call_provider(_model, messages, _adapter, _client, kwargs,
                                  max_retries=cfg.max_retries,
                                  backoff=cfg.retry_backoff_seconds)
            break
        except Exception as e:
            _attempted.append((_model, str(e)))
            log.warning("completion: %s failed (%s)", _model, e)
            if not _fallbacks:
                raise
            fallback_model = _fallbacks.pop(0)
            log.info("completion: trying fallback %s", fallback_model)
            _adapter, _client = _route_provider(fallback_model, cfg)
            _model = fallback_model

    latency_ms = int(resp._latency_ms)
    retry_count = getattr(resp, "_retry_count", 0)
    text, usage_dict = _adapter.from_response(resp._raw)
    # B3 — extract tool calls for the trace (adapters without the hook yield []).
    tool_calls = (getattr(_adapter, "extract_tool_calls", lambda r: [])(resp._raw))
    from debugai.tracing import estimate_cost
    usage = _UsageInfo(usage_dict.get("prompt", 0), usage_dict.get("completion", 0))
    cost = estimate_cost(_model, usage.prompt, usage.completion, cfg.model_prices)
    captured = _adapter.from_request({"model": _model, "messages": messages, **kwargs})

    # Background observability.
    # NOTE(review): a fresh _Diagnoser is built for every call — confirm that
    # construction is cheap / intended.
    if cfg.enable_diagnosis or cfg.enable_traces:
        diagnoser = _Diagnoser(cfg)
        diagnoser.submit(_Job(
            captured=captured, output=text, usage=usage_dict,
            latency_ms=latency_ms, retrieval=_retrieval.get(),
            context_window=None,
            session_id=_session.get() or cfg.session_id,
            tool_calls=tool_calls,
            correlation_id=correlation_id,
            retry_count=retry_count,
        ))

    result = CompletionResponse(text=text, usage=usage, cost_usd=cost,
                                latency_ms=latency_ms, model=_model, raw=resp._raw)
    if _attempted:
        result.fallback_attempts = _attempted
    result.correlation_id = correlation_id
    result.retry_count = retry_count

    # B5 — populate cache with a copy, so later changes the caller makes to
    # the returned object's attributes do not alter the cached entry.
    if cache_key:
        _response_cache.set(cache_key, copy.copy(result), cfg.cache_ttl_seconds)

    return result
1003
+
1004
+
1005
class _RawResp:
    """Small carrier handed back by _call_provider.

    Holds the raw provider response, the measured latency in milliseconds and
    a retry counter that starts at zero.
    """

    def __init__(self, raw, latency_ms: float):
        self._retry_count: int = 0
        self._latency_ms = latency_ms
        self._raw = raw
1011
+
1012
+
1013
def _call_provider(model: str, messages: list, adapter_cls, client, kwargs: dict,
                   max_retries: int = 2, backoff: float = 1.0) -> "_RawResp":
    """Single provider call with retry (B6).

    Args:
        model: Model name forwarded to the provider.
        messages: Chat messages forwarded to the provider.
        adapter_cls: Adapter whose ``create_path`` locates the create callable
            on ``client``.
        client: Provider client object.
        kwargs: Extra call arguments; copied, never mutated.
        max_retries: Retries allowed after the first attempt. At least one
            attempt is always made.
        backoff: Base delay in seconds; doubled on each successive retry
            unless the error carries a usable ``Retry-After`` header.

    Returns:
        ``_RawResp`` holding the raw response, the latency of the successful
        attempt in milliseconds, and the number of retries that preceded it.

    Raises:
        Exception: the provider error, re-raised unchanged, when its status is
        not retryable or the retry allowance is used up.
    """
    _RETRYABLE = (408, 429, 500, 502, 503, 504)

    call_kwargs = dict(kwargs)
    if adapter_cls is _AnthropicAdapter:
        # The Anthropic call is always given max_tokens; default it when absent.
        call_kwargs.setdefault("max_tokens", 1024)

    total_attempts = max(1, max_retries + 1)

    for attempt in range(total_attempts):
        start = time.perf_counter()
        try:
            raw = _resolve(client, adapter_cls.create_path)(
                model=model, messages=messages, **call_kwargs)
            latency = (time.perf_counter() - start) * 1000
            resp = _RawResp(raw, latency)
            resp._retry_count = attempt
            return resp
        except Exception as e:
            status = getattr(e, "status_code", getattr(e, "status", None))
            if status not in _RETRYABLE or attempt >= max_retries:
                raise

            # Best-effort read of Retry-After; the exception may have no
            # response/headers, or the header may not be a plain number.
            retry_after = None
            try:
                retry_after = float(e.response.headers.get("Retry-After", 0))
            except Exception:
                retry_after = None

            # Only honour a positive, finite hint — time.sleep() rejects
            # negative / NaN / infinite values and would mask the real error.
            if retry_after is not None and 0 < retry_after < float("inf"):
                wait = retry_after
            else:
                wait = max(0.0, backoff * (2 ** attempt))

            log.warning("provider %s failed with %s (attempt %d/%d), retrying in %.1fs",
                        model, status, attempt + 1, total_attempts, wait)
            time.sleep(wait)
1051
+
1052
+
1053
def _stream_completion(model, messages, adapter_cls, client, cfg, kwargs):
    """Sync streaming: wrap the iterator so chunks pass through + diagnose at end.

    Args:
        model: Model name forwarded to the provider.
        messages: Chat messages forwarded to the provider.
        adapter_cls: Adapter whose ``create_path`` locates the create callable.
        client: Provider client object.
        cfg: Config handed to the returned ``_StreamWrapper``.
        kwargs: Extra call arguments; copied, never mutated. Any ``stream``
            entry is ignored because streaming is always requested here.

    Returns:
        A ``_StreamWrapper`` around the provider's stream.
    """
    create = _resolve(client, adapter_cls.create_path)

    call_kwargs = dict(kwargs)
    # completion() only reaches this path when kwargs["stream"] is truthy, so
    # the key is normally present; passing it again next to stream=True would
    # raise "got multiple values for keyword argument 'stream'".
    call_kwargs.pop("stream", None)
    if adapter_cls is _AnthropicAdapter:
        # The Anthropic call is always given max_tokens; default it when absent.
        call_kwargs.setdefault("max_tokens", 1024)

    stream = create(model=model, messages=messages, stream=True, **call_kwargs)
    return _StreamWrapper(stream, model, adapter_cls, cfg)
1062
+
1063
+
1064
async def acompletion(model: str, messages: list, *, config: "DebugAIConfig | None" = None,
                      **kwargs) -> CompletionResponse:
    """Async variant of ``completion()``. Requires an async provider client.

    Args:
        model: Model name; the adapter is chosen via ``_route_provider``.
        messages: Chat messages forwarded to the provider unchanged.
        config: Per-call config. Falls back to the module default, then to
            ``DebugAIConfig()``.
        **kwargs: Extra provider arguments (not mutated).

    Returns:
        A ``CompletionResponse`` carrying a fresh ``correlation_id``.

    Raises:
        ImportError: when the ``anthropic`` / ``openai`` package needed for
            the selected adapter cannot be imported.
        ValueError: from ``_route_provider`` when no provider matches.

    NOTE(review): unlike ``completion()``, this path has no budget check,
    cache, retry or fallback handling, and any non-Anthropic adapter is served
    by a default ``AsyncOpenAI`` client — confirm whether that is intended.
    """
    cfg = config or _default_config or DebugAIConfig()
    adapter_cls, _ = _route_provider(model, cfg)
    is_anthropic = adapter_cls is _AnthropicAdapter

    # Build an async client. Only the import is guarded, so a failure while
    # constructing the client (e.g. bad credentials) surfaces as itself
    # instead of being relabelled as a missing package.
    if is_anthropic:
        try:
            from anthropic import AsyncAnthropic
        except ImportError as exc:
            raise ImportError(
                "anthropic[async] required for acompletion with Anthropic models."
            ) from exc
        acreate = AsyncAnthropic(timeout=60.0).messages.create
    else:
        try:
            from openai import AsyncOpenAI
        except ImportError as exc:
            raise ImportError(
                "openai package required for acompletion with OpenAI models."
            ) from exc
        acreate = AsyncOpenAI(timeout=60.0).chat.completions.create

    correlation_id = uuid.uuid4().hex[:16]

    captured = adapter_cls.from_request({"model": model, "messages": messages, **kwargs})

    call_kwargs = dict(kwargs)
    if is_anthropic:
        # The Anthropic call is always given max_tokens; default it when absent.
        call_kwargs.setdefault("max_tokens", 1024)

    start = time.perf_counter()
    resp = await acreate(model=model, messages=messages, **call_kwargs)
    latency_ms = int((time.perf_counter() - start) * 1000)

    text, usage_dict = adapter_cls.from_response(resp)
    from debugai.tracing import estimate_cost
    usage = _UsageInfo(usage_dict.get("prompt", 0), usage_dict.get("completion", 0))
    # Pass the configured price overrides, matching completion().
    cost = estimate_cost(model, usage.prompt, usage.completion, cfg.model_prices)

    if cfg.enable_diagnosis or cfg.enable_traces:
        diagnoser = _Diagnoser(cfg)
        diagnoser.submit(_Job(
            captured=captured, output=text, usage=usage_dict,
            latency_ms=latency_ms, retrieval=_retrieval.get(),
            context_window=None,
            session_id=_session.get() or cfg.session_id,
            correlation_id=correlation_id,
        ))

    result = CompletionResponse(text=text, usage=usage, cost_usd=cost,
                                latency_ms=latency_ms, model=model, raw=resp)
    result.correlation_id = correlation_id
    return result
1112
+
1113
+
1114
+ # --------------------------------------------------------------------------- #
1115
+ # Streaming wrapper
1116
+ # --------------------------------------------------------------------------- #
1117
class _StreamWrapper:
    """Passes streaming chunks through unchanged, accumulates text, and fires
    a background diagnosis job after the last chunk is consumed.

    Usable as an iterator and as a context manager. Leaving the ``with`` block
    closes the underlying stream when it exposes ``close()``.
    """

    def __init__(self, stream, model: str, adapter_cls, cfg: "DebugAIConfig"):
        self._stream = stream
        # Accept any iterable, not only objects that already implement
        # __next__; fall back to the object itself if it is not iterable.
        try:
            self._iter = iter(stream)
        except TypeError:
            self._iter = stream
        self._model = model
        self._adapter_cls = adapter_cls
        self._cfg = cfg
        self._buffer: list[str] = []
        self._usage: dict = {}
        self._started = time.perf_counter()  # basis for the reported latency
        self._finalized = False              # guards against duplicate jobs

    def _extract_chunk_text(self, chunk) -> str:
        """Return the text carried by one chunk, or ``""`` if none is found."""
        # Chunks come from provider SDKs with differing shapes, so both probes
        # are deliberately best-effort and swallow any lookup failure.
        # OpenAI delta pattern
        try:
            return chunk.choices[0].delta.content or ""
        except Exception:
            pass
        # Anthropic content_block_delta pattern
        try:
            if getattr(chunk, "type", None) == "content_block_delta":
                return getattr(chunk.delta, "text", "") or ""
        except Exception:
            pass
        return ""

    def __iter__(self):
        return self

    def __next__(self):
        try:
            chunk = next(self._iter)
        except StopIteration:
            self._finalize()
            raise
        self._buffer.append(self._extract_chunk_text(chunk))
        return chunk

    def _finalize(self):
        """Submit one diagnosis job for the accumulated text (at most once)."""
        if self._finalized:
            # next() after exhaustion raises StopIteration again; do not
            # submit a second job for the same stream.
            return
        self._finalized = True
        if not (self._cfg.enable_diagnosis or self._cfg.enable_traces):
            return
        text = "".join(self._buffer)
        from debugai.signals import _approx_token_count
        completion_tokens = _approx_token_count(text)
        # Elapsed time from wrapper construction to stream exhaustion.
        latency_ms = int((time.perf_counter() - self._started) * 1000)
        diagnoser = _Diagnoser(self._cfg)
        diagnoser.submit(_Job(
            captured=_Captured(user_prompt="(streamed)", model_name=self._model),
            output=text,
            usage={"prompt": 0, "completion": completion_tokens,
                   "total": completion_tokens},
            latency_ms=latency_ms,
            retrieval=_retrieval.get(),
            context_window=None,
            session_id=_session.get() or self._cfg.session_id,
        ))

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Release the provider stream if it can be closed; exceptions raised
        # inside the with-block are never suppressed (returns None).
        close = getattr(self._stream, "close", None)
        if callable(close):
            close()
1178
+
1179
+
1180
def awrap_llm(
    async_client: Any,
    *,
    config: "DebugAIConfig | None" = None,
    on_diagnosis: Callable | None = None,
    on_trace: Callable | None = None,
    session_id: str | None = None,
    context_window: int | None = None,
    thresholds: Thresholds = DEFAULT_THRESHOLDS,
    sample_rate: float = 1.0,
) -> Any:
    """Wrap an async OpenAI/Anthropic client (``AsyncOpenAI``, ``AsyncAnthropic``).

        from openai import AsyncOpenAI
        client = awrap_llm(AsyncOpenAI(), config=DebugAIConfig(sample_rate=0.5))
        resp = await client.chat.completions.create(model="gpt-4o", messages=[...])

    Args:
        async_client: The async provider client to wrap.
        config: Full configuration. When given, ``on_diagnosis``, ``on_trace``,
            ``session_id``, ``thresholds`` and ``sample_rate`` are not used to
            build one.
        on_diagnosis, on_trace, session_id, thresholds, sample_rate: Used to
            build a ``DebugAIConfig`` only when ``config`` is not supplied.
        context_window: Attached to every submitted job.

    Returns:
        A proxy over ``async_client`` whose create call is instrumented; the
        diagnoser is exposed on it as the ``debugai`` attribute.

    The ``debugai_chunks`` / ``debugai_similarity_scores`` /
    ``debugai_retrieval_query`` kwargs are popped before the call is forwarded.
    Failures inside the instrumentation are logged and never raised, so the
    provider response always reaches the caller.
    """
    effective = config or DebugAIConfig(
        on_diagnosis=on_diagnosis,
        on_trace=on_trace,
        session_id=session_id,
        thresholds=thresholds,
        sample_rate=sample_rate,
    )
    adapter = _detect_adapter(async_client)
    diagnoser = _Diagnoser(effective)
    real_create = _resolve(async_client, adapter.create_path)
    _rate = effective.sample_rate
    counter = {"n": 0}

    async def async_instrumented(*args, **kwargs):
        chunks = kwargs.pop("debugai_chunks", None)
        scores = kwargs.pop("debugai_similarity_scores", None)
        rquery = kwargs.pop("debugai_retrieval_query", None)

        # Capturing the request must not stop the real call from happening.
        try:
            captured = adapter.from_request(kwargs)
        except Exception as exc:
            log.warning("awrap_llm: request capture failed (%s)", exc)
            captured = None

        start = time.perf_counter()
        resp = await real_create(*args, **kwargs)
        latency_ms = int((time.perf_counter() - start) * 1000)

        counter["n"] += 1
        # Deterministic sampling driven by the call counter: always on at
        # rate >= 1.0, never on at rate 0.
        sampled = _rate >= 1.0 or (counter["n"] * _rate) % 1 < _rate
        if sampled and captured is not None:
            # The response is already in hand; an instrumentation error must
            # not discard it.
            try:
                output, usage = adapter.from_response(resp)
                tool_calls = (getattr(adapter, "extract_tool_calls", lambda r: [])(resp))
                retrieval = _retrieval.get()
                if chunks is not None:
                    # Explicit kwargs take precedence over the ambient context.
                    retrieval = {
                        "retrieved_chunks": list(chunks),
                        "similarity_scores": list(scores or []),
                        "retrieval_query": rquery,
                    }
                diagnoser.submit(_Job(
                    captured=captured, output=output, usage=usage,
                    latency_ms=latency_ms, retrieval=retrieval,
                    context_window=context_window,
                    session_id=_session.get() or effective.session_id,
                    tool_calls=tool_calls,
                    correlation_id=uuid.uuid4().hex[:16],
                ))
            except Exception as exc:
                log.warning("awrap_llm: instrumentation failed (%s)", exc)
        return resp

    proxy = _PathProxy(async_client, adapter.create_path, async_instrumented)
    object.__setattr__(proxy, "debugai", diagnoser)
    return proxy
1245
+
1246
+
1247
+ def http_trace_sink(url: str, token: str | None = None, timeout: float = 5.0) -> Callable:
1248
+ """An ``on_trace`` sink that POSTs each trace to a DebugAI server.
1249
+
1250
+ client = wrap_llm(OpenAI(), on_trace=http_trace_sink(
1251
+ "http://localhost:8000/api/traces", token="dbg_..."))
1252
+
1253
+ ``token`` is a per-account API token (Account → API tokens). Failures are
1254
+ logged, never raised, so tracing never breaks the app. Uses stdlib only.
1255
+ """
1256
+ import json as _json
1257
+ import urllib.request
1258
+
1259
+ def sink(trace) -> None:
1260
+ payload = trace.to_dict() if hasattr(trace, "to_dict") else trace
1261
+ headers = {"Content-Type": "application/json"}
1262
+ if token:
1263
+ headers["X-API-Key"] = token
1264
+ req = urllib.request.Request(url, data=_json.dumps(payload).encode(),
1265
+ headers=headers, method="POST")
1266
+ try:
1267
+ urllib.request.urlopen(req, timeout=timeout).read()
1268
+ except Exception as e: # pragma: no cover - network dependent
1269
+ log.warning("http_trace_sink: failed to post trace (%s)", e)
1270
+
1271
+ return sink