debugerai 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
debugai/detectors.py ADDED
@@ -0,0 +1,206 @@
1
+ """Layer 2 — Failure Classification Rules (Architecture §5).
2
+
3
+ Five deterministic detectors. Each takes the signal vector + thresholds and
4
+ returns a DetectorResult. All detectors run (§5.2); results are ranked by
5
+ confidence into primary + secondary. Gate patterns prevent nonsensical
6
+ multi-classification.
7
+
8
+ Detector bases are tuned to the doc's worked example: Scenario A (similarity
9
+ 0.41, entity 0.17, overlap 0.12) → retrieval failure 0.95 = 0.70 base + 0.15
10
+ (entity) + 0.10 (overlap).
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from dataclasses import dataclass, field
16
+
17
+ from debugai.schema import CaptureRecord
18
+ from debugai.signals import SignalVector
19
+ from debugai.thresholds import Thresholds
20
+
21
# Failure-type identifiers; meant to be shared with the fix-agent registry later.
CONTEXT_OVERFLOW = "context_overflow"
RETRIEVAL_FAILURE = "retrieval_failure"
ENTITY_GAP = "entity_gap"
HALLUCINATION = "hallucination"
PROMPT_BRITTLENESS = "prompt_brittleness"

# Severity per failure type. Entity gaps and prompt brittleness are warnings;
# every other failure is critical. Key order follows the detector check order.
SEVERITY = {
    failure: ("warning" if failure in (ENTITY_GAP, PROMPT_BRITTLENESS) else "critical")
    for failure in (
        CONTEXT_OVERFLOW,
        RETRIEVAL_FAILURE,
        ENTITY_GAP,
        HALLUCINATION,
        PROMPT_BRITTLENESS,
    )
}

# Starting confidence for a critical, gated detector once it fires; the
# detectors add signal-specific boosts on top of this.
_GATED_BASE = 0.70
37
+
38
+
39
@dataclass
class DetectorResult:
    """Output of one Layer-2 detector for a single signal vector."""

    failure: str  # failure-type identifier, e.g. "retrieval_failure"
    fired: bool  # whether the detector's firing condition held
    confidence: float  # ranking score; normalised by clamp()
    severity: str  # "critical" or "warning"
    root_cause: str = ""  # human-readable cause
    fix: str = ""  # deterministic fix hint (Layer-3 fallback)
    evidence: dict = field(default_factory=dict)  # signal values behind the verdict

    def clamp(self) -> "DetectorResult":
        """Bound ``confidence`` to [0, 1], round it to 4 decimals, and return ``self``."""
        capped_above = min(self.confidence, 1.0)
        self.confidence = round(max(0.0, capped_above), 4)
        return self
52
+
53
+
54
# --------------------------------------------------------------------------- #
# 1. Context overflow — Critical | checked 1st
# --------------------------------------------------------------------------- #
def detect_context_overflow(s: SignalVector, rec: CaptureRecord, t: Thresholds) -> DetectorResult:
    """Detector 1 — context overflow (critical; checked 1st).

    Fires when the prompt fills more than ``t.context_length_ratio_max`` of the
    context window. Confidence starts at ``_GATED_BASE`` and gains +0.15 for
    high token usage, +0.10 for high latency and +0.10 for low output/context
    overlap, then is clamped to [0, 1]. Every signal that contributes to the
    confidence is recorded in ``evidence``. ``rec`` is not used.
    """
    fired = s.context_ratio > t.context_length_ratio_max
    # Same left-to-right accumulation as adding each boost in turn.
    conf = (
        _GATED_BASE
        + (0.15 if s.token_ratio > t.token_usage_high else 0.0)
        + (0.10 if s.latency_ms > t.latency_high_ms else 0.0)
        + (0.10 if s.overlap < t.overlap_low else 0.0)
    )
    return DetectorResult(
        failure=CONTEXT_OVERFLOW,
        fired=fired,
        confidence=conf,
        severity=SEVERITY[CONTEXT_OVERFLOW],
        root_cause=(
            f"Prompt fills {s.context_ratio:.0%} of the context window "
            f"(> {t.context_length_ratio_max:.0%}); content is likely truncated."
        ),
        fix="Reduce retrieved chunks to the top-N most relevant, summarise prior "
            "conversation history, or move to a larger-context model.",
        # overlap also boosts confidence, so it belongs in the evidence.
        evidence={"context_ratio": s.context_ratio, "token_ratio": s.token_ratio,
                  "latency_ms": s.latency_ms, "overlap": s.overlap},
    ).clamp()
80
+
81
+
82
# --------------------------------------------------------------------------- #
# 2. Retrieval failure — Critical | checked 2nd
# --------------------------------------------------------------------------- #
def detect_retrieval_failure(s: SignalVector, rec: CaptureRecord, t: Thresholds) -> DetectorResult:
    """Detector 2 — retrieval failure (critical; checked 2nd).

    Fires when mean retrieval similarity is below ``t.similarity_min``.
    Confidence is ``_GATED_BASE``, plus 0.15 when entity coverage is below
    ``t.entity_coverage_min``, plus 0.10 when overlap is below
    ``t.overlap_very_low``. It is clamped to [0, 1]. ``rec`` is not used.
    """
    weak_entities = s.entity_coverage < t.entity_coverage_min
    very_low_overlap = s.overlap < t.overlap_very_low
    confidence = (
        _GATED_BASE
        + (0.15 if weak_entities else 0.0)
        + (0.10 if very_low_overlap else 0.0)
    )

    cause = (
        f"Mean retrieval similarity {s.similarity:.2f} is below "
        f"{t.similarity_min:.2f}; the retriever returned irrelevant chunks."
    )
    remedy = (
        "Re-chunk source documents with an entity-aware strategy, tune the "
        "retriever / embedding model, or expand the knowledge base."
    )
    result = DetectorResult(
        failure=RETRIEVAL_FAILURE,
        fired=s.similarity < t.similarity_min,
        confidence=confidence,
        severity=SEVERITY[RETRIEVAL_FAILURE],
        root_cause=cause,
        fix=remedy,
        evidence={
            "similarity": s.similarity,
            "entity_coverage": s.entity_coverage,
            "overlap": s.overlap,
        },
    )
    return result.clamp()
106
+
107
+
108
# --------------------------------------------------------------------------- #
# 3. Entity gap — Warning | checked 3rd
# --------------------------------------------------------------------------- #
def detect_entity_gap(s: SignalVector, rec: CaptureRecord, t: Thresholds) -> DetectorResult:
    """Detector 3 — entity gap (warning; checked 3rd).

    Fires only when four conditions all hold:

    * retrieval similarity is at or above ``t.similarity_min``;
    * entity coverage is below ``t.entity_coverage_min``;
    * contradiction is below ``t.contradiction_min``;
    * variance is at most ``t.variance_min``.

    Confidence is 0.60 plus 0.08 per missing entity, with the bonus capped at
    0.30. ``rec`` is not used.
    """
    retrieval_ok = s.similarity >= t.similarity_min
    fired = (
        retrieval_ok
        and s.entity_coverage < t.entity_coverage_min
        and s.contradiction < t.contradiction_min
        and s.variance <= t.variance_min
    )
    missing_bonus = min(0.08 * s.entities_missing, 0.30)

    cause = (
        f"Retrieval is healthy but {s.entities_missing} entity(ies) in the "
        "answer are absent from the retrieved context — a knowledge-base hole."
    )
    remedy = (
        "Check whether the missing entities exist elsewhere in the corpus with "
        "different chunking; if not, flag the knowledge-base gap for human review."
    )
    result = DetectorResult(
        failure=ENTITY_GAP,
        fired=fired,
        confidence=0.60 + missing_bonus,
        severity=SEVERITY[ENTITY_GAP],
        root_cause=cause,
        fix=remedy,
        evidence={
            "entity_coverage": s.entity_coverage,
            "entities_missing": s.entities_missing,
        },
    )
    return result.clamp()
133
+
134
+
135
# --------------------------------------------------------------------------- #
# 4. Hallucination — Critical | checked 4th
# --------------------------------------------------------------------------- #
def detect_hallucination(s: SignalVector, rec: CaptureRecord, t: Thresholds) -> DetectorResult:
    """Detector 4 — hallucination (critical; checked 4th).

    Gated on healthy retrieval (similarity >= ``t.similarity_min``). Inside the
    gate, a fabrication score accumulates these adjustments:

    * +0.35 when entity coverage is below ``t.entity_coverage_hallucination``;
    * +0.30 when contradiction is above ``t.contradiction_min``;
    * +0.20 when variance is above ``t.variance_min``;
    * +0.10 when overlap is in the mid range [``t.overlap_very_low``, 0.70);
    * -0.15 when overlap is 0.70 or more.

    Fires when the gate is open and the score reaches ``t.hallucination_fire``.
    The score, clamped to [0, 1], is the confidence. ``rec`` is not used.
    """
    # Overlap at or above this means the output largely repeats the context,
    # which is grounding evidence against hallucination.
    grounded_overlap = 0.70

    gate_open = s.similarity >= t.similarity_min
    score = 0.0
    if gate_open:
        if s.entity_coverage < t.entity_coverage_hallucination:
            score += 0.35
        if s.contradiction > t.contradiction_min:
            score += 0.30
        if s.variance > t.variance_min:
            score += 0.20
        # Half-open upper bound: exactly grounded_overlap must not earn the
        # mid-range bonus and the grounding penalty at the same time.
        if t.overlap_very_low <= s.overlap < grounded_overlap:
            score += 0.10
        if s.overlap >= grounded_overlap:
            score -= 0.15
    fired = gate_open and score >= t.hallucination_fire
    return DetectorResult(
        failure=HALLUCINATION,
        fired=fired,
        confidence=score,
        severity=SEVERITY[HALLUCINATION],
        root_cause=(
            "Retrieval succeeded but the output is not grounded in it "
            f"(fabrication score {score:.2f}): low entity coverage / contradiction "
            "/ instability indicate invented content."
        ),
        fix="Add grounding constraints to the system prompt (answer only from "
            "provided context; cite sources; say 'not found' when unsupported).",
        evidence={"entity_coverage": s.entity_coverage, "contradiction": s.contradiction,
                  "variance": s.variance, "overlap": s.overlap},
    ).clamp()
169
+
170
+
171
# --------------------------------------------------------------------------- #
# 5. Prompt brittleness — Warning | checked 5th
# --------------------------------------------------------------------------- #
def detect_prompt_brittleness(s: SignalVector, rec: CaptureRecord, t: Thresholds) -> DetectorResult:
    """Detector 5 — prompt brittleness (warning; checked 5th).

    Fires when the grounding signals are healthy but output variance exceeds
    ``t.variance_min``. Healthy means similarity >= ``t.similarity_min``,
    entity coverage >= ``t.entity_coverage_min`` and contradiction below
    ``t.contradiction_min``. Confidence is 0.60, plus 0.15 when the captured
    sampling temperature is known and above ``t.temperature_high``.
    """
    grounding_healthy = (
        s.similarity >= t.similarity_min
        and s.entity_coverage >= t.entity_coverage_min
        and s.contradiction < t.contradiction_min
    )
    temperature = rec.temperature
    runs_hot = temperature is not None and temperature > t.temperature_high

    cause = (
        f"All grounding signals are healthy but output variance {s.variance:.2f} "
        "is high — the same prompt yields inconsistent answers."
    )
    remedy = (
        "Lower the sampling temperature, add an explicit output-format template, "
        "and insert few-shot examples to pin down the expected response shape."
    )
    result = DetectorResult(
        failure=PROMPT_BRITTLENESS,
        fired=grounding_healthy and s.variance > t.variance_min,
        confidence=0.60 + (0.15 if runs_hot else 0.0),
        severity=SEVERITY[PROMPT_BRITTLENESS],
        root_cause=cause,
        fix=remedy,
        evidence={"variance": s.variance, "temperature": rec.temperature},
    )
    return result.clamp()
197
+
198
+
199
# Detectors in the §5.2 check order. Each one reads only the signals,
# capture record and thresholds, so they run independently of one another.
# diagnose() ranks fired results with a stable sort on confidence, so this
# order also decides ties.
DETECTORS = [
    detect_context_overflow,    # 1st — critical
    detect_retrieval_failure,   # 2nd — critical
    detect_entity_gap,          # 3rd — warning
    detect_hallucination,       # 4th — critical
    detect_prompt_brittleness,  # 5th — warning
]
debugai/diagnosis.py ADDED
@@ -0,0 +1,64 @@
1
+ """Diagnosis pipeline (Architecture §5.2 / §7.3).
2
+
3
+ Runs all five detectors, ranks the ones that fired by confidence, and returns a
4
+ primary diagnosis plus secondary issues. Gate patterns in the detectors prevent
5
+ nonsensical combinations.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass, field
11
+
12
+ from debugai.detectors import DETECTORS, DetectorResult
13
+ from debugai.schema import CaptureRecord
14
+ from debugai.signals import SignalVector
15
+ from debugai.thresholds import DEFAULT_THRESHOLDS, Thresholds
16
+
17
+
18
@dataclass
class Diagnosis:
    """Result of ``diagnose()``: the top-ranked failure plus any others that fired.

    ``diagnose()`` sets ``healthy`` to True (with ``primary=None`` and an empty
    ``secondary``) when no detector fired. ``signals`` carries the input signal
    vector for reporting.
    """

    healthy: bool
    primary: DetectorResult | None
    secondary: list[DetectorResult] = field(default_factory=list)
    signals: SignalVector | None = None

    def to_dict(self) -> dict:
        """Serialise to plain dicts and lists.

        Each detector result is reduced to failure / confidence / severity /
        root_cause / fix / evidence; ``fired`` is omitted. ``signals`` is
        serialised via its own ``to_dict()`` whenever present, and is None
        otherwise.
        """
        def summarise(result: DetectorResult | None) -> dict | None:
            if result is None:
                return None
            return {
                "failure": result.failure,
                "confidence": result.confidence,
                "severity": result.severity,
                "root_cause": result.root_cause,
                "fix": result.fix,
                "evidence": result.evidence,
            }

        return {
            "healthy": self.healthy,
            "primary": summarise(self.primary),
            "secondary": [summarise(r) for r in self.secondary],
            # Explicit None check: a present-but-falsy vector must still serialise.
            "signals": None if self.signals is None else self.signals.to_dict(),
        }
44
+
45
+
46
def diagnose(
    signals: SignalVector,
    rec: CaptureRecord,
    thresholds: Thresholds = DEFAULT_THRESHOLDS,
) -> Diagnosis:
    """Classify a signal vector into primary + secondary failures.

    Every detector in ``DETECTORS`` runs. The ones that fired are ranked by
    confidence, highest first. ``sorted`` is stable even with ``reverse=True``,
    so equal confidences keep ``DETECTORS`` order. No fired detector means a
    healthy diagnosis.
    """
    outcomes = (run(signals, rec, thresholds) for run in DETECTORS)
    ranked = sorted(
        (outcome for outcome in outcomes if outcome.fired),
        key=lambda outcome: outcome.confidence,
        reverse=True,
    )
    if not ranked:
        return Diagnosis(healthy=True, primary=None, secondary=[], signals=signals)

    top, *others = ranked
    return Diagnosis(healthy=False, primary=top, secondary=others, signals=signals)
debugai/explainer.py ADDED
@@ -0,0 +1,105 @@
1
+ """Layer 3 — LLM Explainer (Architecture §2.2, §8.2).
2
+
3
+ The ONLY diagnosis-path layer that calls an LLM. It translates the structured,
4
+ deterministic diagnosis into a human-readable explanation + fix, calibrating
5
+ language to confidence and variance type (§2.3 step 5).
6
+
7
+ Fail-open design: if no API key is configured (or the SDK is missing), we fall
8
+ back to a deterministic template built from the detector's own root_cause / fix
9
+ strings. The deterministic system always has the final say (§8.2).
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import json
15
+ import logging
16
+ import os
17
+
18
+ from debugai.diagnosis import Diagnosis
19
+
20
log = logging.getLogger("debugai.explainer")

# The explanation layer is advisory, so a small, fast, cheap model is enough.
# Override it with the DEBUGAI_EXPLAINER_MODEL environment variable.
DEFAULT_MODEL: str = os.environ.get(
    "DEBUGAI_EXPLAINER_MODEL", "claude-haiku-4-5-20251001"
)

# System prompt for the explainer LLM. It must never contradict the
# deterministic diagnosis and must answer as a JSON object with
# "explanation" and "fix" keys.
_SYSTEM: str = (
    "You are DebugAI's explanation layer. A deterministic engine has already "
    "diagnosed why an LLM application's output failed. Your ONLY job is to turn "
    "the structured diagnosis into a crisp, developer-facing explanation and a "
    "concrete fix. Rules: (1) Never contradict the diagnosis — it is ground "
    "truth. (2) Calibrate certainty to the confidence score: state high-"
    "confidence findings plainly, hedge low-confidence ones. (3) Be specific — "
    "never say 'add more context'. (4) Keep it under 120 words. Respond as JSON: "
    '{"explanation": "...", "fix": "..."}.'
)
35
+
36
+
37
def _deterministic(diag: Diagnosis) -> dict:
    """Build ``{explanation, fix, model}`` without calling an LLM.

    A diagnosis that is unhealthy and has a primary result yields the primary
    detector's ``root_cause``, with any secondary failure types appended, and
    its ``fix``. Anything else yields a fixed "no failure" message and an
    empty fix. ``model`` is always ``"deterministic"``.
    """
    primary = diag.primary
    if not diag.healthy and primary is not None:
        also_detected = ", ".join(r.failure for r in diag.secondary)
        text = primary.root_cause
        if also_detected:
            text += f" Secondary issues also detected: {also_detected}."
        return {"explanation": text, "fix": primary.fix, "model": "deterministic"}

    return {
        "explanation": "No failure detected — all signals are within healthy ranges.",
        "fix": "",
        "model": "deterministic",
    }
51
+
52
+
53
def _client():
    """Return an Anthropic client, or None when one cannot be built (fail open).

    Returns None immediately when ANTHROPIC_API_KEY is unset or empty. If the
    SDK import or client construction raises, the error is logged as a
    warning and None is returned.
    """
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if api_key:
        try:
            import anthropic

            return anthropic.Anthropic(timeout=30.0, max_retries=2)
        except Exception as e:  # pragma: no cover - environment dependent
            log.warning("Anthropic client unavailable (%s); using deterministic explain", e)
    return None
64
+
65
+
66
def _parse_json_object(text: str) -> dict:
    """Parse the JSON object embedded in an LLM reply.

    Models sometimes wrap requested JSON in Markdown code fences or surround it
    with a sentence. Parsing the span from the first ``{`` to the last ``}``
    tolerates both. Raises ``ValueError`` (json.JSONDecodeError included) when
    no JSON object can be recovered.
    """
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("no JSON object found in explainer reply")
    parsed = json.loads(text[start:end + 1])
    if not isinstance(parsed, dict):
        raise ValueError("explainer reply is not a JSON object")
    return parsed


def explain(diag: Diagnosis, model: str = DEFAULT_MODEL) -> dict:
    """Produce ``{explanation, fix, model}`` for a diagnosis.

    Healthy diagnoses, a missing Anthropic client, and any LLM or parse failure
    all return the deterministic template (fail open). When the LLM reply
    lacks a field or leaves it empty, the primary detector's ``root_cause`` /
    ``fix`` is used instead. ``model`` in the result names the LLM only when
    its reply was used.
    """
    if diag.healthy or diag.primary is None:
        return _deterministic(diag)

    client = _client()
    if client is None:
        return _deterministic(diag)

    primary = diag.primary
    payload = {
        "primary": {
            "failure": primary.failure,
            "confidence": primary.confidence,
            "severity": primary.severity,
            "root_cause": primary.root_cause,
            "deterministic_fix_hint": primary.fix,
            "evidence": primary.evidence,
        },
        "secondary": [
            {"failure": r.failure, "confidence": r.confidence} for r in diag.secondary
        ],
        "signals": diag.signals.to_dict() if diag.signals else {},
    }
    try:
        msg = client.messages.create(
            model=model,
            max_tokens=400,
            system=_SYSTEM,
            messages=[{"role": "user", "content": json.dumps(payload)}],
        )
        text = "".join(b.text for b in msg.content if getattr(b, "type", "") == "text")
        parsed = _parse_json_object(text)
    except Exception as e:  # pragma: no cover - network dependent
        log.warning("LLM explain failed (%s); falling back to deterministic", e)
        return _deterministic(diag)
    return {
        "explanation": parsed.get("explanation") or primary.root_cause,
        "fix": parsed.get("fix") or primary.fix,
        "model": model,
    }
@@ -0,0 +1,5 @@
1
+ """Framework integrations for DebugAI."""
2
+
3
+ from debugai.integrations.langchain import DebugAICallbackHandler
4
+
5
__all__ = [
    "DebugAICallbackHandler",  # LangChain callback handler, re-exported above
]
@@ -0,0 +1,109 @@
1
+ """LangChain integration — diagnose a RAG/LLM chain automatically.
2
+
3
+ Attach the callback handler to any LangChain run; it captures the retrieved
4
+ documents and the LLM prompt/output, then runs ``analyze()`` on the result and
5
+ hands the diagnosis to your sink:
6
+
7
+ from debugai.integrations import DebugAICallbackHandler
8
+ handler = DebugAICallbackHandler(on_diagnosis=lambda d: print(d["primary"]))
9
+ chain.invoke(question, config={"callbacks": [handler]})
10
+
11
+ Works whether or not ``langchain`` is installed — if the LangChain base class is
12
+ importable we subclass it (so LangChain dispatches the events); otherwise the
13
+ handler is still a usable plain object you can drive directly (and in tests).
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import logging
19
+ from typing import Any, Callable
20
+
21
+ from debugai.analyze import analyze
22
+
23
log = logging.getLogger("debugai.integrations.langchain")


def _load_callback_base():
    """Return LangChain's ``BaseCallbackHandler`` if importable, else ``object``.

    Subclassing the real base lets LangChain route callbacks to our handler.
    langchain-core is tried first, then the older monolithic ``langchain``
    package. Any import error, not just ImportError, triggers the fallback.
    """
    try:  # langchain-core (current)
        from langchain_core.callbacks import BaseCallbackHandler
    except Exception:  # pragma: no cover - optional dep
        try:  # older monolithic langchain
            from langchain.callbacks.base import BaseCallbackHandler
        except Exception:
            return object
    return BaseCallbackHandler


_Base = _load_callback_base()
33
+
34
+
35
def _doc_text(doc: Any) -> str:
    """Return the text of a retrieved document.

    Accepts objects with a ``page_content`` attribute (LangChain
    ``Document``-like), dicts with a ``"page_content"`` key, or anything else,
    which is stringified. An empty ``page_content`` is returned as ``""``
    rather than falling through to the object's ``str()`` repr, which would
    inject noise into the captured chunks.
    """
    if isinstance(doc, dict):
        content = doc.get("page_content")
        return "" if content is None else (content if isinstance(content, str) else str(content))
    content = getattr(doc, "page_content", None)
    if content is not None:
        return content if isinstance(content, str) else str(content)
    return str(doc)
37
+
38
+
39
def _gen_text(response: Any) -> str:
    """Pull the generated text out of a LangChain LLMResult (LLM or chat).

    Reads the first generation of the first prompt. Completion models carry
    the text on ``.text``. Chat models may leave that empty and put it on
    ``.message.content``, which is either a string or a list of parts (plain
    strings, or dicts with a ``"text"`` key). Returns ``""`` when no text is
    found, including results with no generations at all.
    """
    gens = getattr(response, "generations", None)
    if not gens or not gens[0]:
        return ""  # no prompts, or a prompt that produced no generations
    first = gens[0][0]
    text = getattr(first, "text", "") or ""
    if text:
        return text

    # Chat models carry the text on .message.content.
    content = getattr(getattr(first, "message", None), "content", "")
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return " ".join(
            part if isinstance(part, str) else str(part.get("text") or "")
            for part in content
            if isinstance(part, (str, dict))
        )
    return str(content or "")
53
+
54
+
55
class DebugAICallbackHandler(_Base):
    """Capture retrieval + generation from a LangChain run and diagnose it.

    Two hooks buffer state:

    * ``on_retriever_end`` stores the retrieved chunks;
    * ``on_llm_start`` / ``on_chat_model_start`` store the latest prompt.

    When the model finishes (``on_llm_end``), both are passed to ``analyze()``
    with the generated text. The result is stored on ``self.last`` and
    forwarded to ``on_diagnosis`` if given. Both buffers are then cleared for
    the next call. Capture and diagnosis errors are logged, not raised, so
    diagnosis never breaks the user's chain.

    NOTE(review): state lives on the handler, not per run; sharing one handler
    across concurrent runs may mix their buffers.
    """

    def __init__(self, on_diagnosis: Callable[[dict], None] | None = None,
                 system_prompt: str = "", judge: bool = False,
                 explain_with_llm: bool = False):
        super().__init__()
        self._on_diagnosis = on_diagnosis
        self._system_prompt = system_prompt
        self._judge = judge
        self._explain = explain_with_llm
        self.last: dict | None = None  # most recent diagnosis, for inspection
        self._prompt: str = ""
        self._chunks: list[str] = []

    @staticmethod
    def _content_text(content: Any) -> str:
        """Flatten message content into text.

        Content may be a str, or a list of str / ``{"text": ...}`` parts.
        Empty list parts are dropped.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = (
                part if isinstance(part, str) else str(part.get("text") or "")
                for part in content
                if isinstance(part, (str, dict))
            )
            return " ".join(part for part in parts if part)
        return str(content or "")

    # --- LangChain callback hooks -----------------------------------------
    def on_retriever_end(self, documents, **kwargs) -> None:
        try:
            self._chunks = [_doc_text(d) for d in (documents or [])]
        except Exception as e:  # never break the chain
            log.warning("retriever capture failed: %s", e)

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        if prompts:
            self._prompt = prompts[-1]

    def on_chat_model_start(self, serialized, messages, **kwargs) -> None:
        # messages: list[list[BaseMessage]]. Use the last conversation in the
        # batch and prefer its most recent human message. Fall back to the most
        # recent message with any text, e.g. when the turn ends on a tool message.
        try:
            conversation = messages[-1] if messages else []
            fallback = ""
            for m in reversed(conversation):
                text = self._content_text(getattr(m, "content", ""))
                if not text:
                    continue
                if getattr(m, "type", None) == "human":
                    self._prompt = text
                    return
                fallback = fallback or text
            if fallback:
                self._prompt = fallback
        except Exception as e:
            log.warning("chat-start capture failed: %s", e)

    def on_llm_end(self, response, **kwargs) -> None:
        try:
            # Extraction sits inside the try too: a malformed LLMResult must
            # not propagate into the user's chain.
            output = _gen_text(response)
            if not (self._prompt or output):
                return
            self.last = analyze(
                prompt=self._prompt or "(unknown)", output=output,
                system_prompt=self._system_prompt,
                chunks=self._chunks or None,
                explain_with_llm=self._explain, judge=self._judge,
            )
            if self._on_diagnosis is not None:
                self._on_diagnosis(self.last)
        except Exception as e:  # diagnosis must never break the user's chain
            log.warning("DebugAI analyze failed: %s", e)
        finally:
            # Reset both buffers so a later call can't reuse stale retrieval or
            # a stale prompt.
            self._chunks = []
            self._prompt = ""
debugai/judge.py ADDED
@@ -0,0 +1,171 @@
1
+ """Instruction-adherence judge — diagnoses *behavioural* / prompt-following
2
+ failures that the deterministic grounding signals can't see.
3
+
4
+ Some failures aren't about retrieval or hallucination at all — e.g. a Socratic
5
+ tutor that reveals the answer in the first turn or re-asks the same guiding
6
+ question. These are violations of the system prompt's own rules. We detect them
7
+ with an LLM-as-judge (OpenAI by default), with a deterministic heuristic
8
+ fallback so the feature still works without an API key.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ import logging
15
+ import os
16
+ import re
17
+ from dataclasses import asdict, dataclass, field
18
+
19
log = logging.getLogger("debugai.judge")

# Failure-type identifier for system-prompt rule violations.
INSTRUCTION_VIOLATION: str = "instruction_violation"
# Judge model; override it with the DEBUGAI_JUDGE_MODEL environment variable.
DEFAULT_JUDGE_MODEL: str = os.environ.get("DEBUGAI_JUDGE_MODEL", "gpt-5.5")

# Sentence boundary: a run of terminal punctuation marks.
_SENT_RE = re.compile(r"[.!?]+")
# Word token: ASCII letters/digits. Apostrophes are included so contractions
# stay a single token.
_WORD_RE = re.compile(r"[A-Za-z0-9']+")
26
+
27
+
28
@dataclass
class Violation:
    """One system-prompt rule that an assistant response broke."""

    rule: str  # short description of the violated rule
    severity: str = "warning"  # "warning" or "critical"
    evidence: str = ""  # quote or reason supporting the finding

    def to_dict(self) -> dict:
        """Return the fields as a plain dict (via ``dataclasses.asdict``)."""
        plain = asdict(self)
        return plain
36
+
37
+
38
@dataclass
class InstructionDiagnosis:
    """Verdict of the instruction-adherence judge.

    ``healthy`` is True when no violations were found. ``confidence`` is the
    judge's score. ``model`` names the judge that produced the verdict; this
    module sets "heuristic", "n/a" or "openai:<model>".
    """

    healthy: bool
    confidence: float
    violations: list[Violation] = field(default_factory=list)
    model: str = "heuristic"

    def to_dict(self) -> dict:
        """Serialise to plain types; each violation goes through its own ``to_dict()``."""
        summary = {
            "healthy": self.healthy,
            "confidence": self.confidence,
            "model": self.model,
        }
        summary["violations"] = [violation.to_dict() for violation in self.violations]
        return summary
52
+
53
+
54
# --------------------------------------------------------------------------- #
# Public entry point
# --------------------------------------------------------------------------- #
def judge_instructions(system_prompt: str, user_prompt: str, output: str,
                       model: str | None = None) -> InstructionDiagnosis:
    """Evaluate an assistant ``output`` against the rules in its ``system_prompt``.

    A blank system prompt has no rules to check and returns a healthy verdict
    with model "n/a". When ``OPENAI_API_KEY`` is set, the OpenAI judge is
    tried first, using ``model`` or ``DEFAULT_JUDGE_MODEL``. Any error there
    is logged and the deterministic heuristic is used instead, as it is when
    no key is configured. An empty violations list means healthy.
    """
    if not (system_prompt or "").strip():
        return InstructionDiagnosis(healthy=True, confidence=0.0, model="n/a")

    judge_model = model or DEFAULT_JUDGE_MODEL
    if not os.environ.get("OPENAI_API_KEY"):
        return _heuristic_judge(system_prompt, user_prompt, output)

    try:
        return _openai_judge(system_prompt, user_prompt, output, judge_model)
    except Exception as e:  # pragma: no cover - network dependent
        log.warning("OpenAI judge failed (%s); using heuristic fallback", e)
    return _heuristic_judge(system_prompt, user_prompt, output)
72
+
73
+
74
# --------------------------------------------------------------------------- #
# LLM-as-judge (OpenAI)
# --------------------------------------------------------------------------- #
# System prompt for the judge model. It asks for a JSON object with a
# "violations" list (rule / severity / evidence) and a "confidence" in [0, 1].
_JUDGE_SYSTEM: str = (
    "You are a strict evaluator of an AI assistant's adherence to its own system "
    "prompt. You are given the assistant's SYSTEM PROMPT (which contains its rules) "
    "and the assistant's RESPONSE to a student. Identify every rule the response "
    "violates — especially pedagogy rules such as revealing too much of the answer "
    "early, asking more than one question, asking a question that merely restates a "
    "previous one, or paraphrasing the student. Respond ONLY as JSON: "
    '{"violations": [{"rule": "<short rule description>", "severity": '
    '"critical|warning", "evidence": "<quote/why>"}], "confidence": <0..1>}. '
    "If the response fully complies, return an empty violations list."
)
88
+
89
+
90
def _openai_judge(system_prompt: str, user_prompt: str, output: str,
                  model: str) -> InstructionDiagnosis:
    """Ask an OpenAI chat model which system-prompt rules ``output`` breaks.

    Sends ``_JUDGE_SYSTEM`` plus the rules, the student message and the
    assistant response, and requests a JSON object. The reply is parsed
    defensively:

    * non-object entries in ``violations`` are skipped;
    * severities other than "critical" are recorded as "warning";
    * a missing or non-numeric ``confidence`` defaults to 0.8 (violations
      found) or 0.0, and the result is clamped to [0, 1].

    SDK/network errors and a reply that is not a JSON object raise, so
    ``judge_instructions`` can fall back to the heuristic judge.
    """
    from openai import OpenAI  # optional dependency, imported only when used

    client = OpenAI(timeout=30.0, max_retries=2)
    payload = (
        f"SYSTEM PROMPT (rules):\n{system_prompt}\n\n"
        f"STUDENT MESSAGE:\n{user_prompt}\n\n"
        f"ASSISTANT RESPONSE:\n{output}"
    )
    resp = client.chat.completions.create(
        model=model,
        response_format={"type": "json_object"},
        messages=[{"role": "system", "content": _JUDGE_SYSTEM},
                  {"role": "user", "content": payload}],
    )
    data = json.loads(resp.choices[0].message.content or "{}")
    if not isinstance(data, dict):
        raise ValueError("judge reply is not a JSON object")

    violations = [
        Violation(
            rule=str(v.get("rule") or "rule"),
            # The prompt offers "critical|warning"; anything unexpected is a warning.
            severity=("critical"
                      if str(v.get("severity", "")).strip().lower() == "critical"
                      else "warning"),
            evidence=str(v.get("evidence") or ""),
        )
        for v in (data.get("violations") or [])
        if isinstance(v, dict)
    ]

    default_conf = 0.8 if violations else 0.0
    try:
        conf = float(data.get("confidence", default_conf))
    except (TypeError, ValueError):  # e.g. null or a non-numeric string
        conf = default_conf
    conf = max(0.0, min(conf, 1.0))  # NaN also collapses to 0.0 here
    return InstructionDiagnosis(
        healthy=not violations, confidence=round(conf, 4),
        violations=violations, model=f"openai:{model}",
    )
117
+
118
+
119
# --------------------------------------------------------------------------- #
# Deterministic heuristic fallback (no API key)
# --------------------------------------------------------------------------- #
def _sentences(text: str) -> list[str]:
    """Split ``text`` on runs of ``.``, ``!`` or ``?``.

    Returns the stripped, non-empty pieces. The terminal punctuation itself is
    dropped, and ``None`` or ``""`` yields ``[]``.
    """
    pieces = _SENT_RE.split(text or "")
    return [piece for piece in map(str.strip, pieces) if piece]
124
+
125
+
126
def _jaccard(a: str, b: str) -> float:
    """Jaccard similarity of the lower-cased word sets of ``a`` and ``b``.

    Returns 0.0 when neither text contains any word; ``None`` counts as empty.
    """
    left = set(map(str.lower, _WORD_RE.findall(a or "")))
    right = set(map(str.lower, _WORD_RE.findall(b or "")))
    union = left | right
    if not union:
        return 0.0
    return len(left & right) / len(union)
130
+
131
+
132
def _heuristic_judge(system_prompt: str, user_prompt: str, output: str) -> InstructionDiagnosis:
    """Catch the most common Socratic-tutor violations without an LLM.

    Keywords in the lower-cased system prompt switch the rules on:

    * exactly one question per turn — "one question", "socratic" or
      "leading question";
    * don't reveal the solution early — "socratic", "not give away",
      "don't give away", "never give away" or "leading question". This rule
      flags responses over 90 words, or with over 70 words in non-question
      sentences;
    * don't open by paraphrasing the student — always checked (first sentence
      word-set Jaccard with the student message > 0.5).

    Confidence is 0.6 + 0.15 per violation + 0.1 per critical violation, capped
    at 0.95. A healthy verdict has confidence 0.0.
    """
    sysl = (system_prompt or "").lower()
    text = output or ""
    violations: list[Violation] = []
    qn = text.count("?")
    words = _WORD_RE.findall(text)
    sentences = _sentences(text)

    # Rule: exactly one leading question.
    if any(key in sysl for key in ("one question", "socratic", "leading question")):
        if qn == 0:
            violations.append(Violation("No leading question — the turn should advance with one question.",
                                        "critical", "0 question marks in the response."))
        elif qn > 1:
            violations.append(Violation("More than one question in a single turn.",
                                        "warning", f"{qn} question marks found."))

    # Rule: don't reveal the full solution in the first response (length/declarative heuristic).
    give_away_keys = ("socratic", "not give away", "don't give away", "never give away",
                      "leading question")
    if any(key in sysl for key in give_away_keys):
        # _sentences() strips the terminators, so no piece it returns contains "?".
        # Re-split here, keeping terminators, to tell questions from statements.
        pieces = re.findall(r"[^.!?]+[.!?]*", text)
        decl_words = sum(len(_WORD_RE.findall(p)) for p in pieces if "?" not in p)
        if len(words) > 90 or decl_words > 70:
            violations.append(Violation("Reveals too much — long explanation given before the student reasons.",
                                        "critical", f"{decl_words} words of explanation before the question."))

    # Rule: never open by paraphrasing the student.
    first = sentences[0] if sentences else ""
    if first and _jaccard(first, user_prompt) > 0.5:
        violations.append(Violation("Opens by paraphrasing the student's message.",
                                    "warning", "First sentence closely mirrors the student input."))

    # Confidence scales with count + severity.
    if not violations:
        return InstructionDiagnosis(healthy=True, confidence=0.0, model="heuristic")
    crit = sum(1 for v in violations if v.severity == "critical")
    conf = min(0.6 + 0.15 * len(violations) + 0.1 * crit, 0.95)
    return InstructionDiagnosis(healthy=False, confidence=round(conf, 4),
                                violations=violations, model="heuristic")