agentwall-security 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. agentwall/__init__.py +5 -0
  2. agentwall/cli/__init__.py +0 -0
  3. agentwall/cli/main.py +340 -0
  4. agentwall/core/__init__.py +0 -0
  5. agentwall/core/config_manager.py +109 -0
  6. agentwall/core/event_manager.py +147 -0
  7. agentwall/core/session_manager.py +60 -0
  8. agentwall/core/types.py +81 -0
  9. agentwall/detectors/__init__.py +0 -0
  10. agentwall/evaluators/__init__.py +0 -0
  11. agentwall/inspector/__init__.py +0 -0
  12. agentwall/inspector/deps.py +30 -0
  13. agentwall/inspector/desktop.py +58 -0
  14. agentwall/inspector/event_bus.py +45 -0
  15. agentwall/inspector/routes/__init__.py +0 -0
  16. agentwall/inspector/routes/events.py +12 -0
  17. agentwall/inspector/routes/export.py +78 -0
  18. agentwall/inspector/routes/goals.py +14 -0
  19. agentwall/inspector/routes/health.py +9 -0
  20. agentwall/inspector/routes/overview.py +28 -0
  21. agentwall/inspector/routes/policies.py +78 -0
  22. agentwall/inspector/routes/providers.py +90 -0
  23. agentwall/inspector/routes/sessions.py +54 -0
  24. agentwall/inspector/routes/ws.py +28 -0
  25. agentwall/inspector/server.py +52 -0
  26. agentwall/inspector/ui/dist/assets/index-CdF3waHo.js +40 -0
  27. agentwall/inspector/ui/dist/assets/index-OKzO_O9l.css +1 -0
  28. agentwall/inspector/ui/dist/index.html +13 -0
  29. agentwall/integrations/__init__.py +9 -0
  30. agentwall/integrations/crewai.py +161 -0
  31. agentwall/integrations/langchain.py +173 -0
  32. agentwall/integrations/openai_agents.py +185 -0
  33. agentwall/interceptors/__init__.py +24 -0
  34. agentwall/interceptors/agent.py +104 -0
  35. agentwall/interceptors/base.py +21 -0
  36. agentwall/interceptors/tool.py +173 -0
  37. agentwall/models/__init__.py +0 -0
  38. agentwall/models/schemas.py +93 -0
  39. agentwall/policies/__init__.py +0 -0
  40. agentwall/providers/__init__.py +0 -0
  41. agentwall/providers/anthropic.py +44 -0
  42. agentwall/providers/base.py +105 -0
  43. agentwall/providers/chain.py +31 -0
  44. agentwall/providers/deepseek.py +46 -0
  45. agentwall/providers/groq.py +46 -0
  46. agentwall/providers/keyring.py +20 -0
  47. agentwall/providers/ollama.py +48 -0
  48. agentwall/providers/openai.py +44 -0
  49. agentwall/providers/registry.py +64 -0
  50. agentwall/security/__init__.py +5 -0
  51. agentwall/security/detectors.py +145 -0
  52. agentwall/security/engine.py +142 -0
  53. agentwall/security/exceptions.py +13 -0
  54. agentwall/security/goal_tracker.py +98 -0
  55. agentwall/security/policy_engine.py +154 -0
  56. agentwall/security/result_analyzer.py +169 -0
  57. agentwall/security/rules.py +152 -0
  58. agentwall/storage/__init__.py +0 -0
  59. agentwall/storage/database.py +81 -0
  60. agentwall/storage/models.py +97 -0
  61. agentwall_security-0.1.0.dist-info/METADATA +337 -0
  62. agentwall_security-0.1.0.dist-info/RECORD +65 -0
  63. agentwall_security-0.1.0.dist-info/WHEEL +4 -0
  64. agentwall_security-0.1.0.dist-info/entry_points.txt +2 -0
  65. agentwall_security-0.1.0.dist-info/licenses/LICENSE +21 -0
agentwall/__init__.py ADDED
@@ -0,0 +1,5 @@
1
"""AgentWall - AI Agent Runtime Security.

Public API: ``protect_agent`` and ``protect_tool``, re-exported from
``agentwall.interceptors``.
"""

# Defined before importing submodules so that any submodule doing
# ``from agentwall import __version__`` during package import can see it.
__version__ = "0.1.0"

from agentwall.interceptors import protect_agent, protect_tool

__all__ = [
    "protect_agent",
    "protect_tool",
]
File without changes
agentwall/cli/main.py ADDED
@@ -0,0 +1,340 @@
1
+ from __future__ import annotations
2
+
3
+ import importlib
4
+ import threading
5
+ from pathlib import Path
6
+ from typing import Annotated
7
+
8
+ import typer
9
+
10
+ from agentwall import __version__
11
+
12
# Root Typer application; each ``@app.command()`` function below registers a subcommand.
app = typer.Typer(
    name="agentwall",
    help="AgentWall - AI Agent Runtime Security.",
    add_completion=False,
)
13
+
14
# Default bind address/port for ``agentwall inspect`` (loopback interface).
_INSPECTOR_HOST = "127.0.0.1"
_INSPECTOR_PORT = 8080

# Models offered by the configuration wizard, keyed by provider name.
# Dict insertion order is also the provider menu order (see _PROVIDER_NAMES).
_PROVIDER_MODELS: dict[str, list[str]] = {
    "openai": ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo"],
    "anthropic": ["claude-opus-4-8", "claude-sonnet-4-6", "claude-haiku-4-5-20251001"],
    "groq": ["llama-3.3-70b-versatile", "llama-3.1-8b-instant", "gemma2-9b-it"],
    "deepseek": ["deepseek-chat", "deepseek-reasoner"],
    "ollama": ["llama3.2", "llama3.1", "mistral", "gemma2", "phi3"],
}

# Known provider names, derived from the table above so the two never drift apart.
_PROVIDER_NAMES = list(_PROVIDER_MODELS)

# Providers configured without prompting for an API key.
_NO_API_KEY_PROVIDERS = {"ollama"}
28
+
29
+
30
def _get_db():
    """Open and return a new AgentWall storage ``Database``.

    The storage module is imported on first use rather than at CLI import time.
    The caller is responsible for calling ``close()`` on the result.
    """
    from agentwall.storage import database as storage_db

    return storage_db.Database()
33
+
34
+
35
@app.command()
def version() -> None:
    """Print version."""
    # The docstring above doubles as the Typer help text for this command.
    banner = "agentwall " + f"{__version__}"
    typer.echo(banner)
39
+
40
+
41
@app.command()
def doctor() -> None:
    """Check installation health."""
    # (importable module name, human-readable label) for each runtime dependency.
    checks = [
        ("keyring", "API key storage (keyring)"),
        ("sqlalchemy", "ORM (sqlalchemy)"),
        ("fastapi", "Inspector API (fastapi)"),
        ("uvicorn", "Inspector server (uvicorn)"),
        ("pydantic", "Validation (pydantic)"),
        ("openai", "OpenAI/Groq/DeepSeek SDK (openai)"),
        ("anthropic", "Anthropic SDK (anthropic)"),
        ("webview", "Desktop Inspector (pywebview)"),
    ]
    all_ok = True
    for module, label in checks:
        try:
            importlib.import_module(module)
        except ImportError:
            typer.echo(f" MISSING {label}")
            all_ok = False
        else:
            typer.echo(f" OK {label}")

    if all_ok:
        typer.echo("\nAgentWall installation OK.")
        return

    # The distribution is published as "agentwall-security" (see the wheel's
    # dist-info name); "pip install agentwall" would resolve a different project.
    typer.echo("\nSome dependencies missing. Run: pip install agentwall-security")
    raise typer.Exit(1)
68
+
69
+
70
@app.command()
def config(
    provider: Annotated[str | None, typer.Option(help="Provider name")] = None,
    model: Annotated[str | None, typer.Option(help="Model name")] = None,
    priority: Annotated[int | None, typer.Option(help="Priority (1=primary)")] = None,
    low_threshold: Annotated[float | None, typer.Option(help="Low risk threshold")] = None,
    high_threshold: Annotated[float | None, typer.Option(help="High risk threshold")] = None,
    status: Annotated[bool, typer.Option("--status", help="Show provider health")] = False,
) -> None:
    """Configure AgentWall. No flags = interactive wizard."""
    from agentwall.core.config_manager import ConfigManager

    db = _get_db()
    # Always release the database, including when a prompt is aborted
    # (typer.Abort) or a helper stops early via typer.Exit.
    try:
        if status:
            _show_status(db)
            return

        mgr = ConfigManager(db)
        flags = (provider, model, priority, low_threshold, high_threshold)
        if any(x is not None for x in flags):
            _apply_config_flags(mgr, provider, model, priority, low_threshold, high_threshold)
        else:
            _run_wizard(mgr, db)
    finally:
        db.close()
98
+
99
+
100
def _apply_config_flags(mgr, provider, model, priority, low_threshold, high_threshold):
    """Apply ``agentwall config`` command-line flags without running the wizard.

    ``--provider`` and ``--model`` must be given together. ``--priority`` only
    applies alongside them and defaults to 0. Either threshold may be given on
    its own; the other keeps its stored value. Every flag is validated before
    anything is written, so an invalid combination never leaves a half-applied
    configuration.

    Args:
        mgr: ConfigManager used to read and write settings.
        provider, model, priority, low_threshold, high_threshold: raw flag
            values; ``None`` means the flag was not given.

    Raises:
        typer.Exit: with code 1 when the flags are incomplete or invalid.
    """
    wants_provider = any(v is not None for v in (provider, model, priority))
    if wants_provider:
        if not provider or not model:
            typer.echo("Error: --provider and --model must be given together.")
            raise typer.Exit(1)
        # Same normalisation and validation as the interactive wizard.
        provider = provider.strip().lower()
        if provider not in _PROVIDER_NAMES:
            typer.echo(f"Unknown provider: {provider}")
            raise typer.Exit(1)

    wants_thresholds = low_threshold is not None or high_threshold is not None
    if wants_thresholds:
        current = mgr.get_thresholds()
        low = low_threshold if low_threshold is not None else current["low_threshold"]
        high = high_threshold if high_threshold is not None else current["high_threshold"]
        # The wizard rejects inverted thresholds; the flag path must too.
        if low >= high:
            typer.echo("Error: low_threshold must be less than high_threshold.")
            raise typer.Exit(1)

    if wants_provider:
        from agentwall.providers.keyring import store_api_key

        p = priority or 0
        if provider not in _NO_API_KEY_PROVIDERS:
            api_key = typer.prompt(f"API key for {provider}", hide_input=True)
            if api_key:
                store_api_key(provider, api_key)
                typer.echo(f"API key stored in OS keyring for {provider}.")
        mgr.set_provider(provider, model, priority=p)
        typer.echo(f"Set {provider} → {model} (priority={p})")

    if wants_thresholds:
        mgr.set_thresholds(low, high)
        typer.echo(f"Thresholds: low={low} high={high}")
118
+
119
+
120
def _run_wizard(mgr, db) -> None:
    """Show the current configuration, then perform one interactively chosen action.

    Args:
        mgr: ConfigManager used by the provider and threshold actions.
        db: open Database, passed through to the provider status check.
    """
    typer.echo("\nAgentWall Configuration Wizard")
    typer.echo("=" * 32)

    _print_current_config(mgr)

    # Menu label -> handler, in display order. "Exit" has no handler.
    actions = {
        "Add / update provider": lambda: _wizard_add_provider(mgr),
        "Remove provider": lambda: _wizard_remove_provider(mgr),
        "Set risk thresholds": lambda: _wizard_thresholds(mgr),
        "Test provider connections": lambda: _show_status(db),
        "Exit": None,
    }
    labels = list(actions)

    typer.echo("\nWhat would you like to do?")
    for i, label in enumerate(labels, 1):
        typer.echo(f" {i}. {label}")

    choice = typer.prompt("Select", default="1")
    try:
        idx = int(choice) - 1
    except ValueError:
        idx = -1
    # Check bounds explicitly: plain list indexing would wrap negative values
    # around (e.g. "0" would select "Exit" and "-1" "Test provider connections").
    if not 0 <= idx < len(labels):
        typer.echo("Invalid selection.")
        return

    handler = actions[labels[idx]]
    if handler is not None:
        handler()
153
+
154
+
155
def _print_current_config(mgr) -> None:
    """Echo the stored risk thresholds and every provider, in priority order."""
    ordered = mgr.list_providers_ordered()
    limits = mgr.get_thresholds()

    typer.echo(f"\nRisk thresholds: low={limits['low_threshold']} high={limits['high_threshold']}")
    if not ordered:
        typer.echo("No providers configured.")
        return

    typer.echo("Providers (by priority):")
    for entry in ordered:
        state = "disabled"
        if entry.enabled:
            state = "enabled"
        typer.echo(f" [{entry.priority}] {entry.provider} → {entry.model} ({state})")
167
+
168
+
169
def _wizard_add_provider(mgr) -> None:
    """Interactively add or update one provider configuration.

    Prompts for the provider, its API key (skipped for providers in
    ``_NO_API_KEY_PROVIDERS``), a model and a priority, then health-checks the
    result. Nothing is persisted unless the check passes or the user answers
    "Save anyway?" with yes. The API key is written to the OS keyring only at
    that point, so declining never overwrites a key that was already stored
    for this provider.
    """
    provider_name = _prompt_provider_name()
    if provider_name is None:
        return

    # API key (skip for Ollama). Held in memory until the user commits to saving.
    api_key = None
    if provider_name not in _NO_API_KEY_PROVIDERS:
        api_key = typer.prompt(f"API key for {provider_name}", hide_input=True) or None

    model_name = _prompt_model(provider_name)

    existing = mgr.list_providers_ordered()
    next_priority = max((p.priority for p in existing), default=0) + 1
    priority = typer.prompt("Priority (1=primary, higher=fallback)", default=str(next_priority), type=int)

    if not _test_provider_connection(provider_name, model_name, api_key):
        return

    if api_key:
        from agentwall.providers.keyring import store_api_key

        store_api_key(provider_name, api_key)
        typer.echo(f"API key stored in OS keyring for {provider_name}.")

    mgr.set_provider(provider_name, model_name, priority=priority)
    typer.echo(f"\nSaved: {provider_name} → {model_name} (priority={priority})")


def _prompt_provider_name() -> str | None:
    """Ask for a provider by list number or name; return it, or None if unknown."""
    typer.echo("\nAvailable providers:")
    for i, name in enumerate(_PROVIDER_NAMES, 1):
        typer.echo(f" {i}. {name}")

    raw = typer.prompt("Select provider (name or number)", default="openai").strip()
    if raw.isdecimal():
        idx = int(raw) - 1
        provider_name = _PROVIDER_NAMES[idx] if 0 <= idx < len(_PROVIDER_NAMES) else raw
    else:
        provider_name = raw.lower()

    if provider_name not in _PROVIDER_NAMES:
        typer.echo(f"Unknown provider: {provider_name}")
        return None
    return provider_name


def _prompt_model(provider_name: str) -> str:
    """Ask for a model by list number or free-form name and return the choice.

    Out-of-range or blank input is rejected and the user is asked again, rather
    than silently falling back to the first listed model.
    """
    models = _PROVIDER_MODELS.get(provider_name, [])
    default_choice = 2 if len(models) >= 2 else 1

    typer.echo(f"\nAvailable models for {provider_name}:")
    for i, m in enumerate(models, 1):
        # Mark whichever entry the prompt below actually defaults to.
        default_marker = " (default)" if i == default_choice else ""
        typer.echo(f" {i}. {m}{default_marker}")

    while True:
        raw = typer.prompt("Select model (name or number)", default=str(default_choice)).strip()
        if raw.isdecimal():
            idx = int(raw) - 1
            if 0 <= idx < len(models):
                return models[idx]
        elif raw:
            return raw
        typer.echo("Invalid selection.")


def _test_provider_connection(provider_name: str, model_name: str, api_key: str | None) -> bool:
    """Health-check a provider configuration; return True if it should be saved.

    Only evaluator construction and ``health_check()`` sit inside the ``try``.
    The "Save anyway?" confirmation stays outside it so that aborting the
    prompt (``typer.Abort``, a ``RuntimeError`` subclass) is not caught and
    misreported as a connection failure.
    """
    typer.echo(f"\nTesting connection to {provider_name}...")
    try:
        cls = _load_evaluator_class(provider_name)
        if api_key is None and provider_name not in _NO_API_KEY_PROVIDERS:
            # No new key was entered; use whatever is already in the keyring.
            from agentwall.providers.keyring import get_api_key

            api_key = get_api_key(provider_name)
        kwargs = {"model": model_name}
        if api_key:
            kwargs["api_key"] = api_key
        result = cls(**kwargs).health_check()
    except Exception as e:  # deliberately broad: best-effort probe of third-party SDKs
        typer.echo(f"Connection failed: {e}")
        return typer.confirm("Save anyway?")

    if result.health.value == "healthy":
        latency = f" ({result.latency_ms:.0f}ms)" if result.latency_ms is not None else ""
        typer.echo(f"Connection OK{latency}")
        return True

    typer.echo(f"Warning: {result.error}")
    return typer.confirm("Save anyway?")
235
+
236
+
237
def _wizard_remove_provider(mgr) -> None:
    """Interactively remove a configured provider and delete its stored API key.

    The selection, given as a list number or a provider name, must match a
    configured provider. Anything else is reported and nothing is removed.
    """
    providers = mgr.list_providers_ordered()
    if not providers:
        typer.echo("No providers configured.")
        return

    configured = [p.provider for p in providers]
    typer.echo("\nConfigured providers:")
    for i, name in enumerate(configured, 1):
        typer.echo(f" {i}. {name}")

    raw = typer.prompt("Select provider to remove").strip()
    if raw.isdecimal():
        idx = int(raw) - 1
        name = configured[idx] if 0 <= idx < len(configured) else raw
    else:
        name = raw

    # Without this guard a typo made the DB delete a silent no-op while still
    # deleting the keyring entry for the typed name and reporting success.
    if name not in configured:
        typer.echo(f"Unknown provider: {name}")
        return

    if typer.confirm(f"Remove {name}?"):
        mgr.remove_provider(name)
        from agentwall.providers.keyring import delete_api_key

        delete_api_key(name)
        typer.echo(f"Removed {name} (API key deleted from keyring).")
256
+
257
+
258
def _wizard_thresholds(mgr) -> None:
    """Prompt for new low/high risk thresholds; save them unless low >= high."""
    stored = mgr.get_thresholds()
    old_low = stored["low_threshold"]
    old_high = stored["high_threshold"]
    typer.echo(f"\nCurrent: low={old_low} high={old_high}")

    low = typer.prompt("Low threshold (ALLOW below this)", default=str(old_low), type=float)
    high = typer.prompt("High threshold (LLM eval above this)", default=str(old_high), type=float)

    if low >= high:
        typer.echo("Error: low_threshold must be less than high_threshold.")
    else:
        mgr.set_thresholds(low, high)
        typer.echo(f"Saved: low={low} high={high}")
268
+
269
+
270
def _show_status(db) -> None:
    """Health-check every configured provider and print one status line for each."""
    from agentwall.providers.registry import ProviderRegistry
    from agentwall.providers.base import ProviderHealth

    results = ProviderRegistry(db).health_check_all()
    if not results:
        typer.echo("No providers configured. Run: agentwall config")
        return

    typer.echo("\nProvider Status:")
    for res in results:
        if res.health == ProviderHealth.HEALTHY:
            tag = "OK"
        elif res.health == ProviderHealth.DEGRADED:
            tag = "WARN"
        else:
            tag = "FAIL"
        timing = "" if res.latency_ms is None else f" {res.latency_ms:.0f}ms"
        detail = f" — {res.error}" if res.error else ""
        typer.echo(f" [{tag}] {res.provider} ({res.model}){timing}{detail}")
287
+
288
+
289
def _load_evaluator_class(provider_name: str):
    """Return the evaluator class registered under *provider_name*.

    Raises:
        ValueError: if the registry has no (non-None) entry for that name.
    """
    from agentwall.providers.registry import EVALUATOR_CLASSES

    evaluator_cls = EVALUATOR_CLASSES.get(provider_name)
    if evaluator_cls is not None:
        return evaluator_cls
    raise ValueError(f"Unknown provider: {provider_name}")
295
+
296
+
297
@app.command()
def inspect(
    host: Annotated[str, typer.Option(help="Bind host")] = _INSPECTOR_HOST,
    port: Annotated[int, typer.Option(help="Bind port")] = _INSPECTOR_PORT,
    browser: Annotated[bool, typer.Option("--browser", help="Open in browser instead of desktop window")] = False,
) -> None:
    """Launch AgentWall Inspector as a desktop window."""
    # The docstring above is the Typer help text; the built front-end is looked
    # up at <package>/inspector/ui/dist relative to this cli package.
    dist_dir = Path(__file__).parent.parent.joinpath("inspector", "ui", "dist")
    if not dist_dir.exists():
        typer.echo(
            "UI not built. Run:\n"
            " cd agentwall/inspector/ui && npm install && npm run build"
        )
        typer.echo("Starting API-only mode (docs at /api/docs)...")

    if not browser:
        from agentwall.inspector.desktop import launch_desktop

        typer.echo("Starting AgentWall Inspector...")
        launch_desktop(host, port)
        return

    _launch_browser(host, port)
318
+
319
+
320
def _launch_browser(host: str, port: int) -> None:
    """Fallback: start uvicorn and open the system browser.

    The server binds to *host* as given. The URL handed to the browser differs
    in two cases: a wildcard bind address is replaced by loopback, and a bare
    IPv6 literal is wrapped in brackets. ``http://0.0.0.0:PORT`` is a listen
    address, not one every browser or OS will connect to. The browser is opened
    from a daemon thread after a short delay while ``uvicorn.run`` serves in the
    foreground.
    """
    import uvicorn
    import webbrowser

    if host == "0.0.0.0":
        browse_host = "127.0.0.1"
    elif host == "::":
        browse_host = "[::1]"
    elif ":" in host and not host.startswith("["):
        browse_host = f"[{host}]"  # IPv6 literals must be bracketed inside a URL
    else:
        browse_host = host

    url = f"http://{browse_host}:{port}"
    typer.echo(f"Starting AgentWall Inspector at {url}")

    def _open():
        import time

        # Give uvicorn a moment to start listening before the browser connects.
        time.sleep(1.5)
        webbrowser.open(url)

    threading.Thread(target=_open, daemon=True).start()

    uvicorn.run(
        "agentwall.inspector.server:app",
        host=host,
        port=port,
        log_level="warning",
    )
File without changes
agentwall/core/config_manager.py ADDED
@@ -0,0 +1,109 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+
5
+ from ..storage.database import Database
6
+ from ..storage.models import Policy, ProviderSetting
7
+
8
# Risk thresholds used for any key the stored "thresholds" policy does not set.
_DEFAULT_THRESHOLDS = dict(low_threshold=30.0, high_threshold=70.0)
# Name of the Policy row in which the risk thresholds are persisted.
_THRESHOLDS_POLICY = "thresholds"
10
+
11
+
12
class ConfigManager:
    """Read and write AgentWall configuration stored in the database.

    Covers per-provider settings (``ProviderSetting`` rows keyed by provider
    name), named ``Policy`` rows, and the risk thresholds, which are persisted
    as the policy named by ``_THRESHOLDS_POLICY``. Rows returned by the read
    methods are expunged from the session that loaded them.
    """

    def __init__(self, db: Database) -> None:
        self._db = db

    # --- Provider settings ---

    def get_provider(self, provider: str) -> ProviderSetting | None:
        """Return the settings row for *provider*, or None if it is not configured."""
        with self._db.session() as db:
            row = db.get(ProviderSetting, provider)
            if row:
                db.expunge(row)
            return row

    def set_provider(
        self,
        provider: str,
        model: str,
        priority: int = 0,
        enabled: bool = True,
        config: dict | None = None,
    ) -> None:
        """Create or update the settings row for *provider*.

        Args:
            provider: provider name (the row's key).
            model: model identifier to use with that provider.
            priority: ordering key used by ``list_providers_ordered``.
            enabled: whether the provider is enabled.
            config: provider-specific options. ``None`` keeps an existing
                row's config unchanged (new rows get ``{}``). Pass ``{}`` to
                clear it explicitly.
        """
        with self._db.session() as db:
            existing = db.get(ProviderSetting, provider)
            if existing:
                existing.model = model
                existing.priority = priority
                existing.enabled = enabled
                # Callers such as the CLI update only model/priority. Overwriting
                # config with {} here would silently drop stored options.
                if config is not None:
                    existing.config = config
            else:
                db.add(ProviderSetting(
                    provider=provider, model=model,
                    priority=priority, enabled=enabled, config=config or {},
                ))
            db.commit()

    def set_provider_enabled(self, provider: str, enabled: bool) -> None:
        """Set the ``enabled`` flag for *provider*; no-op if it is not configured."""
        with self._db.session() as db:
            row = db.get(ProviderSetting, provider)
            if row:
                row.enabled = enabled
                db.commit()

    def remove_provider(self, provider: str) -> None:
        """Delete the settings row for *provider*; no-op if it is not configured."""
        with self._db.session() as db:
            row = db.get(ProviderSetting, provider)
            if row:
                db.delete(row)
                db.commit()

    def list_providers(self) -> list[ProviderSetting]:
        """Return all provider settings rows, in unspecified order."""
        with self._db.session() as db:
            rows = db.query(ProviderSetting).all()
            for r in rows:
                db.expunge(r)
            return rows

    def list_providers_ordered(self) -> list[ProviderSetting]:
        """Return all provider settings sorted by priority, then provider name.

        The name tiebreaker keeps the order stable when two providers share
        a priority (e.g. both saved with the CLI default of 0).
        """
        with self._db.session() as db:
            rows = (
                db.query(ProviderSetting)
                .order_by(ProviderSetting.priority, ProviderSetting.provider)
                .all()
            )
            for r in rows:
                db.expunge(r)
            return rows

    # --- Policies ---

    def get_policy(self, name: str) -> Policy | None:
        """Return the policy named *name*, or None if it does not exist."""
        with self._db.session() as db:
            row = db.query(Policy).filter(Policy.name == name).first()
            if row:
                db.expunge(row)
            return row

    def set_policy(self, name: str, config: dict) -> None:
        """Create the policy *name* with *config*, or replace its config if it exists."""
        with self._db.session() as db:
            existing = db.query(Policy).filter(Policy.name == name).first()
            if existing:
                existing.config = config
            else:
                db.add(Policy(name=name, config=config, created_at=time.time()))
            db.commit()

    def list_policies(self) -> list[Policy]:
        """Return all policies sorted by name."""
        with self._db.session() as db:
            rows = db.query(Policy).order_by(Policy.name).all()
            for r in rows:
                db.expunge(r)
            return rows

    # --- Risk thresholds (stored as policy) ---

    def get_thresholds(self) -> dict:
        """Return ``{"low_threshold": ..., "high_threshold": ...}``.

        Stored values override ``_DEFAULT_THRESHOLDS`` key by key. A missing
        policy, or one whose config is empty/NULL, yields the defaults.
        """
        policy = self.get_policy(_THRESHOLDS_POLICY)
        stored = policy.config if policy is not None and policy.config else {}
        return {**_DEFAULT_THRESHOLDS, **stored}

    def set_thresholds(self, low: float, high: float) -> None:
        """Persist the low/high risk thresholds (no ordering check is done here)."""
        self.set_policy(_THRESHOLDS_POLICY, {"low_threshold": low, "high_threshold": high})
agentwall/core/event_manager.py ADDED
@@ -0,0 +1,147 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+
5
+ from ..storage.database import Database
6
+ from ..storage.models import Evaluation, GoalSegment, ToolEvent
7
+
8
+
9
class EventManager:
    """Persist and query tool-call events, their evaluations, and goal segments.

    Rows handed back to callers are expunged from the session that loaded
    them.
    """

    def __init__(self, db: Database) -> None:
        self._db = db

    @staticmethod
    def _store(db, row):
        """Add *row*, commit, refresh it from the database, then detach and return it."""
        db.add(row)
        db.commit()
        db.refresh(row)
        db.expunge(row)
        return row

    @staticmethod
    def _detach_all(db, rows):
        """Expunge every row in *rows* from *db* and return the same list."""
        for item in rows:
            db.expunge(item)
        return rows

    def record(
        self,
        session_id: str,
        tool_name: str,
        arguments: dict,
        *,
        tool_type: str | None = None,
        action: str | None = None,
        target: str | None = None,
        resource_category: str | None = None,
    ) -> ToolEvent:
        """Store one tool invocation, timestamped with ``time.time()``, and return it.

        The keyword-only fields are optional classification metadata stored as given.
        """
        with self._db.session() as db:
            event = ToolEvent(
                session_id=session_id,
                tool_name=tool_name,
                arguments=arguments,
                timestamp=time.time(),
                tool_type=tool_type,
                action=action,
                target=target,
                resource_category=resource_category,
            )
            return self._store(db, event)

    def record_evaluation(
        self,
        event_id: int,
        decision: str,
        risk_score: float,
        reason: str,
        llm_used: bool = False,
        alignment_score: float | None = None,
        detector_hits: list[str] | None = None,
        policy_matched: str | None = None,
    ) -> Evaluation:
        """Store the pre-execution evaluation for tool event *event_id* and return it."""
        with self._db.session() as db:
            verdict = Evaluation(
                event_id=event_id,
                decision=decision,
                risk_score=risk_score,
                reason=reason,
                llm_used=llm_used,
                timestamp=time.time(),
                alignment_score=alignment_score,
                detector_hits=detector_hits,
                policy_matched=policy_matched,
            )
            return self._store(db, verdict)

    def update_evaluation_post(
        self,
        event_id: int,
        post_execution_risk: float,
        result_classification: str,
        result_detector_hits: list[str],
        result_metadata: dict,
    ) -> None:
        """Fill in the post-execution result columns on the evaluation(s) of *event_id*.

        A bulk UPDATE; it changes nothing if no evaluation row matches.
        """
        post_fields = {
            "post_execution_risk": post_execution_risk,
            "result_classification": result_classification,
            "result_detector_hits": result_detector_hits,
            "result_metadata": result_metadata,
        }
        with self._db.session() as db:
            matching = db.query(Evaluation).filter(Evaluation.event_id == event_id)
            matching.update(post_fields)
            db.commit()

    def get_events(self, session_id: str) -> list[ToolEvent]:
        """Return the tool events of *session_id*, oldest first."""
        with self._db.session() as db:
            query = (
                db.query(ToolEvent)
                .filter(ToolEvent.session_id == session_id)
                .order_by(ToolEvent.timestamp)
            )
            return self._detach_all(db, query.all())

    def get_events_with_evaluations(self, session_id: str) -> list[ToolEvent]:
        """Return the tool events of *session_id*, oldest first, with ``evaluation`` eagerly loaded."""
        from sqlalchemy.orm import joinedload

        with self._db.session() as db:
            query = (
                db.query(ToolEvent)
                .options(joinedload(ToolEvent.evaluation))
                .filter(ToolEvent.session_id == session_id)
                .order_by(ToolEvent.timestamp)
            )
            events = query.all()
            for event in events:
                # The joined evaluation is a separate instance; detach it as well.
                attached = event.evaluation
                if attached:
                    db.expunge(attached)
                db.expunge(event)
            return events

    def create_goal_segment(self, session_id: str, goal: str, reason: str = "initial") -> str:
        """Open a goal segment for *session_id* starting now and return its id.

        *reason* is stored as the segment's ``transition_reason``.
        """
        with self._db.session() as db:
            segment = GoalSegment(
                session_id=session_id,
                goal_text=goal,
                started_at=time.time(),
                transition_reason=reason,
            )
            db.add(segment)
            db.commit()
            db.refresh(segment)
            return segment.id

    def close_goal_segment(self, segment_id: str) -> None:
        """Set ``ended_at`` to the current time on the goal segment *segment_id*."""
        with self._db.session() as db:
            target = db.query(GoalSegment).filter(GoalSegment.id == segment_id)
            target.update({"ended_at": time.time()})
            db.commit()

    def get_goal_segments(self, session_id: str) -> list[GoalSegment]:
        """Return the goal segments of *session_id*, ordered by start time."""
        with self._db.session() as db:
            query = (
                db.query(GoalSegment)
                .filter(GoalSegment.session_id == session_id)
                .order_by(GoalSegment.started_at)
            )
            return self._detach_all(db, query.all())
agentwall/core/session_manager.py ADDED
@@ -0,0 +1,60 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ import uuid
5
+
6
+ from ..storage.database import Database
7
+ from ..storage.models import Session
8
+
9
+
10
class SessionManager:
    """Create, look up, update and list agent sessions stored in the database.

    Returned ``Session`` rows are expunged from the session that loaded them.
    """

    def __init__(self, db: Database) -> None:
        self._db = db

    def create(self, user_goal: str, meta: dict | None = None) -> Session:
        """Insert a session with a random UUID4 id and return the stored row."""
        new_id = str(uuid.uuid4())
        with self._db.session() as db:
            fresh = Session(
                id=new_id,
                user_goal=user_goal,
                created_at=time.time(),
                meta=meta or {},
            )
            db.add(fresh)
            db.commit()
            db.refresh(fresh)
            db.expunge(fresh)
            return fresh

    def get(self, session_id: str) -> Session | None:
        """Return the session with id *session_id*, or None if it does not exist."""
        with self._db.session() as db:
            found = db.get(Session, session_id)
            if found:
                db.expunge(found)
            return found

    def end(self, session_id: str) -> None:
        """Set ``ended_at`` to the current time; no-op for an unknown id."""
        with self._db.session() as db:
            found = db.get(Session, session_id)
            if not found:
                return
            found.ended_at = time.time()
            db.commit()

    def update_goal(self, session_id: str, user_goal: str) -> None:
        """Replace the stored ``user_goal``; no-op for an unknown id."""
        with self._db.session() as db:
            found = db.get(Session, session_id)
            if not found:
                return
            found.user_goal = user_goal
            db.commit()

    def list(self, limit: int = 100) -> list[Session]:
        """Return up to *limit* sessions, most recently created first."""
        with self._db.session() as db:
            newest_first = Session.created_at.desc()
            found = db.query(Session).order_by(newest_first).limit(limit).all()
            for item in found:
                db.expunge(item)
            return found