evalguard-python 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evalguard/__init__.py +42 -0
- evalguard/anthropic.py +182 -0
- evalguard/bedrock.py +280 -0
- evalguard/client.py +516 -0
- evalguard/crewai.py +189 -0
- evalguard/fastapi.py +273 -0
- evalguard/guardrails.py +160 -0
- evalguard/langchain.py +218 -0
- evalguard/nemoclaw.py +251 -0
- evalguard/openai.py +194 -0
- evalguard/types.py +142 -0
- evalguard_python-1.1.0.dist-info/METADATA +362 -0
- evalguard_python-1.1.0.dist-info/RECORD +15 -0
- evalguard_python-1.1.0.dist-info/WHEEL +5 -0
- evalguard_python-1.1.0.dist-info/top_level.txt +1 -0
evalguard/langchain.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
"""LangChain callback handler for EvalGuard.
|
|
2
|
+
|
|
3
|
+
Usage::
|
|
4
|
+
|
|
5
|
+
from evalguard.langchain import EvalGuardCallback
|
|
6
|
+
from langchain_openai import ChatOpenAI
|
|
7
|
+
|
|
8
|
+
callback = EvalGuardCallback(api_key="eg_...", project_id="proj_...")
|
|
9
|
+
llm = ChatOpenAI(callbacks=[callback])
|
|
10
|
+
# Every LLM call is now guarded and traced
|
|
11
|
+
llm.invoke("Hello, world!")
|
|
12
|
+
|
|
13
|
+
Works with any LangChain LLM, chat model, or chain that supports callbacks.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import time
|
|
19
|
+
import uuid
|
|
20
|
+
from typing import Any, Dict, List, Optional, Sequence, Union
|
|
21
|
+
|
|
22
|
+
from .guardrails import GuardrailClient, GuardrailViolation
|
|
23
|
+
|
|
24
|
+
# LangChain callback protocol: we implement the methods directly rather
|
|
25
|
+
# than inheriting from BaseCallbackHandler so the SDK has zero hard
|
|
26
|
+
# dependencies on LangChain. When LangChain is installed, the duck-typed
|
|
27
|
+
# interface is fully compatible.
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class EvalGuardCallback:
    """LangChain callback that guards every LLM call via EvalGuard.

    This class implements the LangChain callback protocol without importing
    LangChain, so it works with *any* version (0.1.x, 0.2.x, 0.3.x).

    Parameters
    ----------
    api_key:
        EvalGuard API key.
    project_id:
        Optional project ID for trace grouping.
    base_url:
        EvalGuard API endpoint.
    rules:
        Guardrail rules for input checking.
    block_on_violation:
        Raise :class:`GuardrailViolation` when input is blocked.
    timeout:
        Per-request timeout (seconds) for guardrail API calls.
    """

    def __init__(
        self,
        api_key: str,
        project_id: Optional[str] = None,
        base_url: str = "https://api.evalguard.ai",
        rules: Optional[List[str]] = None,
        block_on_violation: bool = True,
        timeout: float = 5.0,
    ) -> None:
        self._guard = GuardrailClient(
            api_key=api_key,
            base_url=base_url,
            project_id=project_id,
            timeout=timeout,
        )
        self._rules = rules
        self._block = block_on_violation
        # Per-run state keyed by run_id; entries are removed in
        # on_llm_end/on_llm_error (or when a blocked input aborts the run).
        self._runs: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def _model_name(serialized: Optional[Dict[str, Any]]) -> str:
        """Best-effort model name from LangChain's ``serialized`` payload.

        LangChain sends either a ``name`` key or an ``id`` list whose last
        element is the class name. Both may be absent, the ``id`` list may
        be empty, and some LangChain versions pass ``serialized=None``;
        all of those fall back to ``"unknown"`` instead of raising.
        """
        if not isinstance(serialized, dict):
            return "unknown"
        if "name" in serialized:
            return serialized["name"]
        ident = serialized.get("id")
        if isinstance(ident, list) and ident:
            return ident[-1]
        return "unknown"

    # ── LangChain callback protocol ──────────────────────────────────

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: Optional[Any] = None,
        parent_run_id: Optional[Any] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Pre-LLM: check for prompt injection, PII, etc."""
        rid = str(run_id or uuid.uuid4())
        prompt_text = "\n".join(prompts)
        model_name = self._model_name(serialized)

        start = time.monotonic()
        check = self._guard.check_input(
            prompt_text,
            rules=self._rules,
            metadata={"model": model_name, "framework": "langchain"},
        )
        guard_ms = (time.monotonic() - start) * 1000

        self._runs[rid] = {
            "model": model_name,
            "input": prompt_text,
            "guard_ms": guard_ms,
            "violations": check.get("violations", []),
            "start": time.monotonic(),
        }

        if not check.get("allowed", True) and self._block:
            # The LLM call never happens, so on_llm_end will not fire for
            # this run: drop the per-run state now to avoid unbounded growth.
            self._runs.pop(rid, None)
            raise GuardrailViolation(check.get("violations", []))

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[Any],
        *,
        run_id: Optional[Any] = None,
        parent_run_id: Optional[Any] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Pre-chat-model: extract text from message objects and check."""
        prompts: List[str] = []
        for message_group in messages:
            group = message_group if isinstance(message_group, list) else [message_group]
            for msg in group:
                content = msg.get("content", "") if isinstance(msg, dict) else getattr(msg, "content", "")
                if isinstance(content, str):
                    prompts.append(content)
                elif isinstance(content, list):
                    # Multimodal messages: guard the textual parts instead of
                    # silently skipping the whole message.
                    for part in content:
                        if isinstance(part, str):
                            prompts.append(part)
                        elif isinstance(part, dict) and part.get("type") == "text":
                            prompts.append(part.get("text", ""))
        self.on_llm_start(serialized, prompts, run_id=run_id, parent_run_id=parent_run_id, tags=tags, metadata=metadata, **kwargs)

    def on_llm_end(
        self,
        response: Any,
        *,
        run_id: Optional[Any] = None,
        parent_run_id: Optional[Any] = None,
        **kwargs: Any,
    ) -> None:
        """Post-LLM: log the complete trace."""
        rid = str(run_id or "")
        run_data = self._runs.pop(rid, {})
        llm_ms = (time.monotonic() - run_data.get("start", time.monotonic())) * 1000

        output_text = _extract_lc_response(response)
        self._guard.log_trace(
            {
                "provider": "langchain",
                "model": run_data.get("model", "unknown"),
                "input": run_data.get("input", ""),
                "output": output_text,
                "guard_latency_ms": round(run_data.get("guard_ms", 0), 2),
                "llm_latency_ms": round(llm_ms, 2),
                "violations": run_data.get("violations", []),
            }
        )

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: Optional[Any] = None,
        parent_run_id: Optional[Any] = None,
        **kwargs: Any,
    ) -> None:
        """Log error traces for failed LLM calls."""
        rid = str(run_id or "")
        run_data = self._runs.pop(rid, {})
        self._guard.log_trace(
            {
                "provider": "langchain",
                "model": run_data.get("model", "unknown"),
                "input": run_data.get("input", ""),
                "output": "",
                "error": str(error),
                "violations": run_data.get("violations", []),
            }
        )

    # ── Chain-level callbacks (no-ops, present for compatibility) ────

    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        pass

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        pass

    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> None:
        pass

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        pass

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        pass

    def on_text(self, text: str, **kwargs: Any) -> None:
        pass

    def on_retry(self, retry_state: Any, **kwargs: Any) -> None:
        pass
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _extract_lc_response(response: Any) -> str:
|
|
201
|
+
"""Extract text from a LangChain LLMResult or ChatResult."""
|
|
202
|
+
try:
|
|
203
|
+
# LLMResult / ChatResult have .generations
|
|
204
|
+
generations = getattr(response, "generations", None)
|
|
205
|
+
if generations:
|
|
206
|
+
parts: list[str] = []
|
|
207
|
+
for gen_list in generations:
|
|
208
|
+
for gen in gen_list:
|
|
209
|
+
# ChatGeneration has .message.content; Generation has .text
|
|
210
|
+
msg = getattr(gen, "message", None)
|
|
211
|
+
if msg:
|
|
212
|
+
parts.append(getattr(msg, "content", str(msg)))
|
|
213
|
+
else:
|
|
214
|
+
parts.append(getattr(gen, "text", str(gen)))
|
|
215
|
+
return "\n".join(parts)
|
|
216
|
+
except Exception:
|
|
217
|
+
pass
|
|
218
|
+
return str(response) if response else ""
|
evalguard/nemoclaw.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
"""NVIDIA NeMo Guardrails / OpenClaw agent integration for EvalGuard.
|
|
2
|
+
|
|
3
|
+
Usage::
|
|
4
|
+
|
|
5
|
+
from evalguard.nemoclaw import EvalGuardAgent
|
|
6
|
+
|
|
7
|
+
agent = EvalGuardAgent(api_key="eg_...", agent_name="support-bot")
|
|
8
|
+
|
|
9
|
+
# Guard any LLM call regardless of provider
|
|
10
|
+
result = agent.guarded_call(
|
|
11
|
+
provider="openai",
|
|
12
|
+
messages=[{"role": "user", "content": "Hello"}],
|
|
13
|
+
llm_fn=lambda: openai_client.chat.completions.create(
|
|
14
|
+
model="gpt-4", messages=[{"role": "user", "content": "Hello"}]
|
|
15
|
+
),
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Or use as a context manager for multi-step agent workflows
|
|
19
|
+
with agent.session("ticket-123") as session:
|
|
20
|
+
session.check("User says: reset my password")
|
|
21
|
+
result = do_llm_call(...)
|
|
22
|
+
session.log_step("password_reset", input="...", output=str(result))
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
import time
|
|
28
|
+
import uuid
|
|
29
|
+
from contextlib import contextmanager
|
|
30
|
+
from typing import Any, Callable, Dict, Generator, List, Optional
|
|
31
|
+
|
|
32
|
+
from .guardrails import GuardrailClient, GuardrailViolation
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class EvalGuardAgent:
    """Agent-level guardrail wrapper for NeMo/OpenClaw-style agent systems.

    Parameters
    ----------
    api_key:
        EvalGuard API key.
    agent_name:
        A human-readable name for this agent (used in traces).
    project_id:
        Optional project ID.
    base_url:
        EvalGuard API endpoint.
    rules:
        Default guardrail rules.
    block_on_violation:
        Raise on violation if *True*.
    timeout:
        Per-request timeout (seconds) for guardrail API calls.
    """

    def __init__(
        self,
        api_key: str,
        agent_name: str = "default",
        project_id: Optional[str] = None,
        base_url: str = "https://api.evalguard.ai",
        rules: Optional[List[str]] = None,
        block_on_violation: bool = True,
        timeout: float = 5.0,
    ) -> None:
        self._guard = GuardrailClient(
            api_key=api_key,
            base_url=base_url,
            project_id=project_id,
            timeout=timeout,
        )
        self._agent_name = agent_name
        self._rules = rules
        self._block = block_on_violation

    def guarded_call(
        self,
        provider: str,
        messages: List[Dict[str, str]],
        llm_fn: Callable[[], Any],
        *,
        rules: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """Execute an LLM call with pre/post guardrail checks.

        Parameters
        ----------
        provider:
            LLM provider name (``"openai"``, ``"anthropic"``, etc.).
        messages:
            The messages being sent (for guardrail input checking).
        llm_fn:
            A zero-argument callable that performs the actual LLM call.
        rules:
            Override default rules for this call.
        metadata:
            Additional metadata for the trace.

        Returns
        -------
        The result of ``llm_fn()``.

        Raises
        ------
        GuardrailViolation
            If the input is blocked and ``block_on_violation`` is *True*.
        """
        prompt_text = "\n".join(
            msg.get("content", "") for msg in messages if isinstance(msg, dict)
        )

        # ── Pre-LLM check ────────────────────────────────────────────
        start = time.monotonic()
        check = self._guard.check_input(
            prompt_text,
            rules=rules or self._rules,
            metadata={
                "agent": self._agent_name,
                "provider": provider,
                "framework": "nemoclaw",
                **(metadata or {}),
            },
        )
        guard_ms = (time.monotonic() - start) * 1000

        if not check.get("allowed", True) and self._block:
            raise GuardrailViolation(check.get("violations", []))

        # ── Execute LLM call ─────────────────────────────────────────
        start = time.monotonic()
        try:
            result = llm_fn()
        except Exception as err:
            # A failed LLM call previously left no trace at all; record the
            # error (mirroring langchain.py's on_llm_error) and re-raise.
            self._guard.log_trace(
                {
                    "provider": provider,
                    "agent": self._agent_name,
                    "framework": "nemoclaw",
                    "input": prompt_text,
                    "output": "",
                    "error": str(err),
                    "guard_latency_ms": round(guard_ms, 2),
                    "violations": check.get("violations", []),
                    **(metadata or {}),
                }
            )
            raise
        llm_ms = (time.monotonic() - start) * 1000

        # ── Post-LLM trace ───────────────────────────────────────────
        # Check against None (not truthiness): falsy-but-real results such
        # as 0 or an empty response object must still be recorded.
        output_text = "" if result is None else str(result)[:2000]
        self._guard.log_trace(
            {
                "provider": provider,
                "agent": self._agent_name,
                "framework": "nemoclaw",
                "input": prompt_text,
                "output": output_text,
                "guard_latency_ms": round(guard_ms, 2),
                "llm_latency_ms": round(llm_ms, 2),
                "violations": check.get("violations", []),
                **(metadata or {}),
            }
        )

        return result

    @contextmanager
    def session(self, session_id: Optional[str] = None) -> Generator["_AgentSession", None, None]:
        """Create a guarded session for multi-step agent workflows.

        Usage::

            with agent.session("ticket-123") as s:
                s.check("user input here")
                result = my_llm_call()
                s.log_step("step_name", input="...", output="...")
        """
        sid = session_id or str(uuid.uuid4())
        sess = _AgentSession(self._guard, self._agent_name, sid, self._rules, self._block)
        try:
            yield sess
        finally:
            # Always emit the session summary, even if the body raised.
            sess._finalize()
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class _AgentSession:
|
|
164
|
+
"""A multi-step guarded session within an agent."""
|
|
165
|
+
|
|
166
|
+
__slots__ = ("_guard", "_agent_name", "_session_id", "_rules", "_block", "_steps")
|
|
167
|
+
|
|
168
|
+
def __init__(
|
|
169
|
+
self,
|
|
170
|
+
guard: GuardrailClient,
|
|
171
|
+
agent_name: str,
|
|
172
|
+
session_id: str,
|
|
173
|
+
rules: Optional[List[str]],
|
|
174
|
+
block: bool,
|
|
175
|
+
) -> None:
|
|
176
|
+
self._guard = guard
|
|
177
|
+
self._agent_name = agent_name
|
|
178
|
+
self._session_id = session_id
|
|
179
|
+
self._rules = rules
|
|
180
|
+
self._block = block
|
|
181
|
+
self._steps: List[Dict[str, Any]] = []
|
|
182
|
+
|
|
183
|
+
def check(self, text: str, *, rules: Optional[List[str]] = None) -> Dict[str, Any]:
|
|
184
|
+
"""Run a guardrail check within this session."""
|
|
185
|
+
result = self._guard.check_input(
|
|
186
|
+
text,
|
|
187
|
+
rules=rules or self._rules,
|
|
188
|
+
metadata={
|
|
189
|
+
"agent": self._agent_name,
|
|
190
|
+
"session_id": self._session_id,
|
|
191
|
+
"framework": "nemoclaw",
|
|
192
|
+
},
|
|
193
|
+
)
|
|
194
|
+
if not result.get("allowed", True) and self._block:
|
|
195
|
+
raise GuardrailViolation(result.get("violations", []))
|
|
196
|
+
return result
|
|
197
|
+
|
|
198
|
+
def log_step(
|
|
199
|
+
self,
|
|
200
|
+
step_name: str,
|
|
201
|
+
*,
|
|
202
|
+
input: str = "",
|
|
203
|
+
output: str = "",
|
|
204
|
+
metadata: Optional[Dict[str, Any]] = None,
|
|
205
|
+
) -> None:
|
|
206
|
+
"""Log a single step in the agent workflow."""
|
|
207
|
+
step = {
|
|
208
|
+
"step": step_name,
|
|
209
|
+
"input": input[:2000],
|
|
210
|
+
"output": output[:2000],
|
|
211
|
+
**(metadata or {}),
|
|
212
|
+
}
|
|
213
|
+
self._steps.append(step)
|
|
214
|
+
self._guard.log_trace(
|
|
215
|
+
{
|
|
216
|
+
"provider": "nemoclaw",
|
|
217
|
+
"agent": self._agent_name,
|
|
218
|
+
"session_id": self._session_id,
|
|
219
|
+
"framework": "nemoclaw",
|
|
220
|
+
**step,
|
|
221
|
+
}
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
def _finalize(self) -> None:
|
|
225
|
+
"""Log session summary on exit."""
|
|
226
|
+
self._guard.log_trace(
|
|
227
|
+
{
|
|
228
|
+
"provider": "nemoclaw",
|
|
229
|
+
"agent": self._agent_name,
|
|
230
|
+
"session_id": self._session_id,
|
|
231
|
+
"framework": "nemoclaw",
|
|
232
|
+
"event": "session_end",
|
|
233
|
+
"total_steps": len(self._steps),
|
|
234
|
+
}
|
|
235
|
+
)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
# Convenience alias
|
|
239
|
+
def init(
    api_key: str,
    agent_name: str = "default",
    **kwargs: Any,
) -> EvalGuardAgent:
    """Shorthand for creating an :class:`EvalGuardAgent`.

    Any extra keyword arguments are forwarded to the
    ``EvalGuardAgent`` constructor unchanged.

    Usage::

        from evalguard.nemoclaw import init
        agent = init(api_key="eg_...", agent_name="support-bot")
    """
    return EvalGuardAgent(agent_name=agent_name, api_key=api_key, **kwargs)
|
evalguard/openai.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""Drop-in OpenAI wrapper with EvalGuard guardrails.
|
|
2
|
+
|
|
3
|
+
Usage::
|
|
4
|
+
|
|
5
|
+
from evalguard.openai import wrap
|
|
6
|
+
from openai import OpenAI
|
|
7
|
+
|
|
8
|
+
client = wrap(OpenAI(), api_key="eg_...", project_id="proj_...")
|
|
9
|
+
# Use exactly like normal -- guardrails are automatic
|
|
10
|
+
response = client.chat.completions.create(
|
|
11
|
+
model="gpt-4",
|
|
12
|
+
messages=[{"role": "user", "content": "Hello"}],
|
|
13
|
+
)
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import time
|
|
19
|
+
from typing import Any, List, Optional
|
|
20
|
+
|
|
21
|
+
from .guardrails import GuardrailClient, GuardrailViolation
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def wrap(
    client: Any,
    *,
    api_key: str,
    project_id: Optional[str] = None,
    base_url: str = "https://api.evalguard.ai",
    rules: Optional[List[str]] = None,
    block_on_violation: bool = True,
    timeout: float = 5.0,
) -> "_OpenAIProxy":
    """Wrap an ``openai.OpenAI`` (or ``AsyncOpenAI``) client with guardrails.

    The returned proxy behaves exactly like the wrapped client except that
    ``chat.completions.create`` is intercepted for guardrail checks and
    trace logging.

    Parameters
    ----------
    client:
        An instantiated ``openai.OpenAI`` or ``openai.AsyncOpenAI`` client.
    api_key:
        EvalGuard API key.
    project_id:
        Optional EvalGuard project ID for trace grouping.
    base_url:
        EvalGuard API endpoint.
    rules:
        Guardrail rules to apply. Defaults to prompt-injection + PII.
    block_on_violation:
        If *True*, raise :class:`GuardrailViolation` when input is blocked.
        If *False*, violations are logged but the request proceeds.
    timeout:
        Per-request timeout (seconds) for guardrail API calls.
    """
    guardrail_client = GuardrailClient(
        api_key=api_key,
        base_url=base_url,
        project_id=project_id,
        timeout=timeout,
    )
    return _OpenAIProxy(client, guardrail_client, rules, block_on_violation)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class _OpenAIProxy:
|
|
60
|
+
"""Transparent proxy that intercepts ``chat.completions.create``."""
|
|
61
|
+
|
|
62
|
+
__slots__ = ("_client", "_guard", "_rules", "_block")
|
|
63
|
+
|
|
64
|
+
def __init__(
|
|
65
|
+
self,
|
|
66
|
+
client: Any,
|
|
67
|
+
guard: GuardrailClient,
|
|
68
|
+
rules: Optional[List[str]],
|
|
69
|
+
block: bool,
|
|
70
|
+
) -> None:
|
|
71
|
+
self._client = client
|
|
72
|
+
self._guard = guard
|
|
73
|
+
self._rules = rules
|
|
74
|
+
self._block = block
|
|
75
|
+
|
|
76
|
+
def __getattr__(self, name: str) -> Any:
|
|
77
|
+
attr = getattr(self._client, name)
|
|
78
|
+
if name == "chat":
|
|
79
|
+
return _ChatProxy(attr, self._guard, self._rules, self._block)
|
|
80
|
+
return attr
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class _ChatProxy:
|
|
84
|
+
__slots__ = ("_chat", "_guard", "_rules", "_block")
|
|
85
|
+
|
|
86
|
+
def __init__(self, chat: Any, guard: GuardrailClient, rules: Optional[List[str]], block: bool) -> None:
|
|
87
|
+
self._chat = chat
|
|
88
|
+
self._guard = guard
|
|
89
|
+
self._rules = rules
|
|
90
|
+
self._block = block
|
|
91
|
+
|
|
92
|
+
def __getattr__(self, name: str) -> Any:
|
|
93
|
+
attr = getattr(self._chat, name)
|
|
94
|
+
if name == "completions":
|
|
95
|
+
return _CompletionsProxy(attr, self._guard, self._rules, self._block)
|
|
96
|
+
return attr
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class _CompletionsProxy:
    """Intercepts ``completions.create`` with guardrail checks and tracing."""

    __slots__ = ("_completions", "_guard", "_rules", "_block")

    def __init__(self, completions: Any, guard: GuardrailClient, rules: Optional[List[str]], block: bool) -> None:
        self._completions = completions
        self._guard = guard
        self._rules = rules
        self._block = block

    def __getattr__(self, name: str) -> Any:
        # Only ``create`` is intercepted; everything else passes through.
        if name == "create":
            return self._guarded_create
        return getattr(self._completions, name)

    def _guarded_create(self, **kwargs: Any) -> Any:
        """Run the guardrail check, call OpenAI, then log a trace.

        Raises :class:`GuardrailViolation` when the input is blocked and
        blocking is enabled; otherwise violations are only recorded in the
        trace. Errors raised by the underlying ``create`` call are logged
        as error traces and re-raised.
        """
        messages = kwargs.get("messages", [])
        prompt_text = _extract_prompt(messages)
        model = kwargs.get("model", "unknown")

        # ── Pre-LLM check ────────────────────────────────────────────
        start = time.monotonic()
        check = self._guard.check_input(
            prompt_text,
            rules=self._rules,
            metadata={"model": model, "framework": "openai"},
        )
        guard_ms = (time.monotonic() - start) * 1000

        if not check.get("allowed", True):
            if self._block:
                raise GuardrailViolation(check.get("violations", []))
            # Non-blocking mode: the violation is recorded in the trace below.

        # ── Call OpenAI ───────────────────────────────────────────────
        start = time.monotonic()
        try:
            response = self._completions.create(**kwargs)
        except Exception as err:
            # Failed calls previously left no trace (unlike langchain.py's
            # on_llm_error); record the error and re-raise unchanged.
            self._guard.log_trace(
                {
                    "provider": "openai",
                    "model": model,
                    "input": prompt_text,
                    "output": "",
                    "error": str(err),
                    "guard_latency_ms": round(guard_ms, 2),
                    "violations": check.get("violations", []),
                }
            )
            raise
        llm_ms = (time.monotonic() - start) * 1000

        # ── Post-LLM trace ───────────────────────────────────────────
        output_text = _extract_response(response)
        self._guard.log_trace(
            {
                "provider": "openai",
                "model": model,
                "input": prompt_text,
                "output": output_text,
                "guard_latency_ms": round(guard_ms, 2),
                "llm_latency_ms": round(llm_ms, 2),
                "violations": check.get("violations", []),
                "token_usage": _extract_usage(response),
            }
        )

        return response
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _extract_prompt(messages: list) -> str:
|
|
156
|
+
"""Join all message contents into a single string for guardrail checking."""
|
|
157
|
+
parts: list[str] = []
|
|
158
|
+
for msg in messages:
|
|
159
|
+
content = msg.get("content", "") if isinstance(msg, dict) else getattr(msg, "content", "")
|
|
160
|
+
if isinstance(content, str):
|
|
161
|
+
parts.append(content)
|
|
162
|
+
elif isinstance(content, list):
|
|
163
|
+
# Multi-modal messages: extract text parts
|
|
164
|
+
for part in content:
|
|
165
|
+
if isinstance(part, dict) and part.get("type") == "text":
|
|
166
|
+
parts.append(part.get("text", ""))
|
|
167
|
+
return "\n".join(parts)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def _extract_response(response: Any) -> str:
|
|
171
|
+
"""Extract text from an OpenAI ChatCompletion response."""
|
|
172
|
+
try:
|
|
173
|
+
choices = response.choices if hasattr(response, "choices") else response.get("choices", [])
|
|
174
|
+
if choices:
|
|
175
|
+
msg = choices[0].message if hasattr(choices[0], "message") else choices[0].get("message", {})
|
|
176
|
+
return msg.content if hasattr(msg, "content") else msg.get("content", "")
|
|
177
|
+
except Exception:
|
|
178
|
+
pass
|
|
179
|
+
return ""
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _extract_usage(response: Any) -> Optional[dict]:
|
|
183
|
+
"""Extract token usage from response."""
|
|
184
|
+
try:
|
|
185
|
+
usage = response.usage if hasattr(response, "usage") else response.get("usage")
|
|
186
|
+
if usage:
|
|
187
|
+
return {
|
|
188
|
+
"prompt_tokens": getattr(usage, "prompt_tokens", None) or usage.get("prompt_tokens"),
|
|
189
|
+
"completion_tokens": getattr(usage, "completion_tokens", None) or usage.get("completion_tokens"),
|
|
190
|
+
"total_tokens": getattr(usage, "total_tokens", None) or usage.get("total_tokens"),
|
|
191
|
+
}
|
|
192
|
+
except Exception:
|
|
193
|
+
pass
|
|
194
|
+
return None
|