iris-security-langchain 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/governed_agent.py +56 -0
- iris_langchain/__init__.py +48 -0
- iris_langchain/_governance.py +349 -0
- iris_langchain/agent.py +140 -0
- iris_langchain/callback.py +366 -0
- iris_langchain/tools.py +105 -0
- iris_security_langchain-0.1.1.dist-info/METADATA +39 -0
- iris_security_langchain-0.1.1.dist-info/RECORD +12 -0
- iris_security_langchain-0.1.1.dist-info/WHEEL +5 -0
- iris_security_langchain-0.1.1.dist-info/top_level.txt +3 -0
- tests/conftest.py +11 -0
- tests/test_langchain_integration.py +254 -0
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Minimal LangChain agent with IRIS governance.
|
|
3
|
+
|
|
4
|
+
Requires optional deps: pip install iris-security-langchain[openai]
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from langchain.agents import AgentExecutor, create_tool_calling_agent
|
|
10
|
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
11
|
+
from langchain_core.tools import tool
|
|
12
|
+
from langchain_openai import ChatOpenAI
|
|
13
|
+
|
|
14
|
+
from iris import AgentPassport, ComplianceTag
|
|
15
|
+
from iris_langchain import IrisLangChainAgent
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@tool
def lookup_account(account_id: str, data_region: str = "us-east-1") -> str:
    """Look up a customer account by ID."""
    # @tool turns the docstring above into the tool description sent to the
    # model and the signature into its argument schema, so both are runtime
    # metadata. This demo returns a canned status instead of a real lookup.
    status = f"Account {account_id} in {data_region} is active."
    return status
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def main() -> None:
    """Build a tool-calling OpenAI agent, wrap it with IRIS governance, and run one query."""
    # Identity and compliance posture that IRIS evaluates tool calls against.
    passport = AgentPassport(
        name="research-agent",
        owner="team@company.com",
        compliance_tags=[ComplianceTag.COLORADO_AI_ACT],
        is_high_risk_ai=True,
        tool_permissions=[],  # production passports should declare their tools here
    )

    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    tools = [lookup_account]
    messages = [
        ("system", "You are a helpful research agent."),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)
    executor = AgentExecutor(
        agent=create_tool_calling_agent(llm, tools, prompt),
        tools=tools,
        verbose=True,
    )

    # from_agent attaches the IRIS callback handler to the executor.
    governed = IrisLangChainAgent.from_agent(
        executor,
        passport,
        compliance=["colorado-ai-act"],
    )
    answer = governed.run("Research this topic and summarize findings")
    print(answer)
    # IRIS evaluated each tool call and logged evidence. DENY decisions raise in
    # staging/prod; in dev/test they are reported as warnings instead
    # (assuming the callback handler enforces decisions via enforce_result).


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""
|
|
2
|
+
IRIS LangChain integration — governance in three lines of code.
|
|
3
|
+
|
|
4
|
+
Quickstart:
|
|
5
|
+
from iris_langchain import IrisLangChainAgent
|
|
6
|
+
from iris import AgentPassport, ComplianceTag
|
|
7
|
+
|
|
8
|
+
passport = AgentPassport(
|
|
9
|
+
name="support-agent",
|
|
10
|
+
owner="team@company.com",
|
|
11
|
+
compliance_tags=[ComplianceTag.COLORADO_AI_ACT],
|
|
12
|
+
)
|
|
13
|
+
agent = IrisLangChainAgent.from_agent(base_agent, passport)
|
|
14
|
+
result = agent.run("Help this customer")
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from iris import IrisViolationError
|
|
20
|
+
from iris_core.models.passport import (
|
|
21
|
+
AgentPassport,
|
|
22
|
+
ComplianceTag,
|
|
23
|
+
DataClassification,
|
|
24
|
+
Environment,
|
|
25
|
+
ToolPermission,
|
|
26
|
+
)
|
|
27
|
+
from iris_core.models.policy import PolicyResult, Severity, Violation
|
|
28
|
+
|
|
29
|
+
from iris_langchain.agent import IrisLangChainAgent
|
|
30
|
+
from iris_langchain.callback import IrisCallbackHandler
|
|
31
|
+
from iris_langchain.tools import iris_tool_guard
|
|
32
|
+
|
|
33
|
+
__version__ = "0.1.0"
|
|
34
|
+
|
|
35
|
+
__all__ = [
|
|
36
|
+
"IrisCallbackHandler",
|
|
37
|
+
"IrisLangChainAgent",
|
|
38
|
+
"iris_tool_guard",
|
|
39
|
+
"IrisViolationError",
|
|
40
|
+
"AgentPassport",
|
|
41
|
+
"ComplianceTag",
|
|
42
|
+
"DataClassification",
|
|
43
|
+
"Environment",
|
|
44
|
+
"ToolPermission",
|
|
45
|
+
"PolicyResult",
|
|
46
|
+
"Severity",
|
|
47
|
+
"Violation",
|
|
48
|
+
]
|
|
@@ -0,0 +1,349 @@
|
|
|
1
|
+
"""Shared IRIS evaluation helpers for LangChain integrations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import re
|
|
9
|
+
import sys
|
|
10
|
+
import threading
|
|
11
|
+
import uuid
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from datetime import datetime
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Any, Dict, List, Optional
|
|
16
|
+
|
|
17
|
+
from iris import IrisViolationError
|
|
18
|
+
from iris_core.engine.cedar import CedarEngine, EvaluationContext
|
|
19
|
+
from iris_core.rbac.context import UserContext
|
|
20
|
+
from iris_core.evidence.vault import EvidenceVault
|
|
21
|
+
from iris_core.models.passport import AgentPassport, Environment
|
|
22
|
+
from iris_core.models.policy import PolicyResult, Severity, Violation
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger("iris.langchain")
|
|
25
|
+
|
|
26
|
+
_VAULT_LOCK = threading.Lock()
|
|
27
|
+
|
|
28
|
+
_SSN = re.compile(r"\b\d{3}-\d{2}-\d{4}\b")
|
|
29
|
+
_CREDIT_CARD = re.compile(r"\b(?:\d{4}[-\s]?){3}\d{4}\b")
|
|
30
|
+
_DOB = re.compile(
|
|
31
|
+
r"\b(?:0[1-9]|1[0-2])[/-](?:0[1-9]|[12]\d|3[01])[/-](?:19|20)\d{2}\b"
|
|
32
|
+
)
|
|
33
|
+
_CROSS_REGION = re.compile(r"cn-north|china|beijing", re.IGNORECASE)
|
|
34
|
+
_HIGH_RISK_DOMAIN = re.compile(r"\b(loan|diagnosis|hiring)\b", re.IGNORECASE)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class RunSession:
    """Counters for one agent run, used to correlate Evidence Vault entries by ``run_id``."""

    run_id: str
    tool_calls: int = 0
    violations: int = 0  # policy violations counted across the run's evaluations
    permits: int = 0
    warnings: int = 0
    pii_output_violations: int = 0  # presumably PII hits on tool output — confirm in callback.py
    finalized: bool = False  # presumably set by the handler's finalize_run()
    events: List[str] = field(default_factory=list)

    @property
    def pass_rate(self) -> float:
        """Share of tool calls not offset by violations: 1.0 with no calls, floored at 0.0."""
        total = self.tool_calls
        if not total:
            return 1.0
        blocked = self.violations + self.pii_output_violations
        ratio = (total - blocked) / total
        return ratio if ratio > 0.0 else 0.0

    def to_summary(self) -> dict:
        """Return the run's counters plus the pass rate rounded to 4 places."""
        return dict(
            run_id=self.run_id,
            total_tool_calls=self.tool_calls,
            violations=self.violations,
            pii_output_violations=self.pii_output_violations,
            permits=self.permits,
            warnings=self.warnings,
            pass_rate=round(self.pass_rate, 4),
        )
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def resolve_environment(env: Optional[Environment] = None) -> Environment:
    """Return *env* when given; otherwise build one from ``$IRIS_ENV`` (default ``"dev"``)."""
    if env is None:
        env = Environment(os.environ.get("IRIS_ENV", "dev"))
    return env
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def has_policy_loaded(engine: CedarEngine, passport: AgentPassport) -> bool:
    """Report whether *engine* holds a non-empty cached policy for this passport's agent."""
    # Reads the engine's private ``_policy_cache`` mapping keyed by agent_id.
    cached = engine._policy_cache.get(passport.agent_id)
    return bool(cached)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def load_passport_policy(engine: CedarEngine, passport: AgentPassport) -> None:
    """Load the Cedar policy referenced by ``passport.policy_ref`` into *engine*.

    Relative references are resolved against the current working directory.
    Does nothing when the passport has no ``policy_ref``. When the referenced
    file does not exist, a warning is logged and no policy is loaded.
    """
    if not passport.policy_ref:
        return
    policy_path = Path(passport.policy_ref)
    if not policy_path.is_absolute():
        policy_path = Path.cwd() / policy_path
    if not policy_path.exists():
        # A configured-but-missing policy must not vanish silently: the agent
        # would then run with no policy loaded (see apply_no_policy_gate).
        logger.warning(
            "Policy file %s referenced by passport for agent %r was not found; "
            "no Cedar policy loaded.",
            policy_path,
            passport.agent_id,
        )
        return
    engine.load_policy_file(passport.agent_id, policy_path)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def apply_no_policy_gate(
    engine: CedarEngine,
    passport: AgentPassport,
    env: Environment,
    result: PolicyResult,
) -> PolicyResult:
    """Relax a DENY to PERMIT_WITH_WARNINGS when no policy is loaded and *env* is dev/test.

    The relaxed result keeps the original violations and identifying fields so
    they are still reported. In every other case (a policy is loaded, *env* is
    not DEV/TEST, or the decision is not DENY) *result* is returned unchanged,
    so outside dev/test the engine's own decision stands.
    """
    if has_policy_loaded(engine, passport):
        return result
    if env in (Environment.DEV, Environment.TEST) and result.decision == "DENY":
        return PolicyResult(
            decision="PERMIT_WITH_WARNINGS",
            violations=result.violations,
            agent_id=result.agent_id,
            action=result.action,
            resource=result.resource,
            environment=result.environment,
        )
    return result
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def extract_regions(inputs: Optional[Dict[str, Any]]) -> tuple[Optional[str], Optional[str]]:
    """Pull ``(data_region, destination_region)`` out of tool inputs as strings.

    ``destination_region`` falls back to ``dest_region`` when missing or falsy.
    Absent values come back as ``None``; present values are ``str()``-coerced.
    Empty or ``None`` inputs yield ``(None, None)``.
    """
    if not inputs:
        return None, None

    def _as_str(value: Any) -> Optional[str]:
        return None if value is None else str(value)

    source = inputs.get("data_region")
    destination = inputs.get("destination_region") or inputs.get("dest_region")
    return _as_str(source), _as_str(destination)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def detect_pii(text: str) -> bool:
    """Return True when *text* matches any SSN, payment-card, or date-of-birth pattern."""
    if not text:
        return False
    return any(pattern.search(text) for pattern in (_SSN, _CREDIT_CARD, _DOB))
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def pii_output_violation(passport: AgentPassport, tool_name: str) -> Violation:
    """Build the HIGH-severity IRIS-DATA-001 violation for tool output that looks like PII."""
    message = (
        f"Tool '{tool_name}' output for agent '{passport.name}' may contain PII "
        "(SSN, payment card, or date-of-birth pattern)."
    )
    remediation = (
        "Redact sensitive data from tool outputs or restrict tools that return PII "
        "in production environments."
    )
    refs = ["colorado-ai-act:impact-assessment", "gdpr:data-minimization"]
    return Violation(
        rule_id="IRIS-DATA-001",
        severity=Severity.HIGH,
        message=message,
        compliance_refs=refs,
        remediation=remediation,
    )
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _prompt_pii_violation(passport: AgentPassport) -> Violation:
    """IRIS-DATA-001 (HIGH): the prompt matched an SSN / card / DOB pattern."""
    return Violation(
        rule_id="IRIS-DATA-001",
        severity=Severity.HIGH,
        message=(
            f"Prompt for agent '{passport.name}' may contain PII "
            "(SSN, payment card, or date-of-birth pattern)."
        ),
        compliance_refs=[
            "colorado-ai-act:impact-assessment",
            "gdpr:data-minimization",
        ],
        remediation=(
            "Remove sensitive identifiers from the prompt or update the agent "
            "passport data_classification to match the data being processed."
        ),
    )


def _prompt_cross_region_violation(passport: AgentPassport) -> Violation:
    """IRIS-XR-001 (CRITICAL): the prompt references restricted China / cn-north geography."""
    return Violation(
        rule_id="IRIS-XR-001",
        severity=Severity.CRITICAL,
        message=(
            f"Prompt for agent '{passport.name}' references restricted "
            "cross-region geography (China / cn-north)."
        ),
        compliance_refs=["china-pipl:cross-border-transfer"],
        remediation=(
            "Remove cross-region references from the prompt or document an "
            "approved exception with your security engineer."
        ),
    )


def _prompt_high_risk_violation(passport: AgentPassport) -> Violation:
    """CO-004 (HIGH): the prompt mentions a consequential domain (loan, diagnosis, hiring)."""
    return Violation(
        rule_id="CO-004",
        severity=Severity.HIGH,
        message=(
            f"Prompt for agent '{passport.name}' references a high-risk "
            "consequential domain (loan, diagnosis, or hiring)."
        ),
        compliance_refs=["colorado-ai-act:sb-24-205:consumer-opt-out"],
        remediation=(
            "Set user_consent_logged=True in policy context for consequential "
            f"actions, or run 'iris compliance assess --agent {passport.agent_id}'."
        ),
    )


def check_prompt_guardrails(prompt: str, passport: AgentPassport) -> List[Violation]:
    """Scan an LLM prompt for PII, restricted cross-region geography, and high-risk domains.

    Returns one violation per matching check, always in this order: PII
    (IRIS-DATA-001), cross-region (IRIS-XR-001), high-risk domain (CO-004).
    An empty prompt yields an empty list.
    """
    if not prompt:
        return []
    # (detector, violation builder) pairs, evaluated in order.
    checks = (
        (detect_pii, _prompt_pii_violation),
        (_CROSS_REGION.search, _prompt_cross_region_violation),
        (_HIGH_RISK_DOMAIN.search, _prompt_high_risk_violation),
    )
    return [build(passport) for matches, build in checks if matches(prompt)]
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def merge_guardrail_violations(
    base: PolicyResult,
    extra_violations: List[Violation],
) -> PolicyResult:
    """Fold guardrail violations into *base* and recompute the decision.

    - No extra violations: *base* itself is returned.
    - Any CRITICAL violation (from *base* or the extras): DENY.
    - Otherwise a base PERMIT becomes PERMIT_WITH_WARNINGS, and any other base
      decision (including DENY) is kept.
    """
    if not extra_violations:
        return base

    combined = [*base.violations, *extra_violations]
    if any(v.severity == Severity.CRITICAL for v in combined):
        decision = "DENY"
    elif base.decision == "PERMIT":
        decision = "PERMIT_WITH_WARNINGS"
    else:
        decision = base.decision

    return PolicyResult(
        decision=decision,
        violations=combined,
        agent_id=base.agent_id,
        action=base.action,
        resource=base.resource,
        environment=base.environment,
    )
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def evaluate_and_record(
    engine: CedarEngine,
    vault: EvidenceVault,
    passport: AgentPassport,
    env: Environment,
    *,
    action: str,
    resource: str,
    resource_type: str = "tool",
    data_region: Optional[str] = None,
    destination_region: Optional[str] = None,
    data_classification: Optional[str] = None,
    user_consent_logged: bool = False,
    run_id: Optional[str] = None,
    extra_violations: Optional[List[Violation]] = None,
    dlp_prompt_findings: Optional[list] = None,
    user_email: Optional[str] = None,
    user_role: Optional[str] = None,
) -> PolicyResult:
    """Evaluate one agent action with the Cedar engine and record the outcome.

    The engine's result is post-processed in order by apply_no_policy_gate
    (dev/test fail-open when no policy is loaded) and then
    merge_guardrail_violations with *extra_violations*. The final
    context/result pair is written to *vault* while holding _VAULT_LOCK.

    Returns:
        The final PolicyResult. This function does not enforce the decision;
        callers decide whether to raise (e.g. via enforce_result).
    """
    user_ctx = UserContext.from_params(user_email, user_role)
    ctx = EvaluationContext(
        agent_id=passport.agent_id,
        action=action,
        resource=resource,
        resource_type=resource_type,
        environment=env,
        data_region=data_region,
        destination_region=destination_region,
        data_classification=data_classification,
        user_consent_logged=user_consent_logged,
        dlp_prompt_findings=dlp_prompt_findings,
        # run_id lets vault entries from one agent run be correlated.
        additional={"run_id": run_id} if run_id else {},
        **user_ctx.evaluation_fields(),
    )
    result = engine.evaluate(passport, ctx)
    result = apply_no_policy_gate(engine, passport, env, result)
    result = merge_guardrail_violations(result, extra_violations or [])
    with _VAULT_LOCK:
        vault.record(ctx, result)
    return result
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def record_audit_event(
    vault: EvidenceVault,
    *,
    run_id: str,
    event_type: str,
    resource: str,
    details: Optional[Dict[str, Any]] = None,
    violations: Optional[List[Violation]] = None,
    decision: str = "AUDIT",
) -> str:
    """Append a non-evaluation audit event, tagged with *run_id*, to the vault log.

    The entry is written as one JSON line directly to ``vault._log_file``
    (bypassing ``vault.record``) while holding _VAULT_LOCK. ``action`` mirrors
    *event_type*, and *violations* are flattened to plain dicts.

    Returns:
        The generated event id (a UUID4 string).
    """
    from datetime import timezone

    event_id = str(uuid.uuid4())
    entry = {
        "event_id": event_id,
        # Naive UTC ISO-8601 string, identical in format to
        # datetime.utcnow().isoformat() but without the 3.12+ deprecation.
        "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
        "agent_id": vault._agent_id,
        "run_id": run_id,
        "event_type": event_type,
        "action": event_type,
        "resource": resource,
        "decision": decision,
        "details": details or {},
        "violations": [
            {
                "rule_id": v.rule_id,
                "severity": v.severity.value,
                "message": v.message,
                "compliance_refs": v.compliance_refs,
            }
            for v in (violations or [])
        ],
    }
    # Serialize before taking the lock: keeps the critical section to the write
    # and avoids touching the log file if the entry is not JSON-serializable.
    line = json.dumps(entry) + "\n"
    with _VAULT_LOCK:
        with open(vault._log_file, "a", encoding="utf-8") as f:
            f.write(line)
    return event_id
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def track_result(session: Optional[RunSession], result: PolicyResult) -> None:
    """Update *session* counters from one policy decision; no-op when *session* is None.

    PERMIT increments ``permits``. PERMIT_WITH_WARNINGS increments ``warnings`` and
    adds each violation. DENY adds its violations, counting at least one. Any other
    decision is ignored.
    """
    if session is None:
        return
    outcome = result.decision
    if outcome == "PERMIT":
        session.permits += 1
        return
    if outcome == "PERMIT_WITH_WARNINGS":
        session.warnings += 1
        session.violations += len(result.violations)
        return
    if outcome == "DENY":
        session.violations += max(1, len(result.violations))
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def _warn_violations(result: PolicyResult) -> None:
    """Report every violation on *result* as an ``[IRIS WARNING]`` via the logger and stderr."""
    for violation in result.violations:
        line = (
            f"[IRIS WARNING] {violation.message} "
            f"Remediation: {violation.remediation}"
        )
        logger.warning(line)
        print(line, file=sys.stderr)


def enforce_result(result: PolicyResult, env: Environment) -> None:
    """Act on a policy decision.

    DENY raises IrisViolationError(result) unless *env* is DEV or TEST, where the
    violations are only reported as warnings. PERMIT_WITH_WARNINGS always reports
    its violations. Any other decision is a no-op.
    """
    decision = result.decision
    if decision == "DENY" and env not in (Environment.DEV, Environment.TEST):
        raise IrisViolationError(result)
    if decision in ("DENY", "PERMIT_WITH_WARNINGS"):
        _warn_violations(result)
|
iris_langchain/agent.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
"""IrisLangChainAgent — wrap any LangChain agent with IRIS governance."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from typing import Any, List, Optional, Union
|
|
7
|
+
|
|
8
|
+
from iris_core.compliance.registry import ComplianceRegistry
|
|
9
|
+
from iris_core.models.passport import AgentPassport, ComplianceTag, Environment
|
|
10
|
+
from iris_core.models.policy import Severity, Violation
|
|
11
|
+
|
|
12
|
+
from iris_langchain.callback import IrisCallbackHandler
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class IrisLangChainAgent:
    """
    Drop-in governance wrapper for LangChain AgentExecutor or Runnable agents.

    The IRIS callback handler is attached to the wrapped agent. Each public
    call (``run``, ``invoke``, ``ainvoke``) opens exactly one governed run via
    ``handler.begin_run()`` and finalizes it when the call returns or raises.

    Example:
        agent = IrisLangChainAgent.from_agent(base_agent, passport)
        result = agent.run("Help this customer with their account")
    """

    def __init__(
        self,
        agent: Any,
        passport: AgentPassport,
        handler: IrisCallbackHandler,
        compliance_frameworks: Optional[List[str]] = None,
    ):
        """Wrap *agent* with *handler*.

        Args:
            agent: The LangChain agent (AgentExecutor / Runnable) to govern.
            passport: The agent's IRIS passport.
            handler: IRIS callback handler; attached to the agent and passed
                along on every call.
            compliance_frameworks: Framework ids used by compliance_check();
                defaults to the ``value`` of each passport compliance tag.
        """
        self._agent = agent
        self.passport = passport
        self._handler = handler
        self._compliance_frameworks = compliance_frameworks or [
            t.value for t in passport.compliance_tags
        ]
        self._inject_callbacks()

    @classmethod
    def from_agent(
        cls,
        agent: Any,
        passport: AgentPassport,
        compliance: Optional[List[str]] = None,
        environment: Optional[str] = None,
        user_email: Optional[str] = None,
        user_role: Optional[str] = None,
    ) -> "IrisLangChainAgent":
        """Build a governed wrapper (and its callback handler) around *agent*.

        Args:
            agent: LangChain agent to wrap.
            passport: Agent passport. NOTE: when *compliance* is given,
                ``passport.compliance_tags`` is overwritten in place.
            compliance: Compliance framework ids, each converted via ComplianceTag().
            environment: Environment name; defaults to ``$IRIS_ENV`` or ``"dev"``.
            user_email, user_role: Forwarded to IrisCallbackHandler (presumably
                for user/RBAC context — confirm in callback.py).
        """
        from iris_core.dev_trust import print_dev_trust_message

        print_dev_trust_message()
        if compliance:
            passport.compliance_tags = [ComplianceTag(c) for c in compliance]
        env = Environment(environment or os.environ.get("IRIS_ENV", "dev"))
        handler = IrisCallbackHandler(passport, env, user_email=user_email, user_role=user_role)
        return cls(agent, passport, handler, compliance_frameworks=compliance)

    def _inject_callbacks(self) -> None:
        """Append the IRIS handler to the agent's own ``callbacks`` list, if it has one."""
        if hasattr(self._agent, "callbacks"):
            existing = list(self._agent.callbacks or [])
            if self._handler not in existing:
                existing.append(self._handler)
            self._agent.callbacks = existing

    def _config_with_callbacks(self, config: Optional[dict] = None) -> dict:
        """Return a copy of *config* whose ``callbacks`` list includes the IRIS handler."""
        config = dict(config or {})
        callbacks = list(config.get("callbacks") or [])
        if self._handler not in callbacks:
            callbacks.append(self._handler)
        config["callbacks"] = callbacks
        return config

    def _complete_run(self, output: Any) -> None:
        """Finalize the current governed run with *output* unless it is already finalized."""
        run = self._handler.current_run
        if not run or not run.finalized:
            self._handler.finalize_run(output=output)

    def _abort_run(self) -> None:
        """Finalize (without output) a governed run left open by an exception."""
        run = self._handler.current_run
        if run and not run.finalized:
            self._handler.finalize_run()

    def run(self, input: Union[str, dict], **kwargs: Any) -> Any:
        """Run the agent under governance through its legacy ``run()`` API.

        Agents without ``run()`` are delegated to :meth:`invoke`, which opens
        its own governed run, so no session is started here in that case.
        """
        if not hasattr(self._agent, "run"):
            return self.invoke(input, **kwargs)
        self._handler.begin_run()
        try:
            callbacks = list(kwargs.pop("callbacks", None) or [])
            if self._handler not in callbacks:
                callbacks.append(self._handler)
            result = self._agent.run(input, callbacks=callbacks, **kwargs)
            self._complete_run(result)
        except Exception:
            self._abort_run()
            raise
        return result

    async def ainvoke(self, input: Union[str, dict], **kwargs: Any) -> Any:
        """Asynchronously invoke the agent under governance.

        A ``str`` input is wrapped as ``{"input": input}``.

        Raises:
            TypeError: the wrapped agent has no ``ainvoke()`` (raised before
                any governed run is opened).
        """
        if not hasattr(self._agent, "ainvoke"):
            raise TypeError("Wrapped agent does not support ainvoke()")
        self._handler.begin_run()
        try:
            config = self._config_with_callbacks(kwargs.pop("config", None))
            payload = input if isinstance(input, dict) else {"input": input}
            result = await self._agent.ainvoke(payload, config=config, **kwargs)
            self._complete_run(result)
        except Exception:
            self._abort_run()
            raise
        return result

    def invoke(self, input: Union[str, dict], **kwargs: Any) -> Any:
        """Invoke the agent under governance through the Runnable ``invoke()`` API.

        A ``str`` input is wrapped as ``{"input": input}``. Agents that only
        offer ``run()`` are delegated to :meth:`run` (the ``config`` kwarg is
        dropped on that path, since legacy ``run()`` does not accept it).

        Raises:
            TypeError: the agent supports neither ``invoke()`` nor ``run()``
                (raised before any governed run is opened).
        """
        config = kwargs.pop("config", None)
        if not hasattr(self._agent, "invoke"):
            if hasattr(self._agent, "run"):
                # run() opens its own governed run; don't start a second one here.
                return self.run(input, **kwargs)
            raise TypeError("Wrapped agent does not support invoke() or run()")
        self._handler.begin_run()
        try:
            merged = self._config_with_callbacks(config)
            payload = input if isinstance(input, dict) else {"input": input}
            result = self._agent.invoke(payload, config=merged, **kwargs)
            self._complete_run(result)
        except Exception:
            self._abort_run()
            raise
        return result

    def compliance_check(self) -> List[Violation]:
        """Check the passport against the configured frameworks via ComplianceRegistry."""
        registry = ComplianceRegistry()
        return registry.check_passport(self.passport, self._compliance_frameworks)

    @property
    def is_ready_for_production(self) -> bool:
        """True when compliance_check() reports no CRITICAL violations."""
        violations = self.compliance_check()
        critical = [v for v in violations if v.severity == Severity.CRITICAL]
        return len(critical) == 0
|