strix-agent 0.4.0__py3-none-any.whl → 0.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- strix/agents/StrixAgent/strix_agent.py +3 -3
- strix/agents/StrixAgent/system_prompt.jinja +30 -26
- strix/agents/base_agent.py +159 -75
- strix/agents/state.py +5 -2
- strix/config/__init__.py +12 -0
- strix/config/config.py +172 -0
- strix/interface/assets/tui_styles.tcss +195 -230
- strix/interface/cli.py +16 -41
- strix/interface/main.py +151 -74
- strix/interface/streaming_parser.py +119 -0
- strix/interface/tool_components/__init__.py +4 -0
- strix/interface/tool_components/agent_message_renderer.py +190 -0
- strix/interface/tool_components/agents_graph_renderer.py +54 -38
- strix/interface/tool_components/base_renderer.py +68 -36
- strix/interface/tool_components/browser_renderer.py +106 -91
- strix/interface/tool_components/file_edit_renderer.py +117 -36
- strix/interface/tool_components/finish_renderer.py +43 -10
- strix/interface/tool_components/notes_renderer.py +63 -38
- strix/interface/tool_components/proxy_renderer.py +133 -92
- strix/interface/tool_components/python_renderer.py +121 -8
- strix/interface/tool_components/registry.py +19 -12
- strix/interface/tool_components/reporting_renderer.py +196 -28
- strix/interface/tool_components/scan_info_renderer.py +22 -19
- strix/interface/tool_components/terminal_renderer.py +270 -90
- strix/interface/tool_components/thinking_renderer.py +8 -6
- strix/interface/tool_components/todo_renderer.py +225 -0
- strix/interface/tool_components/user_message_renderer.py +26 -19
- strix/interface/tool_components/web_search_renderer.py +7 -6
- strix/interface/tui.py +907 -262
- strix/interface/utils.py +236 -4
- strix/llm/__init__.py +6 -2
- strix/llm/config.py +8 -5
- strix/llm/dedupe.py +217 -0
- strix/llm/llm.py +209 -356
- strix/llm/memory_compressor.py +6 -5
- strix/llm/utils.py +17 -8
- strix/runtime/__init__.py +12 -3
- strix/runtime/docker_runtime.py +121 -202
- strix/runtime/tool_server.py +55 -95
- strix/skills/README.md +64 -0
- strix/skills/__init__.py +110 -0
- strix/{prompts → skills}/frameworks/nextjs.jinja +26 -0
- strix/skills/scan_modes/deep.jinja +145 -0
- strix/skills/scan_modes/quick.jinja +63 -0
- strix/skills/scan_modes/standard.jinja +91 -0
- strix/telemetry/README.md +38 -0
- strix/telemetry/__init__.py +7 -1
- strix/telemetry/posthog.py +137 -0
- strix/telemetry/tracer.py +194 -54
- strix/tools/__init__.py +11 -4
- strix/tools/agents_graph/agents_graph_actions.py +20 -21
- strix/tools/agents_graph/agents_graph_actions_schema.xml +8 -8
- strix/tools/browser/browser_actions.py +10 -6
- strix/tools/browser/browser_actions_schema.xml +6 -1
- strix/tools/browser/browser_instance.py +96 -48
- strix/tools/browser/tab_manager.py +121 -102
- strix/tools/context.py +12 -0
- strix/tools/executor.py +63 -4
- strix/tools/file_edit/file_edit_actions.py +6 -3
- strix/tools/file_edit/file_edit_actions_schema.xml +45 -3
- strix/tools/finish/finish_actions.py +80 -105
- strix/tools/finish/finish_actions_schema.xml +121 -14
- strix/tools/notes/notes_actions.py +6 -33
- strix/tools/notes/notes_actions_schema.xml +50 -46
- strix/tools/proxy/proxy_actions.py +14 -2
- strix/tools/proxy/proxy_actions_schema.xml +0 -1
- strix/tools/proxy/proxy_manager.py +28 -16
- strix/tools/python/python_actions.py +2 -2
- strix/tools/python/python_actions_schema.xml +9 -1
- strix/tools/python/python_instance.py +39 -37
- strix/tools/python/python_manager.py +43 -31
- strix/tools/registry.py +73 -12
- strix/tools/reporting/reporting_actions.py +218 -31
- strix/tools/reporting/reporting_actions_schema.xml +256 -8
- strix/tools/terminal/terminal_actions.py +2 -2
- strix/tools/terminal/terminal_actions_schema.xml +6 -0
- strix/tools/terminal/terminal_manager.py +41 -30
- strix/tools/thinking/thinking_actions_schema.xml +27 -25
- strix/tools/todo/__init__.py +18 -0
- strix/tools/todo/todo_actions.py +568 -0
- strix/tools/todo/todo_actions_schema.xml +225 -0
- strix/utils/__init__.py +0 -0
- strix/utils/resource_paths.py +13 -0
- {strix_agent-0.4.0.dist-info → strix_agent-0.6.2.dist-info}/METADATA +90 -65
- strix_agent-0.6.2.dist-info/RECORD +134 -0
- {strix_agent-0.4.0.dist-info → strix_agent-0.6.2.dist-info}/WHEEL +1 -1
- strix/llm/request_queue.py +0 -87
- strix/prompts/README.md +0 -64
- strix/prompts/__init__.py +0 -109
- strix_agent-0.4.0.dist-info/RECORD +0 -118
- /strix/{prompts → skills}/cloud/.gitkeep +0 -0
- /strix/{prompts → skills}/coordination/root_agent.jinja +0 -0
- /strix/{prompts → skills}/custom/.gitkeep +0 -0
- /strix/{prompts → skills}/frameworks/fastapi.jinja +0 -0
- /strix/{prompts → skills}/protocols/graphql.jinja +0 -0
- /strix/{prompts → skills}/reconnaissance/.gitkeep +0 -0
- /strix/{prompts → skills}/technologies/firebase_firestore.jinja +0 -0
- /strix/{prompts → skills}/technologies/supabase.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/authentication_jwt.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/broken_function_level_authorization.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/business_logic.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/csrf.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/idor.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/information_disclosure.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/insecure_file_uploads.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/mass_assignment.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/open_redirect.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/path_traversal_lfi_rfi.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/race_conditions.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/rce.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/sql_injection.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/ssrf.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/subdomain_takeover.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/xss.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/xxe.jinja +0 -0
- {strix_agent-0.4.0.dist-info → strix_agent-0.6.2.dist-info}/entry_points.txt +0 -0
- {strix_agent-0.4.0.dist-info → strix_agent-0.6.2.dist-info/licenses}/LICENSE +0 -0
strix/interface/utils.py
CHANGED
|
@@ -38,6 +38,165 @@ def get_severity_color(severity: str) -> str:
|
|
|
38
38
|
return severity_colors.get(severity, "#6b7280")
|
|
39
39
|
|
|
40
40
|
|
|
41
|
+
def get_cvss_color(cvss_score: float) -> str:
    """Map a CVSS base score to the hex colour used to display it.

    Bands are checked from the top down with inclusive lower bounds:
    >= 9.0, >= 7.0, >= 4.0 and >= 0.1. Anything below 0.1 (including 0.0
    and NaN, which fails every comparison) falls through to neutral grey.
    """
    bands = (
        (9.0, "#dc2626"),
        (7.0, "#ea580c"),
        (4.0, "#d97706"),
        (0.1, "#65a30d"),
    )
    for lower_bound, color in bands:
        if cvss_score >= lower_bound:
            return color
    return "#6b7280"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def format_vulnerability_report(report: dict[str, Any]) -> Text:
    """Format a vulnerability report for CLI display with all rich fields.

    Fields are rendered in a fixed order and only when present and truthy
    (``cvss`` is rendered whenever it is not ``None``). Consecutive fields are
    separated by a blank line; no separator precedes the first rendered field,
    so a report without a title does not start with empty lines.

    Args:
        report: Report mapping. Recognised keys: ``title``, ``severity``,
            ``cvss``, ``target``, ``endpoint``, ``method``, ``cve``,
            ``cvss_breakdown``, ``description``, ``impact``,
            ``technical_analysis``, ``poc_description``, ``poc_script_code``,
            ``code_file``, ``code_before``, ``code_after``, ``code_diff`` and
            ``remediation_steps``. Other keys are ignored.

    Returns:
        A ``Text`` object ready to be printed by a rich console.
    """
    field_style = "bold #4ade80"
    text = Text()

    def as_text(value: Any) -> str | Text:
        # Text.append only accepts str/Text; coerce numbers, lists, etc.
        return value if isinstance(value, (str, Text)) else str(value)

    def separate() -> None:
        # Blank line between fields, never before the first one.
        if text.plain:
            text.append("\n\n")

    def add_inline(label: str, value: Any, style: str | None = None) -> None:
        # "Label: value" on a single line.
        separate()
        text.append(f"{label}: ", style=field_style)
        text.append(as_text(value), style=style)

    def add_section(label: str, value: Any, style: str | None = None) -> None:
        # Heading line followed by the (possibly multi-line) value.
        separate()
        text.append(label, style=field_style)
        text.append("\n")
        text.append(as_text(value), style=style)

    def add_fields(fields: tuple[tuple[str, str, bool, str | None], ...]) -> None:
        # Each entry: (report key, label, render as section?, value style).
        for key, label, is_section, style in fields:
            value = report.get(key)
            if not value:
                continue
            if is_section:
                add_section(label, value, style)
            else:
                add_inline(label, value, style)

    title = report.get("title", "")
    if title:
        text.append("Vulnerability Report", style="bold #ea580c")
        add_inline("Title", title)

    severity = report.get("severity", "")
    if severity:
        severity_str = str(severity)
        severity_color = get_severity_color(severity_str.lower())
        add_inline("Severity", severity_str.upper(), style=f"bold {severity_color}")

    cvss = report.get("cvss")
    if cvss is not None:
        try:
            score = float(cvss)
        except (TypeError, ValueError):
            # Unparseable score (e.g. "N/A"): show it verbatim instead of crashing.
            add_inline("CVSS Score", cvss)
        else:
            add_inline("CVSS Score", f"{score:.1f}", style=f"bold {get_cvss_color(score)}")

    add_fields(
        (
            ("target", "Target", False, None),
            ("endpoint", "Endpoint", False, None),
            ("method", "Method", False, None),
            ("cve", "CVE", False, None),
        )
    )

    cvss_breakdown = report.get("cvss_breakdown")
    if isinstance(cvss_breakdown, dict):
        vector_metrics = (
            ("AV", "attack_vector"),
            ("AC", "attack_complexity"),
            ("PR", "privileges_required"),
            ("UI", "user_interaction"),
            ("S", "scope"),
            ("C", "confidentiality"),
            ("I", "integrity"),
            ("A", "availability"),
        )
        cvss_parts = [
            f"{abbr}:{cvss_breakdown[key]}"
            for abbr, key in vector_metrics
            if cvss_breakdown.get(key)
        ]
        # Only emit the heading (and its separator) when there is a vector to show.
        if cvss_parts:
            add_inline("CVSS Vector", "/".join(cvss_parts), style="dim")

    add_fields(
        (
            ("description", "Description", True, None),
            ("impact", "Impact", True, None),
            ("technical_analysis", "Technical Analysis", True, None),
            ("poc_description", "PoC Description", True, None),
            ("poc_script_code", "PoC Code", True, "dim"),
            ("code_file", "Code File", False, None),
            ("code_before", "Code Before", True, "dim"),
            ("code_after", "Code After", True, "dim"),
            ("code_diff", "Code Diff", True, "dim"),
            ("remediation_steps", "Remediation", True, None),
        )
    )

    return text
|
|
198
|
+
|
|
199
|
+
|
|
41
200
|
def _build_vulnerability_stats(stats_text: Text, tracer: Any) -> None:
|
|
42
201
|
"""Build vulnerability section of stats text."""
|
|
43
202
|
vuln_count = len(tracer.vulnerability_reports)
|
|
@@ -129,11 +288,17 @@ def build_final_stats_text(tracer: Any) -> Text:
|
|
|
129
288
|
return stats_text
|
|
130
289
|
|
|
131
290
|
|
|
132
|
-
def build_live_stats_text(tracer: Any) -> Text:
|
|
291
|
+
def build_live_stats_text(tracer: Any, agent_config: dict[str, Any] | None = None) -> Text:
|
|
133
292
|
stats_text = Text()
|
|
134
293
|
if not tracer:
|
|
135
294
|
return stats_text
|
|
136
295
|
|
|
296
|
+
if agent_config:
|
|
297
|
+
llm_config = agent_config["llm_config"]
|
|
298
|
+
model = getattr(llm_config, "model_name", "Unknown")
|
|
299
|
+
stats_text.append(f"🧠 Model: {model}")
|
|
300
|
+
stats_text.append("\n")
|
|
301
|
+
|
|
137
302
|
vuln_count = len(tracer.vulnerability_reports)
|
|
138
303
|
tool_count = tracer.get_real_tool_count()
|
|
139
304
|
agent_count = len(tracer.agents)
|
|
@@ -196,6 +361,31 @@ def build_live_stats_text(tracer: Any) -> Text:
|
|
|
196
361
|
return stats_text
|
|
197
362
|
|
|
198
363
|
|
|
364
|
+
def build_tui_stats_text(tracer: Any, agent_config: dict[str, Any] | None = None) -> Text:
    """Build the compact, dimmed stats block shown in the TUI.

    Emits up to three lines, each only when there is something to show:

    - the model name, when *agent_config* is given;
    - the total token count;
    - the total spend in dollars.

    Lines are newline-separated, with no leading blank line.

    Args:
        tracer: Object exposing ``get_total_llm_stats()``. A falsy tracer
            yields an empty ``Text``.
        agent_config: Optional agent configuration whose ``llm_config`` entry
            supplies ``model_name``.

    Returns:
        The assembled ``Text``.
    """
    stats_text = Text()
    if not tracer:
        return stats_text

    def add_line(content: str) -> None:
        # Newline only between lines, never before the first one.
        if stats_text.plain:
            stats_text.append("\n")
        stats_text.append(content, style="dim")

    if agent_config:
        llm_config = agent_config.get("llm_config")
        # Fall back to "Unknown" for a missing config, attribute, or empty name;
        # Text.append would reject None.
        model = getattr(llm_config, "model_name", None) or "Unknown"
        add_line(str(model))

    total_stats = tracer.get_total_llm_stats().get("total", {})

    total_tokens = total_stats.get("input_tokens", 0) + total_stats.get("output_tokens", 0)
    if total_tokens > 0:
        add_line(f"{format_token_count(total_tokens)} tokens")

    cost = total_stats.get("cost", 0)
    if cost > 0:
        add_line(f"${cost:.2f} spent")

    return stats_text
|
|
387
|
+
|
|
388
|
+
|
|
199
389
|
# Name generation utilities
|
|
200
390
|
|
|
201
391
|
|
|
@@ -398,6 +588,47 @@ def collect_local_sources(targets_info: list[dict[str, Any]]) -> list[dict[str,
|
|
|
398
588
|
return local_sources
|
|
399
589
|
|
|
400
590
|
|
|
591
|
+
def _is_localhost_host(host: str) -> bool:
|
|
592
|
+
host_lower = host.lower().strip("[]")
|
|
593
|
+
|
|
594
|
+
if host_lower in ("localhost", "0.0.0.0", "::1"): # nosec B104
|
|
595
|
+
return True
|
|
596
|
+
|
|
597
|
+
try:
|
|
598
|
+
ip = ipaddress.ip_address(host_lower)
|
|
599
|
+
if isinstance(ip, ipaddress.IPv4Address):
|
|
600
|
+
return ip.is_loopback # 127.0.0.0/8
|
|
601
|
+
if isinstance(ip, ipaddress.IPv6Address):
|
|
602
|
+
return ip.is_loopback # ::1
|
|
603
|
+
except ValueError:
|
|
604
|
+
pass
|
|
605
|
+
|
|
606
|
+
return False
|
|
607
|
+
|
|
608
|
+
|
|
609
|
+
def rewrite_localhost_targets(targets_info: list[dict[str, Any]], host_gateway: str) -> None:
    """Point localhost-style targets at *host_gateway*, mutating *targets_info* in place.

    Only two target types are touched:

    - ``web_application``: only the host of ``details["target_url"]`` is
      replaced, via yarl's ``URL.with_host``. URLs that yarl rejects with
      ValueError/TypeError are left alone.
    - ``ip_address``: ``details["target_ip"]`` is replaced wholesale.

    Other target types are skipped. Entries without a ``details`` dict have
    nothing to rewrite.
    """
    from yarl import URL  # type: ignore[import-not-found]

    for entry in targets_info:
        kind = entry.get("type")
        info = entry.get("details", {})

        if kind == "ip_address":
            ip_value = info.get("target_ip", "")
            if ip_value and _is_localhost_host(ip_value):
                info["target_ip"] = host_gateway
            continue

        if kind != "web_application":
            continue

        raw_url = info.get("target_url", "")
        try:
            parsed = URL(raw_url)
        except (ValueError, TypeError):
            continue

        host = parsed.host
        if host and _is_localhost_host(host):
            info["target_url"] = str(parsed.with_host(host_gateway))
|
|
630
|
+
|
|
631
|
+
|
|
401
632
|
# Repository utilities
|
|
402
633
|
def clone_repository(repo_url: str, run_name: str, dest_name: str | None = None) -> str:
|
|
403
634
|
console = Console()
|
|
@@ -488,9 +719,10 @@ def check_docker_connection() -> Any:
|
|
|
488
719
|
error_text.append("DOCKER NOT AVAILABLE", style="bold red")
|
|
489
720
|
error_text.append("\n\n", style="white")
|
|
490
721
|
error_text.append("Cannot connect to Docker daemon.\n", style="white")
|
|
491
|
-
error_text.append(
|
|
492
|
-
|
|
493
|
-
|
|
722
|
+
error_text.append(
|
|
723
|
+
"Please ensure Docker Desktop is installed and running, and try running strix again.\n",
|
|
724
|
+
style="white",
|
|
725
|
+
)
|
|
494
726
|
|
|
495
727
|
panel = Panel(
|
|
496
728
|
error_text,
|
strix/llm/__init__.py
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import warnings
|
|
3
|
+
|
|
1
4
|
import litellm
|
|
2
5
|
|
|
3
6
|
from .config import LLMConfig
|
|
@@ -11,5 +14,6 @@ __all__ = [
|
|
|
11
14
|
]
|
|
12
15
|
|
|
13
16
|
# Presumably turns off LiteLLM's debug logging. Note that _logging and
# _disable_debugging are private LiteLLM names, so this call may need
# revisiting when LiteLLM is upgraded.
litellm._logging._disable_debugging()

# Drop asyncio log records below CRITICAL and stop them from propagating to
# ancestor loggers' handlers.
logging.getLogger("asyncio").setLevel(logging.CRITICAL)
logging.getLogger("asyncio").propagate = False
# Ignore RuntimeWarnings whose issuing module name matches the regex "asyncio"
# at its start (i.e. asyncio and its submodules).
warnings.filterwarnings("ignore", category=RuntimeWarning, module="asyncio")
|
strix/llm/config.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import
|
|
1
|
+
from strix.config import Config
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
class LLMConfig:
|
|
@@ -6,15 +6,18 @@ class LLMConfig:
|
|
|
6
6
|
self,
|
|
7
7
|
model_name: str | None = None,
|
|
8
8
|
enable_prompt_caching: bool = True,
|
|
9
|
-
|
|
9
|
+
skills: list[str] | None = None,
|
|
10
10
|
timeout: int | None = None,
|
|
11
|
+
scan_mode: str = "deep",
|
|
11
12
|
):
|
|
12
|
-
self.model_name = model_name or
|
|
13
|
+
self.model_name = model_name or Config.get("strix_llm")
|
|
13
14
|
|
|
14
15
|
if not self.model_name:
|
|
15
16
|
raise ValueError("STRIX_LLM environment variable must be set and not empty")
|
|
16
17
|
|
|
17
18
|
self.enable_prompt_caching = enable_prompt_caching
|
|
18
|
-
self.
|
|
19
|
+
self.skills = skills or []
|
|
19
20
|
|
|
20
|
-
self.timeout = timeout or int(
|
|
21
|
+
self.timeout = timeout or int(Config.get("llm_timeout") or "300")
|
|
22
|
+
|
|
23
|
+
self.scan_mode = scan_mode if scan_mode in ["quick", "standard", "deep"] else "deep"
|
strix/llm/dedupe.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import re
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import litellm
|
|
7
|
+
|
|
8
|
+
from strix.config import Config
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)

# System prompt for the LLM dedupe judge used by check_duplicate. It fixes the
# <dedupe_result> XML reply format that _parse_dedupe_response parses, so keep
# the two in sync when editing either.
DEDUPE_SYSTEM_PROMPT = """You are an expert vulnerability report deduplication judge.
|
|
14
|
+
Your task is to determine if a candidate vulnerability report describes the SAME vulnerability
|
|
15
|
+
as any existing report.
|
|
16
|
+
|
|
17
|
+
CRITICAL DEDUPLICATION RULES:
|
|
18
|
+
|
|
19
|
+
1. SAME VULNERABILITY means:
|
|
20
|
+
- Same root cause (e.g., "missing input validation" not just "SQL injection")
|
|
21
|
+
- Same affected component/endpoint/file (exact match or clear overlap)
|
|
22
|
+
- Same exploitation method or attack vector
|
|
23
|
+
- Would be fixed by the same code change/patch
|
|
24
|
+
|
|
25
|
+
2. NOT DUPLICATES if:
|
|
26
|
+
- Different endpoints even with same vulnerability type (e.g., SQLi in /login vs /search)
|
|
27
|
+
- Different parameters in same endpoint (e.g., XSS in 'name' vs 'comment' field)
|
|
28
|
+
- Different root causes (e.g., stored XSS vs reflected XSS in same field)
|
|
29
|
+
- Different severity levels due to different impact
|
|
30
|
+
- One is authenticated, other is unauthenticated
|
|
31
|
+
|
|
32
|
+
3. ARE DUPLICATES even if:
|
|
33
|
+
- Titles are worded differently
|
|
34
|
+
- Descriptions have different level of detail
|
|
35
|
+
- PoC uses different payloads but exploits same issue
|
|
36
|
+
- One report is more thorough than another
|
|
37
|
+
- Minor variations in technical analysis
|
|
38
|
+
|
|
39
|
+
COMPARISON GUIDELINES:
|
|
40
|
+
- Focus on the technical root cause, not surface-level similarities
|
|
41
|
+
- Same vulnerability type (SQLi, XSS) doesn't mean duplicate - location matters
|
|
42
|
+
- Consider the fix: would fixing one also fix the other?
|
|
43
|
+
- When uncertain, lean towards NOT duplicate
|
|
44
|
+
|
|
45
|
+
FIELDS TO ANALYZE:
|
|
46
|
+
- title, description: General vulnerability info
|
|
47
|
+
- target, endpoint, method: Exact location of vulnerability
|
|
48
|
+
- technical_analysis: Root cause details
|
|
49
|
+
- poc_description: How it's exploited
|
|
50
|
+
- impact: What damage it can cause
|
|
51
|
+
|
|
52
|
+
YOU MUST RESPOND WITH EXACTLY THIS XML FORMAT AND NOTHING ELSE:
|
|
53
|
+
|
|
54
|
+
<dedupe_result>
|
|
55
|
+
<is_duplicate>true</is_duplicate>
|
|
56
|
+
<duplicate_id>vuln-0001</duplicate_id>
|
|
57
|
+
<confidence>0.95</confidence>
|
|
58
|
+
<reason>Both reports describe SQL injection in /api/login via the username parameter</reason>
|
|
59
|
+
</dedupe_result>
|
|
60
|
+
|
|
61
|
+
OR if not a duplicate:
|
|
62
|
+
|
|
63
|
+
<dedupe_result>
|
|
64
|
+
<is_duplicate>false</is_duplicate>
|
|
65
|
+
<duplicate_id></duplicate_id>
|
|
66
|
+
<confidence>0.90</confidence>
|
|
67
|
+
<reason>Different endpoints: candidate is /api/search, existing is /api/login</reason>
|
|
68
|
+
</dedupe_result>
|
|
69
|
+
|
|
70
|
+
RULES:
|
|
71
|
+
- is_duplicate MUST be exactly "true" or "false" (lowercase)
|
|
72
|
+
- duplicate_id MUST be the exact ID from existing reports or empty if not duplicate
|
|
73
|
+
- confidence MUST be a decimal (your confidence level in the decision)
|
|
74
|
+
- reason MUST be a specific explanation mentioning endpoint/parameter/root cause
|
|
75
|
+
- DO NOT include any text outside the <dedupe_result> tags"""
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _prepare_report_for_comparison(report: dict[str, Any]) -> dict[str, Any]:
|
|
79
|
+
relevant_fields = [
|
|
80
|
+
"id",
|
|
81
|
+
"title",
|
|
82
|
+
"description",
|
|
83
|
+
"impact",
|
|
84
|
+
"target",
|
|
85
|
+
"technical_analysis",
|
|
86
|
+
"poc_description",
|
|
87
|
+
"endpoint",
|
|
88
|
+
"method",
|
|
89
|
+
]
|
|
90
|
+
|
|
91
|
+
cleaned = {}
|
|
92
|
+
for field in relevant_fields:
|
|
93
|
+
if report.get(field):
|
|
94
|
+
value = report[field]
|
|
95
|
+
if isinstance(value, str) and len(value) > 8000:
|
|
96
|
+
value = value[:8000] + "...[truncated]"
|
|
97
|
+
cleaned[field] = value
|
|
98
|
+
|
|
99
|
+
return cleaned
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _extract_xml_field(content: str, field: str) -> str:
|
|
103
|
+
pattern = rf"<{field}>(.*?)</{field}>"
|
|
104
|
+
match = re.search(pattern, content, re.DOTALL | re.IGNORECASE)
|
|
105
|
+
if match:
|
|
106
|
+
return match.group(1).strip()
|
|
107
|
+
return ""
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _parse_dedupe_response(content: str) -> dict[str, Any]:
    """Parse the judge's ``<dedupe_result>`` XML block into a result dict.

    Returns:
        A dict with the following keys:

        - ``is_duplicate`` (bool): True only for a case-insensitive "true".
        - ``duplicate_id`` (str): at most 64 characters; always ``""`` when
          ``is_duplicate`` is False.
        - ``confidence`` (float): clamped to [0.0, 1.0]; 0.0 when missing,
          unparseable or non-finite.
        - ``reason`` (str): at most 500 characters.

    Raises:
        ValueError: if no ``<dedupe_result>`` block is present.
    """
    import math

    result_match = re.search(
        r"<dedupe_result>(.*?)</dedupe_result>", content, re.DOTALL | re.IGNORECASE
    )
    if not result_match:
        logger.warning("No <dedupe_result> block found in response: %s", content[:500])
        raise ValueError("No <dedupe_result> block found in response")

    result_content = result_match.group(1)

    is_duplicate = _extract_xml_field(result_content, "is_duplicate").lower() == "true"
    duplicate_id = _extract_xml_field(result_content, "duplicate_id")
    confidence_str = _extract_xml_field(result_content, "confidence")
    reason = _extract_xml_field(result_content, "reason")

    try:
        confidence = float(confidence_str) if confidence_str else 0.0
    except ValueError:
        confidence = 0.0
    # float() happily parses "nan"/"inf"; the prompt asks for a confidence level,
    # so keep the value inside [0, 1].
    if not math.isfinite(confidence):
        confidence = 0.0
    confidence = min(max(confidence, 0.0), 1.0)

    # An id only makes sense for a positive verdict; drop any stray value.
    if not is_duplicate:
        duplicate_id = ""

    return {
        "is_duplicate": is_duplicate,
        "duplicate_id": duplicate_id[:64],
        "confidence": confidence,
        "reason": reason[:500],
    }
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def check_duplicate(
    candidate: dict[str, Any], existing_reports: list[dict[str, Any]]
) -> dict[str, Any]:
    """Ask the configured LLM whether *candidate* duplicates an existing report.

    Both the candidate and the existing reports are reduced to their
    comparison fields, serialised to JSON, and sent with
    ``DEDUPE_SYSTEM_PROMPT``. The request goes to the model named by the
    ``strix_llm`` config value, using the optional ``llm_api_key`` and
    API-base config values.

    The check is best-effort and fails open (``is_duplicate=False``) in two
    cases:

    - on any exception or an empty model response;
    - when the model names a ``duplicate_id`` that is not among the reports
      it was shown.

    Args:
        candidate: The newly submitted vulnerability report.
        existing_reports: Previously recorded reports to compare against.

    Returns:
        A dict with ``is_duplicate`` (bool), ``duplicate_id`` (str),
        ``confidence`` (float) and ``reason`` (str). When an exception
        occurred, an ``error`` key with the exception text is added.
    """
    if not existing_reports:
        return {
            "is_duplicate": False,
            "duplicate_id": "",
            "confidence": 1.0,
            "reason": "No existing reports to compare against",
        }

    try:
        candidate_cleaned = _prepare_report_for_comparison(candidate)
        existing_cleaned = [_prepare_report_for_comparison(r) for r in existing_reports]

        comparison_data = {"candidate": candidate_cleaned, "existing_reports": existing_cleaned}

        model_name = Config.get("strix_llm")
        api_key = Config.get("llm_api_key")
        api_base = (
            Config.get("llm_api_base")
            or Config.get("openai_api_base")
            or Config.get("litellm_base_url")
            or Config.get("ollama_api_base")
        )

        messages = [
            {"role": "system", "content": DEDUPE_SYSTEM_PROMPT},
            {
                "role": "user",
                "content": (
                    "Compare this candidate vulnerability against existing reports:\n\n"
                    f"{json.dumps(comparison_data, indent=2)}\n\n"
                    "Respond with ONLY the <dedupe_result> XML block."
                ),
            },
        ]

        completion_kwargs: dict[str, Any] = {
            "model": model_name,
            "messages": messages,
            "timeout": 120,
        }
        if api_key:
            completion_kwargs["api_key"] = api_key
        if api_base:
            completion_kwargs["api_base"] = api_base

        response = litellm.completion(**completion_kwargs)

        content = response.choices[0].message.content
        if not content:
            return {
                "is_duplicate": False,
                "duplicate_id": "",
                "confidence": 0.0,
                "reason": "Empty response from LLM",
            }

        result = _parse_dedupe_response(content)

        # The judge must name one of the reports it was shown. An id we do not
        # know cannot be acted upon, so fail open instead of suppressing the
        # candidate. Ids are compared after the same 64-character cut that
        # _parse_dedupe_response applies to duplicate_id.
        known_ids = {str(r["id"])[:64] for r in existing_reports if r.get("id")}
        if result["is_duplicate"] and known_ids and result["duplicate_id"] not in known_ids:
            logger.warning(
                "Dedupe judge returned unknown duplicate_id %r; treating as not a duplicate",
                result["duplicate_id"],
            )
            result = {
                "is_duplicate": False,
                "duplicate_id": "",
                "confidence": 0.0,
                "reason": (
                    f"LLM referenced unknown report id {result['duplicate_id']!r}: "
                    f"{result['reason']}"
                )[:500],
            }

        logger.info(
            "Deduplication check: is_duplicate=%s, confidence=%s, reason=%s",
            result["is_duplicate"],
            result["confidence"],
            result["reason"][:100],
        )

    except Exception as e:
        # Deliberate best-effort boundary: any failure yields "not a duplicate".
        logger.exception("Error during vulnerability deduplication check")
        return {
            "is_duplicate": False,
            "duplicate_id": "",
            "confidence": 0.0,
            "reason": f"Deduplication check failed: {e}",
            "error": str(e),
        }
    else:
        return result
|