aiptx 2.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187) hide show
  1. aipt_v2/__init__.py +110 -0
  2. aipt_v2/__main__.py +24 -0
  3. aipt_v2/agents/AIPTxAgent/__init__.py +10 -0
  4. aipt_v2/agents/AIPTxAgent/aiptx_agent.py +211 -0
  5. aipt_v2/agents/__init__.py +46 -0
  6. aipt_v2/agents/base.py +520 -0
  7. aipt_v2/agents/exploit_agent.py +688 -0
  8. aipt_v2/agents/ptt.py +406 -0
  9. aipt_v2/agents/state.py +168 -0
  10. aipt_v2/app.py +957 -0
  11. aipt_v2/browser/__init__.py +31 -0
  12. aipt_v2/browser/automation.py +458 -0
  13. aipt_v2/browser/crawler.py +453 -0
  14. aipt_v2/cli.py +2933 -0
  15. aipt_v2/compliance/__init__.py +71 -0
  16. aipt_v2/compliance/compliance_report.py +449 -0
  17. aipt_v2/compliance/framework_mapper.py +424 -0
  18. aipt_v2/compliance/nist_mapping.py +345 -0
  19. aipt_v2/compliance/owasp_mapping.py +330 -0
  20. aipt_v2/compliance/pci_mapping.py +297 -0
  21. aipt_v2/config.py +341 -0
  22. aipt_v2/core/__init__.py +43 -0
  23. aipt_v2/core/agent.py +630 -0
  24. aipt_v2/core/llm.py +395 -0
  25. aipt_v2/core/memory.py +305 -0
  26. aipt_v2/core/ptt.py +329 -0
  27. aipt_v2/database/__init__.py +14 -0
  28. aipt_v2/database/models.py +232 -0
  29. aipt_v2/database/repository.py +384 -0
  30. aipt_v2/docker/__init__.py +23 -0
  31. aipt_v2/docker/builder.py +260 -0
  32. aipt_v2/docker/manager.py +222 -0
  33. aipt_v2/docker/sandbox.py +371 -0
  34. aipt_v2/evasion/__init__.py +58 -0
  35. aipt_v2/evasion/request_obfuscator.py +272 -0
  36. aipt_v2/evasion/tls_fingerprint.py +285 -0
  37. aipt_v2/evasion/ua_rotator.py +301 -0
  38. aipt_v2/evasion/waf_bypass.py +439 -0
  39. aipt_v2/execution/__init__.py +23 -0
  40. aipt_v2/execution/executor.py +302 -0
  41. aipt_v2/execution/parser.py +544 -0
  42. aipt_v2/execution/terminal.py +337 -0
  43. aipt_v2/health.py +437 -0
  44. aipt_v2/intelligence/__init__.py +194 -0
  45. aipt_v2/intelligence/adaptation.py +474 -0
  46. aipt_v2/intelligence/auth.py +520 -0
  47. aipt_v2/intelligence/chaining.py +775 -0
  48. aipt_v2/intelligence/correlation.py +536 -0
  49. aipt_v2/intelligence/cve_aipt.py +334 -0
  50. aipt_v2/intelligence/cve_info.py +1111 -0
  51. aipt_v2/intelligence/knowledge_graph.py +590 -0
  52. aipt_v2/intelligence/learning.py +626 -0
  53. aipt_v2/intelligence/llm_analyzer.py +502 -0
  54. aipt_v2/intelligence/llm_tool_selector.py +518 -0
  55. aipt_v2/intelligence/payload_generator.py +562 -0
  56. aipt_v2/intelligence/rag.py +239 -0
  57. aipt_v2/intelligence/scope.py +442 -0
  58. aipt_v2/intelligence/searchers/__init__.py +5 -0
  59. aipt_v2/intelligence/searchers/exploitdb_searcher.py +523 -0
  60. aipt_v2/intelligence/searchers/github_searcher.py +467 -0
  61. aipt_v2/intelligence/searchers/google_searcher.py +281 -0
  62. aipt_v2/intelligence/tools.json +443 -0
  63. aipt_v2/intelligence/triage.py +670 -0
  64. aipt_v2/interactive_shell.py +559 -0
  65. aipt_v2/interface/__init__.py +5 -0
  66. aipt_v2/interface/cli.py +230 -0
  67. aipt_v2/interface/main.py +501 -0
  68. aipt_v2/interface/tui.py +1276 -0
  69. aipt_v2/interface/utils.py +583 -0
  70. aipt_v2/llm/__init__.py +39 -0
  71. aipt_v2/llm/config.py +26 -0
  72. aipt_v2/llm/llm.py +514 -0
  73. aipt_v2/llm/memory.py +214 -0
  74. aipt_v2/llm/request_queue.py +89 -0
  75. aipt_v2/llm/utils.py +89 -0
  76. aipt_v2/local_tool_installer.py +1467 -0
  77. aipt_v2/models/__init__.py +15 -0
  78. aipt_v2/models/findings.py +295 -0
  79. aipt_v2/models/phase_result.py +224 -0
  80. aipt_v2/models/scan_config.py +207 -0
  81. aipt_v2/monitoring/grafana/dashboards/aipt-dashboard.json +355 -0
  82. aipt_v2/monitoring/grafana/dashboards/default.yml +17 -0
  83. aipt_v2/monitoring/grafana/datasources/prometheus.yml +17 -0
  84. aipt_v2/monitoring/prometheus.yml +60 -0
  85. aipt_v2/orchestration/__init__.py +52 -0
  86. aipt_v2/orchestration/pipeline.py +398 -0
  87. aipt_v2/orchestration/progress.py +300 -0
  88. aipt_v2/orchestration/scheduler.py +296 -0
  89. aipt_v2/orchestrator.py +2427 -0
  90. aipt_v2/payloads/__init__.py +27 -0
  91. aipt_v2/payloads/cmdi.py +150 -0
  92. aipt_v2/payloads/sqli.py +263 -0
  93. aipt_v2/payloads/ssrf.py +204 -0
  94. aipt_v2/payloads/templates.py +222 -0
  95. aipt_v2/payloads/traversal.py +166 -0
  96. aipt_v2/payloads/xss.py +204 -0
  97. aipt_v2/prompts/__init__.py +60 -0
  98. aipt_v2/proxy/__init__.py +29 -0
  99. aipt_v2/proxy/history.py +352 -0
  100. aipt_v2/proxy/interceptor.py +452 -0
  101. aipt_v2/recon/__init__.py +44 -0
  102. aipt_v2/recon/dns.py +241 -0
  103. aipt_v2/recon/osint.py +367 -0
  104. aipt_v2/recon/subdomain.py +372 -0
  105. aipt_v2/recon/tech_detect.py +311 -0
  106. aipt_v2/reports/__init__.py +17 -0
  107. aipt_v2/reports/generator.py +313 -0
  108. aipt_v2/reports/html_report.py +378 -0
  109. aipt_v2/runtime/__init__.py +53 -0
  110. aipt_v2/runtime/base.py +30 -0
  111. aipt_v2/runtime/docker.py +401 -0
  112. aipt_v2/runtime/local.py +346 -0
  113. aipt_v2/runtime/tool_server.py +205 -0
  114. aipt_v2/runtime/vps.py +830 -0
  115. aipt_v2/scanners/__init__.py +28 -0
  116. aipt_v2/scanners/base.py +273 -0
  117. aipt_v2/scanners/nikto.py +244 -0
  118. aipt_v2/scanners/nmap.py +402 -0
  119. aipt_v2/scanners/nuclei.py +273 -0
  120. aipt_v2/scanners/web.py +454 -0
  121. aipt_v2/scripts/security_audit.py +366 -0
  122. aipt_v2/setup_wizard.py +941 -0
  123. aipt_v2/skills/__init__.py +80 -0
  124. aipt_v2/skills/agents/__init__.py +14 -0
  125. aipt_v2/skills/agents/api_tester.py +706 -0
  126. aipt_v2/skills/agents/base.py +477 -0
  127. aipt_v2/skills/agents/code_review.py +459 -0
  128. aipt_v2/skills/agents/security_agent.py +336 -0
  129. aipt_v2/skills/agents/web_pentest.py +818 -0
  130. aipt_v2/skills/prompts/__init__.py +647 -0
  131. aipt_v2/system_detector.py +539 -0
  132. aipt_v2/telemetry/__init__.py +7 -0
  133. aipt_v2/telemetry/tracer.py +347 -0
  134. aipt_v2/terminal/__init__.py +28 -0
  135. aipt_v2/terminal/executor.py +400 -0
  136. aipt_v2/terminal/sandbox.py +350 -0
  137. aipt_v2/tools/__init__.py +44 -0
  138. aipt_v2/tools/active_directory/__init__.py +78 -0
  139. aipt_v2/tools/active_directory/ad_config.py +238 -0
  140. aipt_v2/tools/active_directory/bloodhound_wrapper.py +447 -0
  141. aipt_v2/tools/active_directory/kerberos_attacks.py +430 -0
  142. aipt_v2/tools/active_directory/ldap_enum.py +533 -0
  143. aipt_v2/tools/active_directory/smb_attacks.py +505 -0
  144. aipt_v2/tools/agents_graph/__init__.py +19 -0
  145. aipt_v2/tools/agents_graph/agents_graph_actions.py +69 -0
  146. aipt_v2/tools/api_security/__init__.py +76 -0
  147. aipt_v2/tools/api_security/api_discovery.py +608 -0
  148. aipt_v2/tools/api_security/graphql_scanner.py +622 -0
  149. aipt_v2/tools/api_security/jwt_analyzer.py +577 -0
  150. aipt_v2/tools/api_security/openapi_fuzzer.py +761 -0
  151. aipt_v2/tools/browser/__init__.py +5 -0
  152. aipt_v2/tools/browser/browser_actions.py +238 -0
  153. aipt_v2/tools/browser/browser_instance.py +535 -0
  154. aipt_v2/tools/browser/tab_manager.py +344 -0
  155. aipt_v2/tools/cloud/__init__.py +70 -0
  156. aipt_v2/tools/cloud/cloud_config.py +273 -0
  157. aipt_v2/tools/cloud/cloud_scanner.py +639 -0
  158. aipt_v2/tools/cloud/prowler_tool.py +571 -0
  159. aipt_v2/tools/cloud/scoutsuite_tool.py +359 -0
  160. aipt_v2/tools/executor.py +307 -0
  161. aipt_v2/tools/parser.py +408 -0
  162. aipt_v2/tools/proxy/__init__.py +5 -0
  163. aipt_v2/tools/proxy/proxy_actions.py +103 -0
  164. aipt_v2/tools/proxy/proxy_manager.py +789 -0
  165. aipt_v2/tools/registry.py +196 -0
  166. aipt_v2/tools/scanners/__init__.py +343 -0
  167. aipt_v2/tools/scanners/acunetix_tool.py +712 -0
  168. aipt_v2/tools/scanners/burp_tool.py +631 -0
  169. aipt_v2/tools/scanners/config.py +156 -0
  170. aipt_v2/tools/scanners/nessus_tool.py +588 -0
  171. aipt_v2/tools/scanners/zap_tool.py +612 -0
  172. aipt_v2/tools/terminal/__init__.py +5 -0
  173. aipt_v2/tools/terminal/terminal_actions.py +37 -0
  174. aipt_v2/tools/terminal/terminal_manager.py +153 -0
  175. aipt_v2/tools/terminal/terminal_session.py +449 -0
  176. aipt_v2/tools/tool_processing.py +108 -0
  177. aipt_v2/utils/__init__.py +17 -0
  178. aipt_v2/utils/logging.py +202 -0
  179. aipt_v2/utils/model_manager.py +187 -0
  180. aipt_v2/utils/searchers/__init__.py +269 -0
  181. aipt_v2/verify_install.py +793 -0
  182. aiptx-2.0.7.dist-info/METADATA +345 -0
  183. aiptx-2.0.7.dist-info/RECORD +187 -0
  184. aiptx-2.0.7.dist-info/WHEEL +5 -0
  185. aiptx-2.0.7.dist-info/entry_points.txt +7 -0
  186. aiptx-2.0.7.dist-info/licenses/LICENSE +21 -0
  187. aiptx-2.0.7.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1111 @@
# from github_searcher import GithubSearcher
# from google_searcher import GoogleSearcher
import os
import re  # FIX: needed at import time — CVE_PATTERN below is compiled before the later "import re"
import shlex
import subprocess

import openai
import pandas as pd
import tiktoken
import torch
from llama_index.core import Settings
from llama_index.core.callbacks import CallbackManager, TokenCountingHandler
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
from llama_index.llms.langchain import LangChainLLM
from llama_index.llms.openai import OpenAI
from transformers import AutoModelForCausalLM, AutoTokenizer

from aipt_v2.utils.doc_handler import DocHandler
18
+
19
+ # Security: CVE ID validation pattern (CWE-78 prevention)
20
+ CVE_PATTERN = re.compile(r"^CVE-\d{4}-\d{4,7}$", re.IGNORECASE)
21
+
22
+ def _validate_cve_id(cve: str) -> str:
23
+ """
24
+ Validate CVE ID format to prevent command injection.
25
+
26
+ Args:
27
+ cve: CVE identifier string
28
+
29
+ Returns:
30
+ Validated CVE ID
31
+
32
+ Raises:
33
+ ValueError: If CVE ID format is invalid
34
+ """
35
+ cve = cve.strip()
36
+ if not CVE_PATTERN.match(cve):
37
+ raise ValueError(f"Invalid CVE ID format: {cve}. Expected format: CVE-YYYY-NNNNN")
38
+ return cve.upper()
39
+
40
+ def _sanitize_product_name(product: str) -> str:
41
+ """
42
+ Sanitize product name to prevent command injection.
43
+
44
+ Args:
45
+ product: Product name string
46
+
47
+ Returns:
48
+ Sanitized product name
49
+ """
50
+ # Remove dangerous characters that could enable shell injection
51
+ dangerous_chars = [";", "&", "|", "$", "`", "\n", "\r", "\\", "'", '"', "(", ")", "{", "}", "[", "]", "<", ">"]
52
+ sanitized = product
53
+ for char in dangerous_chars:
54
+ sanitized = sanitized.replace(char, "")
55
+ return sanitized.strip()[:200] # Limit length
56
+
57
+ from sklearn.metrics import confusion_matrix, cohen_kappa_score
58
+ from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler
59
+ import seaborn as sns
60
+ import matplotlib.pyplot as plt
61
+ from scipy.stats import pearsonr, spearmanr
62
+ from itertools import combinations
63
+ import numpy as np
64
+ import json
65
+ import logging
66
+ import time
67
+ from tqdm import tqdm
68
+ from aipt_v2.utils.searchers.search_once import compose
69
+ import re
70
+ import yaml
71
+ from aipt_v2.utils.model_manager import get_model
72
logger = logging.getLogger(__name__)

# Module-level configuration: loaded once at import time from
# <package>/configs/config.yaml (one directory above this module).
# NOTE(review): importing this module fails if the config file is missing
# or malformed — confirm that is the intended fail-fast behavior.
config_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'config.yaml')
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
planning_config = config['runtime']['planning']  # runtime planning settings
model_name_for_token = config['models']['openai']['model']  # model name used for token counting
79
+
80
def cvemap_search(cve, info_dir):
    """
    Look up CVE details with the external ``cvemap`` tool.

    The CVE ID is validated first and the tool is invoked as an argument
    list (shell=False), preventing command injection (CWE-78).

    Args:
        cve: CVE identifier to query.
        info_dir: Directory where the raw cvemap JSON response is cached.

    Returns:
        The first record of cvemap's JSON output, or None on any failure
        (non-zero exit, timeout, bad JSON, empty result, tool missing).
    """
    safe_cve = _validate_cve_id(cve)
    out_path = f"{info_dir}/cvemap.json"

    try:
        # Argument-list invocation with -id; no shell, no echo pipeline.
        proc = subprocess.run(
            ["cvemap", "-id", safe_cve, "-json"],
            capture_output=True,
            text=True,
            timeout=50,
            check=False,  # non-zero exit is handled explicitly below
        )

        if proc.returncode != 0:
            logging.error(f"cvemap failed for {safe_cve}: {proc.stderr}")
            return None

        # Cache the raw response before parsing so a parse failure still
        # leaves the output on disk for inspection.
        with open(out_path, 'w') as out_file:
            out_file.write(proc.stdout)

        records = json.loads(proc.stdout)
        if len(records) == 0:
            logging.error(f"Error in cvemap search: {safe_cve}")
            return None
        return records[0]

    except subprocess.TimeoutExpired:
        logging.error(f"cvemap search timed out for {safe_cve}")
        return None
    except json.JSONDecodeError as e:
        logging.error(f"Invalid JSON from cvemap for {safe_cve}: {e}")
        return None
    except FileNotFoundError:
        logging.error("cvemap command not found. Please install cvemap.")
        return None
124
+
125
def searchsploit_search(cve):
    """
    Query the local ExploitDB mirror via ``searchsploit``.

    The CVE ID is validated first and the tool is invoked as an argument
    list (shell=False), preventing command injection (CWE-78).

    Args:
        cve: CVE identifier to query.

    Returns:
        Parsed searchsploit JSON, or None on any failure (timeout, bad
        JSON, tool missing).
    """
    safe_cve = _validate_cve_id(cve)
    out_path = f"resources/{safe_cve}/exploitdb.json"

    # Make sure the per-CVE resource directory exists before writing.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    try:
        proc = subprocess.run(
            ["searchsploit", safe_cve, "-j"],
            capture_output=True,
            text=True,
            timeout=50,
            check=False,
        )

        # Cache the raw response before parsing.
        with open(out_path, 'w') as out_file:
            out_file.write(proc.stdout)

        return json.loads(proc.stdout)

    except subprocess.TimeoutExpired:
        logging.error(f"searchsploit search timed out for {safe_cve}")
        return None
    except json.JSONDecodeError as e:
        logging.error(f"Invalid JSON from searchsploit for {safe_cve}: {e}")
        return None
    except FileNotFoundError:
        logging.error("searchsploit command not found. Please install exploitdb.")
        return None
164
+
165
+
166
def categorize_cvss(cvss_score):
    """
    Bucket a CVSS base score by exploitation difficulty.

    Scores below 4.0 map to "hard", [4.0, 7.0) to "medium", [7.0, 10.0]
    to "easy"; anything outside [0.0, 10.0] yields "value not in range".
    """
    if cvss_score < 0.0 or cvss_score > 10.0:
        return "value not in range"
    if cvss_score >= 7.0:
        return "easy"
    if cvss_score >= 4.0:
        return "medium"
    return "hard"
176
+
177
def categorize_epss(epss_score):
    """
    Bucket an EPSS probability/percentile by exploitation difficulty.

    Values below 0.4 map to "hard", [0.4, 0.94) to "medium", [0.94, 1.0]
    to "easy"; anything outside [0.0, 1.0] yields "value not in range".
    """
    if epss_score < 0.0 or epss_score > 1.0:
        return "value not in range"
    if epss_score >= 0.94:
        return "easy"
    if epss_score >= 0.4:
        return "medium"
    return "hard"
187
+
188
def count_cwe(cve_lst, output_dir="resources"):
    """
    Tally CWE occurrences across the cached cvemap records of several CVEs.

    Args:
        cve_lst: Iterable of CVE identifiers whose ``cvemap.json`` files
            have already been cached under ``{output_dir}/{cve}/info/``.
        output_dir: Base resources directory. Added as a backward-compatible
            parameter (default "resources") so callers are no longer pinned
            to the hard-coded path.

    Returns:
        dict mapping CWE ID -> number of CVEs listing that weakness.

    Raises:
        FileNotFoundError: If a CVE's cvemap.json cache is missing.
        Malformed caches propagate json.JSONDecodeError / KeyError / IndexError.
    """
    cwe_dict = {}
    for cve in cve_lst:
        cvemap_json_path = f"{output_dir}/{cve}/info/cvemap.json"
        with open(cvemap_json_path) as f:
            # cvemap caches a JSON list; the first record holds the CVE data.
            cvemap_json = json.load(f)[0]
        # 'weaknesses' may be absent; treat that as "no CWEs for this CVE".
        for cwe in cvemap_json.get('weaknesses', []):
            cwe_id = cwe['cwe_id']
            cwe_dict[cwe_id] = cwe_dict.get(cwe_id, 0) + 1
    return cwe_dict
204
+
205
def calculate_score(features, trending_score):
    """
    Score a CVE's exploitability from extracted features plus a trend score.

    Args:
        features: dict with keys 'code' (per-source feature maps for
            "GitHub" and "ExploitDB", each keyed by repo/entry with
            vul_type / isRemote / attack_complexity / exp_maturity /
            lang_class sub-maps) and 'doc' (document-derived features).
        trending_score: GitHub trend score; the sentinel value 999 means
            "no GitHub data at all" (see cve_classifier).

    Returns:
        Tuple (scores, max_score_repo, max_score, final_score, has_code):
        per-repo scores (dict, or sorted (repo, score) pairs when both
        code and doc features exist), the best-scoring repo, its raw
        score, the trend-adjusted final score, and whether any PoC code
        was found. Returns (None, None, 0, 0, False) when there are no
        usable features.
    """
    # Scoring weights come from the module-level YAML config.
    weights = config['cve_scoring']['weights']

    scores = {}
    max_score_github = 0
    max_score_repo_github = None
    max_score_expdb = 0
    max_score_repo_expdb = None
    max_score = 0
    max_score_repo = None
    final_score = 0
    has_code = False

    code_sources = ["GitHub", "ExploitDB"]

    code_results = None
    doc_results = None

    # --- Code-based scoring: one score per repo/ExploitDB entry ---
    if features['code'].get('GitHub') or features['code'].get('ExploitDB'):
        for source in code_sources:
            if not features['code'].get(source):
                continue  # no data from this source, skip

            has_code = True
            for repo, vul_type in features['code'][source]["vul_type"].items():
                score = 0

                # Additive components: vulnerability type and remote reachability.
                score += weights["vul_type"].get(vul_type, 0)
                score += weights["isRemote"].get(features['code'][source]["isRemote"].get(repo, ""), 0)

                # Additive components: each attack-complexity sub-field.
                attack_complexity = features['code'][source]["attack_complexity"].get(repo, {})
                for field, value in attack_complexity.items():
                    score += 1 * weights["attack_complexity"].get(field, {}).get(value, 0)

                # Bare code files are trusted half as much as a full repo/entry.
                if repo == "Code_File":
                    score = score / 2

                # Multiplicative components: exploit maturity (unknown maturity
                # zeroes the score — default weight is 0) and language class.
                score *= weights["exp_maturity"].get(features['code'][source]["exp_maturity"].get(repo, ""), 0)

                score *= weights["lang_class"].get(features['code'][source]["lang_class"].get(repo, ""), 1)

                if source == "GitHub":
                    score *= weights["source_weights"]["gthb"]
                elif source == "ExploitDB":
                    score *= weights["source_weights"]["expdb"]

                # Store the score
                scores[repo] = score

                # Track the best score per origin.
                # NOTE(review): purely-numeric keys are assumed to be ExploitDB
                # entry IDs and everything else a GitHub repo — confirm that
                # GitHub repo names can never be all digits.
                if repo.isdigit():  # ExploitDB
                    if score >= max_score_expdb:
                        max_score_expdb = score
                        max_score_repo_expdb = repo
                else:  # GitHub
                    if score >= max_score_github:
                        max_score_github = score
                        max_score_repo_github = repo

        if max_score_expdb > max_score_github:
            max_score_repo = max_score_repo_expdb
            max_score = max_score_expdb
        else:
            max_score_repo = max_score_repo_github
            max_score = max_score_github

    if has_code:
        sorted_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True)

        # trending_score == 999 is the "no GitHub data" sentinel from
        # cve_classifier; in that case only ExploitDB gets a default trend.
        if trending_score == 999:
            trending_score_github = 0
            trending_score_expdb = weights["expdb_default_score"]
        else:
            trending_score_github = trending_score
            if features['code'].get('ExploitDB'):
                trending_score_expdb = weights["expdb_default_score"]
            else: trending_score_expdb = 0

        # Add weighted trending score to the max score
        trend_score_weighted_expdb = trending_score_expdb * weights["trending_score"]
        trend_score_weighted_github = trending_score_github * weights["trending_score"]

        # Handle cases where one source has no scores
        final_expdb = (max_score_expdb + trend_score_weighted_expdb) if max_score_expdb > 0 else 0
        final_github = (max_score_github + trend_score_weighted_github) if max_score_github > 0 else 0

        final_score = max(final_expdb, final_github)

        code_results = (sorted_scores, max_score_repo, max_score, final_score, has_code)

    # --- Document-based scoring: a single aggregate 'doc' score ---
    if features['doc']:
        doc_score = 0
        doc_score += weights["vul_type"].get(features["doc"]["vul_type"], 0)
        doc_score += weights["isRemote"].get(features["doc"]["isRemote"], 0)

        # Attack complexity fields for 'doc'
        for field, value in features["doc"]["attack_complexity"].items():
            doc_score += weights["attack_complexity"].get(field, {}).get(value, 0)

        doc_score *= weights["source_weights"]["gg"]

        # NOTE: this mutates the shared `scores` dict, so doc_results below
        # already contains the code scores as well.
        scores["doc"] = doc_score

        # Check if 'doc' has the max score
        if doc_score > max_score:
            max_score = doc_score
            max_score_repo = "doc"
            final_score = doc_score

        doc_results = (scores, max_score_repo, max_score, max_score, has_code)

    if code_results and doc_results:
        code_sorted_scores, code_max_repo, code_max_score, code_final_score, code_has_code = code_results
        doc_scores, doc_max_repo, doc_max_score, doc_final_score, doc_has_code = doc_results

        # Merge code and doc scores into one ranked list.
        all_scores = dict(code_sorted_scores)
        all_scores.update(doc_scores)
        sorted_all_scores = sorted(all_scores.items(), key=lambda item: item[1], reverse=True)

        # Pick the overall winner between the doc score and best code score.
        if doc_max_score > code_max_score:
            final_max_score = doc_max_score
            final_max_repo = doc_max_repo
            final_score = doc_max_score
        else:
            final_max_score = code_max_score
            final_max_repo = code_max_repo
            final_score = code_final_score

        # NOTE(review): "has_code or True" is always True. On this branch
        # has_code is already True (code_results exists), so behavior is
        # unchanged, but the expression is misleading — consider plain True.
        return sorted_all_scores, final_max_repo, final_max_score, final_score, has_code or True

    elif code_results:
        return code_results
    elif doc_results:
        return doc_results
    else:
        return None, None, max_score, final_score, has_code
345
+
346
def calculate_match_stats(values1, values2):
    """
    Pairwise agreement statistics between two equal-length label sequences.

    Labels are drawn from {'easy', 'medium', 'hard'}. An "exact" match is
    identical labels, a "near" match is adjacent severity levels, and
    everything else counts as a mismatch.

    Returns:
        dict with "Exact Match (%)", "Near Match (%)", "Mismatch (%)".
    """
    order = ['easy', 'medium', 'hard']
    total = len(values1)
    exact = 0
    near = 0
    for left, right in zip(values1, values2):
        if left == right:
            exact += 1
        if abs(order.index(left) - order.index(right)) == 1:
            near += 1
    return {
        "Exact Match (%)": (exact / total) * 100,
        "Near Match (%)": (near / total) * 100,
        "Mismatch (%)": ((total - exact - near) / total) * 100,
    }
359
+
360
def create_df(cvss_data, epss_data, pentestasst_data):
    """
    Combine three {cve: score} dicts into a single DataFrame.

    Rows are the sorted union of all keys; entries missing from a dict
    become NaN. Columns are CVSS, EPSS and EEAS.
    """
    keys = sorted(set(cvss_data) | set(epss_data) | set(pentestasst_data))
    columns = {
        'CVSS': [cvss_data.get(k, np.nan) for k in keys],
        'EPSS': [epss_data.get(k, np.nan) for k in keys],
        'EEAS': [pentestasst_data.get(k, np.nan) for k in keys],
    }
    return pd.DataFrame(columns, index=keys)
369
+
370
def normalize_data(df):
    """
    Rescale the EEAS column into [0, 1].

    With a single row the column is pinned to the fixed midpoint 0.5
    (scaling is undefined for one sample); otherwise the column is
    standardized and then min-max scaled. Mutates ``df`` in place and
    returns a fresh frame built from it.
    """
    if len(df) == 1:
        df[['EEAS']] = 0.5
    else:
        standardized = StandardScaler().fit_transform(df[['EEAS']])
        df[['EEAS']] = MinMaxScaler().fit_transform(standardized)
    return pd.DataFrame(df, index=df.index, columns=df.columns)
380
+
381
def bin_agreement_analysis(df):
    """
    Print pairwise agreement diagnostics for every pair of label columns.

    For each column pair: confusion matrix, Cohen's kappa, and the
    exact/near/mismatch percentages from calculate_match_stats (labels
    'easy'/'medium'/'hard'). Returns the input frame unchanged.
    """
    labels = ['easy', 'medium', 'hard']
    pairwise_stats = {}
    for (name1, col1), (name2, col2) in combinations(df.items(), 2):
        entry = {
            "Confusion Matrix": confusion_matrix(col1, col2, labels=labels).tolist(),
            "CohenS Kappa": cohen_kappa_score(col1, col2, labels=labels),
        }
        entry.update(calculate_match_stats(col1, col2))
        pairwise_stats[f"{name1} vs {name2}"] = entry

    # Report every pair on stdout.
    for pair, results in pairwise_stats.items():
        print(f"\nPairwise comparison: {pair}")
        print(f"Confusion Matrix:\n{np.array(results['Confusion Matrix'])}")
        print(f"Cohen's Kappa: {results['CohenS Kappa']:.4f}")
        print(f"Exact Match (%): {results['Exact Match (%)']:.2f}")
        print(f"Near Match (%): {results['Near Match (%)']:.2f}")
        print(f"Mismatch (%): {results['Mismatch (%)']:.2f}")

    return df
405
+
406
def num_agreement_analysis(df):
    """
    Print numeric agreement diagnostics for every pair of score columns.

    For each pair: difference mean/std, rows outside the ±1.96·SD
    agreement limits (Bland-Altman style outliers), Pearson and Spearman
    correlation, MAE and RMSE. Returns the input frame unchanged.
    """
    stats = {}
    for col1, col2 in combinations(df.columns, 2):
        diffs = df[col1] - df[col2]
        mean_diff = np.mean(diffs)
        std_diff = np.std(diffs)

        print(f"diff mean: {mean_diff}")
        print(f"diff std: {std_diff}")

        # Rows whose difference exceeds the 95% agreement limit.
        threshold = 1.96 * std_diff
        outliers = df[abs(df[col1] - df[col2]) > threshold]
        print(f"Rows where the difference between {col1} and {col2} is larger than {threshold}:")
        print(outliers)
        print(f"Total rows: {len(df)}, Rows with large difference: {len(outliers)}, percentage: {len(outliers)/len(df) * 100}%")

        left, right = df[col1], df[col2]
        pearson_corr, _ = pearsonr(left, right)
        spearman_corr, _ = spearmanr(left, right)
        stats[f"{col1} vs {col2}"] = {
            "Pearson Correlation": pearson_corr,
            "Spearman Correlation": spearman_corr,
            "Mean Absolute Error (MAE)": np.mean(np.abs(left - right)),
            "Root Mean Squared Error (RMSE)": np.sqrt(np.mean((left - right) ** 2)),
        }

    for pair, metrics in stats.items():
        print(f"\nAgreement Analysis for {pair}:")
        for metric, value in metrics.items():
            print(f"{metric}: {value:.4f}")

    return df
444
+
445
def visualize_num_results(df, output_dir="plots/"):
    """
    Save numeric-agreement plots for the score columns of ``df`` as PDFs.

    Writes pairwise scatter plots, a correlation heatmap, and one
    Bland-Altman plot per column pair under ``output_dir``.

    Args:
        df: DataFrame of numeric scores (one column per scoring method).
        output_dir: Target directory for the PDFs; assumed to already
            exist (no makedirs here — TODO confirm callers create it).
    """
    plt.rc('font', size=14)  # global matplotlib font size for all plots below

    # Pairwise scatter plots with regression fits; KDE on the diagonal.
    pairplot = sns.pairplot(df, kind="reg", diag_kind="kde")
    pairplot.savefig(os.path.join(output_dir, "pairwise_scatter_plots.pdf"))
    plt.close()

    # Heatmap of column-wise correlations.
    correlation_matrix = df.corr()
    plt.figure(figsize=(8, 6))
    sns.heatmap(correlation_matrix, annot=True, cmap="coolwarm", fmt=".2f")
    plt.title("Heatmap of Correlations")
    plt.savefig(os.path.join(output_dir, "correlation_heatmap.pdf"))
    plt.close()

    # Bland-Altman plots: difference vs. mean for each pair of methods,
    # with the mean difference and ±1.96·SD agreement limits marked.
    for col1, col2 in combinations(df.columns, 2):
        mean_values = (df[col1] + df[col2]) / 2
        diff_values = df[col1] - df[col2]

        plt.figure(figsize=(8, 6))
        plt.scatter(mean_values, diff_values, alpha=0.7)
        plt.axhline(np.mean(diff_values), color='red', linestyle='--', label='Mean Difference')
        plt.axhline(np.mean(diff_values) + 1.96 * np.std(diff_values), color='blue', linestyle='--', label='±1.96SD')
        plt.axhline(np.mean(diff_values) - 1.96 * np.std(diff_values), color='blue', linestyle='--', label=None)
        plt.title(f"Bland-Altman Plot: {col1} vs {col2}", weight='bold')
        plt.xlabel('Mean of Two Measurements', weight='bold')
        plt.ylabel('Difference', weight='bold')
        plt.legend()
        filename = f"bland_altman_{col1}_vs_{col2}.pdf"
        plt.savefig(os.path.join(output_dir, filename))
        plt.close()
479
+
480
def visualize_bin_results(cvss_data, epss_data, pentestasst_data, output_dir):
    """
    Save categorical-agreement plots for three {cve: bucket} dicts as PDFs.

    Inputs map CVE IDs to 'easy'/'medium'/'hard' buckets from the three
    scoring methods (CVSS, EPSS, and this tool's EEAS). Writes per-pair
    confusion-matrix heatmaps, a class-distribution bar chart, and a
    stacked pairwise-agreement chart under ``output_dir``.

    NOTE(review): all three dicts are assumed to share exactly the same
    keys; a key present in one but not another raises KeyError in the
    pairwise loop below — confirm upstream guarantees this.
    """
    # Confusion-matrix heatmap for each pair of methods.
    for (name1, data1), (name2, data2) in combinations([("CVSS", cvss_data), ("EPSS", epss_data), ("EEAS", pentestasst_data)], 2):
        labels = ['easy', 'medium', 'hard']
        # NOTE(review): relies on both dicts having the same insertion order
        # so that values() align row-by-row — verify against the callers.
        confusion = confusion_matrix(list(data1.values()), list(data2.values()), labels=labels)

        plt.figure(figsize=(6, 5))
        sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels)
        plt.title(f"Confusion Matrix: {name1} vs {name2}")
        plt.xlabel(f"{name2} Predictions")
        plt.ylabel(f"{name1} Predictions")
        plt.savefig(f"{output_dir}/confusion_matrix_{name1}_vs_{name2}.pdf")
        plt.close()

    # Bar chart: how many CVEs each method put in each bucket.
    data = {
        "Category": ['easy', 'medium', 'hard'],
        "CVSS": [list(cvss_data.values()).count(c) for c in ['easy', 'medium', 'hard']],
        "EPSS": [list(epss_data.values()).count(c) for c in ['easy', 'medium', 'hard']],
        "EEAS": [list(pentestasst_data.values()).count(c) for c in ['easy', 'medium', 'hard']],
    }
    df = pd.DataFrame(data).set_index("Category")

    df.plot(kind="bar", figsize=(8, 6))
    plt.title("Class Distributions Across Dictionaries")
    plt.ylabel("Frequency")
    plt.xlabel("Category")
    plt.xticks(rotation=0)
    plt.savefig(f"{output_dir}/class_distributions.pdf")
    plt.close()

    # Stacked bar chart of exact/near/mismatch percentages per method pair
    # (same exact/near definitions as calculate_match_stats).
    pairwise_agreements = []
    for (name1, data1), (name2, data2) in combinations([("CVSS", cvss_data), ("EPSS", epss_data), ("EEAS", pentestasst_data)], 2):
        total = len(data1)
        exact_matches = sum(1 for k in data1 if data1[k] == data2[k])
        near_matches = sum(1 for k in data1 if abs(['easy', 'medium', 'hard'].index(data1[k]) -
                                                   ['easy', 'medium', 'hard'].index(data2[k])) == 1)
        mismatches = total - exact_matches - near_matches

        pairwise_agreements.append({
            "Pair": f"{name1} vs {name2}",
            "Exact Matches (%)": exact_matches / total * 100,
            "Near Matches (%)": near_matches / total * 100,
            "Mismatches (%)": mismatches / total * 100
        })

    pairwise_df = pd.DataFrame(pairwise_agreements).set_index("Pair")
    pairwise_df.plot(kind="bar", stacked=True, figsize=(10, 6))
    plt.title("Pairwise Agreement Percentages")
    plt.ylabel("Percentage")
    plt.xlabel("Pairwise Comparison")
    plt.xticks(rotation=0)
    plt.legend(loc="upper right")
    plt.savefig(f"{output_dir}/pairwise_agreements.pdf")
    plt.close()
536
+
537
def cve_classifier(cve, output_dir="resources/", mode="specific"):
    """
    Classify a CVE's exploitability from previously-gathered artifacts.

    Reads ``{output_dir}/{cve}/features.json`` (plus Trend_Score.json when
    GitHub data exists, and info/cvemap.json in "specific" mode), computes
    an exploitability score via calculate_score(), buckets the CVSS/EPSS
    values, and writes the combined result to
    ``{output_dir}/{cve}/classification.json``.

    Args:
        cve: CVE identifier (used as the directory name under output_dir).
        output_dir: Base resources directory.
        mode: "specific" uses real CVSS/EPSS data from the cvemap cache;
            "general" zeroes them out (keyword-based analysis).

    Returns:
        (cvss_score, cvss_category, epss_percentile, epss_category,
         final_score, exploitability, has_code)

    Raises:
        ValueError: If mode is neither "specific" nor "general".
        FileNotFoundError: If features.json (or cvemap.json in "specific"
            mode) is missing.
    """
    trending_score_path = f"{output_dir}/{cve}/Trend_Score.json"
    feature_path = f"{output_dir}/{cve}/features.json"
    cvemap_path = f"{output_dir}/{cve}/info/cvemap.json"

    with open(feature_path, 'r') as f:
        result = json.load(f)

    trending_score = 0
    if result['code'].get('GitHub'):
        try:
            with open(trending_score_path, 'r') as f:
                _trending_score = json.load(f)
            # Cap the GitHub trend score so it cannot dominate the final score.
            trending_score = min(_trending_score.get("trend_score", 0), 50)
        except Exception:
            # FIX: was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt; a missing/corrupt trend file is best-effort.
            logging.warning(f"Trending score not found for {cve}")
    else:
        # Sentinel understood by calculate_score(): no GitHub data at all.
        trending_score = 999

    if mode == "specific":
        with open(cvemap_path, 'r') as f:
            cvemap_json = json.load(f)[0]
        cvss_score = cvemap_json.get('cvss_score', 0)
        epss_score = cvemap_json['epss']['epss_score']
        epss_percentile = cvemap_json['epss']['epss_percentile']
    elif mode == "general":
        cvss_score = 0
        epss_score = 0
        epss_percentile = 0
    else:
        # FIX: an unknown mode previously fell through and crashed later with
        # a NameError on cvss_score; fail fast with a clear message instead.
        raise ValueError(f"Unknown mode: {mode}")

    cvss_category = categorize_cvss(cvss_score)
    epss_category = categorize_epss(epss_percentile)

    scores, max_score_repo, max_score, final_score, has_code = calculate_score(result, trending_score)

    # Bucket the numeric score; anything non-positive stays "hard".
    exploitability = "hard"
    if final_score is not None and final_score > 0:
        if final_score > 50:
            exploitability = "easy"
        elif final_score > 35:
            exploitability = "medium"
        else:
            exploitability = "hard"

    with open(f"{output_dir}/{cve}/classification.json", 'w') as f:
        classification = {
            "cvss_category": cvss_category,
            "epss_category": epss_category,
            "exploitability": exploitability,
            "final_score": final_score,
            "max_score_repo": max_score_repo,
            "max_score": max_score,
            "scores": scores,
            "has_code": has_code,
        }
        json.dump(classification, f, indent=4)

    return cvss_score, cvss_category, epss_percentile, epss_category, final_score, exploitability, has_code
596
+
597
def cve_analysis(cve, output_dir="resources/"):
    """
    Gather resources for a CVE and run the LLM-based feature analysis.

    Composes searcher results, fetches cvemap info, runs
    DocHandler.vul_analysis, and caches the extracted features to
    ``{output_dir}/{cve}/features.json``.

    Returns:
        (searching_time, analysis_time) in seconds, or 0 when the cvemap
        lookup fails. NOTE(review): the return type is inconsistent on
        failure (scalar vs tuple) — callers must handle both.
    """
    info_dir = f"{output_dir}/{cve}/info"
    searching_start_time = time.time()
    compose(output_dir, cve)
    if not os.path.exists(info_dir):
        os.mkdir(info_dir)

    cvemap_json = cvemap_search(cve, info_dir)
    if cvemap_json is None:
        return 0
    cve_description = cvemap_json['cve_description']

    logging.info(f"Analyzing {cve}")
    doc_handler = DocHandler()

    analysis_start_time = time.time()
    features = doc_handler.vul_analysis(cve, output_dir, cve_description)
    analysis_end_time = time.time()

    searching_time = analysis_start_time - searching_start_time
    analysis_time = analysis_end_time - analysis_start_time
    logging.info(f"Analysis time: {analysis_time} seconds")

    with open(f"{output_dir}/{cve}/features.json", 'w') as f:
        json.dump(features, f, indent=4)

    return searching_time, analysis_time
624
+
625
def general_analysis(keyword, output_dir="resources/"):
    """
    Run the LLM-based feature analysis for a free-form keyword (no CVE).

    Composes searcher results in loose mode, runs DocHandler.vul_analysis
    with an empty description, and caches the extracted features to
    ``{output_dir}/features.json``.

    Returns:
        (searching_time, analysis_time) in seconds.
    """
    searching_start_time = time.time()
    compose(output_dir, keyword, loose_mode=True)

    logging.info(f"Analyzing {keyword}")
    doc_handler = DocHandler()

    analysis_start_time = time.time()
    # The keyword stands in for a CVE ID; there is no description to pass.
    features = doc_handler.vul_analysis(keyword, output_dir, "")
    analysis_end_time = time.time()

    searching_time = analysis_start_time - searching_start_time
    analysis_time = analysis_end_time - analysis_start_time
    logging.info(f"Analysis time: {analysis_time} seconds")

    with open(f"{output_dir}/features.json", 'w') as f:
        json.dump(features, f, indent=4)

    return searching_time, analysis_time
648
+
649
def cve_analysis_from_epss_csv(csv_path):
    """Batch-analyze CVEs listed in an EPSS score CSV.

    Only CVEs published between 2017 and 2022 (inclusive) are processed;
    any per-CVE failure is logged and skipped so one bad entry cannot
    abort the whole batch.

    Args:
        csv_path: Path to a CSV file with a ``cve`` column.
    """
    df = pd.read_csv(csv_path)
    with tqdm(total=len(df.index), desc=f'Analyzing CVEs') as pbar:
        for index, row in df.iterrows():
            cve = row['cve']
            try:
                # Robustness fix: the year parse used to sit outside the try
                # block, so a malformed id (no year field) aborted the batch.
                # Also parse the year once instead of twice.
                year = int(cve.split('-')[1])
                if year < 2017 or year > 2022:
                    # Consistency: use the logging module like the rest of
                    # this file (was `logger.info`).
                    logging.info(f"Skipping {cve}")
                    continue
                cve_analysis(cve)
                time.sleep(1)  # throttle requests to external services
                logging.info(f"Finished {index}: {cve}")
            except Exception as e:
                logging.error(f"Error in {cve}: {e}")
            finally:
                # The progress bar advances on every path (skip/ok/error).
                pbar.update()
def _product_dir_name(name):
    """Normalize a product name into a filesystem-safe directory token."""
    cleaned = name.lower()
    for ch in (" ", "/", ":", "\\", "(", ")"):
        cleaned = cleaned.replace(ch, "_")
    for ch in ('"', "'", "\n", "&"):
        cleaned = cleaned.replace(ch, "")
    return cleaned


def product_to_cve(product, output_dir):
    """
    Search for CVEs related to a product using cvemap.

    Results are cached under ``{output_dir}/{product_dir_name}``; a cached
    ``cve_lst.json`` is returned directly when present.

    Security: Uses sanitized product name and subprocess without shell=True (CWE-78 fix).

    Args:
        product: Free-form product name to search for.
        output_dir: Directory under which per-product results are cached.

    Returns:
        List of CVE id strings (duplicates removed, first-seen order kept).
    """
    cve_lst = []
    product_dir_name = _product_dir_name(product)
    cache_path = f"{output_dir}/{product_dir_name}/cve_lst.json"
    # Robustness fix: the directory can exist without the cache file (e.g. a
    # previous run was interrupted); the original then crashed with
    # FileNotFoundError.  Only short-circuit when the file itself exists.
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            cve_lst = json.load(f)
        return cve_lst

    os.makedirs(f"{output_dir}/{product_dir_name}", exist_ok=True)
    cleaned_product = product_keyword_gen_openai(product, 3)
    cleaned_product.append(product_dir_name)

    for p in cleaned_product:
        # Security: Sanitize product name to prevent command injection
        sanitized_product = _sanitize_product_name(p)
        if not sanitized_product:
            logging.warning(f"Skipping empty product name after sanitization: {p}")
            continue

        p_dir_name = _product_dir_name(p)
        cvemap_json_path = f"{output_dir}/{product_dir_name}/cvemap_{p_dir_name}.json"

        try:
            # SECURE: use subprocess without shell=True; run cvemap with the
            # -p flag and capture its output directly.
            result = subprocess.run(
                ["cvemap", "-p", sanitized_product, "-json"],
                capture_output=True,
                text=True,
                timeout=50,
                check=False  # Handle errors ourselves
            )

            if result.returncode != 0:
                logging.warning(f"cvemap failed for product {sanitized_product}: {result.stderr}")
                continue

            # Write output to file (instead of shell redirection)
            with open(cvemap_json_path, 'w') as f:
                f.write(result.stdout)

            json_data = json.loads(result.stdout)
            for cve in json_data:
                cve_lst.append(cve['cve_id'])

        except subprocess.TimeoutExpired:
            logging.error(f"cvemap search timed out for product {sanitized_product}")
            continue
        except json.JSONDecodeError as e:
            logging.warning(f"Invalid JSON from cvemap for product {sanitized_product}: {e}")
            continue
        except FileNotFoundError:
            logging.error("cvemap command not found. Please install cvemap.")
            break

    # De-duplicate while preserving first-seen order: the same CVE typically
    # shows up under several alternative product keywords.
    cve_lst = list(dict.fromkeys(cve_lst))

    with open(cache_path, 'w') as f:
        json.dump(cve_lst, f, indent=4)

    return cve_lst
def product_keyword_gen_openai(product, num):
    """Generate up to *num* alternative product names via an LLM.

    Tries the configured planning model first; if that fails for any reason,
    falls back to calling the OpenAI API directly.  The returned names are
    lowercased and normalized into filesystem-safe tokens.
    """
    prompt = (
        "You're an excellent system administrator. You will be given a product name, which may not be in a standard format. "
        "Your task is to generate alternative product names to broaden the search scope.\n"
        f"The given product name is {product}, please generate at most {num} alternative product names in the following format.\n"
        "[Alternative product name 1, Alternative product name 2, Alternative product name 3, ...]"
    )
    try:
        llm = get_model(planning_config['model'])
        from langchain_core.messages import HumanMessage
        reply = llm.invoke([HumanMessage(content=prompt)])
        raw = reply.content.strip() if hasattr(reply, 'content') else str(reply)
    except Exception:
        # Fallback path: talk to the OpenAI API directly.
        import openai
        client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": prompt},
            ],
        )
        raw = str(completion.choices[0].message.content)

    # Parse the "[a, b, c]" style reply, then normalize each candidate.
    candidates = raw.strip("[]").replace("'", "").split(", ")
    normalized = []
    for name in candidates:
        token = name.lower()
        token = token.replace(" ", "_").replace("/", "_").replace(":", "_")
        token = token.replace("\\", "_").replace("(", "_").replace(")", "_")
        token = token.replace('"', "").replace("'", "").replace("\n", "").replace("&", "")
        normalized.append(token)
    return normalized
def product_keyword_gen_huggingface(product, num):
    """Generate *num* alternative product names with a local Hugging Face model."""
    system_msg = "You're an excellent system administrator. You will be given a product name, which may not be in a standard format. Your task is to generate alternative product names to broaden the search scope."
    user_msg = f"The given product name is {product}, please generate {num} alternative product names in the following format: [Alternative product name 1, Alternative product name 2, Alternative product name 3, ...]"

    # Run with a local LLM (deterministic decoding: temperature 0).
    return chat_completion_huggingface(
        model_name="meta-llama/Meta-Llama-3.1-8B-Instruct",
        messages=[
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg},
        ],
        max_tokens=200,
        temperature=0.0
    )
def chat_completion_huggingface(
    model_name="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=None,
    max_tokens=100,
    temperature=0.0,
    top_p=0.0,
    revision="5206a32e0bd3067aef1ce90f5528ade7d866253f"  # Security: Pinned commit hash (CWE-494)
):
    """
    Generates a chat completion using a specified Hugging Face model.

    Args:
        model_name (str): The Hugging Face model name.
        messages (list): A list of chat messages in OpenAI format
            [{"role": "user", "content": "Hello"}]. Treated as empty if None.
        max_tokens (int): Maximum number of tokens to generate.
        temperature (float): Sampling temperature (ignored by the model here
            since do_sample=False forces greedy decoding).
        top_p (float): Nucleus sampling parameter (likewise ignored when
            do_sample=False).
        revision (str): Model revision to use (commit hash or branch name for
            reproducibility).

    Returns:
        str: The model's response.

    Security:
        Using 'revision' parameter pins the model to a specific version,
        preventing supply chain attacks where a malicious model update
        could be downloaded automatically.
    """
    # Robustness fix: the original iterated `messages` unconditionally and
    # raised TypeError when the argument was omitted (default None).
    if messages is None:
        messages = []

    # Load the model and tokenizer dynamically.
    # Security: Pin model revision to prevent supply chain attacks (CWE-494).
    # Bandit requires string literals for revision detection.
    pinned_revision = revision if revision != "main" else "5206a32e0bd3067aef1ce90f5528ade7d866253f"
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        revision=pinned_revision,  # nosec B615 - revision is pinned via function default
        trust_remote_code=False  # Security: Never execute remote code
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",  # Uses GPU if available
        revision=pinned_revision,  # nosec B615 - revision is pinned via function default
        trust_remote_code=False  # Security: Never execute remote code
    )

    # Convert messages into a formatted prompt
    prompt = ""
    for message in messages:
        if message["role"] == "system":
            prompt += f"<s>[SYSTEM]: {message['content']}\n"
        elif message["role"] == "user":
            prompt += f"[USER]: {message['content']}\n"
        elif message["role"] == "assistant":
            prompt += f"[ASSISTANT]: {message['content']}\n"
    prompt += "[ASSISTANT]:"

    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate response (greedy decoding: do_sample=False)
    output = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=False
    )

    # Decode and return only the assistant's final turn
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    return response.split("[ASSISTANT]:")[-1].strip()
def analyze_cve_lst(cve_lst, output_dir, app_name):
    """Analyze and score a list of CVEs, plus optional app-level exploit info.

    For each CVE, runs :func:`cve_analysis` (skipped when features.json is
    already cached) and :func:`cve_classifier`, then writes raw and normalized
    score tables as CSV.  When *app_name* is given, an additional loose-mode
    "<app> exploit" search is performed and classified.

    Args:
        cve_lst: CVE identifiers to analyze.
        output_dir: Directory in which per-CVE results and CSVs are written.
        app_name: Optional application name for the extra loose-mode search.

    Returns:
        Tuple ``(total_searching_time, total_analysis_time)`` in seconds.
    """
    cvss_results = {}
    cvss_scores = {}
    epss_results = {}
    epss_scores = {}
    pentestasst_results = {}
    pentestasst_scores = {}
    analysis_time_dict = {}
    has_code_lst = []
    total_searching_time = 0
    total_analysis_time = 0

    for cve in cve_lst:
        logging.info(f"Analyzing {cve}")
        if not os.path.exists(f"{output_dir}/{cve}/features.json"):
            searching_time, analysis_time = cve_analysis(cve, output_dir)
            if analysis_time == 0:
                # Zero analysis time means cvemap had no data for this CVE.
                continue
            total_searching_time += searching_time
            total_analysis_time += analysis_time
            analysis_time_dict[cve] = analysis_time
        cvss_score, cvss_category, epss_score, epss_category, final_score, exploitability, has_code = cve_classifier(cve, output_dir, mode = "specific")
        cvss_results[cve] = cvss_category
        cvss_scores[cve] = cvss_score
        epss_results[cve] = epss_category
        epss_scores[cve] = epss_score
        pentestasst_results[cve] = exploitability
        pentestasst_scores[cve] = final_score
        if has_code:
            has_code_lst.append(cve)

    with open(f"{output_dir}/analysis_time.json", 'w') as f:
        json.dump(analysis_time_dict, f, indent=4)

    # Numeric score tables (raw, normalized, and normalized-with-code subsets).
    num_df = create_df(cvss_scores, epss_scores, pentestasst_scores)
    normalized_num_df = normalize_data(num_df)
    normalized_has_code_num_df = normalized_num_df.loc[has_code_lst]
    num_df.to_csv(f'{output_dir}/score_results.csv')
    normalized_num_df.to_csv(f'{output_dir}/normalized_score_results.csv')
    normalized_has_code_num_df.to_csv(f'{output_dir}/normalized_has_code_score_results.csv')

    if app_name:
        logging.info(f"Searching for general exp info for {app_name}")
        print("Start searching general exp info...")
        app_exp_path = app_name + "_exp"
        if not os.path.exists(f"{output_dir}/{app_exp_path}/features.json"):
            keyword = app_name + " exploit"
            searching_time, analysis_time = general_analysis(keyword, os.path.join(output_dir, app_exp_path))
            total_searching_time += searching_time
            total_analysis_time += analysis_time
        cve_classifier(app_exp_path, output_dir, mode = "general")

    # Robustness fix: with an empty cve_lst and no app_name the original
    # divided by zero here.  Skip the averages in that degenerate case.
    num_analyzed = (len(cve_lst) + 1) if app_name else len(cve_lst)
    if num_analyzed > 0:
        avg_searching_time = total_searching_time / num_analyzed
        avg_analysis_time = total_analysis_time / num_analyzed
        logging.info(f"Total searching time: {total_searching_time} seconds, average time: {avg_searching_time} seconds")
        logging.info(f"Total analysis time: {total_analysis_time} seconds, average time: {avg_analysis_time} seconds")

    return total_searching_time, total_analysis_time
def get_exp_info(cve_lst = None, output_dir = "", app_name = ""):
    """Collect and analyze exploit information for a list of CVEs.

    Configures logging, selects an LLM, delegates to :func:`analyze_cve_lst`,
    and (when enabled) reports token usage.

    Args:
        cve_lst: CVE identifiers to analyze; aborts with a message when empty.
            (Default changed from a shared mutable ``[]`` to ``None`` --
            callers see identical behavior since the list is never mutated.)
        output_dir: Directory for results; a default is used when empty.
        app_name: Optional application name for an extra loose-mode search.

    Returns:
        ``(total_searching_time, total_analysis_time)`` in seconds, or
        ``None`` when preconditions (LLM available, non-empty CVE list) fail.
    """
    logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        filename='cve_info.log',
                        level=logging.INFO)

    ### Step 1: Select the LLM

    ## Resolve the configured model once (provider-agnostic).  Previously the
    ## config was read both unguarded and guarded; keep only the guarded read.
    try:
        model_name = planning_config['model']
    except Exception:
        model_name = "openai"

    track_tokens = True
    token_counter = None
    if track_tokens:
        try:
            # Bug fix: this used to reference the undefined name
            # `model_name_for_token` (a NameError that also fired inside the
            # except handler's log message).  Use the configured model name.
            token_counter = TokenCountingHandler(tokenizer=tiktoken.encoding_for_model(model_name).encode)
            Settings.callback_manager = CallbackManager([token_counter])
        except Exception as e:
            logger.warning(f"Tokenizer for model {model_name} not found, token counting disabled: {e}")
            Settings.callback_manager = CallbackManager([])

    ## Initialize LLM from config via model manager (provider-agnostic)
    llm = get_model(model_name)
    if llm is None:
        # Keep previous behavior when no API key/config: abort gracefully
        print("LLM_API_KEY not set")
        return

    try:
        Settings.llm = LangChainLLM(llm=llm)
    except Exception:
        # Fallback to OpenAI config if available to preserve behavior
        cve_config = config['cve']
        print(f"Model: {cve_config['model']}")
        Settings.llm = OpenAI(temperature=cve_config['temperature'], model=cve_config['model'])

    ### Step 2: Input the CVEs to analyze
    if not cve_lst:
        print("CVE to be searched not set! ")
        return

    ### Step 3: Analyze the CVEs
    if not output_dir:
        print("Directory to store not set! will use default settings.")
        output_dir = "resources_cve_lst_try"

    os.makedirs(output_dir, exist_ok=True)

    ## Scenario 2: Given a list of CVEs, analyze the CVEs
    total_searching_time, total_analysis_time = analyze_cve_lst(cve_lst, output_dir, app_name)

    # Bug fix: also require token_counter -- tokenizer setup above may have
    # failed, leaving it None while track_tokens is still True (the original
    # then crashed with AttributeError on the prints below).
    if track_tokens and token_counter is not None:
        print(
            "Embedding Tokens: ",
            token_counter.total_embedding_token_count,
            "\n",
            "LLM Prompt Tokens: ",
            token_counter.prompt_llm_token_count,
            "\n",
            "LLM Completion Tokens: ",
            token_counter.completion_llm_token_count,
            "\n",
            "Total LLM Token Count: ",
            token_counter.total_llm_token_count,
            "\n",
            "Average LLM Token Count: ",
            token_counter.total_llm_token_count / len(cve_lst),
            "\n",
        )

        logging.info(
            "Embedding Tokens: "
            + str(token_counter.total_embedding_token_count)
            + "\n"
            + "LLM Prompt Tokens: "
            + str(token_counter.prompt_llm_token_count)
            + "\n"
            + "LLM Completion Tokens: "
            + str(token_counter.completion_llm_token_count)
            + "\n"
            + "Total LLM Token Count: "
            + str(token_counter.total_llm_token_count)
            + "\n"
            + "Average LLM Token Count: "
            + str(token_counter.total_llm_token_count / len(cve_lst))
            + "\n"
        )

        token_counter.reset_counts()

    return total_searching_time, total_analysis_time
def main():
    """Standalone driver: analyze the xstream CVE list read from a file."""
    logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        filename='cve_info.log',
                        level=logging.INFO)

    track_tokens = False
    token_counter = None
    if track_tokens:
        token_counter = TokenCountingHandler(tokenizer=tiktoken.encoding_for_model("gpt-4o-mini").encode)
        Settings.callback_manager = CallbackManager([token_counter])

    ### Step 1: Select the LLM from the OpenAI API key
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key is None or api_key == "":
        print("OPENAI_API_KEY not set")
        return

    # NOTE(review): model selection keyed on API-key substrings -- presumably
    # distinguishes specific team accounts; confirm before changing.
    if "f4Ph3uIoqGLC9" in api_key:
        chosen_model = "o3-mini"
        print("Model: o3-mini")
    elif "klO8n1OFxLWBoPIeDycM" in api_key:
        chosen_model = "gpt-4o-mini"
        print("Model: gpt-4o-mini")
    else:
        chosen_model = "gpt-4o-mini"
        print("Model: gpt-4o-mini (default)")
    Settings.llm = OpenAI(temperature=0, model=chosen_model)

    ## Option 2: Read the CVE list from a file
    with open("../../target_lst/vulhub_test/xstream.txt", 'r') as f:
        cve_lst = f.read().splitlines()

    ## Option 3: Hard-coded CVE list
    # cve_lst = ['CVE-XXX-XXXX', 'CVE-XXX-XXXX', 'CVE-XXX-XXXX', 'CVE-XXX-XXXX', 'CVE-XXX-XXXX']

    ### Step 3: Analyze the CVEs
    output_dir = "resources_test/xstream"
    os.makedirs(output_dir, exist_ok=True)

    ## Scenario 1: Given a list of CVEs, analyze the CVEs
    analyze_cve_lst(cve_lst, output_dir, "xstream")  # assume app's name is xstream

    if track_tokens:
        print(
            "Embedding Tokens: ",
            token_counter.total_embedding_token_count,
            "\n",
            "LLM Prompt Tokens: ",
            token_counter.prompt_llm_token_count,
            "\n",
            "LLM Completion Tokens: ",
            token_counter.completion_llm_token_count,
            "\n",
            "Total LLM Token Count: ",
            token_counter.total_llm_token_count,
            "\n",
            "Average LLM Token Count: ",
            token_counter.total_llm_token_count / len(cve_lst),
            "\n",
        )

        # Same report to the log file (f-strings build the identical text the
        # original produced by string concatenation).
        logging.info(
            f"Embedding Tokens: {token_counter.total_embedding_token_count}\n"
            f"LLM Prompt Tokens: {token_counter.prompt_llm_token_count}\n"
            f"LLM Completion Tokens: {token_counter.completion_llm_token_count}\n"
            f"Total LLM Token Count: {token_counter.total_llm_token_count}\n"
            f"Average LLM Token Count: {token_counter.total_llm_token_count / len(cve_lst)}\n"
        )

        token_counter.reset_counts()

    # epss_csv_path = 'epss_scores-2024-08-06.csv'
    # cve_analysis_from_epss_csv(epss_csv_path)
# Script entry point: run the standalone analysis driver when executed
# directly (not on import).
if __name__ == "__main__":
    main()