agentrust-py 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
agentrust_sdk/hooks.py ADDED
@@ -0,0 +1,461 @@
1
+ """
2
+ Auto-instrumentation hooks — monkey-patch AI framework entry points so AgentTrust
3
+ governance fires without any `@harness` decorator changes.
4
+
5
+ Supported targets
6
+ -----------------
7
+ - **OpenAI SDK** >= 1.0 — ``openai.chat.completions.create`` (sync + async)
8
+ - **LangChain Core (modern LCEL)** — ``langchain_core.runnables.base.Runnable.invoke``
9
+ - **LangChain legacy** — ``langchain.llms.base.BaseLLM.predict`` + ``Chain.__call__``
10
+ - **LangGraph** — ``langgraph.graph.state.CompiledStateGraph.invoke`` (auto-wrap compiled graphs)
11
+
12
+ Usage
13
+ -----
14
+ Call once at application startup (before frameworks are used)::
15
+
16
+ from agentrust_sdk.hooks import auto_instrument
17
+ auto_instrument() # patches all installed frameworks; silently skips missing ones
18
+
19
+ Disable all patches::
20
+
21
+ AGENTRUST_AUTO_INSTRUMENT=false # env var; read on every auto_instrument() call ("0", "false" or "no" disables it)
22
+
23
+ Rollback / uninstall
24
+ --------------------
25
+ ``remove_patches()`` restores **original callables** (not just clears the set).
26
+ A **process restart** is the safest rollback if the process already served traffic.
27
+ """
28
+ from __future__ import annotations
29
+
30
+ import logging
31
+ import os
32
+ import threading
33
+ from typing import Any
34
+
35
+ from .config import SDK_CONFIG
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
# Registry: framework name → list of (obj, attr, original_callable).
# Each tuple records the object whose attribute was replaced, the attribute
# name, and the callable that was there before, so remove_patches() can put
# every original back with setattr().
#
# Individual reads and writes of the dict go through _REGISTRY_LOCK because
# the patch functions may be called from more than one thread.
# NOTE(review): the lock is held only for each single lookup or assignment;
# the patch_* functions check _is_patched() first and register afterwards, so
# the check-then-patch sequence as a whole is not atomic — confirm that
# concurrent first-time calls are not expected.
_PATCH_REGISTRY: dict[str, list[tuple[Any, str, Any]]] = {}
_REGISTRY_LOCK = threading.Lock()
45
+
46
+
47
def _is_patched(name: str) -> bool:
    """Report whether *name* currently has an entry in the patch registry.

    The membership test runs while holding ``_REGISTRY_LOCK``.
    """
    _REGISTRY_LOCK.acquire()
    try:
        already_registered = name in _PATCH_REGISTRY
    finally:
        _REGISTRY_LOCK.release()
    return already_registered
50
+
51
+
52
def auto_instrument(
    *,
    langchain: bool = True,
    openai: bool = True,
    langgraph: bool = True,
    agent_id: str = "auto-instrumented",
    gateway_url: str | None = None,
    api_key: str | None = None,
) -> list[str]:
    """Apply the AgentTrust patches to every requested framework.

    Nothing is patched (and an empty list is returned) when the SDK is
    disabled via ``SDK_CONFIG.enabled`` or when the environment variable
    ``AGENTRUST_AUTO_INSTRUMENT`` is set to ``0``, ``false`` or ``no``
    (case-insensitive).

    Args:
        langchain: Patch LangChain when true.
        openai: Patch the OpenAI SDK when true.
        langgraph: Patch LangGraph when true.
        agent_id: Agent identifier forwarded to every patch function.
        gateway_url: Optional gateway URL forwarded to every patch function.
        api_key: Optional API key forwarded to every patch function.

    Returns:
        Names of the frameworks whose patch function reported success, in the
        order openai, langchain, langgraph.
    """
    if not SDK_CONFIG.enabled:
        logger.debug("[AgentTrust] Auto-instrumentation skipped: SDK disabled")
        return []

    env_switch = os.environ.get("AGENTRUST_AUTO_INSTRUMENT", "true").lower()
    if env_switch in ("0", "false", "no"):
        logger.debug("[AgentTrust] Auto-instrumentation skipped: AGENTRUST_AUTO_INSTRUMENT=false")
        return []

    shared_kwargs = {"agent_id": agent_id, "gateway_url": gateway_url, "api_key": api_key}

    # (label, requested?, patch function) — processed strictly in this order.
    plan = (
        ("openai", openai, patch_openai),
        ("langchain", langchain, patch_langchain),
        ("langgraph", langgraph, patch_langgraph),
    )

    patched: list[str] = []
    for label, requested, patcher in plan:
        if not requested:
            continue
        if patcher(**shared_kwargs):
            patched.append(label)

    if patched:
        logger.info("[AgentTrust] Auto-instrumented: %s", ", ".join(patched))
    return patched
89
+
90
+
91
+ # ---------------------------------------------------------------------------
92
+ # Internal helpers
93
+ # ---------------------------------------------------------------------------
94
+
95
def _validate_output_sync(
    output: Any,
    input_text: str,
    agent_id: str,
    framework: str,
    gateway_url: str | None,
    api_key: str | None,
    model: str = "unknown",
) -> Any:
    """Send *output* to the gateway for validation and return it unchanged.

    The ``validate`` call is made inline through ``AgentTrustClient``; its
    result is ignored. Every exception is logged at DEBUG level and
    suppressed, so a gateway problem never reaches the caller (fail-open).
    When the SDK is disabled no call is made at all.

    Args:
        output: Framework result. Dicts are sent as-is; anything else is sent
            as ``{"result": str(output)[:2000]}``.
        input_text: Caller input; stringified and cut to 500 characters.
        agent_id: Agent identifier sent with the request.
        framework: Framework label; its lower-cased form is sent as ``user``.
        gateway_url: Overrides ``SDK_CONFIG.gateway_url`` when truthy.
        api_key: Overrides ``SDK_CONFIG.api_key`` when truthy.
        model: Model name sent with the request.

    Returns:
        The *output* argument, untouched.
    """
    if SDK_CONFIG.enabled:
        try:
            from .client import AgentTrustClient

            if isinstance(output, dict):
                body = output
            else:
                body = {"result": str(output)[:2000]}

            endpoint = gateway_url or SDK_CONFIG.gateway_url
            credential = api_key or SDK_CONFIG.api_key
            with AgentTrustClient(base_url=endpoint, api_key=credential) as client:
                client.validate(
                    agent_id=agent_id,
                    user=framework.lower(),
                    input=str(input_text)[:500],
                    output=body,
                    framework=framework,
                    model=model,
                )
        except Exception as exc:
            logger.debug("[AgentTrust] %s hook skipped: %s", framework, exc)
    return output
125
+
126
+
127
async def _validate_output_async(
    output: Any,
    input_text: str,
    agent_id: str,
    framework: str,
    gateway_url: str | None,
    api_key: str | None,
    model: str = "unknown",
) -> Any:
    """Await a gateway validation of *output* and return it unchanged.

    Coroutine counterpart of ``_validate_output_sync`` built on
    ``AsyncAgentTrustClient``. The ``validate`` result is ignored; every
    exception is logged at DEBUG level and suppressed (fail-open). When the
    SDK is disabled no call is made.

    Args:
        output: Framework result. Dicts are sent as-is; anything else is sent
            as ``{"result": str(output)[:2000]}``.
        input_text: Caller input; stringified and cut to 500 characters.
        agent_id: Agent identifier sent with the request.
        framework: Framework label; its lower-cased form is sent as ``user``.
        gateway_url: Overrides ``SDK_CONFIG.gateway_url`` when truthy.
        api_key: Overrides ``SDK_CONFIG.api_key`` when truthy.
        model: Model name sent with the request.

    Returns:
        The *output* argument, untouched.
    """
    if SDK_CONFIG.enabled:
        try:
            from .client import AsyncAgentTrustClient

            if isinstance(output, dict):
                body = output
            else:
                body = {"result": str(output)[:2000]}

            endpoint = gateway_url or SDK_CONFIG.gateway_url
            credential = api_key or SDK_CONFIG.api_key
            async with AsyncAgentTrustClient(base_url=endpoint, api_key=credential) as client:
                await client.validate(
                    agent_id=agent_id,
                    user=framework.lower(),
                    input=str(input_text)[:500],
                    output=body,
                    framework=framework,
                    model=model,
                )
        except Exception as exc:
            logger.debug("[AgentTrust] %s async hook skipped: %s", framework, exc)
    return output
156
+
157
+
158
+ # ---------------------------------------------------------------------------
159
+ # OpenAI patch (SDK >= 1.0)
160
+ # ---------------------------------------------------------------------------
161
+
162
def patch_openai(
    agent_id: str = "openai-auto",
    gateway_url: str | None = None,
    api_key: str | None = None,
) -> bool:
    """Patch ``openai.chat.completions.create`` (sync + async) to add governance.

    The sync patch replaces the ``create`` attribute on the object returned by
    ``openai.chat.completions``. The async patch is attempted on
    ``openai.AsyncOpenAI.chat.completions`` and is skipped (with a DEBUG log)
    if that attribute chain cannot be resolved.

    NOTE(review): both patches set an attribute on one specific object —
    confirm whether separately constructed client instances are covered.

    Args:
        agent_id: Agent identifier sent with every validation request.
        gateway_url: Optional gateway URL override for validation requests.
        api_key: Optional API key override for validation requests.

    Returns:
        ``True`` when OpenAI is already patched or the sync patch was applied;
        ``False`` when the SDK is missing or its entry points could not be
        resolved.
    """
    if _is_patched("openai"):
        return True

    try:
        import openai as _openai
        _sync_completions = _openai.chat.completions
        _async_client_cls = _openai.AsyncOpenAI  # also proves the >= 1.0 layout
    except (ImportError, AttributeError):
        return False
    except Exception as exc:
        # Resolving the module-level client can fail for reasons unrelated to
        # installation (presumably e.g. missing credentials). Instrumentation
        # is fail-open, so skip instead of crashing application startup.
        logger.debug("[AgentTrust] OpenAI auto-instrumentation skipped: %s", exc)
        return False

    def _summarise_call(result: Any, call_kwargs: dict[str, Any]) -> tuple[dict[str, str], str, Any]:
        """Return (output payload, first user message, model) for one call."""
        output_text = ""
        if hasattr(result, "choices") and result.choices:
            msg = result.choices[0].message
            output_text = getattr(msg, "content", "") or ""
        input_text = ""
        for m in call_kwargs.get("messages", []):
            if isinstance(m, dict) and m.get("role") == "user":
                input_text = str(m.get("content", ""))[:500]
                break
        return {"content": output_text[:2000]}, input_text, call_kwargs.get("model", "unknown")

    _orig_sync = _sync_completions.create

    def _patched_sync(*args: Any, **kwargs: Any) -> Any:
        result = _orig_sync(*args, **kwargs)
        if not SDK_CONFIG.enabled:
            return result
        try:
            payload, input_text, model = _summarise_call(result, kwargs)
            _validate_output_sync(
                payload, input_text, agent_id,
                "OpenAI", gateway_url, api_key, model=model,
            )
        except Exception as exc:
            # Governance must never break the caller's completion.
            logger.debug("[AgentTrust] OpenAI sync hook error: %s", exc)
        return result

    _sync_completions.create = _patched_sync  # type: ignore[method-assign]
    targets: list[tuple[Any, str, Any]] = [(_sync_completions, "create", _orig_sync)]

    # Also patch AsyncOpenAI client's chat completions create if accessible
    try:
        _async_completions = _async_client_cls.chat.completions
        _orig_async_create = _async_completions.create

        async def _patched_async(*args: Any, **kwargs: Any) -> Any:
            result = await _orig_async_create(*args, **kwargs)
            if not SDK_CONFIG.enabled:
                return result
            try:
                payload, input_text, model = _summarise_call(result, kwargs)
                await _validate_output_async(
                    payload, input_text, agent_id,
                    "OpenAI", gateway_url, api_key, model=model,
                )
            except Exception as exc:
                logger.debug("[AgentTrust] OpenAI async hook error: %s", exc)
            return result

        _async_completions.create = _patched_async  # type: ignore[method-assign]
        targets.append((_async_completions, "create", _orig_async_create))
    except Exception as exc:
        # The sync patch is already in place; record why async was skipped
        # instead of hiding it.
        logger.debug("[AgentTrust] OpenAI async patch skipped: %s", exc)

    with _REGISTRY_LOCK:
        _PATCH_REGISTRY["openai"] = targets

    # Report what was actually patched rather than always claiming both.
    applied = "sync + async" if len(targets) == 2 else "sync only"
    logger.debug("[AgentTrust] OpenAI auto-instrumentation applied (%s)", applied)
    return True
243
+
244
+
245
+ # ---------------------------------------------------------------------------
246
+ # LangChain patch — modern LCEL (langchain_core) + legacy fallback
247
+ # ---------------------------------------------------------------------------
248
+
249
def patch_langchain(
    agent_id: str = "langchain-auto",
    gateway_url: str | None = None,
    api_key: str | None = None,
) -> bool:
    """Patch LangChain to add governance on invoke calls.

    Every target that can be imported is patched; a missing one is skipped:
      1. ``langchain_core.runnables.base.Runnable.invoke`` — modern LCEL (>= 0.1)
      2. ``langchain.llms.base.BaseLLM.predict``           — legacy LLMs
      3. ``langchain.chains.base.Chain.__call__``          — legacy chains

    Each wrapper calls the original first, then reports the result through
    ``_validate_output_sync`` and returns the original result unchanged.

    NOTE(review): the replacement is installed on the named base classes
    only; subclasses that define their own method of the same name would not
    go through it — confirm the intended coverage.

    Args:
        agent_id: Agent identifier sent with every validation request.
        gateway_url: Optional gateway URL override for validation requests.
        api_key: Optional API key override for validation requests.

    Returns:
        ``True`` when LangChain is already patched or at least one target was
        patched; ``False`` when no target could be imported.
    """
    if _is_patched("langchain"):
        return True

    targets: list[tuple[Any, str, Any]] = []

    # ── 1. langchain_core LCEL (modern, preferred) ─────────────────────────
    try:
        import langchain_core.runnables.base as _runnables
        _orig_invoke = _runnables.Runnable.invoke

        def _patched_lcel_invoke(self: Any, input: Any, config: Any = None, **kwargs: Any) -> Any:
            result = _orig_invoke(self, input, config, **kwargs)
            if not SDK_CONFIG.enabled:
                return result
            input_text = str(input)[:500] if not isinstance(input, dict) else str(input.get("input", input))[:500]
            _validate_output_sync(result, input_text, agent_id, "LangChain", gateway_url, api_key)
            return result

        _runnables.Runnable.invoke = _patched_lcel_invoke  # type: ignore[method-assign]
        targets.append((_runnables.Runnable, "invoke", _orig_invoke))
        logger.debug("[AgentTrust] LangChain LCEL (langchain_core) patched")
    except (ImportError, AttributeError):
        pass  # variant not installed — deliberately skipped

    # ── 2. Legacy BaseLLM.predict ───────────────────────────────────────────
    try:
        import langchain.llms.base as _llm_base
        _orig_predict = getattr(_llm_base.BaseLLM, "predict", None)
        if _orig_predict is not None:
            def _patched_predict(self: Any, text: str, **kwargs: Any) -> str:
                result = _orig_predict(self, text, **kwargs)
                if not SDK_CONFIG.enabled:
                    return result
                _validate_output_sync(result, text, agent_id, "LangChain", gateway_url, api_key)
                return result
            _llm_base.BaseLLM.predict = _patched_predict  # type: ignore[method-assign]
            targets.append((_llm_base.BaseLLM, "predict", _orig_predict))
    except (ImportError, AttributeError):
        pass  # variant not installed — deliberately skipped

    # ── 3. Legacy Chain.__call__ ────────────────────────────────────────────
    try:
        import langchain.chains.base as _chain_base
        _orig_call = getattr(_chain_base.Chain, "__call__", None)
        if _orig_call is not None:
            # Forward extra positional arguments too: the wrapper must accept
            # everything the original accepts, otherwise callers that pass
            # arguments after ``inputs`` positionally get a TypeError.
            def _patched_chain_call(self: Any, inputs: Any, *args: Any, **kwargs: Any) -> Any:
                result = _orig_call(self, inputs, *args, **kwargs)
                if not SDK_CONFIG.enabled:
                    return result
                input_text = str(inputs) if not isinstance(inputs, dict) else str(inputs.get("input", inputs))
                _validate_output_sync(result, input_text, agent_id, "LangChain", gateway_url, api_key)
                return result
            _chain_base.Chain.__call__ = _patched_chain_call  # type: ignore[method-assign]
            targets.append((_chain_base.Chain, "__call__", _orig_call))
    except (ImportError, AttributeError):
        pass  # variant not installed — deliberately skipped

    if not targets:
        return False  # no LangChain variant found

    with _REGISTRY_LOCK:
        _PATCH_REGISTRY["langchain"] = targets
    return True
320
+
321
+
322
+ # ---------------------------------------------------------------------------
323
+ # LangGraph patch — CompiledStateGraph.invoke / ainvoke (P0-3)
324
+ # ---------------------------------------------------------------------------
325
+
326
def patch_langgraph(
    agent_id: str = "langgraph-auto",
    gateway_url: str | None = None,
    api_key: str | None = None,
) -> bool:
    """Install governance wrappers on LangGraph's ``invoke`` and ``ainvoke``.

    The wrappers are set on ``langgraph.graph.state.CompiledStateGraph``; if
    that import fails, ``langgraph.pregel.Pregel`` is used instead. Each
    wrapper runs the original, reports the result for validation when the SDK
    is enabled, and returns the original result.

    Args:
        agent_id: Agent identifier sent with every validation request.
        gateway_url: Optional gateway URL override for validation requests.
        api_key: Optional API key override for validation requests.

    Returns:
        ``True`` when LangGraph is already patched or at least one method was
        wrapped; ``False`` when neither class can be imported or the class
        has neither method.
    """
    if _is_patched("langgraph"):
        return True

    try:
        from langgraph.graph.state import CompiledStateGraph as _CSG
    except ImportError:
        try:
            # fallback for older langgraph layouts
            from langgraph.pregel import Pregel as _CSG  # type: ignore[assignment,no-redef]
        except ImportError:
            return False

    def _describe(value: Any) -> str:
        """First 500 characters of the input (or of its ``"input"`` entry)."""
        if isinstance(value, dict):
            return str(value.get("input", value))[:500]
        return str(value)[:500]

    _orig_invoke = getattr(_CSG, "invoke", None)
    _orig_ainvoke = getattr(_CSG, "ainvoke", None)
    patched_methods: list[tuple[Any, str, Any]] = []

    if _orig_invoke:
        def _patched_invoke(self: Any, input: Any, config: Any = None, **kwargs: Any) -> Any:
            result = _orig_invoke(self, input, config, **kwargs)
            if SDK_CONFIG.enabled:
                _validate_output_sync(
                    result, _describe(input), agent_id, "LangGraph", gateway_url, api_key
                )
            return result

        setattr(_CSG, "invoke", _patched_invoke)
        patched_methods.append((_CSG, "invoke", _orig_invoke))

    if _orig_ainvoke:
        async def _patched_ainvoke(self: Any, input: Any, config: Any = None, **kwargs: Any) -> Any:
            result = await _orig_ainvoke(self, input, config, **kwargs)
            if SDK_CONFIG.enabled:
                await _validate_output_async(
                    result, _describe(input), agent_id, "LangGraph", gateway_url, api_key
                )
            return result

        setattr(_CSG, "ainvoke", _patched_ainvoke)
        patched_methods.append((_CSG, "ainvoke", _orig_ainvoke))

    if len(patched_methods) == 0:
        return False

    with _REGISTRY_LOCK:
        _PATCH_REGISTRY["langgraph"] = patched_methods
    logger.debug("[AgentTrust] LangGraph auto-instrumentation applied")
    return True
384
+
385
+
386
+ # ---------------------------------------------------------------------------
387
+ # LangGraph compile-time wrapper (P0-3)
388
+ # ---------------------------------------------------------------------------
389
+
390
def auto_wrap(
    compiled_graph: Any,
    agent_id: str = "langgraph-wrapped",
    *,
    gateway_url: str | None = None,
    api_key: str | None = None,
) -> Any:
    """Wrap a compiled LangGraph graph so every invoke/ainvoke is governed.

    Usage — zero-touch: pass your compiled graph through this helper once::

        from agentrust_sdk.hooks import auto_wrap

        graph = workflow.compile()
        graph = auto_wrap(graph, agent_id="my-agent")  # governance on every invoke

        # No other code changes needed
        result = graph.invoke({"input": "hello", "user": "alice"})

    This is lighter than the global ``patch_langgraph()`` monkey-patch: it wraps
    only the specific graph instance, not all CompiledStateGraph objects.
    The wrappers are set as instance attributes and are not recorded in the
    patch registry, so ``remove_patches()`` does not undo them.

    Args:
        compiled_graph: Object exposing ``invoke`` and/or ``ainvoke``; it is
            modified in place.
        agent_id: Agent identifier sent with every validation request.
        gateway_url: Optional gateway URL override; ``None`` leaves the
            choice to the validation helper.
        api_key: Optional API key override; when falsy, ``SDK_CONFIG.api_key``
            is read at call time.

    Returns:
        The same ``compiled_graph`` object.
    """
    _orig_invoke = getattr(compiled_graph, "invoke", None)
    _orig_ainvoke = getattr(compiled_graph, "ainvoke", None)

    def _input_text(value: Any) -> str:
        """First 500 characters of the input (or of its ``"input"`` entry)."""
        if isinstance(value, dict):
            return str(value.get("input", value))[:500]
        return str(value)[:500]

    if _orig_invoke:
        # _orig_invoke is already bound to the instance, so no ``self`` here.
        def _wrapped_invoke(input: Any, config: Any = None, **kwargs: Any) -> Any:
            result = _orig_invoke(input, config, **kwargs)
            if SDK_CONFIG.enabled:
                _validate_output_sync(
                    result, _input_text(input), agent_id, "LangGraph",
                    gateway_url, api_key or SDK_CONFIG.api_key,
                )
            return result
        compiled_graph.invoke = _wrapped_invoke

    if _orig_ainvoke:
        async def _wrapped_ainvoke(input: Any, config: Any = None, **kwargs: Any) -> Any:
            result = await _orig_ainvoke(input, config, **kwargs)
            if SDK_CONFIG.enabled:
                await _validate_output_async(
                    result, _input_text(input), agent_id, "LangGraph",
                    gateway_url, api_key or SDK_CONFIG.api_key,
                )
            return result
        compiled_graph.ainvoke = _wrapped_ainvoke

    logger.debug("[AgentTrust] auto_wrap applied to graph instance (agent_id=%s)", agent_id)
    return compiled_graph
432
+
433
+
434
+ # ---------------------------------------------------------------------------
435
+ # Patch removal — restores original callables (P1-3 fix)
436
+ # ---------------------------------------------------------------------------
437
+
438
def remove_patches() -> list[str]:
    """Restore all patched callables to their originals.

    The whole operation runs under ``_REGISTRY_LOCK`` so that a patch
    registered concurrently is either restored by this call or left fully
    intact — never dropped from the registry without being restored.

    A framework is removed from the registry, and reported, only when every
    one of its targets was restored. Targets whose ``setattr`` failed stay
    registered with their original callable so a later call can retry.

    Note: a **process restart** is still the safest rollback if the process
    already served traffic through patched paths.

    Returns:
        The framework names that were fully restored.
    """
    restored: list[str] = []

    with _REGISTRY_LOCK:
        # Iterate over a copy of the keys: entries are deleted while looping.
        for name in list(_PATCH_REGISTRY):
            still_patched: list[tuple[Any, str, Any]] = []
            for obj, attr, original in _PATCH_REGISTRY[name]:
                try:
                    setattr(obj, attr, original)
                except Exception as exc:
                    logger.warning("[AgentTrust] Could not restore %s.%s: %s", obj, attr, exc)
                    still_patched.append((obj, attr, original))

            if still_patched:
                # Keep the originals; discarding them would make the patch
                # permanent for the lifetime of the process.
                _PATCH_REGISTRY[name] = still_patched
            else:
                del _PATCH_REGISTRY[name]
                restored.append(name)

    if restored:
        logger.info("[AgentTrust] Removed patches: %s", ", ".join(restored))
    return restored
@@ -0,0 +1,81 @@
1
+ """SDK-facing request/response models."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Any
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
class ToolCall(BaseModel):
    # One tool invocation record; the element type of
    # ``ValidateRequest.tools_called``.
    name: str  # required
    arguments: dict[str, Any] = Field(default_factory=dict)  # fresh empty dict per instance
    result: Any = None
    latency_ms: float | None = None  # presumably milliseconds, going by the name — confirm
    error: str | None = None  # None by default
14
+
15
+
16
class ValidationResult(BaseModel):
    # Per-check scores plus the list of failures; embedded in
    # ``ValidateResponse.validation``. Every score defaults to 0.0.
    # NOTE(review): ``ValidateResponse.schema_valid`` compares schema_score
    # against 80.0, which suggests a 0–100 scale — confirm for the others.
    schema_score: float = 0.0
    evidence_score: float = 0.0
    tool_trust_score: float = 0.0
    consistency_score: float = 0.0
    policy_score: float = 0.0
    judge_score: float | None = None  # the only optional score; None by default
    final_confidence: float = 0.0
    failures: list[str] = Field(default_factory=list)  # fresh empty list per instance
25
+
26
+
27
class RiskResult(BaseModel):
    # Risk section of a validate response (``ValidateResponse.risk``).
    tier: str = "unknown"  # "unknown" = not computed (tier too low)
    score: float = 0.0
    reason: str = ""  # empty string by default
31
+
32
+
33
class DecisionResult(BaseModel):
    # Decision section of a validate response (``ValidateResponse.decision``).
    # ``ValidateResponse`` tests ``outcome`` against "approve", "block",
    # "escalate" and "request_evidence".
    outcome: str = "pending"  # "pending" = not computed (tier too low)
    reason: str = ""
    policy_version: str = ""  # empty string by default
37
+
38
+
39
class ValidateRequest(BaseModel):
    # Request model for a validation call. ``agent_id``, ``user`` and
    # ``input`` are required; every other field has a default.
    agent_id: str
    framework: str = "REST"
    version: str = "1"
    parent_envelope_id: str | None = None
    user: str
    input: str
    output: dict[str, Any] = Field(default_factory=dict)  # fresh empty dict per instance
    model: str = "unknown"
    tools_called: list[ToolCall] = Field(default_factory=list)  # fresh empty list per instance
    latency_ms: float = 0.0  # presumably milliseconds, going by the name — confirm
    tokens: int = 0
    session_id: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)  # fresh empty dict per instance
53
+
54
+
55
class ValidateResponse(BaseModel):
    # Response model for a validation call. The boolean properties below are
    # shortcuts over ``decision.outcome`` and ``validation``.
    envelope_id: str
    validation: ValidationResult
    risk: RiskResult
    decision: DecisionResult
    latency_ms: float
    trust_chain: dict | None = None
    # Tier metadata — always populated
    tier_info: str = "unknown"
    upgrade_hint: str | None = None

    @property
    def approved(self) -> bool:
        """True when the decision outcome is ``"approve"``."""
        outcome = self.decision.outcome
        return outcome == "approve"

    @property
    def blocked(self) -> bool:
        """True when the decision outcome is ``"block"``."""
        outcome = self.decision.outcome
        return outcome == "block"

    @property
    def needs_review(self) -> bool:
        """True when the outcome is ``"escalate"`` or ``"request_evidence"``."""
        outcome = self.decision.outcome
        return outcome == "escalate" or outcome == "request_evidence"

    @property
    def schema_valid(self) -> bool:
        """Available on all tiers including OSS.

        True when the schema score is at least 80.0 and no failures were
        reported.
        """
        checks = self.validation
        if checks.failures:
            return False
        return checks.schema_score >= 80.0
agentrust_sdk/py.typed ADDED
File without changes