aiptx 2.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aipt_v2/__init__.py +110 -0
- aipt_v2/__main__.py +24 -0
- aipt_v2/agents/AIPTxAgent/__init__.py +10 -0
- aipt_v2/agents/AIPTxAgent/aiptx_agent.py +211 -0
- aipt_v2/agents/__init__.py +46 -0
- aipt_v2/agents/base.py +520 -0
- aipt_v2/agents/exploit_agent.py +688 -0
- aipt_v2/agents/ptt.py +406 -0
- aipt_v2/agents/state.py +168 -0
- aipt_v2/app.py +957 -0
- aipt_v2/browser/__init__.py +31 -0
- aipt_v2/browser/automation.py +458 -0
- aipt_v2/browser/crawler.py +453 -0
- aipt_v2/cli.py +2933 -0
- aipt_v2/compliance/__init__.py +71 -0
- aipt_v2/compliance/compliance_report.py +449 -0
- aipt_v2/compliance/framework_mapper.py +424 -0
- aipt_v2/compliance/nist_mapping.py +345 -0
- aipt_v2/compliance/owasp_mapping.py +330 -0
- aipt_v2/compliance/pci_mapping.py +297 -0
- aipt_v2/config.py +341 -0
- aipt_v2/core/__init__.py +43 -0
- aipt_v2/core/agent.py +630 -0
- aipt_v2/core/llm.py +395 -0
- aipt_v2/core/memory.py +305 -0
- aipt_v2/core/ptt.py +329 -0
- aipt_v2/database/__init__.py +14 -0
- aipt_v2/database/models.py +232 -0
- aipt_v2/database/repository.py +384 -0
- aipt_v2/docker/__init__.py +23 -0
- aipt_v2/docker/builder.py +260 -0
- aipt_v2/docker/manager.py +222 -0
- aipt_v2/docker/sandbox.py +371 -0
- aipt_v2/evasion/__init__.py +58 -0
- aipt_v2/evasion/request_obfuscator.py +272 -0
- aipt_v2/evasion/tls_fingerprint.py +285 -0
- aipt_v2/evasion/ua_rotator.py +301 -0
- aipt_v2/evasion/waf_bypass.py +439 -0
- aipt_v2/execution/__init__.py +23 -0
- aipt_v2/execution/executor.py +302 -0
- aipt_v2/execution/parser.py +544 -0
- aipt_v2/execution/terminal.py +337 -0
- aipt_v2/health.py +437 -0
- aipt_v2/intelligence/__init__.py +194 -0
- aipt_v2/intelligence/adaptation.py +474 -0
- aipt_v2/intelligence/auth.py +520 -0
- aipt_v2/intelligence/chaining.py +775 -0
- aipt_v2/intelligence/correlation.py +536 -0
- aipt_v2/intelligence/cve_aipt.py +334 -0
- aipt_v2/intelligence/cve_info.py +1111 -0
- aipt_v2/intelligence/knowledge_graph.py +590 -0
- aipt_v2/intelligence/learning.py +626 -0
- aipt_v2/intelligence/llm_analyzer.py +502 -0
- aipt_v2/intelligence/llm_tool_selector.py +518 -0
- aipt_v2/intelligence/payload_generator.py +562 -0
- aipt_v2/intelligence/rag.py +239 -0
- aipt_v2/intelligence/scope.py +442 -0
- aipt_v2/intelligence/searchers/__init__.py +5 -0
- aipt_v2/intelligence/searchers/exploitdb_searcher.py +523 -0
- aipt_v2/intelligence/searchers/github_searcher.py +467 -0
- aipt_v2/intelligence/searchers/google_searcher.py +281 -0
- aipt_v2/intelligence/tools.json +443 -0
- aipt_v2/intelligence/triage.py +670 -0
- aipt_v2/interactive_shell.py +559 -0
- aipt_v2/interface/__init__.py +5 -0
- aipt_v2/interface/cli.py +230 -0
- aipt_v2/interface/main.py +501 -0
- aipt_v2/interface/tui.py +1276 -0
- aipt_v2/interface/utils.py +583 -0
- aipt_v2/llm/__init__.py +39 -0
- aipt_v2/llm/config.py +26 -0
- aipt_v2/llm/llm.py +514 -0
- aipt_v2/llm/memory.py +214 -0
- aipt_v2/llm/request_queue.py +89 -0
- aipt_v2/llm/utils.py +89 -0
- aipt_v2/local_tool_installer.py +1467 -0
- aipt_v2/models/__init__.py +15 -0
- aipt_v2/models/findings.py +295 -0
- aipt_v2/models/phase_result.py +224 -0
- aipt_v2/models/scan_config.py +207 -0
- aipt_v2/monitoring/grafana/dashboards/aipt-dashboard.json +355 -0
- aipt_v2/monitoring/grafana/dashboards/default.yml +17 -0
- aipt_v2/monitoring/grafana/datasources/prometheus.yml +17 -0
- aipt_v2/monitoring/prometheus.yml +60 -0
- aipt_v2/orchestration/__init__.py +52 -0
- aipt_v2/orchestration/pipeline.py +398 -0
- aipt_v2/orchestration/progress.py +300 -0
- aipt_v2/orchestration/scheduler.py +296 -0
- aipt_v2/orchestrator.py +2427 -0
- aipt_v2/payloads/__init__.py +27 -0
- aipt_v2/payloads/cmdi.py +150 -0
- aipt_v2/payloads/sqli.py +263 -0
- aipt_v2/payloads/ssrf.py +204 -0
- aipt_v2/payloads/templates.py +222 -0
- aipt_v2/payloads/traversal.py +166 -0
- aipt_v2/payloads/xss.py +204 -0
- aipt_v2/prompts/__init__.py +60 -0
- aipt_v2/proxy/__init__.py +29 -0
- aipt_v2/proxy/history.py +352 -0
- aipt_v2/proxy/interceptor.py +452 -0
- aipt_v2/recon/__init__.py +44 -0
- aipt_v2/recon/dns.py +241 -0
- aipt_v2/recon/osint.py +367 -0
- aipt_v2/recon/subdomain.py +372 -0
- aipt_v2/recon/tech_detect.py +311 -0
- aipt_v2/reports/__init__.py +17 -0
- aipt_v2/reports/generator.py +313 -0
- aipt_v2/reports/html_report.py +378 -0
- aipt_v2/runtime/__init__.py +53 -0
- aipt_v2/runtime/base.py +30 -0
- aipt_v2/runtime/docker.py +401 -0
- aipt_v2/runtime/local.py +346 -0
- aipt_v2/runtime/tool_server.py +205 -0
- aipt_v2/runtime/vps.py +830 -0
- aipt_v2/scanners/__init__.py +28 -0
- aipt_v2/scanners/base.py +273 -0
- aipt_v2/scanners/nikto.py +244 -0
- aipt_v2/scanners/nmap.py +402 -0
- aipt_v2/scanners/nuclei.py +273 -0
- aipt_v2/scanners/web.py +454 -0
- aipt_v2/scripts/security_audit.py +366 -0
- aipt_v2/setup_wizard.py +941 -0
- aipt_v2/skills/__init__.py +80 -0
- aipt_v2/skills/agents/__init__.py +14 -0
- aipt_v2/skills/agents/api_tester.py +706 -0
- aipt_v2/skills/agents/base.py +477 -0
- aipt_v2/skills/agents/code_review.py +459 -0
- aipt_v2/skills/agents/security_agent.py +336 -0
- aipt_v2/skills/agents/web_pentest.py +818 -0
- aipt_v2/skills/prompts/__init__.py +647 -0
- aipt_v2/system_detector.py +539 -0
- aipt_v2/telemetry/__init__.py +7 -0
- aipt_v2/telemetry/tracer.py +347 -0
- aipt_v2/terminal/__init__.py +28 -0
- aipt_v2/terminal/executor.py +400 -0
- aipt_v2/terminal/sandbox.py +350 -0
- aipt_v2/tools/__init__.py +44 -0
- aipt_v2/tools/active_directory/__init__.py +78 -0
- aipt_v2/tools/active_directory/ad_config.py +238 -0
- aipt_v2/tools/active_directory/bloodhound_wrapper.py +447 -0
- aipt_v2/tools/active_directory/kerberos_attacks.py +430 -0
- aipt_v2/tools/active_directory/ldap_enum.py +533 -0
- aipt_v2/tools/active_directory/smb_attacks.py +505 -0
- aipt_v2/tools/agents_graph/__init__.py +19 -0
- aipt_v2/tools/agents_graph/agents_graph_actions.py +69 -0
- aipt_v2/tools/api_security/__init__.py +76 -0
- aipt_v2/tools/api_security/api_discovery.py +608 -0
- aipt_v2/tools/api_security/graphql_scanner.py +622 -0
- aipt_v2/tools/api_security/jwt_analyzer.py +577 -0
- aipt_v2/tools/api_security/openapi_fuzzer.py +761 -0
- aipt_v2/tools/browser/__init__.py +5 -0
- aipt_v2/tools/browser/browser_actions.py +238 -0
- aipt_v2/tools/browser/browser_instance.py +535 -0
- aipt_v2/tools/browser/tab_manager.py +344 -0
- aipt_v2/tools/cloud/__init__.py +70 -0
- aipt_v2/tools/cloud/cloud_config.py +273 -0
- aipt_v2/tools/cloud/cloud_scanner.py +639 -0
- aipt_v2/tools/cloud/prowler_tool.py +571 -0
- aipt_v2/tools/cloud/scoutsuite_tool.py +359 -0
- aipt_v2/tools/executor.py +307 -0
- aipt_v2/tools/parser.py +408 -0
- aipt_v2/tools/proxy/__init__.py +5 -0
- aipt_v2/tools/proxy/proxy_actions.py +103 -0
- aipt_v2/tools/proxy/proxy_manager.py +789 -0
- aipt_v2/tools/registry.py +196 -0
- aipt_v2/tools/scanners/__init__.py +343 -0
- aipt_v2/tools/scanners/acunetix_tool.py +712 -0
- aipt_v2/tools/scanners/burp_tool.py +631 -0
- aipt_v2/tools/scanners/config.py +156 -0
- aipt_v2/tools/scanners/nessus_tool.py +588 -0
- aipt_v2/tools/scanners/zap_tool.py +612 -0
- aipt_v2/tools/terminal/__init__.py +5 -0
- aipt_v2/tools/terminal/terminal_actions.py +37 -0
- aipt_v2/tools/terminal/terminal_manager.py +153 -0
- aipt_v2/tools/terminal/terminal_session.py +449 -0
- aipt_v2/tools/tool_processing.py +108 -0
- aipt_v2/utils/__init__.py +17 -0
- aipt_v2/utils/logging.py +202 -0
- aipt_v2/utils/model_manager.py +187 -0
- aipt_v2/utils/searchers/__init__.py +269 -0
- aipt_v2/verify_install.py +793 -0
- aiptx-2.0.7.dist-info/METADATA +345 -0
- aiptx-2.0.7.dist-info/RECORD +187 -0
- aiptx-2.0.7.dist-info/WHEEL +5 -0
- aiptx-2.0.7.dist-info/entry_points.txt +7 -0
- aiptx-2.0.7.dist-info/licenses/LICENSE +21 -0
- aiptx-2.0.7.dist-info/top_level.txt +1 -0
aipt_v2/app.py
ADDED
|
@@ -0,0 +1,957 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AIPT REST API - FastAPI application
|
|
3
|
+
Provides REST endpoints for AIPT operations.
|
|
4
|
+
|
|
5
|
+
Endpoints:
|
|
6
|
+
- /projects - Project management
|
|
7
|
+
- /sessions - Session management
|
|
8
|
+
- /findings - Finding management
|
|
9
|
+
- /scan - Run scans
|
|
10
|
+
- /tools - List available tools
|
|
11
|
+
- /auth - Authentication endpoints
|
|
12
|
+
|
|
13
|
+
Security:
|
|
14
|
+
- JWT-based authentication (optional but recommended)
|
|
15
|
+
- CORS restricted to configured origins
|
|
16
|
+
- Rate limiting per client IP
|
|
17
|
+
- Input validation on all endpoints
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import os
|
|
21
|
+
import re
|
|
22
|
+
import secrets
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
from typing import Optional, List
|
|
25
|
+
from datetime import datetime, timedelta, timezone
|
|
26
|
+
from urllib.parse import urlparse
|
|
27
|
+
|
|
28
|
+
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, Depends
|
|
29
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
30
|
+
from fastapi.responses import JSONResponse
|
|
31
|
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
32
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
33
|
+
from starlette.responses import Response
|
|
34
|
+
from pydantic import BaseModel, Field, field_validator
|
|
35
|
+
from slowapi import Limiter, _rate_limit_exceeded_handler
|
|
36
|
+
from slowapi.util import get_remote_address
|
|
37
|
+
from slowapi.errors import RateLimitExceeded
|
|
38
|
+
|
|
39
|
+
# JWT imports with fallback
|
|
40
|
+
try:
|
|
41
|
+
import jwt
|
|
42
|
+
JWT_AVAILABLE = True
|
|
43
|
+
except ImportError:
|
|
44
|
+
JWT_AVAILABLE = False
|
|
45
|
+
|
|
46
|
+
# Import AIPT v2 components
|
|
47
|
+
from aipt_v2.database.repository import Repository
|
|
48
|
+
from aipt_v2.intelligence import ToolRAG, CVEIntelligence
|
|
49
|
+
from aipt_v2.tools.tool_processing import process_tool_invocations
|
|
50
|
+
from aipt_v2.utils.logging import logger
|
|
51
|
+
from .health import health_router, record_scan, record_tool_invocation
|
|
52
|
+
|
|
53
|
+
# Rate limiter instance
|
|
54
|
+
limiter = Limiter(key_func=get_remote_address)
|
|
55
|
+
|
|
56
|
+
# Security constants
|
|
57
|
+
# URL schemes a target may use when it is supplied as a full URL.
ALLOWED_SCAN_PROTOCOLS = {"http", "https"}
# Upper bound on the length of a target string (used as Field max_length below).
MAX_TARGET_LENGTH = 2048
# Matches "CVE-<4 digits>-<4 or more digits>", case-insensitively, anchored at both ends.
CVE_PATTERN = re.compile(r"^CVE-\d{4}-\d{4,}$", re.IGNORECASE)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# ============== Pydantic Models ==============
|
|
63
|
+
|
|
64
|
+
class ProjectCreate(BaseModel):
    # Request payload for creating a project. Documented with a comment (not a
    # docstring) so the generated schema description stays unchanged.
    name: str = Field(..., min_length=1, max_length=255, description="Project name")
    target: str = Field(..., min_length=1, max_length=MAX_TARGET_LENGTH, description="Target URL or domain")
    description: Optional[str] = Field(None, max_length=2000, description="Project description")
    scope: Optional[List[str]] = Field(None, max_length=100, description="In-scope domains/IPs")

    @field_validator("target")
    @classmethod
    def validate_target(cls, v: str) -> str:
        """Validate target is a valid URL or domain.

        Args:
            v: Raw target string from the request.

        Returns:
            The target with surrounding whitespace removed.

        Raises:
            ValueError: If the target is empty, uses a scheme other than
                http/https, is a URL without a hostname, or (for non-URL
                targets) contains shell metacharacters.
        """
        v = v.strip()
        if not v:
            raise ValueError("Target cannot be empty")

        # Treat anything with a "<scheme>://" prefix as a URL. Checking only for
        # a literal "http://" / "https://" prefix would let "ftp://host" or
        # "HTTP://host" slip into the domain branch and bypass the scheme check.
        if re.match(r"^[A-Za-z][A-Za-z0-9+.\-]*://", v):
            parsed = urlparse(v)  # urlparse lower-cases the scheme
            if parsed.scheme not in ALLOWED_SCAN_PROTOCOLS:
                raise ValueError(f"Protocol must be http or https, got: {parsed.scheme}")
            if not parsed.netloc:
                raise ValueError("Invalid URL: missing hostname")
        else:
            # Validate as domain - basic check for dangerous characters
            if any(c in v for c in [";", "&", "|", "$", "`", "\n", "\r"]):
                raise ValueError("Target contains invalid characters")
        return v

    @field_validator("name")
    @classmethod
    def validate_name(cls, v: str) -> str:
        """Validate project name.

        Args:
            v: Raw project name from the request.

        Returns:
            The name with surrounding whitespace removed.

        Raises:
            ValueError: If the name is blank or contains "..", "/" or "\\".
        """
        v = v.strip()
        if not v:
            raise ValueError("Name cannot be empty")
        # Prevent path traversal in name
        if ".." in v or "/" in v or "\\" in v:
            raise ValueError("Name contains invalid characters")
        return v
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class ProjectResponse(BaseModel):
    # Response shape for a project record.
    id: int
    name: str
    target: str
    description: Optional[str]
    scope: List[str]
    status: str
    created_at: datetime

    class Config:
        # Allow building the model from objects exposing these fields as attributes.
        from_attributes = True
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class SessionCreate(BaseModel):
    # Request payload for creating a session; every field has a default.
    name: Optional[str] = None
    phase: str = "recon"
    max_iterations: int = 100
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class SessionResponse(BaseModel):
    # Response shape for a session record.
    id: int
    project_id: int
    name: Optional[str]
    phase: str
    status: str
    iteration: int
    started_at: datetime

    class Config:
        # Allow building the model from objects exposing these fields as attributes.
        from_attributes = True
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class FindingResponse(BaseModel):
    # Response shape for a single finding.
    id: int
    type: str
    value: str
    description: Optional[str]
    severity: str
    phase: Optional[str]
    tool: Optional[str]
    verified: bool
    discovered_at: datetime

    class Config:
        # Allow building the model from objects exposing these fields as attributes.
        from_attributes = True
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class ScanRequest(BaseModel):
    # Request payload for running a scan. Documented with a comment (not a
    # docstring) so the generated schema description stays unchanged.
    target: str = Field(..., min_length=1, max_length=MAX_TARGET_LENGTH, description="Target URL or domain")
    tools: Optional[List[str]] = Field(None, max_length=20, description="Tools to run")
    phase: str = Field(default="recon", description="Scan phase")

    @field_validator("target")
    @classmethod
    def validate_target(cls, v: str) -> str:
        """Validate and sanitize target.

        Args:
            v: Raw target string from the request.

        Returns:
            The target with surrounding whitespace removed.

        Raises:
            ValueError: If the target is empty, uses a scheme other than
                http/https, is a URL without a hostname, or (for non-URL
                targets) contains command-injection characters.
        """
        v = v.strip()
        if not v:
            raise ValueError("Target cannot be empty")

        # Treat anything with a "<scheme>://" prefix as a URL. Checking only for
        # a literal "http://" / "https://" prefix would let "ftp://host" or
        # "HTTP://host" slip into the domain branch and bypass the scheme check.
        if re.match(r"^[A-Za-z][A-Za-z0-9+.\-]*://", v):
            parsed = urlparse(v)  # urlparse lower-cases the scheme
            if parsed.scheme not in ALLOWED_SCAN_PROTOCOLS:
                raise ValueError("Protocol must be http or https")
            if not parsed.netloc:
                raise ValueError("Invalid URL: missing hostname")
        else:
            # Check for command injection characters
            dangerous_chars = [";", "&", "|", "$", "`", "\n", "\r", "'", '"', "(", ")", "{", "}", "<", ">"]
            if any(c in v for c in dangerous_chars):
                raise ValueError("Target contains invalid characters")
        return v

    @field_validator("phase")
    @classmethod
    def validate_phase(cls, v: str) -> str:
        """Validate phase is allowed.

        Args:
            v: Raw phase name from the request.

        Returns:
            The phase lower-cased and stripped.

        Raises:
            ValueError: If the normalized phase is not one of the allowed phases.
        """
        allowed_phases = {"recon", "scan", "exploit", "report"}
        v = v.lower().strip()
        if v not in allowed_phases:
            # Sort so the error message is stable; set iteration order is not.
            raise ValueError(f"Phase must be one of: {', '.join(sorted(allowed_phases))}")
        return v

    @field_validator("tools")
    @classmethod
    def validate_tools(cls, v: Optional[List[str]]) -> Optional[List[str]]:
        """Validate tool names.

        Args:
            v: Tool names from the request, or None when omitted.

        Returns:
            None when no tools were given, otherwise the names stripped and
            lower-cased, in their original order.

        Raises:
            ValueError: If a normalized name contains anything other than
                a-z, 0-9, underscore or hyphen (this includes empty names).
        """
        if v is None:
            return v
        # Sanitize tool names - only allow alphanumeric and underscore/hyphen
        clean_tools = []
        for tool in v:
            tool = tool.strip().lower()
            if not re.match(r"^[a-z0-9_-]+$", tool):
                raise ValueError(f"Invalid tool name: {tool}")
            clean_tools.append(tool)
        return clean_tools
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
class ScanResponse(BaseModel):
    # Response shape for a scan request; findings defaults to an empty list.
    status: str
    message: str
    findings: List[dict] = []
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class ToolInfo(BaseModel):
    # Description of a single tool entry.
    name: str
    description: str
    phase: str
    keywords: List[str]
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
class CVERequest(BaseModel):
    # Request payload carrying one CVE identifier.
    cve_id: str = Field(..., min_length=9, max_length=20, description="CVE identifier")

    @field_validator("cve_id")
    @classmethod
    def validate_cve_id(cls, v: str) -> str:
        """Normalize the CVE ID (strip + upper-case) and check its format."""
        normalized = v.strip().upper()
        if CVE_PATTERN.match(normalized):
            return normalized
        raise ValueError("Invalid CVE ID format. Expected: CVE-YYYY-NNNNN")
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class CVEResponse(BaseModel):
    # Response shape for a CVE lookup.
    # NOTE(review): cvss/epss presumably carry the CVSS and EPSS scores - confirm against the producer.
    cve_id: str
    cvss: float
    epss: float
    priority_score: float
    has_poc: bool
    description: str
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
# ============== Authentication Models ==============
|
|
240
|
+
|
|
241
|
+
class TokenRequest(BaseModel):
    """Request model for token generation."""
    # Length is constrained to 32-128 characters by the Field declaration.
    api_key: str = Field(..., min_length=32, max_length=128, description="API key")
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
class TokenResponse(BaseModel):
    """Response model for token generation."""
    access_token: str
    # Defaults to "bearer" when not supplied.
    token_type: str = "bearer"
    # NOTE(review): presumably a lifetime in seconds - confirm against the producer.
    expires_in: int
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
class UserInfo(BaseModel):
    """User information extracted from JWT."""
    sub: str  # Subject (user identifier)
    exp: datetime  # value of the token's "exp" claim
    iat: datetime  # value of the token's "iat" claim
    is_authenticated: bool = True  # defaults to True
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
# ============== JWT Security ==============
|
|
262
|
+
|
|
263
|
+
# Security scheme for OpenAPI docs
|
|
264
|
+
security_scheme = HTTPBearer(auto_error=False)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
class JWTAuth:
    """
    JWT Authentication handler.

    Security features:
    - HS256 algorithm with secure secret
    - Token expiration (default 24h)
    - API key validation before token issuance (constant-time comparison)
    """

    def __init__(self, secret_key: Optional[str] = None, expires_hours: int = 24):
        """
        Initialize JWT auth handler.

        Args:
            secret_key: Secret for signing tokens. If not provided, uses
                AIPT_JWT_SECRET env var or generates a random one.
            expires_hours: Token expiration time in hours.
        """
        self.secret_key = secret_key or os.getenv("AIPT_JWT_SECRET")
        if not self.secret_key:
            # Generate a secure random secret if not configured
            # Note: This means tokens won't survive server restarts
            self.secret_key = secrets.token_urlsafe(32)
            logger.warning(
                "JWT secret not configured (AIPT_JWT_SECRET). "
                "Using random secret - tokens will be invalid after restart."
            )
        self.expires_hours = expires_hours
        self.algorithm = "HS256"

        # Valid API keys - in production, load from secure storage.
        # AIPT_API_KEYS is a comma-separated list; blank entries are dropped.
        env_keys = os.getenv("AIPT_API_KEYS", "")
        self._valid_api_keys = {k.strip() for k in env_keys.split(",") if k.strip()}

    def validate_api_key(self, api_key: str) -> bool:
        """
        Validate an API key.

        Args:
            api_key: Key presented by the client.

        Returns:
            True if the key is accepted. When no keys are configured, any key
            of at least 32 characters is accepted (dev mode) and a warning is
            logged. Otherwise the key must equal one of the configured keys.
        """
        if not self._valid_api_keys:
            # If no keys configured, accept any non-empty key (dev mode)
            logger.warning("No API keys configured (AIPT_API_KEYS). Accepting any key.")
            return bool(api_key and len(api_key) >= 32)

        if not isinstance(api_key, str):
            return False

        # Compare against every configured key with a constant-time primitive,
        # without short-circuiting, so response timing does not reveal how much
        # of a key matched. Encoding to bytes is needed because compare_digest
        # only accepts ASCII-only str values.
        candidate = api_key.encode("utf-8", "surrogatepass")
        matched = False
        for known in self._valid_api_keys:
            if secrets.compare_digest(candidate, known.encode("utf-8", "surrogatepass")):
                matched = True
        return matched

    def create_token(self, subject: str) -> tuple[str, int]:
        """
        Create a new JWT token.

        Args:
            subject: User/client identifier

        Returns:
            Tuple of (token, expires_in_seconds)

        Raises:
            HTTPException: 503 if PyJWT could not be imported.
        """
        if not JWT_AVAILABLE:
            raise HTTPException(
                status_code=503,
                detail="JWT support not available. Install PyJWT: pip install PyJWT"
            )

        now = datetime.now(timezone.utc)
        expires = now + timedelta(hours=self.expires_hours)

        payload = {
            "sub": subject,
            "iat": now,
            "exp": expires,
            "iss": "aipt-api",
        }

        token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
        expires_in = int((expires - now).total_seconds())

        return token, expires_in

    def verify_token(self, token: str) -> Optional[UserInfo]:
        """
        Verify and decode a JWT token.

        Args:
            token: JWT token string

        Returns:
            UserInfo if valid, None otherwise (including when PyJWT is not
            installed, the token is expired, or it fails validation).
        """
        if not JWT_AVAILABLE:
            return None

        try:
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                # Reject tokens that lack any of these claims.
                options={"require": ["exp", "iat", "sub"]}
            )
            return UserInfo(
                sub=payload["sub"],
                exp=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
                iat=datetime.fromtimestamp(payload["iat"], tz=timezone.utc),
            )
        except jwt.ExpiredSignatureError:
            logger.debug("Token expired")
            return None
        except jwt.InvalidTokenError as e:
            logger.debug(f"Invalid token: {e}")
            return None
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
# Global JWT auth instance
|
|
377
|
+
jwt_auth: Optional[JWTAuth] = None
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
# ============== Security Headers Middleware ==============
|
|
381
|
+
|
|
382
|
+
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """
    Middleware that stamps a fixed set of security headers on every response.

    Headers set:
    - X-Content-Type-Options
    - X-Frame-Options
    - X-XSS-Protection
    - Referrer-Policy
    - Permissions-Policy
    - Content-Security-Policy

    Strict-Transport-Security is intentionally not set here; enable it only
    when the API is served over HTTPS in production.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        """Run the downstream handler, then add the security headers to its response."""
        # CSP - restrictive policy for API responses
        csp_policy = "; ".join(
            (
                "default-src 'none'",
                "frame-ancestors 'none'",
                "base-uri 'none'",
                "form-action 'none'",
            )
        )
        header_table = (
            ("X-Content-Type-Options", "nosniff"),
            ("X-Frame-Options", "DENY"),
            ("X-XSS-Protection", "1; mode=block"),
            ("Referrer-Policy", "strict-origin-when-cross-origin"),
            ("Permissions-Policy", "geolocation=(), microphone=(), camera=()"),
            ("Content-Security-Policy", csp_policy),
        )

        response = await call_next(request)
        for header_name, header_value in header_table:
            response.headers[header_name] = header_value
        return response
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
# ============== WAF Middleware ==============
|
|
421
|
+
|
|
422
|
+
class WAFMiddleware(BaseHTTPMiddleware):
    """
    Minimal Web Application Firewall middleware.

    Rejects (HTTP 403) requests whose query-string keys/values or URL path
    match one of the regexes below, which target:
    - SQL injection
    - XSS
    - Path traversal
    - Command injection
    """

    # Regex sources, compiled lazily (case-insensitive) on first use.
    SUSPICIOUS_PATTERNS = [
        # SQL Injection
        r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|UNION|ALTER)\b.*\b(FROM|INTO|TABLE|DATABASE)\b)",
        r"(--|#|/\*|\*/|;)",
        r"(\bOR\b\s+\d+\s*=\s*\d+)",
        # XSS
        r"(<script|javascript:|on\w+\s*=)",
        # Path traversal
        r"(\.\./|\.\.\\|%2e%2e)",
        # Command injection
        r"(;|\||&&|\$\(|`)",
    ]

    def __init__(self, app, enabled: bool = True):
        """Wrap *app*; when *enabled* is False every request passes through unchecked."""
        super().__init__(app)
        self.enabled = enabled
        self._patterns = None

    @property
    def patterns(self):
        """Compiled forms of SUSPICIOUS_PATTERNS, built once and cached on the instance."""
        compiled = self._patterns
        if compiled is None:
            compiled = [re.compile(source, re.IGNORECASE) for source in self.SUSPICIOUS_PATTERNS]
            self._patterns = compiled
        return compiled

    def _is_suspicious(self, value: str) -> bool:
        """Return True when *value* is non-empty and any pattern matches somewhere in it."""
        if not value:
            return False
        return any(regex.search(value) for regex in self.patterns)

    async def dispatch(self, request: Request, call_next) -> Response:
        """Block the request with a 403 if its query parameters or path look malicious."""
        if not self.enabled:
            return await call_next(request)

        client_host = request.client.host if request.client else "unknown"

        # Query parameters are inspected first; the first offending key is logged.
        flagged_param = next(
            (
                key
                for key, value in request.query_params.items()
                if self._is_suspicious(key) or self._is_suspicious(value)
            ),
            None,
        )
        if flagged_param is not None:
            logger.warning(
                "WAF blocked suspicious request",
                path=request.url.path,
                param=flagged_param,
                client=client_host,
            )
            return JSONResponse(
                status_code=403,
                content={"detail": "Request blocked by security policy"}
            )

        # Then the URL path itself.
        if self._is_suspicious(request.url.path):
            logger.warning(
                "WAF blocked suspicious path",
                path=request.url.path,
                client=client_host,
            )
            return JSONResponse(
                status_code=403,
                content={"detail": "Request blocked by security policy"}
            )

        return await call_next(request)
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
async def get_current_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security_scheme)
) -> Optional[UserInfo]:
    """
    Resolve the caller's identity from the bearer token, if any.

    Yields None when no credentials were sent or the module-level JWT handler
    has not been initialised; otherwise returns whatever token verification
    produces (UserInfo on success, None on failure). Authentication is
    optional, so this dependency never raises on its own.
    """
    if credentials and jwt_auth:
        return jwt_auth.verify_token(credentials.credentials)
    return None
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
async def require_auth(
    user: Optional[UserInfo] = Depends(get_current_user)
) -> UserInfo:
    """
    Dependency that insists on an authenticated caller.

    Returns the resolved user, or raises HTTPException 401 (with a
    WWW-Authenticate: Bearer header) when there is none.
    """
    if user:
        return user
    raise HTTPException(
        status_code=401,
        detail="Authentication required",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
# ============== FastAPI App ==============
|
|
536
|
+
|
|
537
|
+
# Default CORS origins - restricted for security
|
|
538
|
+
# Loopback-only origins (ports 3000 and 8080) used when neither the
# cors_origins argument nor AIPT_CORS_ORIGINS supplies a list.
DEFAULT_CORS_ORIGINS = [
    "http://localhost:3000",
    "http://localhost:8080",
    "http://127.0.0.1:3000",
    "http://127.0.0.1:8080",
]
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
def create_app(
|
|
547
|
+
db_url: str = "sqlite:///~/.aipt/aipt.db",
|
|
548
|
+
title: str = "AIPT API",
|
|
549
|
+
cors_origins: Optional[List[str]] = None,
|
|
550
|
+
rate_limit: str = "100/minute",
|
|
551
|
+
) -> FastAPI:
|
|
552
|
+
"""
|
|
553
|
+
Create FastAPI application with security middleware.
|
|
554
|
+
|
|
555
|
+
Args:
|
|
556
|
+
db_url: Database connection URL
|
|
557
|
+
title: API title
|
|
558
|
+
cors_origins: Allowed CORS origins (defaults to localhost only)
|
|
559
|
+
rate_limit: Rate limit string (e.g., "100/minute")
|
|
560
|
+
"""
|
|
561
|
+
# Get CORS origins from env or use defaults
|
|
562
|
+
if cors_origins is None:
|
|
563
|
+
env_origins = os.getenv("AIPT_CORS_ORIGINS", "")
|
|
564
|
+
if env_origins:
|
|
565
|
+
cors_origins = [o.strip() for o in env_origins.split(",") if o.strip()]
|
|
566
|
+
else:
|
|
567
|
+
cors_origins = DEFAULT_CORS_ORIGINS
|
|
568
|
+
|
|
569
|
+
app = FastAPI(
|
|
570
|
+
title=title,
|
|
571
|
+
description="AI-Powered Penetration Testing Framework API",
|
|
572
|
+
version="0.2.0",
|
|
573
|
+
)
|
|
574
|
+
|
|
575
|
+
# Rate limiting
|
|
576
|
+
app.state.limiter = limiter
|
|
577
|
+
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
|
578
|
+
|
|
579
|
+
# CORS - Restricted to configured origins only
|
|
580
|
+
app.add_middleware(
|
|
581
|
+
CORSMiddleware,
|
|
582
|
+
allow_origins=cors_origins,
|
|
583
|
+
allow_credentials=True,
|
|
584
|
+
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
|
585
|
+
allow_headers=["Authorization", "Content-Type", "X-Requested-With"],
|
|
586
|
+
expose_headers=["X-RateLimit-Limit", "X-RateLimit-Remaining", "X-RateLimit-Reset"],
|
|
587
|
+
)
|
|
588
|
+
|
|
589
|
+
# Security Headers Middleware - OWASP recommended headers
|
|
590
|
+
app.add_middleware(SecurityHeadersMiddleware)
|
|
591
|
+
|
|
592
|
+
# WAF Middleware - Basic attack protection
|
|
593
|
+
# Can be disabled via AIPT_WAF_ENABLED=false for testing
|
|
594
|
+
waf_enabled = os.getenv("AIPT_WAF_ENABLED", "true").lower() == "true"
|
|
595
|
+
app.add_middleware(WAFMiddleware, enabled=waf_enabled)
|
|
596
|
+
|
|
597
|
+
# Log security configuration
|
|
598
|
+
logger.info(
|
|
599
|
+
f"API security configured: cors_origins={cors_origins}, "
|
|
600
|
+
f"rate_limit={rate_limit}, waf_enabled={waf_enabled}, security_headers=enabled"
|
|
601
|
+
)
|
|
602
|
+
|
|
603
|
+
# Initialize components
|
|
604
|
+
repo = Repository(db_url)
|
|
605
|
+
tools_path = Path(__file__).parent / "intelligence" / "tools.json"
|
|
606
|
+
tools_rag = ToolRAG(tools_path=str(tools_path), lazy_load=True)
|
|
607
|
+
cve_intel = CVEIntelligence()
|
|
608
|
+
|
|
609
|
+
# Initialize JWT authentication
|
|
610
|
+
global jwt_auth
|
|
611
|
+
jwt_auth = JWTAuth()
|
|
612
|
+
app.state.jwt_auth = jwt_auth
|
|
613
|
+
|
|
614
|
+
# Store in app state
|
|
615
|
+
app.state.repo = repo
|
|
616
|
+
app.state.tools_rag = tools_rag
|
|
617
|
+
app.state.cve_intel = cve_intel
|
|
618
|
+
app.state.rate_limit = rate_limit
|
|
619
|
+
|
|
620
|
+
# ============== Authentication ==============
|
|
621
|
+
|
|
622
|
+
@app.post("/auth/token", response_model=TokenResponse, tags=["Authentication"])
@limiter.limit("10/minute")  # Token issuance is throttled harder than other routes
async def get_token(request: Request, token_request: TokenRequest):
    """
    Exchange an API key for a JWT access token.

    Security:
    - Rate limited to 10 requests per minute
    - The key is checked with jwt_auth.validate_api_key(); any length or
      format rules live there (not visible in this handler)
    - Token lifetime is whatever jwt_auth.create_token() reports in expires_in

    Usage:
    1. Configure the accepted API keys on the server
       (presumably via the AIPT_API_KEYS env var - confirm in JWTAuth)
    2. POST to /auth/token with your API key
    3. Use returned token in Authorization header: Bearer <token>
    """
    import hashlib

    api_key = token_request.api_key
    if not jwt_auth.validate_api_key(api_key):
        raise HTTPException(status_code=401, detail="Invalid API key")

    # The token subject is a truncated SHA-256 digest so the raw key never
    # appears inside an issued token.
    key_digest = hashlib.sha256(api_key.encode()).hexdigest()
    token, expires_in = jwt_auth.create_token(key_digest[:16])

    return TokenResponse(access_token=token, expires_in=expires_in)
|
|
653
|
+
|
|
654
|
+
@app.get("/auth/me", tags=["Authentication"])
async def get_current_user_info(user: UserInfo = Depends(require_auth)):
    """
    Describe the caller identified by the request's credentials.

    The user is resolved by the require_auth dependency; the response carries
    the token subject and its expiry as an ISO-8601 string.
    """
    payload = {"sub": user.sub, "authenticated": True}
    payload["expires"] = user.exp.isoformat()
    return payload
|
|
666
|
+
|
|
667
|
+
@app.get("/auth/status", tags=["Authentication"])
async def auth_status(user: Optional[UserInfo] = Depends(get_current_user)):
    """
    Report the authentication state of the current request.

    The response says whether JWT support is available (JWT_AVAILABLE) and,
    when get_current_user resolved a user, that user's subject.
    """
    subject = None
    if user:
        subject = user.sub

    return {
        "jwt_available": JWT_AVAILABLE,
        "authenticated": user is not None,
        "user": subject,
    }
|
|
679
|
+
|
|
680
|
+
# ============== Health & Metrics ==============
|
|
681
|
+
# Include comprehensive health check router with:
|
|
682
|
+
# - /health - Basic liveness probe
|
|
683
|
+
# - /health/live - Kubernetes liveness probe
|
|
684
|
+
# - /health/ready - Readiness probe with dependency checks
|
|
685
|
+
# - /metrics - Prometheus-compatible metrics
|
|
686
|
+
# - /health/info - Service information
|
|
687
|
+
app.include_router(health_router)
|
|
688
|
+
|
|
689
|
+
# ============== Projects ==============
|
|
690
|
+
|
|
691
|
+
@app.post("/projects", response_model=ProjectResponse)
async def create_project(project: ProjectCreate):
    """Persist a new project built from the request body and return it"""
    fields = {
        "name": project.name,
        "target": project.target,
        "description": project.description,
        "scope": project.scope,
    }
    return repo.create_project(**fields)
|
|
701
|
+
|
|
702
|
+
@app.get("/projects", response_model=List[ProjectResponse])
async def list_projects(status: Optional[str] = None):
    """Return the stored projects, passing the optional status filter to the repository"""
    projects = repo.list_projects(status=status)
    return projects
|
|
706
|
+
|
|
707
|
+
@app.get("/projects/{project_id}", response_model=ProjectResponse)
async def get_project(project_id: int):
    """Fetch a single project by ID; responds 404 when the repository has none"""
    found = repo.get_project(project_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Project not found")
|
|
714
|
+
|
|
715
|
+
@app.delete("/projects/{project_id}")
async def delete_project(project_id: int):
    """Remove a project; responds 404 when the repository reports nothing was deleted"""
    was_deleted = repo.delete_project(project_id)
    if was_deleted:
        return {"status": "deleted"}
    raise HTTPException(status_code=404, detail="Project not found")
|
|
721
|
+
|
|
722
|
+
# ============== Sessions ==============
|
|
723
|
+
|
|
724
|
+
@app.post("/projects/{project_id}/sessions", response_model=SessionResponse)
async def create_session(project_id: int, session: SessionCreate):
    """
    Create a new session under an existing project.

    Args:
        project_id: ID of the project the session belongs to.
        session: Request body providing name, phase and max_iterations.

    Returns:
        The session object returned by repo.create_session().

    Raises:
        HTTPException: 404 when the project does not exist, 400 when
            session.phase is not a valid PhaseType value.
    """
    if not repo.get_project(project_id):
        raise HTTPException(status_code=404, detail="Project not found")

    from aipt_v2.database.models import PhaseType

    # Converting an unknown phase raises ValueError (assuming PhaseType is an
    # Enum - confirm); report that as a client error instead of a 500.
    try:
        phase = PhaseType(session.phase)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid phase: {session.phase}",
        ) from None

    return repo.create_session(
        project_id=project_id,
        name=session.name,
        phase=phase,
        max_iterations=session.max_iterations,
    )
|
|
739
|
+
|
|
740
|
+
@app.get("/projects/{project_id}/sessions", response_model=List[SessionResponse])
async def list_sessions(project_id: int):
    """Return every session the repository holds for the given project"""
    sessions = repo.list_sessions(project_id)
    return sessions
|
|
744
|
+
|
|
745
|
+
# ============== Findings ==============
|
|
746
|
+
|
|
747
|
+
@app.get("/projects/{project_id}/findings", response_model=List[FindingResponse])
async def get_findings(
    project_id: int,
    type: Optional[str] = None,
    severity: Optional[str] = None,
    phase: Optional[str] = None,
):
    """Return a project's findings, forwarding the optional type/severity/phase filters"""
    # `type` shadows the builtin, but it is the public query-parameter name.
    filters = {"type": type, "severity": severity, "phase": phase}
    return repo.get_findings(project_id=project_id, **filters)
|
|
761
|
+
|
|
762
|
+
@app.get("/projects/{project_id}/findings/summary")
async def get_findings_summary(project_id: int):
    """Return the repository's findings summary for the given project"""
    summary = repo.get_findings_summary(project_id)
    return summary
|
|
766
|
+
|
|
767
|
+
@app.post("/findings/{finding_id}/verify")
async def verify_finding(finding_id: int, notes: Optional[str] = None):
    """Flag a finding as verified, optionally attaching reviewer notes"""
    update = {"verified": True, "notes": notes}
    repo.verify_finding(finding_id, **update)
    return dict(status="verified")
|
|
772
|
+
|
|
773
|
+
@app.post("/findings/{finding_id}/false-positive")
async def mark_false_positive(finding_id: int, notes: Optional[str] = None):
    """Flag a finding as a false positive, optionally attaching reviewer notes"""
    repo.mark_false_positive(finding_id, notes=notes)
    response = {"status": "marked as false positive"}
    return response
|
|
778
|
+
|
|
779
|
+
# ============== Scanning ==============
|
|
780
|
+
|
|
781
|
+
@app.post("/scan/quick", response_model=ScanResponse)
@limiter.limit("10/minute")  # Limit scan requests
async def quick_scan(request: Request, scan_request: ScanRequest):
    """
    Run a quick scan on target (rate limited: 10/minute).

    When an nmap binary is found on PATH, runs `nmap -F` against the
    whitespace-separated target(s) with a 60 second timeout and converts the
    parsed output into finding dicts. When nmap is missing, times out or
    fails, the response still reports "completed" with whatever findings
    were collected (possibly none).

    Args:
        request: Not used directly; presumably consumed by the rate limiter
            decorator.
        scan_request: Request body; only `target` is read here.

    Returns:
        ScanResponse with status, message and the list of finding dicts
        (keys: type, value, description, severity).

    Raises:
        HTTPException: 400 when a target token starts with "-" and would be
            interpreted by nmap as an option.
    """
    # NOTE(review): no auth dependency is declared on this route in the
    # visible code - confirm it is protected elsewhere.
    import asyncio
    import shutil
    from aipt_v2.tools.parser import OutputParser

    # Split the way the shell used to, so several space-separated targets
    # still work, but pass them as argv entries instead of a shell string.
    targets = str(scan_request.target).split()
    if any(token.startswith("-") for token in targets):
        raise HTTPException(status_code=400, detail="Invalid target: must not start with '-'")

    findings = []

    # Check if nmap is available
    nmap_path = shutil.which("nmap")
    if nmap_path:
        proc = None
        try:
            # exec (no shell) so metacharacters in the target are never
            # interpreted as shell syntax.
            proc = await asyncio.create_subprocess_exec(
                nmap_path,
                "-F",
                *targets,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=60)
            # Tolerate non-UTF-8 bytes rather than failing the whole scan.
            output = stdout.decode(errors="replace") if stdout else ""

            if proc.returncode == 0:
                parser = OutputParser()
                parsed = parser.parse(output, "nmap")
                findings.extend([{
                    "type": f.type,
                    "value": f.value,
                    "description": f.description,
                    "severity": f.severity,
                } for f in parsed])
        except asyncio.TimeoutError:
            logger.warning("Tool execution timed out", tool="nmap", target=scan_request.target)
            # wait_for() only cancels communicate(); the child keeps running
            # unless it is killed and reaped explicitly.
            if proc is not None and proc.returncode is None:
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass  # exited between the check and the kill
                await proc.wait()
        except Exception as e:
            logger.error("Tool execution failed", tool="nmap", error=str(e))

    return ScanResponse(
        status="completed",
        message=f"Quick scan completed on {scan_request.target}",
        findings=findings,
    )
|
|
822
|
+
|
|
823
|
+
@app.post("/scan/tool")
@limiter.limit("5/minute")  # Stricter limit for tool execution
async def run_tool(request: Request, tool_name: str, target: str, options: Optional[str] = None):
    """
    Run a specific tool (rate limited: 5/minute).

    The tool is looked up by name in tools_rag, "{target}" in its command
    template is replaced by the shell-quoted target, shell-quoted options are
    appended, and the command runs through the shell with a 300 second
    timeout.

    Args:
        request: Not used directly; presumably consumed by the rate limiter
            decorator.
        tool_name: Tool identifier; letters, digits, "_" and "-" only.
        target: Value substituted for "{target}" in the command template.
        options: Optional extra arguments, split shell-style and appended.

    Returns:
        Dict with keys tool, target, command, return_code, output (first
        10000 characters of stdout followed by stderr) and duration in
        seconds. return_code is -1 on timeout or when the command could not
        be run.

    Raises:
        HTTPException: 400 for an invalid tool name, target or options;
            404 when the tool is unknown; 500 when the tool has no command.
    """
    # NOTE(review): no auth dependency is declared on this route in the
    # visible code - confirm it is protected elsewhere.
    import asyncio
    import shlex
    import time

    # Validate tool_name - only alphanumeric and underscore/hyphen.
    # fullmatch, because "$" in re.match also accepts a trailing newline.
    if not re.fullmatch(r"[a-zA-Z0-9_-]+", tool_name):
        raise HTTPException(status_code=400, detail="Invalid tool name")

    # Validate target - check for command injection
    dangerous_chars = [";", "&", "|", "$", "`", "\n", "\r", "'", '"']
    if any(c in target for c in dangerous_chars):
        raise HTTPException(status_code=400, detail="Invalid target: contains dangerous characters")
    # A leading "-" would let the target be read as an option by the tool.
    if target.startswith("-"):
        raise HTTPException(status_code=400, detail="Invalid target: must not start with '-'")

    # Validate options if provided. Newlines are rejected as well: the shell
    # treats them as command separators, just like ";".
    option_args = []
    if options:
        if any(c in options for c in [";", "&", "|", "$", "`", "\n", "\r"]):
            raise HTTPException(status_code=400, detail="Invalid options: contains dangerous characters")
        try:
            option_args = shlex.split(options)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid options: unbalanced quotes") from None

    # Get tool from RAG
    tool = tools_rag.get_tool_by_name(tool_name)
    if not tool:
        raise HTTPException(status_code=404, detail=f"Tool {tool_name} not found")

    # Without a template the options alone would become the command line.
    template = tool.get("cmd", "")
    if not template:
        raise HTTPException(status_code=500, detail=f"Tool {tool_name} has no command defined")

    # Build command. The blacklist above cannot cover every shell construct
    # (redirections, globs, subshells, ...), so user-supplied pieces are
    # quoted before they reach the shell.
    # NOTE(review): if a template already wraps {target} in quotes, a target
    # that needs quoting will carry literal single quotes - confirm templates.
    cmd = template.replace("{target}", shlex.quote(target))
    if option_args:
        quoted_options = " ".join(shlex.quote(arg) for arg in option_args)
        cmd = f"{cmd} {quoted_options}"

    # Execute
    start_time = time.monotonic()  # monotonic: unaffected by wall-clock adjustments
    proc = None
    try:
        proc = await asyncio.create_subprocess_shell(
            cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=300)
        # Tolerate non-UTF-8 tool output instead of reporting a failure.
        output = (
            (stdout.decode(errors="replace") if stdout else "")
            + (stderr.decode(errors="replace") if stderr else "")
        )
        return_code = proc.returncode
    except asyncio.TimeoutError:
        # wait_for() only cancels communicate(); the shell keeps running
        # unless it is killed and reaped explicitly.
        if proc is not None and proc.returncode is None:
            try:
                proc.kill()
            except ProcessLookupError:
                pass  # exited between the check and the kill
            await proc.wait()
        output = "Command timed out after 300 seconds"
        return_code = -1
    except Exception as e:
        output = f"Command failed: {str(e)}"
        return_code = -1

    duration = time.monotonic() - start_time

    return {
        "tool": tool_name,
        "target": target,
        "command": cmd,
        "return_code": return_code,
        "output": output[:10000],  # Truncate
        "duration": duration,
    }
|
|
882
|
+
|
|
883
|
+
# ============== Tools ==============
|
|
884
|
+
|
|
885
|
+
@app.get("/tools", response_model=List[ToolInfo])
async def list_tools(phase: Optional[str] = None):
    """Return the known tools as ToolInfo entries, optionally restricted to one phase"""
    selected = tools_rag.tools
    if phase:
        selected = [entry for entry in selected if entry.get("phase") == phase]

    infos = []
    for entry in selected:
        # Missing fields fall back to empty values rather than failing.
        infos.append(
            ToolInfo(
                name=entry.get("name", ""),
                description=entry.get("description", ""),
                phase=entry.get("phase", ""),
                keywords=entry.get("keywords", []),
            )
        )
    return infos
|
|
901
|
+
|
|
902
|
+
@app.get("/tools/{tool_name}")
async def get_tool(tool_name: str):
    """Return the stored definition of one tool; responds 404 when the name is unknown"""
    entry = tools_rag.get_tool_by_name(tool_name)
    if entry:
        return entry
    raise HTTPException(status_code=404, detail="Tool not found")
|
|
909
|
+
|
|
910
|
+
@app.get("/tools/search/{query}")
async def search_tools(query: str, top_k: int = 5):
    """Search the tool index for the query and return up to top_k matches as given by tools_rag"""
    return tools_rag.search(query, top_k=top_k)
|
|
915
|
+
|
|
916
|
+
# ============== CVE ==============
|
|
917
|
+
|
|
918
|
+
@app.post("/cve/lookup", response_model=CVEResponse)
async def lookup_cve(request: CVERequest):
    """Look up one CVE and return its scores, PoC flag and a description capped at 500 characters"""
    record = cve_intel.lookup(request.cve_id)
    fields = {
        "cve_id": record.cve_id,
        "cvss": record.cvss,
        "epss": record.epss,
        "priority_score": record.priority_score,
        "has_poc": record.has_poc,
        "description": record.description[:500],
    }
    return CVEResponse(**fields)
|
|
930
|
+
|
|
931
|
+
@app.post("/cve/prioritize")
async def prioritize_cves(cve_ids: List[str]):
    """Return score summaries for the given CVE IDs in the order cve_intel.prioritize() yields them"""
    ranked = []
    for entry in cve_intel.prioritize(cve_ids):
        ranked.append(
            {
                "cve_id": entry.cve_id,
                "cvss": entry.cvss,
                "epss": entry.epss,
                "priority_score": entry.priority_score,
                "has_poc": entry.has_poc,
            }
        )
    return ranked
|
|
945
|
+
|
|
946
|
+
return app
|
|
947
|
+
|
|
948
|
+
|
|
949
|
+
# Module-level application object built with create_app()'s defaults
app = create_app()


if __name__ == "__main__":
    import uvicorn

    # Security: bind to loopback only, so running this file directly never
    # exposes the API on the network by accident. This script parses no
    # command-line flags; to listen on other interfaces (e.g. behind a
    # reverse proxy), start the server another way with an explicit host.
    bind_host = "127.0.0.1"
    bind_port = 8000
    uvicorn.run(app, host=bind_host, port=bind_port)
|