agentwall-security 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentwall/__init__.py +5 -0
- agentwall/cli/__init__.py +0 -0
- agentwall/cli/main.py +340 -0
- agentwall/core/__init__.py +0 -0
- agentwall/core/config_manager.py +109 -0
- agentwall/core/event_manager.py +147 -0
- agentwall/core/session_manager.py +60 -0
- agentwall/core/types.py +81 -0
- agentwall/detectors/__init__.py +0 -0
- agentwall/evaluators/__init__.py +0 -0
- agentwall/inspector/__init__.py +0 -0
- agentwall/inspector/deps.py +30 -0
- agentwall/inspector/desktop.py +58 -0
- agentwall/inspector/event_bus.py +45 -0
- agentwall/inspector/routes/__init__.py +0 -0
- agentwall/inspector/routes/events.py +12 -0
- agentwall/inspector/routes/export.py +78 -0
- agentwall/inspector/routes/goals.py +14 -0
- agentwall/inspector/routes/health.py +9 -0
- agentwall/inspector/routes/overview.py +28 -0
- agentwall/inspector/routes/policies.py +78 -0
- agentwall/inspector/routes/providers.py +90 -0
- agentwall/inspector/routes/sessions.py +54 -0
- agentwall/inspector/routes/ws.py +28 -0
- agentwall/inspector/server.py +52 -0
- agentwall/inspector/ui/dist/assets/index-CdF3waHo.js +40 -0
- agentwall/inspector/ui/dist/assets/index-OKzO_O9l.css +1 -0
- agentwall/inspector/ui/dist/index.html +13 -0
- agentwall/integrations/__init__.py +9 -0
- agentwall/integrations/crewai.py +161 -0
- agentwall/integrations/langchain.py +173 -0
- agentwall/integrations/openai_agents.py +185 -0
- agentwall/interceptors/__init__.py +24 -0
- agentwall/interceptors/agent.py +104 -0
- agentwall/interceptors/base.py +21 -0
- agentwall/interceptors/tool.py +173 -0
- agentwall/models/__init__.py +0 -0
- agentwall/models/schemas.py +93 -0
- agentwall/policies/__init__.py +0 -0
- agentwall/providers/__init__.py +0 -0
- agentwall/providers/anthropic.py +44 -0
- agentwall/providers/base.py +105 -0
- agentwall/providers/chain.py +31 -0
- agentwall/providers/deepseek.py +46 -0
- agentwall/providers/groq.py +46 -0
- agentwall/providers/keyring.py +20 -0
- agentwall/providers/ollama.py +48 -0
- agentwall/providers/openai.py +44 -0
- agentwall/providers/registry.py +64 -0
- agentwall/security/__init__.py +5 -0
- agentwall/security/detectors.py +145 -0
- agentwall/security/engine.py +142 -0
- agentwall/security/exceptions.py +13 -0
- agentwall/security/goal_tracker.py +98 -0
- agentwall/security/policy_engine.py +154 -0
- agentwall/security/result_analyzer.py +169 -0
- agentwall/security/rules.py +152 -0
- agentwall/storage/__init__.py +0 -0
- agentwall/storage/database.py +81 -0
- agentwall/storage/models.py +97 -0
- agentwall_security-0.1.0.dist-info/METADATA +337 -0
- agentwall_security-0.1.0.dist-info/RECORD +65 -0
- agentwall_security-0.1.0.dist-info/WHEEL +4 -0
- agentwall_security-0.1.0.dist-info/entry_points.txt +2 -0
- agentwall_security-0.1.0.dist-info/licenses/LICENSE +21 -0
agentwall/__init__.py
ADDED
|
File without changes
|
agentwall/cli/main.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
import threading
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Annotated
|
|
7
|
+
|
|
8
|
+
import typer
|
|
9
|
+
|
|
10
|
+
from agentwall import __version__
|
|
11
|
+
|
|
12
|
+
app = typer.Typer(name="agentwall", help="AgentWall - AI Agent Runtime Security.", add_completion=False)
|
|
13
|
+
|
|
14
|
+
# Defaults for the Inspector's "Bind host" / "Bind port" CLI options.
_INSPECTOR_HOST = "127.0.0.1"
_INSPECTOR_PORT = 8080

# Models the configuration wizard lists for each provider.
_PROVIDER_MODELS: dict[str, list[str]] = {
    "openai": ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo"],
    "anthropic": ["claude-opus-4-8", "claude-sonnet-4-6", "claude-haiku-4-5-20251001"],
    "groq": ["llama-3.3-70b-versatile", "llama-3.1-8b-instant", "gemma2-9b-it"],
    "deepseek": ["deepseek-chat", "deepseek-reasoner"],
    "ollama": ["llama3.2", "llama3.1", "mistral", "gemma2", "phi3"],
}

# Provider names in wizard display order; this is the key order of the table above.
_PROVIDER_NAMES = list(_PROVIDER_MODELS)

# Providers for which the CLI never prompts for an API key.
_NO_API_KEY_PROVIDERS = {"ollama"}
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _get_db():
    """Create and return a new ``Database`` instance.

    The storage module is imported inside the function, so it is only loaded
    when a command actually calls this helper.
    """
    from agentwall.storage import database

    return database.Database()
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@app.command()
def version() -> None:
    """Print version."""
    # Output has the form "agentwall <version>".
    banner = f"agentwall {__version__}"
    typer.echo(banner)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@app.command()
def doctor() -> None:
    """Check installation health."""
    # Each entry is (importable module name, label shown to the user).
    checks = [
        ("keyring", "API key storage (keyring)"),
        ("sqlalchemy", "ORM (sqlalchemy)"),
        ("fastapi", "Inspector API (fastapi)"),
        ("uvicorn", "Inspector server (uvicorn)"),
        ("pydantic", "Validation (pydantic)"),
        ("openai", "OpenAI/Groq/DeepSeek SDK (openai)"),
        ("anthropic", "Anthropic SDK (anthropic)"),
        ("webview", "Desktop Inspector (pywebview)"),
    ]

    missing = []
    for module, label in checks:
        try:
            importlib.import_module(module)
        except ImportError:
            typer.echo(f" MISSING {label}")
            missing.append(module)
        else:
            typer.echo(f" OK {label}")

    if not missing:
        typer.echo("\nAgentWall installation OK.")
        return

    # The distribution is published as "agentwall-security"; the import
    # package name "agentwall" is not the name to hand to pip.
    typer.echo("\nSome dependencies missing. Run: pip install agentwall-security")
    raise typer.Exit(1)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@app.command()
def config(
    provider: Annotated[str | None, typer.Option(help="Provider name")] = None,
    model: Annotated[str | None, typer.Option(help="Model name")] = None,
    priority: Annotated[int | None, typer.Option(help="Priority (1=primary)")] = None,
    low_threshold: Annotated[float | None, typer.Option(help="Low risk threshold")] = None,
    high_threshold: Annotated[float | None, typer.Option(help="High risk threshold")] = None,
    status: Annotated[bool, typer.Option("--status", help="Show provider health")] = False,
) -> None:
    """Configure AgentWall. No flags = interactive wizard."""
    # Dispatch order: --status wins; otherwise any value option selects the
    # non-interactive path; with no options the wizard runs.
    from agentwall.core.config_manager import ConfigManager

    db = _get_db()
    try:
        mgr = ConfigManager(db)

        if status:
            _show_status(db)
            return

        non_interactive = any(
            x is not None for x in [provider, model, priority, low_threshold, high_threshold]
        )
        if non_interactive:
            _apply_config_flags(mgr, provider, model, priority, low_threshold, high_threshold)
        else:
            _run_wizard(mgr, db)
    finally:
        # Close the database even when a helper raises (for example when the
        # user aborts a prompt), not only on the normal return paths.
        db.close()
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _apply_config_flags(mgr, provider, model, priority, low_threshold, high_threshold):
    """Apply the non-interactive ``config`` options.

    All options are validated before anything is written, so an invalid
    invocation leaves both the keyring and the stored configuration untouched.

    Args:
        mgr: Configuration manager providing ``set_provider``,
            ``get_thresholds`` and ``set_thresholds``.
        provider: Provider name, or ``None``.
        model: Model name, or ``None``.
        priority: Provider priority, or ``None`` (stored as 0).
        low_threshold: New low threshold, or ``None`` to keep the stored one.
        high_threshold: New high threshold, or ``None`` to keep the stored one.

    Raises:
        typer.Exit: With code 1 when the provider options are incomplete, the
            provider is unknown, or the resulting low threshold is not below
            the high threshold.
    """
    provider_requested = any(v is not None for v in (provider, model, priority))
    if provider_requested:
        # Previously an incomplete provider/model pair was silently ignored.
        if not (provider and model):
            typer.echo("Error: --provider and --model must be given together.")
            raise typer.Exit(1)
        # Same normalisation and check as the interactive wizard.
        provider = provider.strip().lower()
        if provider not in _PROVIDER_NAMES:
            typer.echo(f"Unknown provider: {provider}")
            raise typer.Exit(1)

    thresholds_requested = low_threshold is not None or high_threshold is not None
    if thresholds_requested:
        current = mgr.get_thresholds()
        low = low_threshold if low_threshold is not None else current["low_threshold"]
        high = high_threshold if high_threshold is not None else current["high_threshold"]
        # Same rule the wizard enforces before saving thresholds.
        if low >= high:
            typer.echo("Error: low_threshold must be less than high_threshold.")
            raise typer.Exit(1)

    if provider_requested:
        from agentwall.providers.keyring import store_api_key

        p = priority or 0
        if provider not in _NO_API_KEY_PROVIDERS:
            api_key = typer.prompt(f"API key for {provider}", hide_input=True)
            if api_key:
                store_api_key(provider, api_key)
                typer.echo(f"API key stored in OS keyring for {provider}.")
        mgr.set_provider(provider, model, priority=p)
        typer.echo(f"Set {provider} → {model} (priority={p})")

    if thresholds_requested:
        mgr.set_thresholds(low, high)
        typer.echo(f"Thresholds: low={low} high={high}")
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _run_wizard(mgr, db) -> None:
    """Show the current configuration, then run one menu action chosen by the user.

    Args:
        mgr: Configuration manager handed to the provider/threshold steps.
        db: Database handle handed to the provider status check.
    """
    typer.echo("\nAgentWall Configuration Wizard")
    typer.echo("=" * 32)

    _print_current_config(mgr)

    actions = [
        "Add / update provider",
        "Remove provider",
        "Set risk thresholds",
        "Test provider connections",
        "Exit",
    ]
    typer.echo("\nWhat would you like to do?")
    for i, a in enumerate(actions, 1):
        typer.echo(f" {i}. {a}")

    choice = typer.prompt("Select", default="1")
    try:
        idx = int(choice) - 1
    except ValueError:
        idx = -1
    # Explicit range check: plain list indexing would accept 0 and negative
    # numbers by counting from the end of the menu.
    if not 0 <= idx < len(actions):
        typer.echo("Invalid selection.")
        return
    action = actions[idx]

    if action == "Add / update provider":
        _wizard_add_provider(mgr)
    elif action == "Remove provider":
        _wizard_remove_provider(mgr)
    elif action == "Set risk thresholds":
        _wizard_thresholds(mgr)
    elif action == "Test provider connections":
        _show_status(db)
    # "Exit" needs no work: the function simply returns.
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _print_current_config(mgr) -> None:
    """Echo the stored risk thresholds, then the configured providers in priority order."""
    configured = mgr.list_providers_ordered()
    limits = mgr.get_thresholds()

    typer.echo(f"\nRisk thresholds: low={limits['low_threshold']} high={limits['high_threshold']}")

    if not configured:
        typer.echo("No providers configured.")
        return

    typer.echo("Providers (by priority):")
    for entry in configured:
        if entry.enabled:
            state = "enabled"
        else:
            state = "disabled"
        typer.echo(f" [{entry.priority}] {entry.provider} → {entry.model} ({state})")
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _wizard_add_provider(mgr) -> None:
    """Interactively add or update one provider.

    Steps: choose a provider, enter its API key (unless the provider is in
    ``_NO_API_KEY_PROVIDERS``), choose a model and a priority, run a health
    check, then save.

    The API key is kept in memory until the user commits to saving. Cancelling
    after a failed connection test therefore leaves any previously stored key
    for that provider untouched.

    Args:
        mgr: Configuration manager providing ``list_providers_ordered`` and
            ``set_provider``.
    """
    typer.echo("\nAvailable providers:")
    for i, name in enumerate(_PROVIDER_NAMES, 1):
        typer.echo(f" {i}. {name}")

    raw = typer.prompt("Select provider (name or number)", default="openai").strip()
    # isdecimal() only accepts characters that int() can parse.
    if raw.isdecimal():
        idx = int(raw) - 1
        provider_name = _PROVIDER_NAMES[idx] if 0 <= idx < len(_PROVIDER_NAMES) else raw
    else:
        provider_name = raw.lower()

    if provider_name not in _PROVIDER_NAMES:
        typer.echo(f"Unknown provider: {provider_name}")
        return

    # API key (skipped for providers that do not need one)
    needs_api_key = provider_name not in _NO_API_KEY_PROVIDERS
    new_api_key = None
    if needs_api_key:
        new_api_key = typer.prompt(f"API key for {provider_name}", hide_input=True) or None

    # Model selection
    models = _PROVIDER_MODELS.get(provider_name, [])
    typer.echo(f"\nAvailable models for {provider_name}:")
    for i, m in enumerate(models, 1):
        default_marker = " (default)" if i == 2 else ""
        typer.echo(f" {i}. {m}{default_marker}")
    model_raw = typer.prompt(
        "Select model (name or number)", default="2" if len(models) >= 2 else "1"
    ).strip()
    if model_raw.isdecimal():
        idx = int(model_raw) - 1
        # An out-of-range number falls back to the first listed model.
        model_name = models[idx] if 0 <= idx < len(models) else models[0]
    else:
        model_name = model_raw

    # Priority: suggest one past the highest priority currently stored.
    existing = mgr.list_providers_ordered()
    next_priority = max((p.priority for p in existing), default=0) + 1
    priority = typer.prompt(
        "Priority (1=primary, higher=fallback)", default=str(next_priority), type=int
    )

    # Test connection
    typer.echo(f"\nTesting connection to {provider_name}...")
    try:
        cls = _load_evaluator_class(provider_name)
        api_key_val = new_api_key
        if needs_api_key and api_key_val is None:
            # No key entered in this run: try the one already in the keyring.
            from agentwall.providers.keyring import get_api_key

            api_key_val = get_api_key(provider_name)
        kwargs = {"model": model_name}
        if api_key_val:
            kwargs["api_key"] = api_key_val
        evaluator = cls(**kwargs)
        result = evaluator.health_check()
        if result.health.value == "healthy":
            typer.echo(f"Connection OK ({result.latency_ms:.0f}ms)")
        else:
            typer.echo(f"Warning: {result.error}")
            if not typer.confirm("Save anyway?"):
                return
    except Exception as e:
        # Any failure while building or probing the evaluator is reported and
        # the user decides whether to keep the settings.
        typer.echo(f"Connection failed: {e}")
        if not typer.confirm("Save anyway?"):
            return

    # Persist the key only now that the user has committed to saving.
    if new_api_key:
        from agentwall.providers.keyring import store_api_key

        store_api_key(provider_name, new_api_key)
        typer.echo(f"API key stored in OS keyring for {provider_name}.")

    mgr.set_provider(provider_name, model_name, priority=priority)
    typer.echo(f"\nSaved: {provider_name} → {model_name} (priority={priority})")
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def _wizard_remove_provider(mgr) -> None:
    """Interactively remove a configured provider and its keyring entry.

    The selection (name or 1-based number) must match a configured provider.
    Anything else is reported as unknown and nothing is deleted.

    Args:
        mgr: Configuration manager providing ``list_providers_ordered`` and
            ``remove_provider``.
    """
    providers = mgr.list_providers_ordered()
    if not providers:
        typer.echo("No providers configured.")
        return

    typer.echo("\nConfigured providers:")
    for i, p in enumerate(providers, 1):
        typer.echo(f" {i}. {p.provider}")

    raw = typer.prompt("Select provider to remove").strip()
    # isdecimal() only accepts characters that int() can parse.
    if raw.isdecimal():
        idx = int(raw) - 1
        name = providers[idx].provider if 0 <= idx < len(providers) else raw
    else:
        name = raw

    # Refuse out-of-range numbers and unknown names instead of "removing"
    # them and touching the keyring for a provider that was never configured.
    if name not in {p.provider for p in providers}:
        typer.echo(f"Unknown provider: {name}")
        return

    if typer.confirm(f"Remove {name}?"):
        mgr.remove_provider(name)
        from agentwall.providers.keyring import delete_api_key

        delete_api_key(name)
        typer.echo(f"Removed {name} (API key deleted from keyring).")
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def _wizard_thresholds(mgr) -> None:
    """Prompt for new low/high thresholds and store them unless low >= high."""
    stored = mgr.get_thresholds()
    typer.echo(f"\nCurrent: low={stored['low_threshold']} high={stored['high_threshold']}")

    # The stored values are offered as the prompt defaults.
    low = typer.prompt(
        "Low threshold (ALLOW below this)",
        default=str(stored["low_threshold"]),
        type=float,
    )
    high = typer.prompt(
        "High threshold (LLM eval above this)",
        default=str(stored["high_threshold"]),
        type=float,
    )

    if low >= high:
        typer.echo("Error: low_threshold must be less than high_threshold.")
    else:
        mgr.set_thresholds(low, high)
        typer.echo(f"Saved: low={low} high={high}")
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def _show_status(db) -> None:
    """Run a health check on every configured provider and echo one line per result."""
    from agentwall.providers.registry import ProviderRegistry
    from agentwall.providers.base import ProviderHealth

    results = ProviderRegistry(db).health_check_all()
    if not results:
        typer.echo("No providers configured. Run: agentwall config")
        return

    typer.echo("\nProvider Status:")
    for entry in results:
        if entry.health == ProviderHealth.HEALTHY:
            icon = "OK"
        elif entry.health == ProviderHealth.DEGRADED:
            icon = "WARN"
        else:
            icon = "FAIL"

        # Latency and error are appended only when the result carries them.
        latency = ""
        if entry.latency_ms is not None:
            latency = f" {entry.latency_ms:.0f}ms"
        error = ""
        if entry.error:
            error = f" — {entry.error}"

        typer.echo(f" [{icon}] {entry.provider} ({entry.model}){latency}{error}")
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def _load_evaluator_class(provider_name: str):
    """Look up the evaluator class registered for ``provider_name``.

    Raises:
        ValueError: If ``EVALUATOR_CLASSES`` has no entry for the name.
    """
    from agentwall.providers.registry import EVALUATOR_CLASSES

    evaluator_cls = EVALUATOR_CLASSES.get(provider_name)
    if evaluator_cls is not None:
        return evaluator_cls
    raise ValueError(f"Unknown provider: {provider_name}")
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
@app.command()
def inspect(
    host: Annotated[str, typer.Option(help="Bind host")] = _INSPECTOR_HOST,
    port: Annotated[int, typer.Option(help="Bind port")] = _INSPECTOR_PORT,
    browser: Annotated[bool, typer.Option("--browser", help="Open in browser instead of desktop window")] = False,
) -> None:
    """Launch AgentWall Inspector as a desktop window."""
    # The built UI is looked up at <package>/inspector/ui/dist; a missing
    # build only produces a notice, the launch still proceeds.
    package_root = Path(__file__).parent.parent
    ui_dist = package_root / "inspector" / "ui" / "dist"
    if not ui_dist.exists():
        build_hint = (
            "UI not built. Run:\n"
            " cd agentwall/inspector/ui && npm install && npm run build"
        )
        typer.echo(build_hint)
        typer.echo("Starting API-only mode (docs at /api/docs)...")

    if browser:
        _launch_browser(host, port)
        return

    from agentwall.inspector.desktop import launch_desktop

    typer.echo("Starting AgentWall Inspector...")
    launch_desktop(host, port)
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def _launch_browser(host: str, port: int) -> None:
    """Fallback: start uvicorn and open the system browser.

    The server binds to ``host`` unchanged. The URL handed to the browser is
    derived from it: wildcard bind addresses are replaced by ``localhost`` and
    bare IPv6 literals are wrapped in brackets.

    Args:
        host: Address uvicorn binds to.
        port: Port uvicorn binds to.
    """
    import time
    import uvicorn
    import webbrowser

    if host in ("", "0.0.0.0", "::"):
        # "All interfaces" is a bind address, not a destination to browse to.
        browse_host = "localhost"
    elif ":" in host and not host.startswith("["):
        # IPv6 literals must be bracketed inside a URL.
        browse_host = f"[{host}]"
    else:
        browse_host = host

    url = f"http://{browse_host}:{port}"
    typer.echo(f"Starting AgentWall Inspector at {url}")

    def _open():
        # Fixed delay before opening the page; uvicorn.run() below blocks,
        # so the browser is opened from a background thread.
        time.sleep(1.5)
        webbrowser.open(url)

    threading.Thread(target=_open, daemon=True).start()

    uvicorn.run(
        "agentwall.inspector.server:app",
        host=host,
        port=port,
        log_level="warning",
    )
|
|
File without changes
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
|
|
5
|
+
from ..storage.database import Database
|
|
6
|
+
from ..storage.models import Policy, ProviderSetting
|
|
7
|
+
|
|
8
|
+
# Values used when the thresholds policy is absent or lacks a key.
_DEFAULT_THRESHOLDS = {"low_threshold": 30.0, "high_threshold": 70.0}
# Name of the policy row under which the risk thresholds are stored.
_THRESHOLDS_POLICY = "thresholds"


class ConfigManager:
    """Read and write provider settings, policies and risk thresholds.

    Each method opens its own session through ``db.session()``. Rows returned
    to the caller are expunged from that session before it is left.
    """

    def __init__(self, db: Database) -> None:
        self._db = db

    # --- Provider settings ---

    def get_provider(self, provider: str) -> ProviderSetting | None:
        """Return the stored setting for ``provider``, or ``None`` if there is none."""
        with self._db.session() as db:
            row = db.get(ProviderSetting, provider)
            if row:
                db.expunge(row)
            return row

    def set_provider(
        self,
        provider: str,
        model: str,
        priority: int = 0,
        enabled: bool = True,
        config: dict | None = None,
    ) -> None:
        """Create or update the setting for ``provider``.

        Args:
            provider: Provider name; the lookup key of the row.
            model: Model name to store.
            priority: Priority to store.
            enabled: Enabled flag to store.
            config: Extra provider configuration. When ``None`` and the
                provider already exists, the stored config is left as it is,
                so updating the model does not erase it. A new row gets ``{}``.
                Pass ``{}`` explicitly to clear an existing config.
        """
        with self._db.session() as db:
            existing = db.get(ProviderSetting, provider)
            if existing:
                existing.model = model
                existing.priority = priority
                existing.enabled = enabled
                if config is not None:
                    existing.config = config
            else:
                db.add(ProviderSetting(
                    provider=provider, model=model,
                    priority=priority, enabled=enabled, config=config or {},
                ))
            db.commit()

    def set_provider_enabled(self, provider: str, enabled: bool) -> None:
        """Set the enabled flag of ``provider``; does nothing if it is not stored."""
        with self._db.session() as db:
            row = db.get(ProviderSetting, provider)
            if row:
                row.enabled = enabled
                db.commit()

    def remove_provider(self, provider: str) -> None:
        """Delete the setting for ``provider``; does nothing if it is not stored."""
        with self._db.session() as db:
            row = db.get(ProviderSetting, provider)
            if row:
                db.delete(row)
                db.commit()

    def list_providers(self) -> list[ProviderSetting]:
        """Return all provider settings, with no ordering applied."""
        with self._db.session() as db:
            rows = db.query(ProviderSetting).all()
            for r in rows:
                db.expunge(r)
            return rows

    def list_providers_ordered(self) -> list[ProviderSetting]:
        """Return all provider settings ordered by ascending priority."""
        with self._db.session() as db:
            rows = db.query(ProviderSetting).order_by(ProviderSetting.priority).all()
            for r in rows:
                db.expunge(r)
            return rows

    # --- Policies ---

    def get_policy(self, name: str) -> Policy | None:
        """Return the first policy named ``name``, or ``None`` if there is none."""
        with self._db.session() as db:
            row = db.query(Policy).filter(Policy.name == name).first()
            if row:
                db.expunge(row)
            return row

    def set_policy(self, name: str, config: dict) -> None:
        """Replace the config of policy ``name``, creating the policy if needed."""
        with self._db.session() as db:
            existing = db.query(Policy).filter(Policy.name == name).first()
            if existing:
                existing.config = config
            else:
                db.add(Policy(name=name, config=config, created_at=time.time()))
            db.commit()

    def list_policies(self) -> list[Policy]:
        """Return all policies ordered by name."""
        with self._db.session() as db:
            rows = db.query(Policy).order_by(Policy.name).all()
            for r in rows:
                db.expunge(r)
            return rows

    # --- Risk thresholds (stored as policy) ---

    def get_thresholds(self) -> dict:
        """Return the risk thresholds as a new dict.

        Stored values override ``_DEFAULT_THRESHOLDS`` key by key. A missing
        policy, or a policy whose config is empty or ``None``, yields the
        defaults.
        """
        policy = self.get_policy(_THRESHOLDS_POLICY)
        if policy:
            # "or {}" guards against a policy row whose config is None.
            return {**_DEFAULT_THRESHOLDS, **(policy.config or {})}
        return dict(_DEFAULT_THRESHOLDS)

    def set_thresholds(self, low: float, high: float) -> None:
        """Store ``low`` and ``high`` as the thresholds policy (no range check here)."""
        self.set_policy(_THRESHOLDS_POLICY, {"low_threshold": low, "high_threshold": high})
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
|
|
5
|
+
from ..storage.database import Database
|
|
6
|
+
from ..storage.models import Evaluation, GoalSegment, ToolEvent
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class EventManager:
    """Persist tool events, their evaluations and goal segments.

    Each method opens its own session through ``db.session()``. ORM rows that
    are returned to the caller are expunged from the session first.
    """

    def __init__(self, db: Database) -> None:
        self._db = db

    def record(
        self,
        session_id: str,
        tool_name: str,
        arguments: dict,
        *,
        tool_type: str | None = None,
        action: str | None = None,
        target: str | None = None,
        resource_category: str | None = None,
    ) -> ToolEvent:
        """Insert a ``ToolEvent`` stamped with the current time and return it."""
        with self._db.session() as orm:
            event = ToolEvent(
                session_id=session_id,
                tool_name=tool_name,
                arguments=arguments,
                timestamp=time.time(),
                tool_type=tool_type,
                action=action,
                target=target,
                resource_category=resource_category,
            )
            orm.add(event)
            orm.commit()
            # Reload after the commit, then detach before returning.
            orm.refresh(event)
            orm.expunge(event)
            return event

    def record_evaluation(
        self,
        event_id: int,
        decision: str,
        risk_score: float,
        reason: str,
        llm_used: bool = False,
        alignment_score: float | None = None,
        detector_hits: list[str] | None = None,
        policy_matched: str | None = None,
    ) -> Evaluation:
        """Insert an ``Evaluation`` for ``event_id`` stamped with the current time and return it."""
        with self._db.session() as orm:
            verdict = Evaluation(
                event_id=event_id,
                decision=decision,
                risk_score=risk_score,
                reason=reason,
                llm_used=llm_used,
                timestamp=time.time(),
                alignment_score=alignment_score,
                detector_hits=detector_hits,
                policy_matched=policy_matched,
            )
            orm.add(verdict)
            orm.commit()
            orm.refresh(verdict)
            orm.expunge(verdict)
            return verdict

    def update_evaluation_post(
        self,
        event_id: int,
        post_execution_risk: float,
        result_classification: str,
        result_detector_hits: list[str],
        result_metadata: dict,
    ) -> None:
        """Write the post-execution fields on the evaluations matching ``event_id``."""
        changes = {
            "post_execution_risk": post_execution_risk,
            "result_classification": result_classification,
            "result_detector_hits": result_detector_hits,
            "result_metadata": result_metadata,
        }
        with self._db.session() as orm:
            matching = orm.query(Evaluation).filter(Evaluation.event_id == event_id)
            matching.update(changes)
            orm.commit()

    def get_events(self, session_id: str) -> list[ToolEvent]:
        """Return the tool events of ``session_id`` ordered by timestamp."""
        with self._db.session() as orm:
            query = orm.query(ToolEvent).filter(ToolEvent.session_id == session_id)
            events = query.order_by(ToolEvent.timestamp).all()
            for event in events:
                orm.expunge(event)
            return events

    def get_events_with_evaluations(self, session_id: str) -> list[ToolEvent]:
        """Return the tool events of ``session_id`` ordered by timestamp, with ``evaluation`` joined-loaded."""
        from sqlalchemy.orm import joinedload

        with self._db.session() as orm:
            query = orm.query(ToolEvent).options(joinedload(ToolEvent.evaluation))
            query = query.filter(ToolEvent.session_id == session_id)
            events = query.order_by(ToolEvent.timestamp).all()
            for event in events:
                # Detach the loaded evaluation first, then the event itself.
                if event.evaluation:
                    orm.expunge(event.evaluation)
                orm.expunge(event)
            return events

    def create_goal_segment(self, session_id: str, goal: str, reason: str = "initial") -> str:
        """Insert a ``GoalSegment`` starting now and return its ``id``."""
        with self._db.session() as orm:
            segment = GoalSegment(
                session_id=session_id,
                goal_text=goal,
                started_at=time.time(),
                transition_reason=reason,
            )
            orm.add(segment)
            orm.commit()
            orm.refresh(segment)
            # Read the id while the row is still attached to the session.
            new_id = segment.id
        return new_id

    def close_goal_segment(self, segment_id: str) -> None:
        """Set ``ended_at`` to the current time on the segment with ``segment_id``."""
        with self._db.session() as orm:
            matching = orm.query(GoalSegment).filter(GoalSegment.id == segment_id)
            matching.update({"ended_at": time.time()})
            orm.commit()

    def get_goal_segments(self, session_id: str) -> list[GoalSegment]:
        """Return the goal segments of ``session_id`` ordered by start time."""
        with self._db.session() as orm:
            query = orm.query(GoalSegment).filter(GoalSegment.session_id == session_id)
            segments = query.order_by(GoalSegment.started_at).all()
            for segment in segments:
                orm.expunge(segment)
            return segments
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
import uuid
|
|
5
|
+
|
|
6
|
+
from ..storage.database import Database
|
|
7
|
+
from ..storage.models import Session
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SessionManager:
    """Create, look up, update and list ``Session`` rows.

    Each method opens its own session through ``db.session()``. Rows that are
    returned to the caller are expunged from it first.
    """

    def __init__(self, db: Database) -> None:
        self._db = db

    def create(self, user_goal: str, meta: dict | None = None) -> Session:
        """Insert a session with a fresh UUID4 id, stamped with the current time, and return it."""
        new_id = str(uuid.uuid4())
        with self._db.session() as orm:
            record = Session(
                id=new_id,
                user_goal=user_goal,
                created_at=time.time(),
                meta=meta or {},
            )
            orm.add(record)
            orm.commit()
            orm.refresh(record)
            orm.expunge(record)
            return record

    def get(self, session_id: str) -> Session | None:
        """Return the session with ``session_id``, or ``None`` if there is none."""
        with self._db.session() as orm:
            record = orm.get(Session, session_id)
            if record:
                orm.expunge(record)
            return record

    def end(self, session_id: str) -> None:
        """Set ``ended_at`` to the current time; does nothing if the session is not found."""
        with self._db.session() as orm:
            record = orm.get(Session, session_id)
            if not record:
                return
            record.ended_at = time.time()
            orm.commit()

    def update_goal(self, session_id: str, user_goal: str) -> None:
        """Replace the stored ``user_goal``; does nothing if the session is not found."""
        with self._db.session() as orm:
            record = orm.get(Session, session_id)
            if not record:
                return
            record.user_goal = user_goal
            orm.commit()

    # NOTE: the method name shadows the builtin ``list`` inside the class
    # body; it is part of the public interface and is kept as is.
    def list(self, limit: int = 100) -> list[Session]:
        """Return at most ``limit`` sessions, newest ``created_at`` first."""
        with self._db.session() as orm:
            newest_first = orm.query(Session).order_by(Session.created_at.desc())
            records = newest_first.limit(limit).all()
            for record in records:
                orm.expunge(record)
            return records
|