aiptx-2.0.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aipt_v2/__init__.py +110 -0
- aipt_v2/__main__.py +24 -0
- aipt_v2/agents/AIPTxAgent/__init__.py +10 -0
- aipt_v2/agents/AIPTxAgent/aiptx_agent.py +211 -0
- aipt_v2/agents/__init__.py +46 -0
- aipt_v2/agents/base.py +520 -0
- aipt_v2/agents/exploit_agent.py +688 -0
- aipt_v2/agents/ptt.py +406 -0
- aipt_v2/agents/state.py +168 -0
- aipt_v2/app.py +957 -0
- aipt_v2/browser/__init__.py +31 -0
- aipt_v2/browser/automation.py +458 -0
- aipt_v2/browser/crawler.py +453 -0
- aipt_v2/cli.py +2933 -0
- aipt_v2/compliance/__init__.py +71 -0
- aipt_v2/compliance/compliance_report.py +449 -0
- aipt_v2/compliance/framework_mapper.py +424 -0
- aipt_v2/compliance/nist_mapping.py +345 -0
- aipt_v2/compliance/owasp_mapping.py +330 -0
- aipt_v2/compliance/pci_mapping.py +297 -0
- aipt_v2/config.py +341 -0
- aipt_v2/core/__init__.py +43 -0
- aipt_v2/core/agent.py +630 -0
- aipt_v2/core/llm.py +395 -0
- aipt_v2/core/memory.py +305 -0
- aipt_v2/core/ptt.py +329 -0
- aipt_v2/database/__init__.py +14 -0
- aipt_v2/database/models.py +232 -0
- aipt_v2/database/repository.py +384 -0
- aipt_v2/docker/__init__.py +23 -0
- aipt_v2/docker/builder.py +260 -0
- aipt_v2/docker/manager.py +222 -0
- aipt_v2/docker/sandbox.py +371 -0
- aipt_v2/evasion/__init__.py +58 -0
- aipt_v2/evasion/request_obfuscator.py +272 -0
- aipt_v2/evasion/tls_fingerprint.py +285 -0
- aipt_v2/evasion/ua_rotator.py +301 -0
- aipt_v2/evasion/waf_bypass.py +439 -0
- aipt_v2/execution/__init__.py +23 -0
- aipt_v2/execution/executor.py +302 -0
- aipt_v2/execution/parser.py +544 -0
- aipt_v2/execution/terminal.py +337 -0
- aipt_v2/health.py +437 -0
- aipt_v2/intelligence/__init__.py +194 -0
- aipt_v2/intelligence/adaptation.py +474 -0
- aipt_v2/intelligence/auth.py +520 -0
- aipt_v2/intelligence/chaining.py +775 -0
- aipt_v2/intelligence/correlation.py +536 -0
- aipt_v2/intelligence/cve_aipt.py +334 -0
- aipt_v2/intelligence/cve_info.py +1111 -0
- aipt_v2/intelligence/knowledge_graph.py +590 -0
- aipt_v2/intelligence/learning.py +626 -0
- aipt_v2/intelligence/llm_analyzer.py +502 -0
- aipt_v2/intelligence/llm_tool_selector.py +518 -0
- aipt_v2/intelligence/payload_generator.py +562 -0
- aipt_v2/intelligence/rag.py +239 -0
- aipt_v2/intelligence/scope.py +442 -0
- aipt_v2/intelligence/searchers/__init__.py +5 -0
- aipt_v2/intelligence/searchers/exploitdb_searcher.py +523 -0
- aipt_v2/intelligence/searchers/github_searcher.py +467 -0
- aipt_v2/intelligence/searchers/google_searcher.py +281 -0
- aipt_v2/intelligence/tools.json +443 -0
- aipt_v2/intelligence/triage.py +670 -0
- aipt_v2/interactive_shell.py +559 -0
- aipt_v2/interface/__init__.py +5 -0
- aipt_v2/interface/cli.py +230 -0
- aipt_v2/interface/main.py +501 -0
- aipt_v2/interface/tui.py +1276 -0
- aipt_v2/interface/utils.py +583 -0
- aipt_v2/llm/__init__.py +39 -0
- aipt_v2/llm/config.py +26 -0
- aipt_v2/llm/llm.py +514 -0
- aipt_v2/llm/memory.py +214 -0
- aipt_v2/llm/request_queue.py +89 -0
- aipt_v2/llm/utils.py +89 -0
- aipt_v2/local_tool_installer.py +1467 -0
- aipt_v2/models/__init__.py +15 -0
- aipt_v2/models/findings.py +295 -0
- aipt_v2/models/phase_result.py +224 -0
- aipt_v2/models/scan_config.py +207 -0
- aipt_v2/monitoring/grafana/dashboards/aipt-dashboard.json +355 -0
- aipt_v2/monitoring/grafana/dashboards/default.yml +17 -0
- aipt_v2/monitoring/grafana/datasources/prometheus.yml +17 -0
- aipt_v2/monitoring/prometheus.yml +60 -0
- aipt_v2/orchestration/__init__.py +52 -0
- aipt_v2/orchestration/pipeline.py +398 -0
- aipt_v2/orchestration/progress.py +300 -0
- aipt_v2/orchestration/scheduler.py +296 -0
- aipt_v2/orchestrator.py +2427 -0
- aipt_v2/payloads/__init__.py +27 -0
- aipt_v2/payloads/cmdi.py +150 -0
- aipt_v2/payloads/sqli.py +263 -0
- aipt_v2/payloads/ssrf.py +204 -0
- aipt_v2/payloads/templates.py +222 -0
- aipt_v2/payloads/traversal.py +166 -0
- aipt_v2/payloads/xss.py +204 -0
- aipt_v2/prompts/__init__.py +60 -0
- aipt_v2/proxy/__init__.py +29 -0
- aipt_v2/proxy/history.py +352 -0
- aipt_v2/proxy/interceptor.py +452 -0
- aipt_v2/recon/__init__.py +44 -0
- aipt_v2/recon/dns.py +241 -0
- aipt_v2/recon/osint.py +367 -0
- aipt_v2/recon/subdomain.py +372 -0
- aipt_v2/recon/tech_detect.py +311 -0
- aipt_v2/reports/__init__.py +17 -0
- aipt_v2/reports/generator.py +313 -0
- aipt_v2/reports/html_report.py +378 -0
- aipt_v2/runtime/__init__.py +53 -0
- aipt_v2/runtime/base.py +30 -0
- aipt_v2/runtime/docker.py +401 -0
- aipt_v2/runtime/local.py +346 -0
- aipt_v2/runtime/tool_server.py +205 -0
- aipt_v2/runtime/vps.py +830 -0
- aipt_v2/scanners/__init__.py +28 -0
- aipt_v2/scanners/base.py +273 -0
- aipt_v2/scanners/nikto.py +244 -0
- aipt_v2/scanners/nmap.py +402 -0
- aipt_v2/scanners/nuclei.py +273 -0
- aipt_v2/scanners/web.py +454 -0
- aipt_v2/scripts/security_audit.py +366 -0
- aipt_v2/setup_wizard.py +941 -0
- aipt_v2/skills/__init__.py +80 -0
- aipt_v2/skills/agents/__init__.py +14 -0
- aipt_v2/skills/agents/api_tester.py +706 -0
- aipt_v2/skills/agents/base.py +477 -0
- aipt_v2/skills/agents/code_review.py +459 -0
- aipt_v2/skills/agents/security_agent.py +336 -0
- aipt_v2/skills/agents/web_pentest.py +818 -0
- aipt_v2/skills/prompts/__init__.py +647 -0
- aipt_v2/system_detector.py +539 -0
- aipt_v2/telemetry/__init__.py +7 -0
- aipt_v2/telemetry/tracer.py +347 -0
- aipt_v2/terminal/__init__.py +28 -0
- aipt_v2/terminal/executor.py +400 -0
- aipt_v2/terminal/sandbox.py +350 -0
- aipt_v2/tools/__init__.py +44 -0
- aipt_v2/tools/active_directory/__init__.py +78 -0
- aipt_v2/tools/active_directory/ad_config.py +238 -0
- aipt_v2/tools/active_directory/bloodhound_wrapper.py +447 -0
- aipt_v2/tools/active_directory/kerberos_attacks.py +430 -0
- aipt_v2/tools/active_directory/ldap_enum.py +533 -0
- aipt_v2/tools/active_directory/smb_attacks.py +505 -0
- aipt_v2/tools/agents_graph/__init__.py +19 -0
- aipt_v2/tools/agents_graph/agents_graph_actions.py +69 -0
- aipt_v2/tools/api_security/__init__.py +76 -0
- aipt_v2/tools/api_security/api_discovery.py +608 -0
- aipt_v2/tools/api_security/graphql_scanner.py +622 -0
- aipt_v2/tools/api_security/jwt_analyzer.py +577 -0
- aipt_v2/tools/api_security/openapi_fuzzer.py +761 -0
- aipt_v2/tools/browser/__init__.py +5 -0
- aipt_v2/tools/browser/browser_actions.py +238 -0
- aipt_v2/tools/browser/browser_instance.py +535 -0
- aipt_v2/tools/browser/tab_manager.py +344 -0
- aipt_v2/tools/cloud/__init__.py +70 -0
- aipt_v2/tools/cloud/cloud_config.py +273 -0
- aipt_v2/tools/cloud/cloud_scanner.py +639 -0
- aipt_v2/tools/cloud/prowler_tool.py +571 -0
- aipt_v2/tools/cloud/scoutsuite_tool.py +359 -0
- aipt_v2/tools/executor.py +307 -0
- aipt_v2/tools/parser.py +408 -0
- aipt_v2/tools/proxy/__init__.py +5 -0
- aipt_v2/tools/proxy/proxy_actions.py +103 -0
- aipt_v2/tools/proxy/proxy_manager.py +789 -0
- aipt_v2/tools/registry.py +196 -0
- aipt_v2/tools/scanners/__init__.py +343 -0
- aipt_v2/tools/scanners/acunetix_tool.py +712 -0
- aipt_v2/tools/scanners/burp_tool.py +631 -0
- aipt_v2/tools/scanners/config.py +156 -0
- aipt_v2/tools/scanners/nessus_tool.py +588 -0
- aipt_v2/tools/scanners/zap_tool.py +612 -0
- aipt_v2/tools/terminal/__init__.py +5 -0
- aipt_v2/tools/terminal/terminal_actions.py +37 -0
- aipt_v2/tools/terminal/terminal_manager.py +153 -0
- aipt_v2/tools/terminal/terminal_session.py +449 -0
- aipt_v2/tools/tool_processing.py +108 -0
- aipt_v2/utils/__init__.py +17 -0
- aipt_v2/utils/logging.py +202 -0
- aipt_v2/utils/model_manager.py +187 -0
- aipt_v2/utils/searchers/__init__.py +269 -0
- aipt_v2/verify_install.py +793 -0
- aiptx-2.0.7.dist-info/METADATA +345 -0
- aiptx-2.0.7.dist-info/RECORD +187 -0
- aiptx-2.0.7.dist-info/WHEEL +5 -0
- aiptx-2.0.7.dist-info/entry_points.txt +7 -0
- aiptx-2.0.7.dist-info/licenses/LICENSE +21 -0
- aiptx-2.0.7.dist-info/top_level.txt +1 -0
aipt_v2/llm/llm.py
ADDED
@@ -0,0 +1,514 @@
from __future__ import annotations

import logging
import os
from dataclasses import dataclass
from enum import Enum
from fnmatch import fnmatch
from pathlib import Path
from typing import Any

import litellm
from jinja2 import (
    Environment,
    FileSystemLoader,
    select_autoescape,
)
from litellm import ModelResponse, completion_cost
from litellm.utils import supports_prompt_caching, supports_vision

from aipt_v2.llm.config import LLMConfig
from aipt_v2.llm.memory import MemoryCompressor
from aipt_v2.llm.request_queue import get_global_queue
from aipt_v2.llm.utils import _truncate_to_first_function, parse_tool_invocations
from aipt_v2.prompts import load_prompt_modules, get_tools_prompt


logger = logging.getLogger(__name__)

litellm.drop_params = True
litellm.modify_params = True

_LLM_API_KEY = os.getenv("LLM_API_KEY")
_LLM_API_BASE = (
    os.getenv("LLM_API_BASE")
    or os.getenv("OPENAI_API_BASE")
    or os.getenv("LITELLM_BASE_URL")
    or os.getenv("OLLAMA_API_BASE")
)


class LLMRequestFailedError(Exception):
    def __init__(self, message: str, details: str | None = None):
        super().__init__(message)
        self.message = message
        self.details = details


SUPPORTS_STOP_WORDS_FALSE_PATTERNS: list[str] = [
    "o1*",
    "grok-4-0709",
    "grok-code-fast-1",
    "deepseek-r1-0528*",
]

REASONING_EFFORT_PATTERNS: list[str] = [
    "o1-2024-12-17",
    "o1",
    "o3",
    "o3-2025-04-16",
    "o3-mini-2025-01-31",
    "o3-mini",
    "o4-mini",
    "o4-mini-2025-04-16",
    "gemini-2.5-flash",
    "gemini-2.5-pro",
    "gpt-5*",
    "deepseek-r1-0528*",
    "claude-sonnet-4-5*",
    "claude-haiku-4-5*",
]


def normalize_model_name(model: str) -> str:
    raw = (model or "").strip().lower()
    if "/" in raw:
        name = raw.split("/")[-1]
        if ":" in name:
            name = name.split(":", 1)[0]
    else:
        name = raw
    if name.endswith("-gguf"):
        name = name[: -len("-gguf")]
    return name


def model_matches(model: str, patterns: list[str]) -> bool:
    raw = (model or "").strip().lower()
    name = normalize_model_name(model)
    for pat in patterns:
        pat_l = pat.lower()
        if "/" in pat_l:
            if fnmatch(raw, pat_l):
                return True
        elif fnmatch(name, pat_l):
            return True
    return False


class StepRole(str, Enum):
    AGENT = "agent"
    USER = "user"
    SYSTEM = "system"


@dataclass
class LLMResponse:
    content: str
    tool_invocations: list[dict[str, Any]] | None = None
    scan_id: str | None = None
    step_number: int = 1
    role: StepRole = StepRole.AGENT


@dataclass
class RequestStats:
    input_tokens: int = 0
    output_tokens: int = 0
    cached_tokens: int = 0
    cache_creation_tokens: int = 0
    cost: float = 0.0
    requests: int = 0
    failed_requests: int = 0

    def to_dict(self) -> dict[str, int | float]:
        return {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cached_tokens": self.cached_tokens,
            "cache_creation_tokens": self.cache_creation_tokens,
            "cost": round(self.cost, 4),
            "requests": self.requests,
            "failed_requests": self.failed_requests,
        }


class LLM:
    def __init__(
        self, config: LLMConfig, agent_name: str | None = None, agent_id: str | None = None
    ):
        self.config = config
        self.agent_name = agent_name
        self.agent_id = agent_id
        self._total_stats = RequestStats()
        self._last_request_stats = RequestStats()

        self.memory_compressor = MemoryCompressor(
            model_name=self.config.model_name,
            timeout=self.config.timeout,
        )

        if agent_name:
            prompt_dir = Path(__file__).parent.parent / "agents" / agent_name
            prompts_dir = Path(__file__).parent.parent / "prompts"

            loader = FileSystemLoader([prompt_dir, prompts_dir])
            self.jinja_env = Environment(
                loader=loader,
                autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
            )

            try:
                prompt_module_content = load_prompt_modules(
                    self.config.prompt_modules or [], self.jinja_env
                )

                def get_module(name: str) -> str:
                    return prompt_module_content.get(name, "")

                self.jinja_env.globals["get_module"] = get_module

                self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render(
                    get_tools_prompt=get_tools_prompt,
                    loaded_module_names=list(prompt_module_content.keys()),
                    **prompt_module_content,
                )
            except (FileNotFoundError, OSError, ValueError) as e:
                logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
                self.system_prompt = "You are a helpful AI assistant."
        else:
            self.system_prompt = "You are a helpful AI assistant."

    def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
        if agent_name:
            self.agent_name = agent_name
        if agent_id:
            self.agent_id = agent_id

    def _build_identity_message(self) -> dict[str, Any] | None:
        if not (self.agent_name and str(self.agent_name).strip()):
            return None
        identity_name = self.agent_name
        identity_id = self.agent_id
        content = (
            "\n\n"
            "<agent_identity>\n"
            "<meta>Internal metadata: do not echo or reference; "
            "not part of history or tool calls.</meta>\n"
            "<note>You are now assuming the role of this agent. "
            "Act strictly as this agent and maintain self-identity for this step. "
            "Now go answer the next needed step!</note>\n"
            f"<agent_name>{identity_name}</agent_name>\n"
            f"<agent_id>{identity_id}</agent_id>\n"
            "</agent_identity>\n\n"
        )
        return {"role": "user", "content": content}

    def _add_cache_control_to_content(
        self, content: str | list[dict[str, Any]]
    ) -> str | list[dict[str, Any]]:
        if isinstance(content, str):
            return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
        if isinstance(content, list) and content:
            last_item = content[-1]
            if isinstance(last_item, dict) and last_item.get("type") == "text":
                return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}]
        return content

    def _is_anthropic_model(self) -> bool:
        if not self.config.model_name:
            return False
        model_lower = self.config.model_name.lower()
        return any(provider in model_lower for provider in ["anthropic/", "claude"])

    def _calculate_cache_interval(self, total_messages: int) -> int:
        if total_messages <= 1:
            return 10

        max_cached_messages = 3
        non_system_messages = total_messages - 1

        interval = 10
        while non_system_messages // interval > max_cached_messages:
            interval += 10

        return interval

    def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        if (
            not self.config.enable_prompt_caching
            or not supports_prompt_caching(self.config.model_name)
            or not messages
        ):
            return messages

        if not self._is_anthropic_model():
            return messages

        cached_messages = list(messages)

        if cached_messages and cached_messages[0].get("role") == "system":
            system_message = cached_messages[0].copy()
            system_message["content"] = self._add_cache_control_to_content(
                system_message["content"]
            )
            cached_messages[0] = system_message

        total_messages = len(cached_messages)
        if total_messages > 1:
            interval = self._calculate_cache_interval(total_messages)

            cached_count = 0
            for i in range(interval, total_messages, interval):
                if cached_count >= 3:
                    break

                if i < len(cached_messages):
                    message = cached_messages[i].copy()
                    message["content"] = self._add_cache_control_to_content(message["content"])
                    cached_messages[i] = message
                    cached_count += 1

        return cached_messages

    async def generate(  # noqa: PLR0912, PLR0915
        self,
        conversation_history: list[dict[str, Any]],
        scan_id: str | None = None,
        step_number: int = 1,
    ) -> LLMResponse:
        messages = [{"role": "system", "content": self.system_prompt}]

        identity_message = self._build_identity_message()
        if identity_message:
            messages.append(identity_message)

        compressed_history = list(self.memory_compressor.compress_history(conversation_history))

        conversation_history.clear()
        conversation_history.extend(compressed_history)
        messages.extend(compressed_history)

        cached_messages = self._prepare_cached_messages(messages)

        try:
            response = await self._make_request(cached_messages)
            self._update_usage_stats(response)

            content = ""
            if (
                response.choices
                and hasattr(response.choices[0], "message")
                and response.choices[0].message
            ):
                content = getattr(response.choices[0].message, "content", "") or ""

            content = _truncate_to_first_function(content)

            if "</function>" in content:
                function_end_index = content.find("</function>") + len("</function>")
                content = content[:function_end_index]

            tool_invocations = parse_tool_invocations(content)

            return LLMResponse(
                scan_id=scan_id,
                step_number=step_number,
                role=StepRole.AGENT,
                content=content,
                tool_invocations=tool_invocations if tool_invocations else None,
            )

        except litellm.RateLimitError as e:
            raise LLMRequestFailedError("LLM request failed: Rate limit exceeded", str(e)) from e
        except litellm.AuthenticationError as e:
            raise LLMRequestFailedError("LLM request failed: Invalid API key", str(e)) from e
        except litellm.NotFoundError as e:
            raise LLMRequestFailedError("LLM request failed: Model not found", str(e)) from e
        except litellm.ContextWindowExceededError as e:
            raise LLMRequestFailedError("LLM request failed: Context too long", str(e)) from e
        except litellm.ContentPolicyViolationError as e:
            raise LLMRequestFailedError(
                "LLM request failed: Content policy violation", str(e)
            ) from e
        except litellm.ServiceUnavailableError as e:
            raise LLMRequestFailedError("LLM request failed: Service unavailable", str(e)) from e
        except litellm.Timeout as e:
            raise LLMRequestFailedError("LLM request failed: Request timed out", str(e)) from e
        except litellm.UnprocessableEntityError as e:
            raise LLMRequestFailedError("LLM request failed: Unprocessable entity", str(e)) from e
        except litellm.InternalServerError as e:
            raise LLMRequestFailedError("LLM request failed: Internal server error", str(e)) from e
        except litellm.APIConnectionError as e:
            raise LLMRequestFailedError("LLM request failed: Connection error", str(e)) from e
        except litellm.UnsupportedParamsError as e:
            raise LLMRequestFailedError("LLM request failed: Unsupported parameters", str(e)) from e
        except litellm.BudgetExceededError as e:
            raise LLMRequestFailedError("LLM request failed: Budget exceeded", str(e)) from e
        except litellm.APIResponseValidationError as e:
            raise LLMRequestFailedError(
                "LLM request failed: Response validation error", str(e)
            ) from e
        except litellm.JSONSchemaValidationError as e:
            raise LLMRequestFailedError(
                "LLM request failed: JSON schema validation error", str(e)
            ) from e
        except litellm.InvalidRequestError as e:
            raise LLMRequestFailedError("LLM request failed: Invalid request", str(e)) from e
        except litellm.BadRequestError as e:
            raise LLMRequestFailedError("LLM request failed: Bad request", str(e)) from e
        except litellm.APIError as e:
            raise LLMRequestFailedError("LLM request failed: API error", str(e)) from e
        except litellm.OpenAIError as e:
            raise LLMRequestFailedError("LLM request failed: OpenAI error", str(e)) from e
        except Exception as e:
            raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e

    @property
    def usage_stats(self) -> dict[str, dict[str, int | float]]:
        return {
            "total": self._total_stats.to_dict(),
            "last_request": self._last_request_stats.to_dict(),
        }

    def get_cache_config(self) -> dict[str, bool]:
        return {
            "enabled": self.config.enable_prompt_caching,
            "supported": supports_prompt_caching(self.config.model_name),
        }

    def _should_include_stop_param(self) -> bool:
        if not self.config.model_name:
            return True

        return not model_matches(self.config.model_name, SUPPORTS_STOP_WORDS_FALSE_PATTERNS)

    def _should_include_reasoning_effort(self) -> bool:
        if not self.config.model_name:
            return False

        return model_matches(self.config.model_name, REASONING_EFFORT_PATTERNS)

    def _model_supports_vision(self) -> bool:
        if not self.config.model_name:
            return False
        try:
            return supports_vision(model=self.config.model_name)
        except Exception:  # noqa: BLE001
            return False

    def _filter_images_from_messages(
        self, messages: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        filtered_messages = []
        for msg in messages:
            content = msg.get("content")
            if isinstance(content, list):
                filtered_content = []
                for item in content:
                    if isinstance(item, dict):
                        if item.get("type") == "image_url":
                            filtered_content.append({
                                "type": "text",
                                "text": "[Screenshot removed - model does not support vision. "
                                "Use view_source or execute_js to interact with the page instead.]",
                            })
                        else:
                            filtered_content.append(item)
                    else:
                        filtered_content.append(item)
                if filtered_content:
                    text_parts = [
                        item.get("text", "") if isinstance(item, dict) else str(item)
                        for item in filtered_content
                    ]
                    if all(isinstance(item, dict) and item.get("type") == "text" for item in filtered_content):
                        msg = {**msg, "content": "\n".join(text_parts)}
                    else:
                        msg = {**msg, "content": filtered_content}
                else:
                    msg = {**msg, "content": ""}
            filtered_messages.append(msg)
        return filtered_messages

    async def _make_request(
        self,
        messages: list[dict[str, Any]],
    ) -> ModelResponse:
        if not self._model_supports_vision():
            messages = self._filter_images_from_messages(messages)

        completion_args: dict[str, Any] = {
            "model": self.config.model_name,
            "messages": messages,
            "timeout": self.config.timeout,
        }

        if _LLM_API_KEY:
            completion_args["api_key"] = _LLM_API_KEY
        if _LLM_API_BASE:
            completion_args["api_base"] = _LLM_API_BASE

        if self._should_include_stop_param():
            completion_args["stop"] = ["</function>"]

        if self._should_include_reasoning_effort():
            completion_args["reasoning_effort"] = "high"

        queue = get_global_queue()
        response = await queue.make_request(completion_args)

        self._total_stats.requests += 1
        self._last_request_stats = RequestStats(requests=1)

        return response

    def _update_usage_stats(self, response: ModelResponse) -> None:
        try:
            if hasattr(response, "usage") and response.usage:
                input_tokens = getattr(response.usage, "prompt_tokens", 0)
                output_tokens = getattr(response.usage, "completion_tokens", 0)

                cached_tokens = 0
                cache_creation_tokens = 0

                if hasattr(response.usage, "prompt_tokens_details"):
                    prompt_details = response.usage.prompt_tokens_details
                    if hasattr(prompt_details, "cached_tokens"):
                        cached_tokens = prompt_details.cached_tokens or 0

                if hasattr(response.usage, "cache_creation_input_tokens"):
                    cache_creation_tokens = response.usage.cache_creation_input_tokens or 0

            else:
                input_tokens = 0
                output_tokens = 0
                cached_tokens = 0
                cache_creation_tokens = 0

            try:
                cost = completion_cost(response) or 0.0
            except Exception as e:  # noqa: BLE001
                logger.warning(f"Failed to calculate cost: {e}")
                cost = 0.0

            self._total_stats.input_tokens += input_tokens
            self._total_stats.output_tokens += output_tokens
            self._total_stats.cached_tokens += cached_tokens
            self._total_stats.cache_creation_tokens += cache_creation_tokens
            self._total_stats.cost += cost

            self._last_request_stats.input_tokens = input_tokens
            self._last_request_stats.output_tokens = output_tokens
            self._last_request_stats.cached_tokens = cached_tokens
            self._last_request_stats.cache_creation_tokens = cache_creation_tokens
            self._last_request_stats.cost = cost

            if cached_tokens > 0:
                logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens")
            if cache_creation_tokens > 0:
                logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache")

            logger.info(f"Usage stats: {self.usage_stats}")
        except Exception as e:  # noqa: BLE001
            logger.warning(f"Failed to update usage stats: {e}")