openhack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. openhack/__init__.py +2 -0
  2. openhack/__main__.py +225 -0
  3. openhack/agents/__init__.py +30 -0
  4. openhack/agents/base.py +230 -0
  5. openhack/agents/browser_verifier.py +679 -0
  6. openhack/agents/browser_verifier_swarm.py +256 -0
  7. openhack/agents/checkpoint.py +89 -0
  8. openhack/agents/context_manager.py +356 -0
  9. openhack/agents/coordinator.py +1105 -0
  10. openhack/agents/endpoint_analyst.py +307 -0
  11. openhack/agents/feature_hunter.py +93 -0
  12. openhack/agents/hunter.py +481 -0
  13. openhack/agents/hunter_swarm.py +385 -0
  14. openhack/agents/llm.py +334 -0
  15. openhack/agents/recon.py +19 -0
  16. openhack/agents/sandbox_verifier.py +396 -0
  17. openhack/agents/sandbox_verifier_swarm.py +250 -0
  18. openhack/agents/session.py +286 -0
  19. openhack/agents/validator.py +217 -0
  20. openhack/agents/validator_swarm.py +106 -0
  21. openhack/auth.py +175 -0
  22. openhack/browser/__init__.py +12 -0
  23. openhack/browser/runner.py +385 -0
  24. openhack/categories.py +130 -0
  25. openhack/config.py +201 -0
  26. openhack/deterministic_recon.py +464 -0
  27. openhack/entry_points.py +745 -0
  28. openhack/framework_classifier.py +515 -0
  29. openhack/framework_detection.py +269 -0
  30. openhack/headless_scan.py +179 -0
  31. openhack/prompts/__init__.py +108 -0
  32. openhack/prompts/browser_verifier.py +171 -0
  33. openhack/prompts/coordinator.py +31 -0
  34. openhack/prompts/django/__init__.py +32 -0
  35. openhack/prompts/django/auth_bypass.py +76 -0
  36. openhack/prompts/django/csrf.py +62 -0
  37. openhack/prompts/django/data_exposure.py +67 -0
  38. openhack/prompts/django/idor.py +74 -0
  39. openhack/prompts/django/injection.py +67 -0
  40. openhack/prompts/django/misconfiguration.py +70 -0
  41. openhack/prompts/django/ssrf.py +64 -0
  42. openhack/prompts/endpoint_analyst.py +122 -0
  43. openhack/prompts/express/__init__.py +29 -0
  44. openhack/prompts/express/auth_bypass.py +71 -0
  45. openhack/prompts/express/data_exposure.py +77 -0
  46. openhack/prompts/express/idor.py +69 -0
  47. openhack/prompts/express/injection.py +75 -0
  48. openhack/prompts/express/misconfiguration.py +72 -0
  49. openhack/prompts/express/ssrf.py +63 -0
  50. openhack/prompts/feature_hunter.py +140 -0
  51. openhack/prompts/flask/__init__.py +29 -0
  52. openhack/prompts/flask/auth_bypass.py +86 -0
  53. openhack/prompts/flask/data_exposure.py +78 -0
  54. openhack/prompts/flask/idor.py +83 -0
  55. openhack/prompts/flask/injection.py +77 -0
  56. openhack/prompts/flask/misconfiguration.py +73 -0
  57. openhack/prompts/flask/ssrf.py +65 -0
  58. openhack/prompts/hunter.py +362 -0
  59. openhack/prompts/hunter_continuation_loop.py +12 -0
  60. openhack/prompts/hunter_continuation_no_findings.py +19 -0
  61. openhack/prompts/hunter_continuation_no_progress.py +22 -0
  62. openhack/prompts/hunter_tool_instructions.py +55 -0
  63. openhack/prompts/nextjs/__init__.py +42 -0
  64. openhack/prompts/nextjs/auth_bypass.py +80 -0
  65. openhack/prompts/nextjs/csrf.py +71 -0
  66. openhack/prompts/nextjs/data_exposure.py +88 -0
  67. openhack/prompts/nextjs/idor.py +64 -0
  68. openhack/prompts/nextjs/injection.py +65 -0
  69. openhack/prompts/nextjs/middleware_bypass.py +75 -0
  70. openhack/prompts/nextjs/misconfiguration.py +92 -0
  71. openhack/prompts/nextjs/server_actions.py +97 -0
  72. openhack/prompts/nextjs/ssrf.py +66 -0
  73. openhack/prompts/nextjs/xss.py +69 -0
  74. openhack/prompts/pr_analysis_system.py +80 -0
  75. openhack/prompts/pr_analysis_user.py +11 -0
  76. openhack/prompts/project_context.py +89 -0
  77. openhack/prompts/recon.py +199 -0
  78. openhack/prompts/reporter.py +88 -0
  79. openhack/prompts/researchers.py +434 -0
  80. openhack/prompts/sandbox_verifier.py +128 -0
  81. openhack/prompts/supabase/__init__.py +39 -0
  82. openhack/prompts/supabase/auth_tokens.py +131 -0
  83. openhack/prompts/supabase/edge_functions.py +150 -0
  84. openhack/prompts/supabase/graphql.py +102 -0
  85. openhack/prompts/supabase/postgrest.py +99 -0
  86. openhack/prompts/supabase/realtime.py +93 -0
  87. openhack/prompts/supabase/rls.py +110 -0
  88. openhack/prompts/supabase/rpc_functions.py +127 -0
  89. openhack/prompts/supabase/storage.py +110 -0
  90. openhack/prompts/supabase/tenant_isolation.py +118 -0
  91. openhack/prompts/validator.py +319 -0
  92. openhack/prompts/validator_continuation_incomplete.py +12 -0
  93. openhack/prompts/validator_tool_instructions.py +29 -0
  94. openhack/quality.py +231 -0
  95. openhack/sandbox/__init__.py +12 -0
  96. openhack/sandbox/orchestrator.py +517 -0
  97. openhack/sandbox/runner.py +177 -0
  98. openhack/scan_session.py +245 -0
  99. openhack/setup.py +452 -0
  100. openhack/static_validator.py +612 -0
  101. openhack/tools/__init__.py +1 -0
  102. openhack/tools/ast_tools.py +307 -0
  103. openhack/tools/coverage.py +1078 -0
  104. openhack/tools/filesystem.py +404 -0
  105. openhack/tools/nextjs.py +258 -0
  106. openhack/tools/registry.py +52 -0
  107. openhack/tui.py +3450 -0
  108. openhack/updates.py +170 -0
  109. openhack-0.1.0.dist-info/METADATA +189 -0
  110. openhack-0.1.0.dist-info/RECORD +113 -0
  111. openhack-0.1.0.dist-info/WHEEL +4 -0
  112. openhack-0.1.0.dist-info/entry_points.txt +2 -0
  113. openhack-0.1.0.dist-info/licenses/LICENSE +661 -0
@@ -0,0 +1,286 @@
1
+ """
2
+ Session management for vulnerability scanning.
3
+ """
4
+
5
+ import threading
6
+ import time
7
+ from typing import Any, Optional
8
+ from uuid import uuid4
9
+ from enum import Enum
10
+ from dataclasses import dataclass, field
11
+
12
+
13
class SessionStatus(str, Enum):
    """Lifecycle state of a scan session.

    Because the enum also derives from ``str``, each member compares equal
    to its plain string value (``SessionStatus.RUNNING == "running"``).
    """

    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    PAUSED = "paused"
18
+
19
+
20
@dataclass
class Finding:
    """A single vulnerability finding reported during a scan.

    ``to_dict`` renders the finding with camelCase keys for export, and
    ``fingerprint`` derives a short hash from the category and file path only.
    """

    category: str
    severity: str
    title: str
    description: str
    file_path: str
    id: str = field(default_factory=lambda: str(uuid4()))
    line_number: Optional[int] = None
    code_snippet: Optional[str] = None
    poc: Optional[str] = None
    fix: Optional[str] = None
    cvss_score: Optional[float] = None
    confidence: str = "medium"
    validated: bool = False
    source: Optional[str] = None

    def fingerprint(self) -> str:
        """Return the first 16 hex chars of SHA-256 over ``"<category>::<path>"``.

        Both parts are stripped and lower-cased. Everything from the first
        ``:`` in the path onward (e.g. a ``:line`` suffix) is discarded, so
        findings of the same category in the same file share a fingerprint.
        """
        import hashlib

        path_part = (self.file_path or "").strip().lower().partition(":")[0]
        category_part = self.category.strip().lower()
        digest = hashlib.sha256(f"{category_part}::{path_part}".encode())
        return digest.hexdigest()[:16]

    def to_dict(self) -> dict:
        """Serialize to a camelCase dict.

        ``verificationSource`` is included only when ``source`` is truthy.
        """
        payload = dict(
            id=self.id,
            category=self.category,
            severity=self.severity,
            title=self.title,
            description=self.description,
            filePath=self.file_path,
            lineNumber=self.line_number,
            relevantCode=self.code_snippet,
            poc=self.poc,
            recommendation=self.fix,
            cvssScore=self.cvss_score,
            confidence=self.confidence,
            validated=self.validated,
            vulnerabilityType=self._generate_vulnerability_type(),
            fingerprint=self.fingerprint(),
        )
        if not self.source:
            return payload
        payload["verificationSource"] = self.source
        return payload

    def _generate_vulnerability_type(self) -> str:
        """Map the free-form category (plus description hints) to a type slug.

        Rules are checked in order, and the first substring match wins. When
        nothing matches, the normalized category itself is returned.
        """
        slug = self.category.lower().replace(" ", "_").replace("-", "_")

        if "xss" in slug:
            # "dangerouslysetinnerhtml" must be tested before its substring
            # "innerhtml".
            sink_hints = (
                ("dangerouslysetinnerhtml", "xss_dangerously_set_html"),
                ("innerhtml", "xss_innerhtml"),
                ("document.write", "xss_document_write"),
            )
            lowered_description = self.description.lower()
            for needle, vuln_type in sink_hints:
                if needle in lowered_description:
                    return vuln_type
            return f"xss_{slug}"

        # NOTE(review): any category containing "injection" (e.g. command
        # injection) is reported as sql_injection here.
        if "sql" in slug or "injection" in slug:
            if "raw" in self.description.lower():
                return "sql_injection_raw_query"
            return "sql_injection"

        fixed_types = (
            ("idor", "idor_direct_object_reference"),
            ("ssrf", "ssrf_server_side_request"),
            ("csrf", "csrf_missing_token"),
            ("auth", "auth_bypass"),
        )
        for needle, vuln_type in fixed_types:
            if needle in slug:
                return vuln_type
        return slug
90
+
91
+
92
@dataclass
class TraceEntry:
    """One event in a session's trace log, kept for debugging and logging.

    The three ``tool_*`` fields are optional and default to ``None``.
    """

    timestamp: float  # wall-clock seconds (Session.add_trace passes time.time())
    agent: str  # name of the agent that emitted the event
    event_type: str  # event label, e.g. "thinking" or "tool_call"
    content: Any  # event payload; its type varies by event
    tool_name: Optional[str] = None
    tool_input: Optional[dict] = None
    tool_output: Optional[Any] = None
102
+
103
+
104
class Session:
    """In-memory state for one vulnerability scan.

    Tracks findings, a trace log, cost/token accounting, user instructions
    queued while the scan runs, and the pause/cancel controls that agent
    loops consult between iterations.
    """

    def __init__(
        self,
        target_dir: str,
        scan_id: Optional[str] = None,
        project_context: Optional[dict] = None,
        trace_id: Optional[str] = None,
        on_trace: Optional[Any] = None,
    ):
        """Create a session in the RUNNING state.

        Args:
            target_dir: Directory being scanned (stored as given).
            scan_id: Session id; a random UUID4 string is used when falsy.
            project_context: Optional project metadata (stored as given).
            trace_id: Short trace id; defaults to the first 8 characters of
                the session id.
            on_trace: Optional callable invoked with every new ``TraceEntry``.
        """
        now = time.time()
        self.id = scan_id or str(uuid4())
        # self.id is never empty here, so its prefix is always a usable trace id.
        self.trace_id = trace_id or self.id[:8]
        self.target_dir = target_dir
        self.project_context = project_context
        self.status = SessionStatus.RUNNING
        self.created_at = now
        self.updated_at = now
        self.current_agent: Optional[str] = None
        self.current_step: Optional[str] = None
        self.findings: list[Finding] = []
        self.trace: list[TraceEntry] = []
        self.context: dict = {}
        # total_cost / total_tokens are not touched by record_step_cost; callers
        # add to them directly. The input/output totals are updated there.
        self.total_cost: float = 0.0
        self.total_tokens: int = 0
        self.total_input_tokens: int = 0
        self.total_output_tokens: int = 0
        # Per-step figures keyed by step name (written by record_step_cost).
        self.step_costs: dict[str, float] = {}
        self.step_tokens: dict[str, int] = {}
        self.step_input_tokens: dict[str, int] = {}
        self.step_output_tokens: dict[str, int] = {}
        self._on_trace = on_trace
        # Instructions are guarded by a lock so they can be queued from another
        # thread. The version counter equals the number of instructions added.
        self._user_instructions: list[str] = []
        self._instructions_lock = threading.Lock()
        self._instructions_version: int = 0
        self.cancelled: bool = False
        # Pause control: an asyncio.Event that's *set* when running (default)
        # and *cleared* when paused. Agents call `await wait_if_paused()`
        # between iterations, which blocks while the event is cleared.
        # Lazily created on first access because Session may be instantiated
        # outside an event loop (e.g. in serialization tests).
        self._pause_event: Optional[Any] = None

    def _ensure_pause_event(self) -> Any:
        """Return the pause event, creating it in the set (running) state on first use."""
        if self._pause_event is None:
            import asyncio

            self._pause_event = asyncio.Event()
            self._pause_event.set()  # default: not paused
        return self._pause_event

    @property
    def paused(self) -> bool:
        """True while the pause event exists and is cleared."""
        return self._pause_event is not None and not self._pause_event.is_set()

    def pause(self) -> None:
        """Block agent loops at their next ``wait_if_paused`` checkpoint.

        Ignored once the session is cancelled: ``cancel`` sets the event so
        paused agents wake up and see the cancellation, and clearing it again
        would make them block in ``wait_if_paused`` instead.
        """
        if self.cancelled:
            return
        self._ensure_pause_event().clear()

    def resume(self) -> None:
        """Unblock paused agent loops."""
        self._ensure_pause_event().set()

    async def wait_if_paused(self) -> None:
        """If the session is paused, await until resumed. No-op when not paused."""
        await self._ensure_pause_event().wait()

    def cancel(self) -> None:
        """Set the cancellation flag so all agents break out of their loops.

        Also resumes the pause event so paused agents wake up and see the
        cancellation flag instead of blocking forever.
        """
        self.cancelled = True
        if self._pause_event is not None:
            self._pause_event.set()

    def add_user_instruction(self, text: str) -> None:
        """Thread-safe: queue an instruction from the user during a running scan.

        The instruction is also recorded as a ``user_instruction`` trace event.
        """
        with self._instructions_lock:
            self._user_instructions.append(text)
            self._instructions_version += 1
        self.add_trace(agent="user", event_type="user_instruction", content=text)

    def get_new_instructions(self, seen_version: int) -> tuple[list[str], int]:
        """Thread-safe: return instructions added since *seen_version*.

        Unlike drain, this does NOT clear instructions — every agent that calls
        this will independently see every instruction added after its own
        watermark, which is critical for the swarm pattern where many agents
        run concurrently. A negative watermark is treated as 0 (return
        everything) rather than being used as a tail slice.

        Returns (new_instructions, current_version).
        """
        with self._instructions_lock:
            current = self._instructions_version
            start = max(seen_version, 0)
            if start >= current:
                return [], current
            # Slicing already produces a fresh list, safe to hand out.
            return self._user_instructions[start:], current

    def get_all_instructions(self) -> list[str]:
        """Thread-safe: return a snapshot of all accumulated instructions."""
        with self._instructions_lock:
            return list(self._user_instructions)

    def add_trace(
        self,
        agent: str,
        event_type: str,
        content: Any,
        tool_name: Optional[str] = None,
        tool_input: Optional[dict] = None,
        tool_output: Optional[Any] = None,
    ) -> TraceEntry:
        """Append a trace entry stamped with the current time and return it.

        The ``on_trace`` callback, if any, is invoked with the new entry after
        it has been appended.
        """
        entry = TraceEntry(
            timestamp=time.time(),
            agent=agent,
            event_type=event_type,
            content=content,
            tool_name=tool_name,
            tool_input=tool_input,
            tool_output=tool_output,
        )
        self.trace.append(entry)
        self.updated_at = entry.timestamp
        if self._on_trace:
            self._on_trace(entry)
        return entry

    def add_finding(self, finding: Finding) -> None:
        """Append *finding* and bump ``updated_at``."""
        self.findings.append(finding)
        self.updated_at = time.time()

    def get_findings_dict(self) -> list[dict]:
        """Return every finding serialized with ``Finding.to_dict``."""
        return [f.to_dict() for f in self.findings]

    def record_step_cost(
        self,
        step_name: str,
        cost: float,
        tokens: int,
        input_tokens: int = 0,
        output_tokens: int = 0,
    ) -> None:
        """Store per-step usage and add the input/output tokens to the totals.

        ``total_cost`` and ``total_tokens`` are deliberately not updated here.
        """
        # NOTE(review): per-step values are overwritten while the totals are
        # incremented, so recording the same step twice counts its tokens twice
        # in the totals. Kept as-is; confirm whether callers can re-record a
        # step (e.g. after resuming from a checkpoint).
        self.step_costs[step_name] = cost
        self.step_tokens[step_name] = tokens
        self.step_input_tokens[step_name] = input_tokens
        self.step_output_tokens[step_name] = output_tokens
        self.total_input_tokens += input_tokens
        self.total_output_tokens += output_tokens
        self.updated_at = time.time()

    def restore_from_checkpoint(self, checkpoint_data: dict) -> None:
        """Restore cost/token accounting from checkpoint data.

        Totals are replaced (missing keys reset them to zero). Per-step maps
        are merged into the current ones. Missing or ``None`` maps are skipped.
        """
        self.total_cost = checkpoint_data.get("total_cost", 0.0)
        self.total_tokens = checkpoint_data.get("total_tokens", 0)
        self.total_input_tokens = checkpoint_data.get("total_input_tokens", 0)
        self.total_output_tokens = checkpoint_data.get("total_output_tokens", 0)
        self.step_costs.update(checkpoint_data.get("step_costs") or {})
        self.step_tokens.update(checkpoint_data.get("step_tokens") or {})
        self.step_input_tokens.update(checkpoint_data.get("step_input_tokens") or {})
        self.step_output_tokens.update(checkpoint_data.get("step_output_tokens") or {})

    def get_cost_breakdown(self) -> dict:
        """Return totals plus a per-step breakdown.

        Steps are listed in first-seen order across all per-step maps, so a
        step restored with tokens but no cost still appears. Missing figures
        default to 0.
        """
        step_names = dict.fromkeys(
            [
                *self.step_costs,
                *self.step_tokens,
                *self.step_input_tokens,
                *self.step_output_tokens,
            ]
        )
        return {
            "total_cost": self.total_cost,
            "total_tokens": self.total_tokens,
            "total_input_tokens": self.total_input_tokens,
            "total_output_tokens": self.total_output_tokens,
            "steps": {
                step: {
                    "cost": self.step_costs.get(step, 0),
                    "tokens": self.step_tokens.get(step, 0),
                    "input_tokens": self.step_input_tokens.get(step, 0),
                    "output_tokens": self.step_output_tokens.get(step, 0),
                }
                for step in step_names
            },
        }
@@ -0,0 +1,217 @@
1
+ """
2
+ Validator agent for confirming vulnerabilities.
3
+ """
4
+
5
+ import json
6
+ import logging
7
+ from typing import Optional
8
+
9
+ from .base import BaseAgent
10
+ from .llm import Message, ToolResult
11
+ from openhack.prompts import VALIDATOR_PROMPT, VALIDATOR_TOOL_INSTRUCTIONS, format_project_context
12
+
13
logger = logging.getLogger(__name__)


# Tool schema through which the model reports its verdict on one finding.
VALIDATE_FINDING_TOOL = {
    "name": "validate_finding",
    "description": "Report the validation result for the potential vulnerability.",
    "parameters": {
        "type": "object",
        "properties": {
            "finding_index": {
                "type": "integer",
                "description": "Index (1-based) of the finding",
            },
            "status": {
                "type": "string",
                "enum": ["confirmed", "false_positive", "needs_more_info"],
            },
            "confidence": {
                "type": "string",
                "enum": ["high", "medium", "low"],
            },
            "cvss_score": {
                "type": "number",
                "description": "CVSS 3.1 score (0.0 - 10.0)",
            },
            "evidence": {
                "type": "string",
                "description": "Evidence supporting the validation",
            },
            "poc": {
                "type": "string",
                "description": "Proof of concept",
            },
            "fix": {
                "type": "string",
                "description": "Recommended fix",
            },
        },
        "required": ["finding_index", "status", "confidence"],
    },
}

# Tool schema the model calls once all findings have a verdict; the validator
# loop returns after processing a call to it.
FINISH_VALIDATION_TOOL = {
    "name": "finish_validation",
    "description": "Call after validating all findings. Signals validation completion.",
    "parameters": {
        "type": "object",
        "properties": {
            "summary": {"type": "string"},
            "total_confirmed": {"type": "integer"},
            "total_false_positives": {"type": "integer"},
        },
        "required": ["summary", "total_confirmed", "total_false_positives"],
    },
}
47
+
48
+
49
class ValidatorAgent(BaseAgent):
    """LLM agent that confirms or rejects candidate findings from the hunter.

    The model reports verdicts through the ``validate_finding`` tool and ends
    the run with ``finish_validation``. Verdicts with status ``confirmed`` go
    to ``validated_findings``; every other status (including
    ``needs_more_info``) goes to ``false_positives``.

    When ``original_finding_index`` is given (as the validator swarm does),
    the agent validates a single finding and attributes every verdict to that
    index, regardless of the ``finding_index`` the model reports.
    """

    name = "validator"
    description = "Validating and confirming vulnerabilities"

    def __init__(self, *args, original_finding_index=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.validated_findings: list[dict] = []
        self.false_positives: list[dict] = []
        self.original_finding_index = original_finding_index

        if original_finding_index is not None:
            # Per-instance names keep trace entries of parallel sub-validators apart.
            self.name = f"validator:finding_{original_finding_index}"
            self.description = f"Validating finding {original_finding_index}"

    def get_system_prompt(self, context: dict) -> str:
        """Render VALIDATOR_PROMPT with the hunter's findings and project context.

        Findings are numbered from 1. Missing or ``None`` fields are shown with
        a placeholder rather than the literal text ``None``.
        """
        findings = (context.get("hunter") or {}).get("findings") or []
        project_context_str = format_project_context(context.get("project_context", {}))

        def show(finding: dict, key: str, placeholder: str):
            value = finding.get(key)
            return placeholder if value is None else value

        findings_text = "".join(
            f"\n### Finding {i}\n"
            f"- **Category**: {show(f, 'category', 'Unknown')}\n"
            f"- **Severity**: {show(f, 'severity', 'Unknown')}\n"
            f"- **File**: {show(f, 'file_path', 'Unknown')}\n"
            f"- **Line**: {show(f, 'line_number', 'Unknown')}\n"
            f"- **Description**: {show(f, 'description', 'No description')}\n"
            "- **Code**:\n"
            "```\n"
            f"{show(f, 'code_snippet', 'No code snippet')}\n"
            "```\n"
            for i, f in enumerate(findings, 1)
        )

        base_prompt = VALIDATOR_PROMPT.format(
            findings=findings_text or "No findings to validate",
            project_context=project_context_str,
        )
        return base_prompt + VALIDATOR_TOOL_INSTRUCTIONS

    def get_tools(self) -> list[dict]:
        """Base tools plus the two validator-specific reporting tools."""
        return super().get_tools() + [VALIDATE_FINDING_TOOL, FINISH_VALIDATION_TOOL]

    def _handle_validate_finding(self, args: dict) -> dict:
        """Record one verdict reported through ``validate_finding``.

        A later verdict for the same finding replaces an earlier one, so a
        finding never ends up in both ``validated_findings`` and
        ``false_positives``.

        Returns the tool result sent back to the model. When ``finding_index``
        is needed but is not an integer, an ``error`` result is returned and
        nothing is recorded, so the model can retry.
        """
        if self.original_finding_index is not None:
            original_index = self.original_finding_index
        else:
            try:
                original_index = int(args.get("finding_index", 1)) - 1
            except (TypeError, ValueError):
                return {
                    "status": "error",
                    "error": "finding_index must be an integer (1-based)",
                    "finding_index": args.get("finding_index"),
                }

        status = str(args.get("status") or "").lower()
        validation = {
            "original_index": original_index,
            "status": status,
            "confidence": str(args.get("confidence") or "medium").lower(),
            "cvss_score": args.get("cvss_score"),
            "evidence": args.get("evidence", ""),
            "poc": args.get("poc"),
            "fix": args.get("fix"),
        }

        # Last verdict wins: drop any earlier verdict for this finding, in place.
        for bucket in (self.validated_findings, self.false_positives):
            bucket[:] = [v for v in bucket if v.get("original_index") != original_index]

        if status == "confirmed":
            self.validated_findings.append(validation)
        else:
            self.false_positives.append(validation)

        return {"status": "recorded", "finding_index": args.get("finding_index")}

    def _handle_finish_validation(self, args: dict) -> dict:
        """Acknowledge ``finish_validation`` with the current verdict counts."""
        return {
            "status": "validation_complete",
            "confirmed": len(self.validated_findings),
            "false_positives": len(self.false_positives),
        }

    async def run(self, task: str, context: Optional[dict] = None) -> dict:
        """Drive the tool-calling loop and return the collected verdicts.

        The loop stops when:
        - the model replies without tool calls;
        - the model calls ``finish_validation``;
        - the session is cancelled (summary ``"Cancelled"``);
        - the iteration budget runs out (15 for a single-finding validator,
          50 otherwise).

        Each iteration first waits while the session is paused.
        """
        context = context or {}
        self.session.current_agent = self.name
        self.validated_findings = []
        self.false_positives = []

        system_prompt = self.get_system_prompt(context)
        self.messages = [Message(role="user", content=task)]
        self._seed_existing_instructions()

        max_iterations = 15 if self.original_finding_index is not None else 50

        for _ in range(max_iterations):
            # Honour the session pause control; cancel() also releases this wait.
            await self.session.wait_if_paused()
            if self.session.cancelled:
                return self._build_result("Cancelled")

            self._inject_pending_instructions()

            response = await self.llm.chat(
                messages=self.messages, tools=self.get_tools(), system=system_prompt,
            )

            self.session.total_cost += response.cost
            if response.usage:
                self.session.total_tokens += response.usage.get("total_tokens", 0)
                self.context_manager.update_usage(response.usage.get("input_tokens", 0))

            if response.content:
                self.session.add_trace(agent=self.name, event_type="thinking", content=response.content)

            if not response.tool_calls:
                return self._build_result(response.content or "")

            self.messages.append(
                Message(
                    role="assistant",
                    content=response.content,
                    tool_calls=[
                        {
                            "id": tc.id,
                            "type": "function",
                            "function": {"name": tc.name, "arguments": json.dumps(tc.arguments)},
                        }
                        for tc in response.tool_calls
                    ],
                    reasoning_content=getattr(response, "reasoning_content", None),
                )
            )

            # Every tool call in the batch is answered, even after finish_validation,
            # so the message history stays well-formed.
            should_finish = False
            for tool_call in response.tool_calls:
                self.session.add_trace(
                    agent=self.name, event_type="tool_call",
                    content=f"Calling {tool_call.name}",
                    tool_name=tool_call.name, tool_input=tool_call.arguments,
                )

                if tool_call.name == "validate_finding":
                    result = self._handle_validate_finding(tool_call.arguments)
                elif tool_call.name == "finish_validation":
                    result = self._handle_finish_validation(tool_call.arguments)
                    should_finish = True
                else:
                    result = self.tools.execute_tool(tool_call.name, tool_call.arguments)

                self.session.add_trace(
                    agent=self.name, event_type="tool_result",
                    content=f"Result from {tool_call.name}",
                    tool_name=tool_call.name, tool_output=result,
                )

                raw_content = json.dumps(result) if isinstance(result, dict) else str(result)
                truncated_content = self.context_manager.truncate_tool_result(tool_call.name, raw_content)
                self.messages.append(
                    ToolResult(tool_call_id=tool_call.id, content=truncated_content).to_message()
                )

            if should_finish:
                return self._build_result(response.content or "")

            if self.context_manager.needs_compaction():
                self.messages = self.context_manager.compact_messages(self.messages)
                logger.info(
                    "[%s] Compacted message history (%s input tokens)",
                    self.name, self.context_manager.last_input_tokens,
                )

        return self._build_result("Max iterations reached")

    def _build_result(self, summary: str) -> dict:
        """Package the verdicts collected so far with a free-text summary."""
        return {
            "raw_output": summary,
            "validated_findings": self.validated_findings,
            "false_positives": self.false_positives,
            "type": "validation_complete",
        }

    def _parse_final_response(self, content: str) -> dict:
        """Treat a final text reply as the summary of the collected verdicts."""
        return self._build_result(content)
@@ -0,0 +1,106 @@
1
+ """
2
+ Validator swarm agent that spawns one sub-validator per finding.
3
+ """
4
+
5
+ import asyncio
6
+ import logging
7
+ from typing import Optional
8
+
9
+ from .validator import ValidatorAgent
10
+ from .llm import LLMClient
11
+ from .session import Session
12
+ from openhack.tools.registry import ToolRegistry
13
+ from openhack.config import settings
14
+
15
logger = logging.getLogger(__name__)


class ValidatorSwarmAgent:
    """Validate findings in parallel with one ``ValidatorAgent`` per finding.

    Each sub-validator gets its own LLM client and a context containing only
    its finding. Sub-validators run concurrently, bounded by
    ``settings.max_concurrent_validators``. Their verdicts are merged into a
    result shaped like ``ValidatorAgent.run``'s.
    """

    name = "validator_swarm"
    description = "Validator swarm coordinator"

    def __init__(self, llm: LLMClient, tools: ToolRegistry, session: Session):
        self.llm = llm
        self.tools = tools
        self.session = session
        # LLM usage of every sub-validator this swarm has run, accumulated
        # across run() calls.
        self.total_cost: float = 0.0
        self.total_tokens: int = 0
        self.total_input_tokens: int = 0
        self.total_output_tokens: int = 0

    def _create_llm_for_sub_validator(self) -> LLMClient:
        """Build a fresh deterministic client for one sub-validator.

        Uses ``settings.validator_model_id`` when set, otherwise the swarm's
        own model, and reuses the swarm's provider and prompt cache key.
        """
        model = settings.validator_model_id or self.llm.model
        return LLMClient(
            model=model,
            temperature=0.0,
            max_tokens=8192,
            provider=self.llm.provider,
            prompt_cache_key=self.llm.prompt_cache_key,
        )

    def _build_sub_context(self, finding: dict, full_context: dict) -> dict:
        """Narrow the shared context to a single finding for one sub-validator."""
        return {
            "hunter": {"findings": [finding]},
            "recon": full_context.get("recon", {}),
            "project_context": full_context.get("project_context", {}),
        }

    async def run(self, task: str, context: Optional[dict] = None) -> dict:
        """Validate every hunter finding concurrently and merge the verdicts.

        A sub-validator that raises is logged with its traceback and
        contributes no verdicts; the number of such failures is reported in
        the ``swarm_complete`` trace event. If this coroutine is cancelled,
        all sub-validator tasks are cancelled and awaited before re-raising.
        """
        context = context or {}
        findings = context.get("hunter", {}).get("findings", [])

        if not findings:
            return {"raw_output": "No findings to validate", "validated_findings": [], "false_positives": [], "type": "validation_complete"}

        self.session.add_trace(agent=self.name, event_type="swarm_start", content={"findings_count": len(findings)})

        sub_validators: list[tuple[int, ValidatorAgent, dict]] = []
        for idx, finding in enumerate(findings):
            llm = self._create_llm_for_sub_validator()
            validator = ValidatorAgent(llm, self.tools, self.session, original_finding_index=idx)
            sub_validators.append((idx, validator, self._build_sub_context(finding, context)))

        # A limit of 0 would block every sub-validator forever (and a negative
        # one raises), so run at least one at a time.
        semaphore = asyncio.Semaphore(max(1, settings.max_concurrent_validators or 1))
        sub_task = "Validate this potential vulnerability. Confirm whether it is real, generate a PoC, and suggest a fix."

        async def run_sub_validator(finding_idx, validator, sub_context):
            async with semaphore:
                try:
                    result = await validator.run(sub_task, sub_context)
                    return finding_idx, result
                except Exception as e:
                    # Keep the traceback; one failing sub-validator must not sink the swarm.
                    logger.exception("Sub-validator for finding %s failed: %s", finding_idx, e)
                    return finding_idx, {"validated_findings": [], "false_positives": [], "type": "validation_failed"}

        tasks = [
            asyncio.create_task(run_sub_validator(idx, validator, sub_ctx))
            for idx, validator, sub_ctx in sub_validators
        ]
        try:
            results = await asyncio.gather(*tasks)
        except asyncio.CancelledError:
            for t in tasks:
                t.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

        all_validated: list[dict] = []
        all_false_positives: list[dict] = []
        failed_count = 0

        for _finding_idx, result in results:
            all_validated.extend(result.get("validated_findings", []))
            all_false_positives.extend(result.get("false_positives", []))
            if result.get("type") == "validation_failed":
                failed_count += 1

        for _, validator, _ in sub_validators:
            self.total_cost += validator.llm.total_cost
            self.total_tokens += validator.llm.total_tokens
            self.total_input_tokens += validator.llm.total_input_tokens
            self.total_output_tokens += validator.llm.total_output_tokens

        self.session.add_trace(
            agent=self.name, event_type="swarm_complete",
            content={"total_confirmed": len(all_validated), "total_false_positives": len(all_false_positives),
                     "total_failed": failed_count,
                     "total_cost": self.total_cost, "total_tokens": self.total_tokens},
        )

        return {
            "raw_output": f"Validated {len(findings)} findings: {len(all_validated)} confirmed, {len(all_false_positives)} false positives",
            "validated_findings": all_validated,
            "false_positives": all_false_positives,
            "type": "validation_complete",
        }