aiptx 2.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187) hide show
  1. aipt_v2/__init__.py +110 -0
  2. aipt_v2/__main__.py +24 -0
  3. aipt_v2/agents/AIPTxAgent/__init__.py +10 -0
  4. aipt_v2/agents/AIPTxAgent/aiptx_agent.py +211 -0
  5. aipt_v2/agents/__init__.py +46 -0
  6. aipt_v2/agents/base.py +520 -0
  7. aipt_v2/agents/exploit_agent.py +688 -0
  8. aipt_v2/agents/ptt.py +406 -0
  9. aipt_v2/agents/state.py +168 -0
  10. aipt_v2/app.py +957 -0
  11. aipt_v2/browser/__init__.py +31 -0
  12. aipt_v2/browser/automation.py +458 -0
  13. aipt_v2/browser/crawler.py +453 -0
  14. aipt_v2/cli.py +2933 -0
  15. aipt_v2/compliance/__init__.py +71 -0
  16. aipt_v2/compliance/compliance_report.py +449 -0
  17. aipt_v2/compliance/framework_mapper.py +424 -0
  18. aipt_v2/compliance/nist_mapping.py +345 -0
  19. aipt_v2/compliance/owasp_mapping.py +330 -0
  20. aipt_v2/compliance/pci_mapping.py +297 -0
  21. aipt_v2/config.py +341 -0
  22. aipt_v2/core/__init__.py +43 -0
  23. aipt_v2/core/agent.py +630 -0
  24. aipt_v2/core/llm.py +395 -0
  25. aipt_v2/core/memory.py +305 -0
  26. aipt_v2/core/ptt.py +329 -0
  27. aipt_v2/database/__init__.py +14 -0
  28. aipt_v2/database/models.py +232 -0
  29. aipt_v2/database/repository.py +384 -0
  30. aipt_v2/docker/__init__.py +23 -0
  31. aipt_v2/docker/builder.py +260 -0
  32. aipt_v2/docker/manager.py +222 -0
  33. aipt_v2/docker/sandbox.py +371 -0
  34. aipt_v2/evasion/__init__.py +58 -0
  35. aipt_v2/evasion/request_obfuscator.py +272 -0
  36. aipt_v2/evasion/tls_fingerprint.py +285 -0
  37. aipt_v2/evasion/ua_rotator.py +301 -0
  38. aipt_v2/evasion/waf_bypass.py +439 -0
  39. aipt_v2/execution/__init__.py +23 -0
  40. aipt_v2/execution/executor.py +302 -0
  41. aipt_v2/execution/parser.py +544 -0
  42. aipt_v2/execution/terminal.py +337 -0
  43. aipt_v2/health.py +437 -0
  44. aipt_v2/intelligence/__init__.py +194 -0
  45. aipt_v2/intelligence/adaptation.py +474 -0
  46. aipt_v2/intelligence/auth.py +520 -0
  47. aipt_v2/intelligence/chaining.py +775 -0
  48. aipt_v2/intelligence/correlation.py +536 -0
  49. aipt_v2/intelligence/cve_aipt.py +334 -0
  50. aipt_v2/intelligence/cve_info.py +1111 -0
  51. aipt_v2/intelligence/knowledge_graph.py +590 -0
  52. aipt_v2/intelligence/learning.py +626 -0
  53. aipt_v2/intelligence/llm_analyzer.py +502 -0
  54. aipt_v2/intelligence/llm_tool_selector.py +518 -0
  55. aipt_v2/intelligence/payload_generator.py +562 -0
  56. aipt_v2/intelligence/rag.py +239 -0
  57. aipt_v2/intelligence/scope.py +442 -0
  58. aipt_v2/intelligence/searchers/__init__.py +5 -0
  59. aipt_v2/intelligence/searchers/exploitdb_searcher.py +523 -0
  60. aipt_v2/intelligence/searchers/github_searcher.py +467 -0
  61. aipt_v2/intelligence/searchers/google_searcher.py +281 -0
  62. aipt_v2/intelligence/tools.json +443 -0
  63. aipt_v2/intelligence/triage.py +670 -0
  64. aipt_v2/interactive_shell.py +559 -0
  65. aipt_v2/interface/__init__.py +5 -0
  66. aipt_v2/interface/cli.py +230 -0
  67. aipt_v2/interface/main.py +501 -0
  68. aipt_v2/interface/tui.py +1276 -0
  69. aipt_v2/interface/utils.py +583 -0
  70. aipt_v2/llm/__init__.py +39 -0
  71. aipt_v2/llm/config.py +26 -0
  72. aipt_v2/llm/llm.py +514 -0
  73. aipt_v2/llm/memory.py +214 -0
  74. aipt_v2/llm/request_queue.py +89 -0
  75. aipt_v2/llm/utils.py +89 -0
  76. aipt_v2/local_tool_installer.py +1467 -0
  77. aipt_v2/models/__init__.py +15 -0
  78. aipt_v2/models/findings.py +295 -0
  79. aipt_v2/models/phase_result.py +224 -0
  80. aipt_v2/models/scan_config.py +207 -0
  81. aipt_v2/monitoring/grafana/dashboards/aipt-dashboard.json +355 -0
  82. aipt_v2/monitoring/grafana/dashboards/default.yml +17 -0
  83. aipt_v2/monitoring/grafana/datasources/prometheus.yml +17 -0
  84. aipt_v2/monitoring/prometheus.yml +60 -0
  85. aipt_v2/orchestration/__init__.py +52 -0
  86. aipt_v2/orchestration/pipeline.py +398 -0
  87. aipt_v2/orchestration/progress.py +300 -0
  88. aipt_v2/orchestration/scheduler.py +296 -0
  89. aipt_v2/orchestrator.py +2427 -0
  90. aipt_v2/payloads/__init__.py +27 -0
  91. aipt_v2/payloads/cmdi.py +150 -0
  92. aipt_v2/payloads/sqli.py +263 -0
  93. aipt_v2/payloads/ssrf.py +204 -0
  94. aipt_v2/payloads/templates.py +222 -0
  95. aipt_v2/payloads/traversal.py +166 -0
  96. aipt_v2/payloads/xss.py +204 -0
  97. aipt_v2/prompts/__init__.py +60 -0
  98. aipt_v2/proxy/__init__.py +29 -0
  99. aipt_v2/proxy/history.py +352 -0
  100. aipt_v2/proxy/interceptor.py +452 -0
  101. aipt_v2/recon/__init__.py +44 -0
  102. aipt_v2/recon/dns.py +241 -0
  103. aipt_v2/recon/osint.py +367 -0
  104. aipt_v2/recon/subdomain.py +372 -0
  105. aipt_v2/recon/tech_detect.py +311 -0
  106. aipt_v2/reports/__init__.py +17 -0
  107. aipt_v2/reports/generator.py +313 -0
  108. aipt_v2/reports/html_report.py +378 -0
  109. aipt_v2/runtime/__init__.py +53 -0
  110. aipt_v2/runtime/base.py +30 -0
  111. aipt_v2/runtime/docker.py +401 -0
  112. aipt_v2/runtime/local.py +346 -0
  113. aipt_v2/runtime/tool_server.py +205 -0
  114. aipt_v2/runtime/vps.py +830 -0
  115. aipt_v2/scanners/__init__.py +28 -0
  116. aipt_v2/scanners/base.py +273 -0
  117. aipt_v2/scanners/nikto.py +244 -0
  118. aipt_v2/scanners/nmap.py +402 -0
  119. aipt_v2/scanners/nuclei.py +273 -0
  120. aipt_v2/scanners/web.py +454 -0
  121. aipt_v2/scripts/security_audit.py +366 -0
  122. aipt_v2/setup_wizard.py +941 -0
  123. aipt_v2/skills/__init__.py +80 -0
  124. aipt_v2/skills/agents/__init__.py +14 -0
  125. aipt_v2/skills/agents/api_tester.py +706 -0
  126. aipt_v2/skills/agents/base.py +477 -0
  127. aipt_v2/skills/agents/code_review.py +459 -0
  128. aipt_v2/skills/agents/security_agent.py +336 -0
  129. aipt_v2/skills/agents/web_pentest.py +818 -0
  130. aipt_v2/skills/prompts/__init__.py +647 -0
  131. aipt_v2/system_detector.py +539 -0
  132. aipt_v2/telemetry/__init__.py +7 -0
  133. aipt_v2/telemetry/tracer.py +347 -0
  134. aipt_v2/terminal/__init__.py +28 -0
  135. aipt_v2/terminal/executor.py +400 -0
  136. aipt_v2/terminal/sandbox.py +350 -0
  137. aipt_v2/tools/__init__.py +44 -0
  138. aipt_v2/tools/active_directory/__init__.py +78 -0
  139. aipt_v2/tools/active_directory/ad_config.py +238 -0
  140. aipt_v2/tools/active_directory/bloodhound_wrapper.py +447 -0
  141. aipt_v2/tools/active_directory/kerberos_attacks.py +430 -0
  142. aipt_v2/tools/active_directory/ldap_enum.py +533 -0
  143. aipt_v2/tools/active_directory/smb_attacks.py +505 -0
  144. aipt_v2/tools/agents_graph/__init__.py +19 -0
  145. aipt_v2/tools/agents_graph/agents_graph_actions.py +69 -0
  146. aipt_v2/tools/api_security/__init__.py +76 -0
  147. aipt_v2/tools/api_security/api_discovery.py +608 -0
  148. aipt_v2/tools/api_security/graphql_scanner.py +622 -0
  149. aipt_v2/tools/api_security/jwt_analyzer.py +577 -0
  150. aipt_v2/tools/api_security/openapi_fuzzer.py +761 -0
  151. aipt_v2/tools/browser/__init__.py +5 -0
  152. aipt_v2/tools/browser/browser_actions.py +238 -0
  153. aipt_v2/tools/browser/browser_instance.py +535 -0
  154. aipt_v2/tools/browser/tab_manager.py +344 -0
  155. aipt_v2/tools/cloud/__init__.py +70 -0
  156. aipt_v2/tools/cloud/cloud_config.py +273 -0
  157. aipt_v2/tools/cloud/cloud_scanner.py +639 -0
  158. aipt_v2/tools/cloud/prowler_tool.py +571 -0
  159. aipt_v2/tools/cloud/scoutsuite_tool.py +359 -0
  160. aipt_v2/tools/executor.py +307 -0
  161. aipt_v2/tools/parser.py +408 -0
  162. aipt_v2/tools/proxy/__init__.py +5 -0
  163. aipt_v2/tools/proxy/proxy_actions.py +103 -0
  164. aipt_v2/tools/proxy/proxy_manager.py +789 -0
  165. aipt_v2/tools/registry.py +196 -0
  166. aipt_v2/tools/scanners/__init__.py +343 -0
  167. aipt_v2/tools/scanners/acunetix_tool.py +712 -0
  168. aipt_v2/tools/scanners/burp_tool.py +631 -0
  169. aipt_v2/tools/scanners/config.py +156 -0
  170. aipt_v2/tools/scanners/nessus_tool.py +588 -0
  171. aipt_v2/tools/scanners/zap_tool.py +612 -0
  172. aipt_v2/tools/terminal/__init__.py +5 -0
  173. aipt_v2/tools/terminal/terminal_actions.py +37 -0
  174. aipt_v2/tools/terminal/terminal_manager.py +153 -0
  175. aipt_v2/tools/terminal/terminal_session.py +449 -0
  176. aipt_v2/tools/tool_processing.py +108 -0
  177. aipt_v2/utils/__init__.py +17 -0
  178. aipt_v2/utils/logging.py +202 -0
  179. aipt_v2/utils/model_manager.py +187 -0
  180. aipt_v2/utils/searchers/__init__.py +269 -0
  181. aipt_v2/verify_install.py +793 -0
  182. aiptx-2.0.7.dist-info/METADATA +345 -0
  183. aiptx-2.0.7.dist-info/RECORD +187 -0
  184. aiptx-2.0.7.dist-info/WHEEL +5 -0
  185. aiptx-2.0.7.dist-info/entry_points.txt +7 -0
  186. aiptx-2.0.7.dist-info/licenses/LICENSE +21 -0
  187. aiptx-2.0.7.dist-info/top_level.txt +1 -0
aipt_v2/app.py ADDED
@@ -0,0 +1,957 @@
1
+ """
2
+ AIPT REST API - FastAPI application
3
+ Provides REST endpoints for AIPT operations.
4
+
5
+ Endpoints:
6
+ - /projects - Project management
7
+ - /sessions - Session management
8
+ - /findings - Finding management
9
+ - /scan - Run scans
10
+ - /tools - List available tools
11
+ - /auth - Authentication endpoints
12
+
13
+ Security:
14
+ - JWT-based authentication (optional but recommended)
15
+ - CORS restricted to configured origins
16
+ - Rate limiting per client IP
17
+ - Input validation on all endpoints
18
+ """
19
+
20
+ import os
21
+ import re
22
+ import secrets
23
+ from pathlib import Path
24
+ from typing import Optional, List
25
+ from datetime import datetime, timedelta, timezone
26
+ from urllib.parse import urlparse
27
+
28
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, Depends
29
+ from fastapi.middleware.cors import CORSMiddleware
30
+ from fastapi.responses import JSONResponse
31
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
32
+ from starlette.middleware.base import BaseHTTPMiddleware
33
+ from starlette.responses import Response
34
+ from pydantic import BaseModel, Field, field_validator
35
+ from slowapi import Limiter, _rate_limit_exceeded_handler
36
+ from slowapi.util import get_remote_address
37
+ from slowapi.errors import RateLimitExceeded
38
+
39
+ # JWT imports with fallback
40
+ try:
41
+ import jwt
42
+ JWT_AVAILABLE = True
43
+ except ImportError:
44
+ JWT_AVAILABLE = False
45
+
46
+ # Import AIPT v2 components
47
+ from aipt_v2.database.repository import Repository
48
+ from aipt_v2.intelligence import ToolRAG, CVEIntelligence
49
+ from aipt_v2.tools.tool_processing import process_tool_invocations
50
+ from aipt_v2.utils.logging import logger
51
+ from .health import health_router, record_scan, record_tool_invocation
52
+
53
# Rate limiter instance; requests are bucketed by the client's remote address.
limiter = Limiter(key_func=get_remote_address)

# Security constants
# URL schemes a scan/project target may use.
ALLOWED_SCAN_PROTOCOLS = {"http", "https"}
# Upper bound on the length of a target string accepted by the request models.
MAX_TARGET_LENGTH = 2048
# CVE identifier: "CVE-", a 4-digit year, "-", then 4 or more digits (case-insensitive).
CVE_PATTERN = re.compile(r"^CVE-\d{4}-\d{4,}$", re.IGNORECASE)
60
+
61
+
62
+ # ============== Pydantic Models ==============
63
+
64
class ProjectCreate(BaseModel):
    # Request payload for creating a project. `target` and `name` are
    # normalised (stripped) and validated by the field validators below.
    name: str = Field(..., min_length=1, max_length=255, description="Project name")
    target: str = Field(..., min_length=1, max_length=MAX_TARGET_LENGTH, description="Target URL or domain")
    description: Optional[str] = Field(None, max_length=2000, description="Project description")
    scope: Optional[List[str]] = Field(None, max_length=100, description="In-scope domains/IPs")

    @field_validator("target")
    @classmethod
    def validate_target(cls, v: str) -> str:
        """Validate that the target is an http(s) URL or a bare domain/host.

        Returns the stripped target.

        Raises:
            ValueError: if the target is empty, carries a URL scheme outside
                ALLOWED_SCAN_PROTOCOLS, is a URL without a hostname, or is a
                bare host containing shell metacharacters.
        """
        v = v.strip()
        if not v:
            raise ValueError("Target cannot be empty")

        # Treat anything with a leading "<scheme>://" as a URL. Checking only
        # for the literal "http://"/"https://" prefixes made the scheme check
        # below unreachable and let e.g. "ftp://host" or "file:///x" slip
        # through the bare-domain branch.
        if re.match(r"^[A-Za-z][A-Za-z0-9+.-]*://", v):
            parsed = urlparse(v)  # urlparse lower-cases the scheme
            if parsed.scheme not in ALLOWED_SCAN_PROTOCOLS:
                raise ValueError(f"Protocol must be http or https, got: {parsed.scheme}")
            if not parsed.netloc:
                raise ValueError("Invalid URL: missing hostname")
        else:
            # Validate as domain - basic check for dangerous characters
            forbidden = (";", "&", "|", "$", "`", "\n", "\r")
            if any(ch in v for ch in forbidden):
                raise ValueError("Target contains invalid characters")
        return v

    @field_validator("name")
    @classmethod
    def validate_name(cls, v: str) -> str:
        """Validate the project name and return it stripped.

        Raises:
            ValueError: if the name is empty after stripping or contains
                "..", "/" or a backslash (path traversal guard).
        """
        v = v.strip()
        if not v:
            raise ValueError("Name cannot be empty")
        # Prevent path traversal in name
        if ".." in v or "/" in v or "\\" in v:
            raise ValueError("Name contains invalid characters")
        return v
102
+
103
+
104
class ProjectResponse(BaseModel):
    # Response shape for a project; used as the response_model of the
    # /projects endpoints.
    id: int
    name: str
    target: str
    description: Optional[str]
    scope: List[str]
    status: str
    created_at: datetime

    class Config:
        # Lets pydantic populate the model from object attributes
        # (the endpoints return repository objects directly).
        from_attributes = True
115
+
116
+
117
class SessionCreate(BaseModel):
    # Request payload for creating a session under a project.
    name: Optional[str] = None  # optional session label
    phase: str = "recon"  # starting phase; defaults to "recon"
    max_iterations: int = 100  # iteration cap forwarded to the repository
121
+
122
+
123
class SessionResponse(BaseModel):
    # Response shape for a session; used as the response_model of the
    # /projects/{project_id}/sessions endpoints.
    id: int
    project_id: int
    name: Optional[str]
    phase: str
    status: str
    iteration: int
    started_at: datetime

    class Config:
        # Lets pydantic populate the model from object attributes.
        from_attributes = True
134
+
135
+
136
class FindingResponse(BaseModel):
    # Response shape for a single finding; used as the response_model of
    # /projects/{project_id}/findings.
    id: int
    type: str
    value: str
    description: Optional[str]
    severity: str
    phase: Optional[str]
    tool: Optional[str]
    verified: bool
    discovered_at: datetime

    class Config:
        # Lets pydantic populate the model from object attributes.
        from_attributes = True
149
+
150
+
151
class ScanRequest(BaseModel):
    # Request payload for running a scan. All three fields are normalised
    # and validated by the field validators below.
    target: str = Field(..., min_length=1, max_length=MAX_TARGET_LENGTH, description="Target URL or domain")
    tools: Optional[List[str]] = Field(None, max_length=20, description="Tools to run")
    phase: str = Field(default="recon", description="Scan phase")

    @field_validator("target")
    @classmethod
    def validate_target(cls, v: str) -> str:
        """Validate and sanitize the scan target; return it stripped.

        Raises:
            ValueError: if the target is empty, carries a URL scheme outside
                ALLOWED_SCAN_PROTOCOLS, is a URL without a hostname, or is a
                bare host containing command-injection characters.
        """
        v = v.strip()
        if not v:
            raise ValueError("Target cannot be empty")

        # Treat anything with a leading "<scheme>://" as a URL. Gating on the
        # literal "http://"/"https://" prefixes made the scheme check
        # unreachable and allowed other schemes through the bare-host branch.
        if re.match(r"^[A-Za-z][A-Za-z0-9+.-]*://", v):
            parsed = urlparse(v)  # urlparse lower-cases the scheme
            if parsed.scheme not in ALLOWED_SCAN_PROTOCOLS:
                raise ValueError("Protocol must be http or https")
            if not parsed.netloc:
                raise ValueError("Invalid URL: missing hostname")
        else:
            # Check for command injection characters
            dangerous_chars = (
                ";", "&", "|", "$", "`", "\n", "\r",
                "'", '"', "(", ")", "{", "}", "<", ">",
            )
            if any(ch in v for ch in dangerous_chars):
                raise ValueError("Target contains invalid characters")
        return v

    @field_validator("phase")
    @classmethod
    def validate_phase(cls, v: str) -> str:
        """Lower-case/strip the phase and ensure it is an allowed value.

        Raises:
            ValueError: if the phase is not one of recon, scan, exploit, report.
        """
        allowed_phases = {"recon", "scan", "exploit", "report"}
        v = v.lower().strip()
        if v not in allowed_phases:
            # Sort so the error message is stable (set order is arbitrary).
            raise ValueError(f"Phase must be one of: {', '.join(sorted(allowed_phases))}")
        return v

    @field_validator("tools")
    @classmethod
    def validate_tools(cls, v: Optional[List[str]]) -> Optional[List[str]]:
        """Normalise tool names (strip + lower-case) and validate each one.

        Returns None unchanged; otherwise the list of cleaned names.

        Raises:
            ValueError: on the first name that is not made up solely of
                a-z, 0-9, underscore or hyphen.
        """
        if v is None:
            return v
        # Sanitize tool names - only allow alphanumeric and underscore/hyphen
        clean_tools = []
        for raw in v:
            tool = raw.strip().lower()
            if not re.match(r"^[a-z0-9_-]+$", tool):
                raise ValueError(f"Invalid tool name: {tool}")
            clean_tools.append(tool)
        return clean_tools
202
+
203
+
204
class ScanResponse(BaseModel):
    # Response shape for a scan request.
    status: str
    message: str
    findings: List[dict] = []  # defaults to an empty list when omitted
208
+
209
+
210
class ToolInfo(BaseModel):
    # Descriptor for one available tool.
    name: str
    description: str
    phase: str
    keywords: List[str]
215
+
216
+
217
class CVERequest(BaseModel):
    # Request payload carrying a single CVE identifier.
    cve_id: str = Field(..., min_length=9, max_length=20, description="CVE identifier")

    @field_validator("cve_id")
    @classmethod
    def validate_cve_id(cls, v: str) -> str:
        """Return the CVE ID stripped and upper-cased, or reject it.

        Raises:
            ValueError: if the normalised value does not match CVE_PATTERN.
        """
        normalised = v.strip().upper()
        if CVE_PATTERN.match(normalised):
            return normalised
        raise ValueError("Invalid CVE ID format. Expected: CVE-YYYY-NNNNN")
228
+
229
+
230
class CVEResponse(BaseModel):
    # Response shape for a CVE lookup.
    cve_id: str
    cvss: float
    epss: float
    priority_score: float
    has_poc: bool
    description: str
237
+
238
+
239
+ # ============== Authentication Models ==============
240
+
241
class TokenRequest(BaseModel):
    """Request model for token generation."""

    # The key is length-checked here (32-128 chars); whether it is a known
    # key is decided separately by JWTAuth.validate_api_key.
    api_key: str = Field(..., min_length=32, max_length=128, description="API key")
244
+
245
+
246
class TokenResponse(BaseModel):
    """Response model for token generation."""

    access_token: str  # the encoded JWT
    token_type: str = "bearer"
    expires_in: int  # token lifetime in seconds
251
+
252
+
253
class UserInfo(BaseModel):
    """User information extracted from JWT."""

    sub: str  # Subject (user identifier)
    exp: datetime  # expiry time taken from the token's "exp" claim
    iat: datetime  # issue time taken from the token's "iat" claim
    is_authenticated: bool = True
259
+
260
+
261
+ # ============== JWT Security ==============
262
+
263
# Security scheme for OpenAPI docs.
# auto_error=False: a missing Authorization header yields None credentials
# instead of an automatic error, so get_current_user can treat auth as optional.
security_scheme = HTTPBearer(auto_error=False)
265
+
266
+
267
class JWTAuth:
    """
    JWT Authentication handler.

    Security features:
    - HS256 algorithm with secure secret
    - Token expiration (default 24h)
    - API key validation before token issuance (constant-time comparison)
    - Issuer ("iss") claim written on issue and verified on decode
    """

    # Issuer claim stamped into every token and required when verifying one.
    ISSUER = "aipt-api"

    def __init__(self, secret_key: Optional[str] = None, expires_hours: int = 24):
        """
        Initialize JWT auth handler.

        Args:
            secret_key: Secret for signing tokens. If not provided, uses
                AIPT_JWT_SECRET env var or generates a random one.
            expires_hours: Token expiration time in hours.
        """
        self.secret_key = secret_key or os.getenv("AIPT_JWT_SECRET")
        if not self.secret_key:
            # Generate a secure random secret if not configured
            # Note: This means tokens won't survive server restarts
            self.secret_key = secrets.token_urlsafe(32)
            logger.warning(
                "JWT secret not configured (AIPT_JWT_SECRET). "
                "Using random secret - tokens will be invalid after restart."
            )
        self.expires_hours = expires_hours
        self.algorithm = "HS256"

        # Valid API keys - in production, load from secure storage.
        # AIPT_API_KEYS is a comma-separated list; blank entries are dropped.
        env_keys = os.getenv("AIPT_API_KEYS", "")
        self._valid_api_keys = {k.strip() for k in env_keys.split(",") if k.strip()}

    def validate_api_key(self, api_key: str) -> bool:
        """
        Validate an API key.

        Args:
            api_key: Key presented by the client.

        Returns:
            True if the key is accepted, False otherwise.
        """
        if not self._valid_api_keys:
            # If no keys configured, accept any non-empty key (dev mode).
            # NOTE(review): this makes authentication effectively open when
            # AIPT_API_KEYS is unset - confirm this is acceptable for deployments.
            logger.warning("No API keys configured (AIPT_API_KEYS). Accepting any key.")
            return bool(api_key and len(api_key) >= 32)

        if not isinstance(api_key, str):
            return False

        # Compare in constant time against every configured key (no early
        # exit) so response timing does not reveal how much of a key matched.
        # Bytes are used because compare_digest rejects non-ASCII str input.
        candidate = api_key.encode("utf-8")
        matched = False
        for known in self._valid_api_keys:
            if secrets.compare_digest(candidate, known.encode("utf-8")):
                matched = True
        return matched

    def create_token(self, subject: str) -> tuple[str, int]:
        """
        Create a new JWT token.

        Args:
            subject: User/client identifier

        Returns:
            Tuple of (token, expires_in_seconds)

        Raises:
            HTTPException: 503 if PyJWT is not installed.
        """
        if not JWT_AVAILABLE:
            raise HTTPException(
                status_code=503,
                detail="JWT support not available. Install PyJWT: pip install PyJWT"
            )

        now = datetime.now(timezone.utc)
        expires = now + timedelta(hours=self.expires_hours)

        payload = {
            "sub": subject,
            "iat": now,
            "exp": expires,
            "iss": self.ISSUER,
        }

        token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
        expires_in = int((expires - now).total_seconds())

        return token, expires_in

    def verify_token(self, token: str) -> Optional[UserInfo]:
        """
        Verify and decode a JWT token.

        The signature, expiry, required claims and the issuer claim are all
        checked; any failure results in None rather than an exception.

        Args:
            token: JWT token string

        Returns:
            UserInfo if valid, None otherwise
        """
        if not JWT_AVAILABLE:
            return None

        try:
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                # Reject tokens not minted by this API, even if they were
                # signed with the same secret.
                issuer=self.ISSUER,
                options={"require": ["exp", "iat", "sub", "iss"]},
            )
        except jwt.ExpiredSignatureError:
            logger.debug("Token expired")
            return None
        except jwt.InvalidTokenError as e:
            logger.debug(f"Invalid token: {e}")
            return None

        return UserInfo(
            sub=payload["sub"],
            exp=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
            iat=datetime.fromtimestamp(payload["iat"], tz=timezone.utc),
        )
374
+
375
+
376
# Global JWT auth instance. Starts as None and is assigned inside create_app();
# get_current_user treats None as "not authenticated".
jwt_auth: Optional[JWTAuth] = None
378
+
379
+
380
+ # ============== Security Headers Middleware ==============
381
+
382
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """
    Middleware that stamps a fixed set of security headers on every response.

    Headers set:
    - X-Content-Type-Options: Prevents MIME type sniffing
    - X-Frame-Options: Prevents clickjacking
    - X-XSS-Protection: Legacy XSS protection
    - Referrer-Policy: Controls referrer information
    - Permissions-Policy: Restricts browser features
    - Content-Security-Policy: Restricts resource loading

    Strict-Transport-Security is NOT sent by this middleware; enable it only
    when the API is served over HTTPS in production.
    """

    # Header name -> value, applied in this order to each outgoing response.
    _STATIC_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        "X-XSS-Protection": "1; mode=block",
        "Referrer-Policy": "strict-origin-when-cross-origin",
        "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
        # CSP - Restrictive policy for API responses
        "Content-Security-Policy": (
            "default-src 'none'; "
            "frame-ancestors 'none'; "
            "base-uri 'none'; "
            "form-action 'none'"
        ),
    }

    async def dispatch(self, request: Request, call_next) -> Response:
        """Run the downstream handler, then add the security headers."""
        response = await call_next(request)
        for header_name, header_value in self._STATIC_HEADERS.items():
            response.headers[header_name] = header_value
        return response
418
+
419
+
420
+ # ============== WAF Middleware ==============
421
+
422
class WAFMiddleware(BaseHTTPMiddleware):
    """
    Simple Web Application Firewall middleware.

    Scans the request path and every query-parameter name/value against a
    list of regular expressions and answers 403 on the first match:
    - SQL Injection patterns
    - XSS patterns
    - Path traversal attempts
    - Command injection patterns

    Request bodies and headers are not inspected.
    """

    # Suspicious patterns that may indicate attacks
    SUSPICIOUS_PATTERNS = [
        # SQL Injection
        r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|UNION|ALTER)\b.*\b(FROM|INTO|TABLE|DATABASE)\b)",
        r"(--|#|/\*|\*/|;)",
        r"(\bOR\b\s+\d+\s*=\s*\d+)",
        # XSS
        r"(<script|javascript:|on\w+\s*=)",
        # Path traversal
        r"(\.\./|\.\.\\|%2e%2e)",
        # Command injection
        r"(;|\||&&|\$\(|`)",
    ]

    def __init__(self, app, enabled: bool = True):
        """Wrap `app`; when `enabled` is False every request passes through unchecked."""
        super().__init__(app)
        self.enabled = enabled
        self._patterns = None  # compiled lazily by the `patterns` property

    @property
    def patterns(self):
        """Compiled, case-insensitive forms of SUSPICIOUS_PATTERNS (built on first use)."""
        if self._patterns is None:
            self._patterns = [
                re.compile(expr, re.IGNORECASE) for expr in self.SUSPICIOUS_PATTERNS
            ]
        return self._patterns

    def _is_suspicious(self, value: str) -> bool:
        """Check if a value contains suspicious patterns (empty values never do)."""
        if not value:
            return False
        return any(rx.search(value) for rx in self.patterns)

    @staticmethod
    def _blocked_response() -> JSONResponse:
        """Build the 403 response returned for any blocked request."""
        return JSONResponse(
            status_code=403,
            content={"detail": "Request blocked by security policy"}
        )

    async def dispatch(self, request: Request, call_next) -> Response:
        """Block the request with 403 if it looks malicious, else forward it."""
        if self.enabled:
            # Query parameters first: both the name and the value are checked.
            for key, value in request.query_params.items():
                if not (self._is_suspicious(key) or self._is_suspicious(value)):
                    continue
                logger.warning(
                    "WAF blocked suspicious request",
                    path=request.url.path,
                    param=key,
                    client=request.client.host if request.client else "unknown",
                )
                return self._blocked_response()

            # Then the URL path itself.
            if self._is_suspicious(request.url.path):
                logger.warning(
                    "WAF blocked suspicious path",
                    path=request.url.path,
                    client=request.client.host if request.client else "unknown",
                )
                return self._blocked_response()

        return await call_next(request)
501
+
502
+
503
async def get_current_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security_scheme)
) -> Optional[UserInfo]:
    """
    Dependency resolving the caller's identity from a bearer token.

    Returns the UserInfo produced by jwt_auth.verify_token when credentials
    were supplied and the global jwt_auth has been initialised; returns None
    in every other case. Never raises for missing or invalid credentials, so
    endpoints depending on it work without auth.
    """
    if credentials and jwt_auth:
        return jwt_auth.verify_token(credentials.credentials)
    return None
516
+
517
+
518
async def require_auth(
    user: Optional[UserInfo] = Depends(get_current_user)
) -> UserInfo:
    """
    Dependency that enforces authentication.

    Returns the authenticated UserInfo.

    Raises:
        HTTPException: 401 with a "WWW-Authenticate: Bearer" header when
            get_current_user yielded no user.
    """
    if user:
        return user
    raise HTTPException(
        status_code=401,
        detail="Authentication required",
        headers={"WWW-Authenticate": "Bearer"},
    )
533
+
534
+
535
+ # ============== FastAPI App ==============
536
+
537
# Default CORS origins - restricted for security.
# Used by create_app() when neither the cors_origins argument nor the
# AIPT_CORS_ORIGINS environment variable supplies a list.
DEFAULT_CORS_ORIGINS = [
    "http://localhost:3000",
    "http://localhost:8080",
    "http://127.0.0.1:3000",
    "http://127.0.0.1:8080",
]
544
+
545
+
546
+ def create_app(
547
+ db_url: str = "sqlite:///~/.aipt/aipt.db",
548
+ title: str = "AIPT API",
549
+ cors_origins: Optional[List[str]] = None,
550
+ rate_limit: str = "100/minute",
551
+ ) -> FastAPI:
552
+ """
553
+ Create FastAPI application with security middleware.
554
+
555
+ Args:
556
+ db_url: Database connection URL
557
+ title: API title
558
+ cors_origins: Allowed CORS origins (defaults to localhost only)
559
+ rate_limit: Rate limit string (e.g., "100/minute")
560
+ """
561
+ # Get CORS origins from env or use defaults
562
+ if cors_origins is None:
563
+ env_origins = os.getenv("AIPT_CORS_ORIGINS", "")
564
+ if env_origins:
565
+ cors_origins = [o.strip() for o in env_origins.split(",") if o.strip()]
566
+ else:
567
+ cors_origins = DEFAULT_CORS_ORIGINS
568
+
569
+ app = FastAPI(
570
+ title=title,
571
+ description="AI-Powered Penetration Testing Framework API",
572
+ version="0.2.0",
573
+ )
574
+
575
+ # Rate limiting
576
+ app.state.limiter = limiter
577
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
578
+
579
+ # CORS - Restricted to configured origins only
580
+ app.add_middleware(
581
+ CORSMiddleware,
582
+ allow_origins=cors_origins,
583
+ allow_credentials=True,
584
+ allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
585
+ allow_headers=["Authorization", "Content-Type", "X-Requested-With"],
586
+ expose_headers=["X-RateLimit-Limit", "X-RateLimit-Remaining", "X-RateLimit-Reset"],
587
+ )
588
+
589
+ # Security Headers Middleware - OWASP recommended headers
590
+ app.add_middleware(SecurityHeadersMiddleware)
591
+
592
+ # WAF Middleware - Basic attack protection
593
+ # Can be disabled via AIPT_WAF_ENABLED=false for testing
594
+ waf_enabled = os.getenv("AIPT_WAF_ENABLED", "true").lower() == "true"
595
+ app.add_middleware(WAFMiddleware, enabled=waf_enabled)
596
+
597
+ # Log security configuration
598
+ logger.info(
599
+ f"API security configured: cors_origins={cors_origins}, "
600
+ f"rate_limit={rate_limit}, waf_enabled={waf_enabled}, security_headers=enabled"
601
+ )
602
+
603
+ # Initialize components
604
+ repo = Repository(db_url)
605
+ tools_path = Path(__file__).parent / "intelligence" / "tools.json"
606
+ tools_rag = ToolRAG(tools_path=str(tools_path), lazy_load=True)
607
+ cve_intel = CVEIntelligence()
608
+
609
+ # Initialize JWT authentication
610
+ global jwt_auth
611
+ jwt_auth = JWTAuth()
612
+ app.state.jwt_auth = jwt_auth
613
+
614
+ # Store in app state
615
+ app.state.repo = repo
616
+ app.state.tools_rag = tools_rag
617
+ app.state.cve_intel = cve_intel
618
+ app.state.rate_limit = rate_limit
619
+
620
+ # ============== Authentication ==============
621
+
622
+ @app.post("/auth/token", response_model=TokenResponse, tags=["Authentication"])
623
+ @limiter.limit("10/minute") # Strict rate limit on token requests
624
+ async def get_token(request: Request, token_request: TokenRequest):
625
+ """
626
+ Get JWT access token using API key.
627
+
628
+ Security:
629
+ - Rate limited to 10 requests per minute
630
+ - API key must be at least 32 characters
631
+ - Token expires in 24 hours
632
+
633
+ Usage:
634
+ 1. Set AIPT_API_KEYS env var with comma-separated valid API keys
635
+ 2. POST to /auth/token with your API key
636
+ 3. Use returned token in Authorization header: Bearer <token>
637
+ """
638
+ if not jwt_auth.validate_api_key(token_request.api_key):
639
+ raise HTTPException(
640
+ status_code=401,
641
+ detail="Invalid API key"
642
+ )
643
+
644
+ # Create token with API key hash as subject (don't expose full key)
645
+ import hashlib
646
+ subject = hashlib.sha256(token_request.api_key.encode()).hexdigest()[:16]
647
+ token, expires_in = jwt_auth.create_token(subject)
648
+
649
+ return TokenResponse(
650
+ access_token=token,
651
+ expires_in=expires_in
652
+ )
653
+
654
+ @app.get("/auth/me", tags=["Authentication"])
655
+ async def get_current_user_info(user: UserInfo = Depends(require_auth)):
656
+ """
657
+ Get current authenticated user information.
658
+
659
+ Requires valid JWT token in Authorization header.
660
+ """
661
+ return {
662
+ "sub": user.sub,
663
+ "authenticated": True,
664
+ "expires": user.exp.isoformat(),
665
+ }
666
+
667
+ @app.get("/auth/status", tags=["Authentication"])
668
+ async def auth_status(user: Optional[UserInfo] = Depends(get_current_user)):
669
+ """
670
+ Check authentication status.
671
+
672
+ Returns whether JWT auth is available and if current request is authenticated.
673
+ """
674
+ return {
675
+ "jwt_available": JWT_AVAILABLE,
676
+ "authenticated": user is not None,
677
+ "user": user.sub if user else None,
678
+ }
679
+
680
+ # ============== Health & Metrics ==============
681
+ # Include comprehensive health check router with:
682
+ # - /health - Basic liveness probe
683
+ # - /health/live - Kubernetes liveness probe
684
+ # - /health/ready - Readiness probe with dependency checks
685
+ # - /metrics - Prometheus-compatible metrics
686
+ # - /health/info - Service information
687
+ app.include_router(health_router)
688
+
689
+ # ============== Projects ==============
690
+
691
+ @app.post("/projects", response_model=ProjectResponse)
692
+ async def create_project(project: ProjectCreate):
693
+ """Create a new project"""
694
+ db_project = repo.create_project(
695
+ name=project.name,
696
+ target=project.target,
697
+ description=project.description,
698
+ scope=project.scope,
699
+ )
700
+ return db_project
701
+
702
+ @app.get("/projects", response_model=List[ProjectResponse])
703
+ async def list_projects(status: Optional[str] = None):
704
+ """List all projects"""
705
+ return repo.list_projects(status=status)
706
+
707
@app.get("/projects/{project_id}", response_model=ProjectResponse)
async def get_project(project_id: int):
    """Fetch a single project by its ID; responds with 404 when it is missing."""
    found = repo.get_project(project_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Project not found")
714
+
715
@app.delete("/projects/{project_id}")
async def delete_project(project_id: int):
    """Remove a project by ID; responds with 404 when nothing was deleted."""
    was_deleted = repo.delete_project(project_id)
    if was_deleted:
        return {"status": "deleted"}
    raise HTTPException(status_code=404, detail="Project not found")
721
+
722
+ # ============== Sessions ==============
723
+
724
@app.post("/projects/{project_id}/sessions", response_model=SessionResponse)
async def create_session(project_id: int, session: SessionCreate):
    """Create a new session for an existing project.

    Responds with 404 when the project does not exist and with 400 when the
    requested phase is not a recognised phase value.
    """
    project = repo.get_project(project_id)
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    from aipt_v2.database.models import PhaseType

    # Converting an unknown value raises ValueError (assuming PhaseType is an
    # Enum - TODO confirm). Report it as a client error instead of letting it
    # surface as an unhandled 500.
    try:
        phase = PhaseType(session.phase)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid phase: {session.phase}",
        ) from None

    db_session = repo.create_session(
        project_id=project_id,
        name=session.name,
        phase=phase,
        max_iterations=session.max_iterations,
    )
    return db_session
739
+
740
@app.get("/projects/{project_id}/sessions", response_model=List[SessionResponse])
async def list_sessions(project_id: int):
    """Return every session the repository holds for the given project."""
    sessions = repo.list_sessions(project_id)
    return sessions
744
+
745
+ # ============== Findings ==============
746
+
747
@app.get("/projects/{project_id}/findings", response_model=List[FindingResponse])
async def get_findings(
    project_id: int,
    type: Optional[str] = None,
    severity: Optional[str] = None,
    phase: Optional[str] = None,
):
    """Return findings for a project.

    The optional `type`, `severity` and `phase` query values are passed
    through to the repository unchanged.
    """
    # `type` shadows the builtin, but it is the public query-parameter name.
    filters = {"type": type, "severity": severity, "phase": phase}
    return repo.get_findings(project_id=project_id, **filters)
761
+
762
@app.get("/projects/{project_id}/findings/summary")
async def get_findings_summary(project_id: int):
    """Return the repository's findings summary for the given project."""
    summary = repo.get_findings_summary(project_id)
    return summary
766
+
767
@app.post("/findings/{finding_id}/verify")
async def verify_finding(finding_id: int, notes: Optional[str] = None):
    """Flag a finding as verified, optionally attaching reviewer notes."""
    repo.verify_finding(finding_id, notes=notes, verified=True)
    response = {"status": "verified"}
    return response
772
+
773
@app.post("/findings/{finding_id}/false-positive")
async def mark_false_positive(finding_id: int, notes: Optional[str] = None):
    """Flag a finding as a false positive, optionally attaching reviewer notes."""
    repo.mark_false_positive(finding_id, notes=notes)
    response = {"status": "marked as false positive"}
    return response
778
+
779
+ # ============== Scanning ==============
780
+
781
@app.post("/scan/quick", response_model=ScanResponse)
@limiter.limit("10/minute")  # Limit scan requests
async def quick_scan(request: Request, scan_request: ScanRequest):
    """Run a quick scan on target (rate limited: 10/minute)

    Runs `nmap -F` against the requested target when nmap is installed and
    returns the parsed findings. Responds with 400 when the target is empty
    or contains an option-like token.
    """
    import asyncio
    import shutil
    from aipt_v2.tools.parser import OutputParser

    # The target comes from the request body. It is passed to nmap as discrete
    # argv entries (no shell), so shell metacharacters have no effect.
    # Whitespace-separated hosts keep working as separate arguments.
    targets = scan_request.target.split()
    # Reject tokens starting with "-" so a caller cannot smuggle nmap options
    # (e.g. output-file or script flags) through the target field.
    if not targets or any(token.startswith("-") for token in targets):
        raise HTTPException(status_code=400, detail="Invalid target")

    findings = []

    # Check if nmap is available
    nmap_path = shutil.which("nmap")
    if nmap_path:
        proc = None
        try:
            proc = await asyncio.create_subprocess_exec(
                nmap_path,
                "-F",
                *targets,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            stdout, _stderr = await asyncio.wait_for(proc.communicate(), timeout=60)
            output = stdout.decode(errors="replace") if stdout else ""

            if proc.returncode == 0:
                parser = OutputParser()
                parsed = parser.parse(output, "nmap")
                findings.extend([{
                    "type": f.type,
                    "value": f.value,
                    "description": f.description,
                    "severity": f.severity,
                } for f in parsed])
        except asyncio.TimeoutError:
            logger.warning("Tool execution timed out", tool="nmap", target=scan_request.target)
        except Exception as e:
            # Best-effort scan: a tool failure is logged and the endpoint
            # still answers with whatever findings were collected.
            logger.error("Tool execution failed", tool="nmap", error=str(e))
        finally:
            # A timeout only cancels communicate(); the child keeps running
            # unless it is killed and reaped explicitly.
            if proc is not None and proc.returncode is None:
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass  # process exited between the check and the kill
                await proc.wait()

    return ScanResponse(
        status="completed",
        message=f"Quick scan completed on {scan_request.target}",
        findings=findings,
    )
822
+
823
@app.post("/scan/tool")
@limiter.limit("5/minute")  # Stricter limit for tool execution
async def run_tool(request: Request, tool_name: str, target: str, options: Optional[str] = None):
    """Run a specific tool (rate limited: 5/minute)

    Looks the tool up by name, substitutes the target into its command
    template, appends any options and executes the result. Responds with 400
    for a malformed tool name or for a target/options value containing shell
    metacharacters, and with 404 when the tool is unknown.
    """
    import asyncio
    import time

    # Validate tool_name - only alphanumeric and underscore/hyphen.
    # fullmatch is used because "$" in re.match also matches before a trailing
    # newline, which would let "name\n" through.
    if not re.fullmatch(r"[a-zA-Z0-9_-]+", tool_name):
        raise HTTPException(status_code=400, detail="Invalid tool name")

    # Characters that let a value break out of its place in a shell command
    # line: separators, substitution, line breaks (a newline starts a new
    # command), redirections and subshell grouping.
    shell_metachars = (";", "&", "|", "$", "`", "\n", "\r", "<", ">", "(", ")")

    # Validate target - check for command injection (quotes are rejected too)
    if any(c in target for c in shell_metachars + ("'", '"')):
        raise HTTPException(status_code=400, detail="Invalid target: contains dangerous characters")

    # Validate options if provided, against the same metacharacter set
    if options and any(c in options for c in shell_metachars):
        raise HTTPException(status_code=400, detail="Invalid options: contains dangerous characters")

    # Get tool from RAG
    tool = tools_rag.get_tool_by_name(tool_name)
    if not tool:
        raise HTTPException(status_code=404, detail=f"Tool {tool_name} not found")

    # Build command
    cmd = tool.get("cmd", "").replace("{target}", target)
    if options:
        cmd = f"{cmd} {options}"

    # Execute. The shell is kept because tool templates may rely on shell
    # syntax; the validation above is what keeps request values inert.
    start_time = time.time()
    proc = None
    try:
        proc = await asyncio.create_subprocess_shell(
            cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=300)
        output = (stdout.decode() if stdout else "") + (stderr.decode() if stderr else "")
        return_code = proc.returncode
    except asyncio.TimeoutError:
        output = "Command timed out after 300 seconds"
        return_code = -1
    except Exception as e:
        output = f"Command failed: {str(e)}"
        return_code = -1
    finally:
        # A timeout only cancels communicate(); kill and reap the child so it
        # does not keep running after the response is sent.
        if proc is not None and proc.returncode is None:
            try:
                proc.kill()
            except ProcessLookupError:
                pass  # process exited between the check and the kill
            await proc.wait()

    duration = time.time() - start_time

    return {
        "tool": tool_name,
        "target": target,
        "command": cmd,
        "return_code": return_code,
        "output": output[:10000],  # Truncate
        "duration": duration,
    }
882
+
883
+ # ============== Tools ==============
884
+
885
@app.get("/tools", response_model=List[ToolInfo])
async def list_tools(phase: Optional[str] = None):
    """List the registered tools, restricted to one phase when `phase` is given."""
    listing = []
    for entry in tools_rag.tools:
        # With no phase requested every tool is kept.
        if phase and entry.get("phase") != phase:
            continue
        listing.append(
            ToolInfo(
                name=entry.get("name", ""),
                description=entry.get("description", ""),
                phase=entry.get("phase", ""),
                keywords=entry.get("keywords", []),
            )
        )
    return listing
901
+
902
@app.get("/tools/{tool_name}")
async def get_tool(tool_name: str):
    """Return the stored definition of one tool; responds with 404 when unknown."""
    found = tools_rag.get_tool_by_name(tool_name)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Tool not found")
909
+
910
@app.get("/tools/search/{query}")
async def search_tools(query: str, top_k: int = 5):
    """Search the tool index for `query`, passing `top_k` through to the search."""
    return tools_rag.search(query, top_k=top_k)
915
+
916
+ # ============== CVE ==============
917
+
918
@app.post("/cve/lookup", response_model=CVEResponse)
async def lookup_cve(request: CVERequest):
    """Look up one CVE and return its scores and a shortened description."""
    record = cve_intel.lookup(request.cve_id)
    response_fields = {
        "cve_id": record.cve_id,
        "cvss": record.cvss,
        "epss": record.epss,
        "priority_score": record.priority_score,
        "has_poc": record.has_poc,
        # Only the first 500 characters of the description are returned.
        "description": record.description[:500],
    }
    return CVEResponse(**response_fields)
930
+
931
@app.post("/cve/prioritize")
async def prioritize_cves(cve_ids: List[str]):
    """Prioritize the submitted CVE IDs and return their scores.

    Entries are returned in the order produced by the prioritization call.
    """
    ranked = []
    for entry in cve_intel.prioritize(cve_ids):
        ranked.append(
            {
                "cve_id": entry.cve_id,
                "cvss": entry.cvss,
                "epss": entry.epss,
                "priority_score": entry.priority_score,
                "has_poc": entry.has_poc,
            }
        )
    return ranked
945
+
946
+ return app
947
+
948
+
949
# Default app instance
app = create_app()


if __name__ == "__main__":
    import argparse

    import uvicorn

    # Security: default to localhost to prevent accidental network exposure.
    # Pass --host 0.0.0.0 explicitly for production behind a reverse proxy.
    # Without arguments the server binds to 127.0.0.1:8000.
    cli = argparse.ArgumentParser(description="Run the API server")
    cli.add_argument(
        "--host",
        default="127.0.0.1",
        help="interface to bind (default: 127.0.0.1)",
    )
    cli.add_argument(
        "--port",
        type=int,
        default=8000,
        help="TCP port to listen on (default: 8000)",
    )
    args = cli.parse_args()
    uvicorn.run(app, host=args.host, port=args.port)