aiptx 2.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. aipt_v2/__init__.py +110 -0
  2. aipt_v2/__main__.py +24 -0
  3. aipt_v2/agents/AIPTxAgent/__init__.py +10 -0
  4. aipt_v2/agents/AIPTxAgent/aiptx_agent.py +211 -0
  5. aipt_v2/agents/__init__.py +46 -0
  6. aipt_v2/agents/base.py +520 -0
  7. aipt_v2/agents/exploit_agent.py +688 -0
  8. aipt_v2/agents/ptt.py +406 -0
  9. aipt_v2/agents/state.py +168 -0
  10. aipt_v2/app.py +957 -0
  11. aipt_v2/browser/__init__.py +31 -0
  12. aipt_v2/browser/automation.py +458 -0
  13. aipt_v2/browser/crawler.py +453 -0
  14. aipt_v2/cli.py +2933 -0
  15. aipt_v2/compliance/__init__.py +71 -0
  16. aipt_v2/compliance/compliance_report.py +449 -0
  17. aipt_v2/compliance/framework_mapper.py +424 -0
  18. aipt_v2/compliance/nist_mapping.py +345 -0
  19. aipt_v2/compliance/owasp_mapping.py +330 -0
  20. aipt_v2/compliance/pci_mapping.py +297 -0
  21. aipt_v2/config.py +341 -0
  22. aipt_v2/core/__init__.py +43 -0
  23. aipt_v2/core/agent.py +630 -0
  24. aipt_v2/core/llm.py +395 -0
  25. aipt_v2/core/memory.py +305 -0
  26. aipt_v2/core/ptt.py +329 -0
  27. aipt_v2/database/__init__.py +14 -0
  28. aipt_v2/database/models.py +232 -0
  29. aipt_v2/database/repository.py +384 -0
  30. aipt_v2/docker/__init__.py +23 -0
  31. aipt_v2/docker/builder.py +260 -0
  32. aipt_v2/docker/manager.py +222 -0
  33. aipt_v2/docker/sandbox.py +371 -0
  34. aipt_v2/evasion/__init__.py +58 -0
  35. aipt_v2/evasion/request_obfuscator.py +272 -0
  36. aipt_v2/evasion/tls_fingerprint.py +285 -0
  37. aipt_v2/evasion/ua_rotator.py +301 -0
  38. aipt_v2/evasion/waf_bypass.py +439 -0
  39. aipt_v2/execution/__init__.py +23 -0
  40. aipt_v2/execution/executor.py +302 -0
  41. aipt_v2/execution/parser.py +544 -0
  42. aipt_v2/execution/terminal.py +337 -0
  43. aipt_v2/health.py +437 -0
  44. aipt_v2/intelligence/__init__.py +194 -0
  45. aipt_v2/intelligence/adaptation.py +474 -0
  46. aipt_v2/intelligence/auth.py +520 -0
  47. aipt_v2/intelligence/chaining.py +775 -0
  48. aipt_v2/intelligence/correlation.py +536 -0
  49. aipt_v2/intelligence/cve_aipt.py +334 -0
  50. aipt_v2/intelligence/cve_info.py +1111 -0
  51. aipt_v2/intelligence/knowledge_graph.py +590 -0
  52. aipt_v2/intelligence/learning.py +626 -0
  53. aipt_v2/intelligence/llm_analyzer.py +502 -0
  54. aipt_v2/intelligence/llm_tool_selector.py +518 -0
  55. aipt_v2/intelligence/payload_generator.py +562 -0
  56. aipt_v2/intelligence/rag.py +239 -0
  57. aipt_v2/intelligence/scope.py +442 -0
  58. aipt_v2/intelligence/searchers/__init__.py +5 -0
  59. aipt_v2/intelligence/searchers/exploitdb_searcher.py +523 -0
  60. aipt_v2/intelligence/searchers/github_searcher.py +467 -0
  61. aipt_v2/intelligence/searchers/google_searcher.py +281 -0
  62. aipt_v2/intelligence/tools.json +443 -0
  63. aipt_v2/intelligence/triage.py +670 -0
  64. aipt_v2/interactive_shell.py +559 -0
  65. aipt_v2/interface/__init__.py +5 -0
  66. aipt_v2/interface/cli.py +230 -0
  67. aipt_v2/interface/main.py +501 -0
  68. aipt_v2/interface/tui.py +1276 -0
  69. aipt_v2/interface/utils.py +583 -0
  70. aipt_v2/llm/__init__.py +39 -0
  71. aipt_v2/llm/config.py +26 -0
  72. aipt_v2/llm/llm.py +514 -0
  73. aipt_v2/llm/memory.py +214 -0
  74. aipt_v2/llm/request_queue.py +89 -0
  75. aipt_v2/llm/utils.py +89 -0
  76. aipt_v2/local_tool_installer.py +1467 -0
  77. aipt_v2/models/__init__.py +15 -0
  78. aipt_v2/models/findings.py +295 -0
  79. aipt_v2/models/phase_result.py +224 -0
  80. aipt_v2/models/scan_config.py +207 -0
  81. aipt_v2/monitoring/grafana/dashboards/aipt-dashboard.json +355 -0
  82. aipt_v2/monitoring/grafana/dashboards/default.yml +17 -0
  83. aipt_v2/monitoring/grafana/datasources/prometheus.yml +17 -0
  84. aipt_v2/monitoring/prometheus.yml +60 -0
  85. aipt_v2/orchestration/__init__.py +52 -0
  86. aipt_v2/orchestration/pipeline.py +398 -0
  87. aipt_v2/orchestration/progress.py +300 -0
  88. aipt_v2/orchestration/scheduler.py +296 -0
  89. aipt_v2/orchestrator.py +2427 -0
  90. aipt_v2/payloads/__init__.py +27 -0
  91. aipt_v2/payloads/cmdi.py +150 -0
  92. aipt_v2/payloads/sqli.py +263 -0
  93. aipt_v2/payloads/ssrf.py +204 -0
  94. aipt_v2/payloads/templates.py +222 -0
  95. aipt_v2/payloads/traversal.py +166 -0
  96. aipt_v2/payloads/xss.py +204 -0
  97. aipt_v2/prompts/__init__.py +60 -0
  98. aipt_v2/proxy/__init__.py +29 -0
  99. aipt_v2/proxy/history.py +352 -0
  100. aipt_v2/proxy/interceptor.py +452 -0
  101. aipt_v2/recon/__init__.py +44 -0
  102. aipt_v2/recon/dns.py +241 -0
  103. aipt_v2/recon/osint.py +367 -0
  104. aipt_v2/recon/subdomain.py +372 -0
  105. aipt_v2/recon/tech_detect.py +311 -0
  106. aipt_v2/reports/__init__.py +17 -0
  107. aipt_v2/reports/generator.py +313 -0
  108. aipt_v2/reports/html_report.py +378 -0
  109. aipt_v2/runtime/__init__.py +53 -0
  110. aipt_v2/runtime/base.py +30 -0
  111. aipt_v2/runtime/docker.py +401 -0
  112. aipt_v2/runtime/local.py +346 -0
  113. aipt_v2/runtime/tool_server.py +205 -0
  114. aipt_v2/runtime/vps.py +830 -0
  115. aipt_v2/scanners/__init__.py +28 -0
  116. aipt_v2/scanners/base.py +273 -0
  117. aipt_v2/scanners/nikto.py +244 -0
  118. aipt_v2/scanners/nmap.py +402 -0
  119. aipt_v2/scanners/nuclei.py +273 -0
  120. aipt_v2/scanners/web.py +454 -0
  121. aipt_v2/scripts/security_audit.py +366 -0
  122. aipt_v2/setup_wizard.py +941 -0
  123. aipt_v2/skills/__init__.py +80 -0
  124. aipt_v2/skills/agents/__init__.py +14 -0
  125. aipt_v2/skills/agents/api_tester.py +706 -0
  126. aipt_v2/skills/agents/base.py +477 -0
  127. aipt_v2/skills/agents/code_review.py +459 -0
  128. aipt_v2/skills/agents/security_agent.py +336 -0
  129. aipt_v2/skills/agents/web_pentest.py +818 -0
  130. aipt_v2/skills/prompts/__init__.py +647 -0
  131. aipt_v2/system_detector.py +539 -0
  132. aipt_v2/telemetry/__init__.py +7 -0
  133. aipt_v2/telemetry/tracer.py +347 -0
  134. aipt_v2/terminal/__init__.py +28 -0
  135. aipt_v2/terminal/executor.py +400 -0
  136. aipt_v2/terminal/sandbox.py +350 -0
  137. aipt_v2/tools/__init__.py +44 -0
  138. aipt_v2/tools/active_directory/__init__.py +78 -0
  139. aipt_v2/tools/active_directory/ad_config.py +238 -0
  140. aipt_v2/tools/active_directory/bloodhound_wrapper.py +447 -0
  141. aipt_v2/tools/active_directory/kerberos_attacks.py +430 -0
  142. aipt_v2/tools/active_directory/ldap_enum.py +533 -0
  143. aipt_v2/tools/active_directory/smb_attacks.py +505 -0
  144. aipt_v2/tools/agents_graph/__init__.py +19 -0
  145. aipt_v2/tools/agents_graph/agents_graph_actions.py +69 -0
  146. aipt_v2/tools/api_security/__init__.py +76 -0
  147. aipt_v2/tools/api_security/api_discovery.py +608 -0
  148. aipt_v2/tools/api_security/graphql_scanner.py +622 -0
  149. aipt_v2/tools/api_security/jwt_analyzer.py +577 -0
  150. aipt_v2/tools/api_security/openapi_fuzzer.py +761 -0
  151. aipt_v2/tools/browser/__init__.py +5 -0
  152. aipt_v2/tools/browser/browser_actions.py +238 -0
  153. aipt_v2/tools/browser/browser_instance.py +535 -0
  154. aipt_v2/tools/browser/tab_manager.py +344 -0
  155. aipt_v2/tools/cloud/__init__.py +70 -0
  156. aipt_v2/tools/cloud/cloud_config.py +273 -0
  157. aipt_v2/tools/cloud/cloud_scanner.py +639 -0
  158. aipt_v2/tools/cloud/prowler_tool.py +571 -0
  159. aipt_v2/tools/cloud/scoutsuite_tool.py +359 -0
  160. aipt_v2/tools/executor.py +307 -0
  161. aipt_v2/tools/parser.py +408 -0
  162. aipt_v2/tools/proxy/__init__.py +5 -0
  163. aipt_v2/tools/proxy/proxy_actions.py +103 -0
  164. aipt_v2/tools/proxy/proxy_manager.py +789 -0
  165. aipt_v2/tools/registry.py +196 -0
  166. aipt_v2/tools/scanners/__init__.py +343 -0
  167. aipt_v2/tools/scanners/acunetix_tool.py +712 -0
  168. aipt_v2/tools/scanners/burp_tool.py +631 -0
  169. aipt_v2/tools/scanners/config.py +156 -0
  170. aipt_v2/tools/scanners/nessus_tool.py +588 -0
  171. aipt_v2/tools/scanners/zap_tool.py +612 -0
  172. aipt_v2/tools/terminal/__init__.py +5 -0
  173. aipt_v2/tools/terminal/terminal_actions.py +37 -0
  174. aipt_v2/tools/terminal/terminal_manager.py +153 -0
  175. aipt_v2/tools/terminal/terminal_session.py +449 -0
  176. aipt_v2/tools/tool_processing.py +108 -0
  177. aipt_v2/utils/__init__.py +17 -0
  178. aipt_v2/utils/logging.py +202 -0
  179. aipt_v2/utils/model_manager.py +187 -0
  180. aipt_v2/utils/searchers/__init__.py +269 -0
  181. aipt_v2/verify_install.py +793 -0
  182. aiptx-2.0.7.dist-info/METADATA +345 -0
  183. aiptx-2.0.7.dist-info/RECORD +187 -0
  184. aiptx-2.0.7.dist-info/WHEEL +5 -0
  185. aiptx-2.0.7.dist-info/entry_points.txt +7 -0
  186. aiptx-2.0.7.dist-info/licenses/LICENSE +21 -0
  187. aiptx-2.0.7.dist-info/top_level.txt +1 -0
aipt_v2/llm/llm.py ADDED
@@ -0,0 +1,514 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ from dataclasses import dataclass
6
+ from enum import Enum
7
+ from fnmatch import fnmatch
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ import litellm
12
+ from jinja2 import (
13
+ Environment,
14
+ FileSystemLoader,
15
+ select_autoescape,
16
+ )
17
+ from litellm import ModelResponse, completion_cost
18
+ from litellm.utils import supports_prompt_caching, supports_vision
19
+
20
+ from aipt_v2.llm.config import LLMConfig
21
+ from aipt_v2.llm.memory import MemoryCompressor
22
+ from aipt_v2.llm.request_queue import get_global_queue
23
+ from aipt_v2.llm.utils import _truncate_to_first_function, parse_tool_invocations
24
+ from aipt_v2.prompts import load_prompt_modules, get_tools_prompt
25
+
26
+
27
logger = logging.getLogger(__name__)

# Provider-compatibility switches: litellm drops parameters a provider does
# not accept (drop_params) and rewrites others into accepted forms
# (modify_params) instead of failing the request.
litellm.drop_params = True
litellm.modify_params = True

# Optional credential/endpoint overrides, resolved once at import time.
# The API base falls back through several commonly used environment variable
# names (LiteLLM, OpenAI-compatible, and Ollama conventions).
_LLM_API_KEY = os.getenv("LLM_API_KEY")
_LLM_API_BASE = (
    os.getenv("LLM_API_BASE")
    or os.getenv("OPENAI_API_BASE")
    or os.getenv("LITELLM_BASE_URL")
    or os.getenv("OLLAMA_API_BASE")
)
39
+
40
+
41
class LLMRequestFailedError(Exception):
    """Raised when a completion request cannot be fulfilled.

    Attributes:
        message: Short, human-readable failure category.
        details: Optional verbatim detail text from the underlying error.
    """

    def __init__(self, message: str, details: str | None = None):
        self.message = message
        self.details = details
        super().__init__(message)
46
+
47
+
48
# fnmatch-style patterns for models whose requests must NOT include a `stop`
# parameter (consumed by LLM._should_include_stop_param via model_matches).
SUPPORTS_STOP_WORDS_FALSE_PATTERNS: list[str] = [
    "o1*",
    "grok-4-0709",
    "grok-code-fast-1",
    "deepseek-r1-0528*",
]
54
+
55
# fnmatch-style patterns for models that accept a `reasoning_effort`
# parameter (consumed by LLM._should_include_reasoning_effort; when matched,
# requests are sent with reasoning_effort="high").
REASONING_EFFORT_PATTERNS: list[str] = [
    "o1-2024-12-17",
    "o1",
    "o3",
    "o3-2025-04-16",
    "o3-mini-2025-01-31",
    "o3-mini",
    "o4-mini",
    "o4-mini-2025-04-16",
    "gemini-2.5-flash",
    "gemini-2.5-pro",
    "gpt-5*",
    "deepseek-r1-0528*",
    "claude-sonnet-4-5*",
    "claude-haiku-4-5*",
]
71
+
72
+
73
def normalize_model_name(model: str) -> str:
    """Collapse a possibly provider-qualified model id to a bare model name.

    Lower-cases and strips the input; for "provider/model[:tag]" forms keeps
    only the final path segment without the ":tag" suffix, then drops a
    trailing "-gguf" marker. A name without "/" keeps any ":" intact.
    """
    cleaned = (model or "").strip().lower()
    if "/" in cleaned:
        base = cleaned.rsplit("/", 1)[-1].split(":", 1)[0]
    else:
        base = cleaned
    suffix = "-gguf"
    return base[: -len(suffix)] if base.endswith(suffix) else base


def model_matches(model: str, patterns: list[str]) -> bool:
    """Return True when `model` matches any fnmatch pattern in `patterns`.

    Patterns containing "/" are matched against the full lower-cased model
    string; all other patterns are matched against the normalized bare name.
    """
    full_name = (model or "").strip().lower()
    short_name = normalize_model_name(model)
    for pattern in patterns:
        lowered = pattern.lower()
        subject = full_name if "/" in lowered else short_name
        if fnmatch(subject, lowered):
            return True
    return False
97
+
98
+
99
class StepRole(str, Enum):
    """Originator of a conversation step; str-based so values serialize as-is."""

    AGENT = "agent"
    USER = "user"
    SYSTEM = "system"
103
+
104
+
105
@dataclass
class LLMResponse:
    """Result of a single LLM generation step."""

    # Assistant text, truncated after the first </function> tag when present.
    content: str
    # Tool calls parsed from `content`, or None when none were found.
    tool_invocations: list[dict[str, Any]] | None = None
    # Identifier of the scan this step belongs to, when known.
    scan_id: str | None = None
    # Step counter within the conversation (defaults to 1).
    step_number: int = 1
    # Who produced this step; generate() always emits StepRole.AGENT.
    role: StepRole = StepRole.AGENT
112
+
113
+
114
@dataclass
class RequestStats:
    """Token, cost, and request counters for LLM usage accounting."""

    input_tokens: int = 0
    output_tokens: int = 0
    cached_tokens: int = 0
    cache_creation_tokens: int = 0
    cost: float = 0.0
    requests: int = 0
    failed_requests: int = 0

    def to_dict(self) -> dict[str, int | float]:
        """Serialize the counters to a plain dict; cost rounded to 4 decimals."""
        return dict(
            input_tokens=self.input_tokens,
            output_tokens=self.output_tokens,
            cached_tokens=self.cached_tokens,
            cache_creation_tokens=self.cache_creation_tokens,
            cost=round(self.cost, 4),
            requests=self.requests,
            failed_requests=self.failed_requests,
        )
134
+
135
+
136
class LLM:
    """litellm-backed chat client with prompt templating, caching, and stats.

    Wraps a single model configuration: renders the agent's Jinja system
    prompt, compresses conversation history, adds Anthropic prompt-cache
    markers, funnels requests through the global request queue, and keeps
    token/cost accounting.
    """

    def __init__(
        self, config: LLMConfig, agent_name: str | None = None, agent_id: str | None = None
    ):
        """Build the client and render the agent's system prompt.

        Args:
            config: Model name, timeout, caching, and prompt-module settings.
            agent_name: Optional agent whose template directory (under
                aipt_v2/agents/<agent_name>) is searched for prompts first.
            agent_id: Optional unique id, echoed in the identity message.
        """
        self.config = config
        self.agent_name = agent_name
        self.agent_id = agent_id
        self._total_stats = RequestStats()
        self._last_request_stats = RequestStats()

        self.memory_compressor = MemoryCompressor(
            model_name=self.config.model_name,
            timeout=self.config.timeout,
        )

        if agent_name:
            # Agent-specific templates take precedence over the shared
            # prompts directory (FileSystemLoader searches in order).
            prompt_dir = Path(__file__).parent.parent / "agents" / agent_name
            prompts_dir = Path(__file__).parent.parent / "prompts"

            loader = FileSystemLoader([prompt_dir, prompts_dir])
            # Autoescaping is disabled for every extension: these templates
            # render LLM prompt text, not HTML.
            self.jinja_env = Environment(
                loader=loader,
                autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
            )

            try:
                prompt_module_content = load_prompt_modules(
                    self.config.prompt_modules or [], self.jinja_env
                )

                def get_module(name: str) -> str:
                    # Template helper: body of a loaded prompt module, or "".
                    return prompt_module_content.get(name, "")

                self.jinja_env.globals["get_module"] = get_module

                self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render(
                    get_tools_prompt=get_tools_prompt,
                    loaded_module_names=list(prompt_module_content.keys()),
                    **prompt_module_content,
                )
            except (FileNotFoundError, OSError, ValueError) as e:
                # Missing or broken templates degrade to a generic prompt
                # instead of failing construction.
                logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
                self.system_prompt = "You are a helpful AI assistant."
        else:
            self.system_prompt = "You are a helpful AI assistant."
181
+
182
+ def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
183
+ if agent_name:
184
+ self.agent_name = agent_name
185
+ if agent_id:
186
+ self.agent_id = agent_id
187
+
188
+ def _build_identity_message(self) -> dict[str, Any] | None:
189
+ if not (self.agent_name and str(self.agent_name).strip()):
190
+ return None
191
+ identity_name = self.agent_name
192
+ identity_id = self.agent_id
193
+ content = (
194
+ "\n\n"
195
+ "<agent_identity>\n"
196
+ "<meta>Internal metadata: do not echo or reference; "
197
+ "not part of history or tool calls.</meta>\n"
198
+ "<note>You are now assuming the role of this agent. "
199
+ "Act strictly as this agent and maintain self-identity for this step. "
200
+ "Now go answer the next needed step!</note>\n"
201
+ f"<agent_name>{identity_name}</agent_name>\n"
202
+ f"<agent_id>{identity_id}</agent_id>\n"
203
+ "</agent_identity>\n\n"
204
+ )
205
+ return {"role": "user", "content": content}
206
+
207
+ def _add_cache_control_to_content(
208
+ self, content: str | list[dict[str, Any]]
209
+ ) -> str | list[dict[str, Any]]:
210
+ if isinstance(content, str):
211
+ return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
212
+ if isinstance(content, list) and content:
213
+ last_item = content[-1]
214
+ if isinstance(last_item, dict) and last_item.get("type") == "text":
215
+ return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}]
216
+ return content
217
+
218
+ def _is_anthropic_model(self) -> bool:
219
+ if not self.config.model_name:
220
+ return False
221
+ model_lower = self.config.model_name.lower()
222
+ return any(provider in model_lower for provider in ["anthropic/", "claude"])
223
+
224
+ def _calculate_cache_interval(self, total_messages: int) -> int:
225
+ if total_messages <= 1:
226
+ return 10
227
+
228
+ max_cached_messages = 3
229
+ non_system_messages = total_messages - 1
230
+
231
+ interval = 10
232
+ while non_system_messages // interval > max_cached_messages:
233
+ interval += 10
234
+
235
+ return interval
236
+
237
    def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Return a copy of `messages` with prompt-cache markers added.

        No-op (returns `messages` unchanged) when caching is disabled in
        config, litellm reports the model does not support prompt caching,
        the list is empty, or the model is not an Anthropic/Claude model.
        """
        if (
            not self.config.enable_prompt_caching
            or not supports_prompt_caching(self.config.model_name)
            or not messages
        ):
            return messages

        if not self._is_anthropic_model():
            return messages

        # Shallow copy so callers' message list is never mutated.
        cached_messages = list(messages)

        # Always mark the system prompt (first message) when present.
        if cached_messages and cached_messages[0].get("role") == "system":
            system_message = cached_messages[0].copy()
            system_message["content"] = self._add_cache_control_to_content(
                system_message["content"]
            )
            cached_messages[0] = system_message

        # Mark at most 3 additional messages, spaced evenly through the
        # history (3 + the system prompt — presumably sized to Anthropic's
        # cache-breakpoint limit; confirm against Anthropic docs).
        total_messages = len(cached_messages)
        if total_messages > 1:
            interval = self._calculate_cache_interval(total_messages)

            cached_count = 0
            for i in range(interval, total_messages, interval):
                if cached_count >= 3:
                    break

                if i < len(cached_messages):
                    message = cached_messages[i].copy()
                    message["content"] = self._add_cache_control_to_content(message["content"])
                    cached_messages[i] = message
                    cached_count += 1

        return cached_messages
273
+
274
    async def generate(  # noqa: PLR0912, PLR0915
        self,
        conversation_history: list[dict[str, Any]],
        scan_id: str | None = None,
        step_number: int = 1,
    ) -> LLMResponse:
        """Run one agent step: build the prompt, call the model, parse tool calls.

        Side effect: `conversation_history` is compressed in place (cleared
        and refilled with the compressed messages), so callers carry a
        bounded history into subsequent steps.

        Args:
            conversation_history: Prior messages; mutated in place as above.
            scan_id: Optional scan identifier echoed into the response.
            step_number: Step counter echoed into the response.

        Raises:
            LLMRequestFailedError: wrapping any litellm/provider failure with
                a short category message plus the original detail text.
        """
        # The system prompt always leads the request.
        messages = [{"role": "system", "content": self.system_prompt}]

        # Optional per-agent identity block, injected as a user message.
        identity_message = self._build_identity_message()
        if identity_message:
            messages.append(identity_message)

        compressed_history = list(self.memory_compressor.compress_history(conversation_history))

        # Persist the compressed form back into the caller's list.
        conversation_history.clear()
        conversation_history.extend(compressed_history)
        messages.extend(compressed_history)

        # Add prompt-cache markers where supported (Anthropic models only).
        cached_messages = self._prepare_cached_messages(messages)

        try:
            response = await self._make_request(cached_messages)
            self._update_usage_stats(response)

            # Defensively extract the assistant text from the first choice.
            content = ""
            if (
                response.choices
                and hasattr(response.choices[0], "message")
                and response.choices[0].message
            ):
                content = getattr(response.choices[0].message, "content", "") or ""

            # Keep only the first tool call (helper from aipt_v2.llm.utils),
            # then drop anything after its closing </function> tag.
            content = _truncate_to_first_function(content)

            if "</function>" in content:
                function_end_index = content.find("</function>") + len("</function>")
                content = content[:function_end_index]

            tool_invocations = parse_tool_invocations(content)

            return LLMResponse(
                scan_id=scan_id,
                step_number=step_number,
                role=StepRole.AGENT,
                content=content,
                tool_invocations=tool_invocations if tool_invocations else None,
            )

        # Map litellm's exception hierarchy onto LLMRequestFailedError with a
        # readable category; specific subclasses are listed before their
        # bases (APIError/OpenAIError last) so the most precise message wins.
        except litellm.RateLimitError as e:
            raise LLMRequestFailedError("LLM request failed: Rate limit exceeded", str(e)) from e
        except litellm.AuthenticationError as e:
            raise LLMRequestFailedError("LLM request failed: Invalid API key", str(e)) from e
        except litellm.NotFoundError as e:
            raise LLMRequestFailedError("LLM request failed: Model not found", str(e)) from e
        except litellm.ContextWindowExceededError as e:
            raise LLMRequestFailedError("LLM request failed: Context too long", str(e)) from e
        except litellm.ContentPolicyViolationError as e:
            raise LLMRequestFailedError(
                "LLM request failed: Content policy violation", str(e)
            ) from e
        except litellm.ServiceUnavailableError as e:
            raise LLMRequestFailedError("LLM request failed: Service unavailable", str(e)) from e
        except litellm.Timeout as e:
            raise LLMRequestFailedError("LLM request failed: Request timed out", str(e)) from e
        except litellm.UnprocessableEntityError as e:
            raise LLMRequestFailedError("LLM request failed: Unprocessable entity", str(e)) from e
        except litellm.InternalServerError as e:
            raise LLMRequestFailedError("LLM request failed: Internal server error", str(e)) from e
        except litellm.APIConnectionError as e:
            raise LLMRequestFailedError("LLM request failed: Connection error", str(e)) from e
        except litellm.UnsupportedParamsError as e:
            raise LLMRequestFailedError("LLM request failed: Unsupported parameters", str(e)) from e
        except litellm.BudgetExceededError as e:
            raise LLMRequestFailedError("LLM request failed: Budget exceeded", str(e)) from e
        except litellm.APIResponseValidationError as e:
            raise LLMRequestFailedError(
                "LLM request failed: Response validation error", str(e)
            ) from e
        except litellm.JSONSchemaValidationError as e:
            raise LLMRequestFailedError(
                "LLM request failed: JSON schema validation error", str(e)
            ) from e
        except litellm.InvalidRequestError as e:
            raise LLMRequestFailedError("LLM request failed: Invalid request", str(e)) from e
        except litellm.BadRequestError as e:
            raise LLMRequestFailedError("LLM request failed: Bad request", str(e)) from e
        except litellm.APIError as e:
            raise LLMRequestFailedError("LLM request failed: API error", str(e)) from e
        except litellm.OpenAIError as e:
            raise LLMRequestFailedError("LLM request failed: OpenAI error", str(e)) from e
        except Exception as e:
            # Final catch-all so callers only ever see LLMRequestFailedError.
            raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
366
+
367
+ @property
368
+ def usage_stats(self) -> dict[str, dict[str, int | float]]:
369
+ return {
370
+ "total": self._total_stats.to_dict(),
371
+ "last_request": self._last_request_stats.to_dict(),
372
+ }
373
+
374
+ def get_cache_config(self) -> dict[str, bool]:
375
+ return {
376
+ "enabled": self.config.enable_prompt_caching,
377
+ "supported": supports_prompt_caching(self.config.model_name),
378
+ }
379
+
380
+ def _should_include_stop_param(self) -> bool:
381
+ if not self.config.model_name:
382
+ return True
383
+
384
+ return not model_matches(self.config.model_name, SUPPORTS_STOP_WORDS_FALSE_PATTERNS)
385
+
386
+ def _should_include_reasoning_effort(self) -> bool:
387
+ if not self.config.model_name:
388
+ return False
389
+
390
+ return model_matches(self.config.model_name, REASONING_EFFORT_PATTERNS)
391
+
392
+ def _model_supports_vision(self) -> bool:
393
+ if not self.config.model_name:
394
+ return False
395
+ try:
396
+ return supports_vision(model=self.config.model_name)
397
+ except Exception: # noqa: BLE001
398
+ return False
399
+
400
+ def _filter_images_from_messages(
401
+ self, messages: list[dict[str, Any]]
402
+ ) -> list[dict[str, Any]]:
403
+ filtered_messages = []
404
+ for msg in messages:
405
+ content = msg.get("content")
406
+ if isinstance(content, list):
407
+ filtered_content = []
408
+ for item in content:
409
+ if isinstance(item, dict):
410
+ if item.get("type") == "image_url":
411
+ filtered_content.append({
412
+ "type": "text",
413
+ "text": "[Screenshot removed - model does not support vision. "
414
+ "Use view_source or execute_js to interact with the page instead.]",
415
+ })
416
+ else:
417
+ filtered_content.append(item)
418
+ else:
419
+ filtered_content.append(item)
420
+ if filtered_content:
421
+ text_parts = [
422
+ item.get("text", "") if isinstance(item, dict) else str(item)
423
+ for item in filtered_content
424
+ ]
425
+ if all(isinstance(item, dict) and item.get("type") == "text" for item in filtered_content):
426
+ msg = {**msg, "content": "\n".join(text_parts)}
427
+ else:
428
+ msg = {**msg, "content": filtered_content}
429
+ else:
430
+ msg = {**msg, "content": ""}
431
+ filtered_messages.append(msg)
432
+ return filtered_messages
433
+
434
    async def _make_request(
        self,
        messages: list[dict[str, Any]],
    ) -> ModelResponse:
        """Dispatch one completion request through the global request queue.

        Builds the litellm completion arguments (model, timeout, optional
        api_key/api_base overrides, optional stop sequence and reasoning
        effort) and bumps the request counters.
        """
        # Strip screenshots for models that cannot consume images.
        if not self._model_supports_vision():
            messages = self._filter_images_from_messages(messages)

        completion_args: dict[str, Any] = {
            "model": self.config.model_name,
            "messages": messages,
            "timeout": self.config.timeout,
        }

        # Environment-provided credential/endpoint overrides (resolved once
        # at module import).
        if _LLM_API_KEY:
            completion_args["api_key"] = _LLM_API_KEY
        if _LLM_API_BASE:
            completion_args["api_base"] = _LLM_API_BASE

        # Stop generation at the end of the first tool call, unless the
        # model rejects stop words (SUPPORTS_STOP_WORDS_FALSE_PATTERNS).
        if self._should_include_stop_param():
            completion_args["stop"] = ["</function>"]

        if self._should_include_reasoning_effort():
            completion_args["reasoning_effort"] = "high"

        queue = get_global_queue()
        response = await queue.make_request(completion_args)

        # Count the request here; per-request token/cost numbers are filled
        # in afterwards by _update_usage_stats().
        self._total_stats.requests += 1
        self._last_request_stats = RequestStats(requests=1)

        return response
465
+
466
    def _update_usage_stats(self, response: ModelResponse) -> None:
        """Fold token counts and cost from `response` into the running stats.

        Best-effort: every lookup is guarded with hasattr/getattr because
        providers report usage in different shapes, and any failure is logged
        as a warning rather than propagated — accounting must never break a
        successful request.
        """
        try:
            if hasattr(response, "usage") and response.usage:
                input_tokens = getattr(response.usage, "prompt_tokens", 0)
                output_tokens = getattr(response.usage, "completion_tokens", 0)

                cached_tokens = 0
                cache_creation_tokens = 0

                # Cached prompt tokens, reported via prompt_tokens_details
                # (OpenAI-compatible providers).
                if hasattr(response.usage, "prompt_tokens_details"):
                    prompt_details = response.usage.prompt_tokens_details
                    if hasattr(prompt_details, "cached_tokens"):
                        cached_tokens = prompt_details.cached_tokens or 0

                # Cache-write tokens (Anthropic-style reporting).
                if hasattr(response.usage, "cache_creation_input_tokens"):
                    cache_creation_tokens = response.usage.cache_creation_input_tokens or 0

            else:
                # No usage payload at all — record zeros for this request.
                input_tokens = 0
                output_tokens = 0
                cached_tokens = 0
                cache_creation_tokens = 0

            try:
                cost = completion_cost(response) or 0.0
            except Exception as e:  # noqa: BLE001
                # Unknown models have no pricing data; treat as zero cost.
                logger.warning(f"Failed to calculate cost: {e}")
                cost = 0.0

            self._total_stats.input_tokens += input_tokens
            self._total_stats.output_tokens += output_tokens
            self._total_stats.cached_tokens += cached_tokens
            self._total_stats.cache_creation_tokens += cache_creation_tokens
            self._total_stats.cost += cost

            # _last_request_stats was reset by _make_request; fill it in.
            self._last_request_stats.input_tokens = input_tokens
            self._last_request_stats.output_tokens = output_tokens
            self._last_request_stats.cached_tokens = cached_tokens
            self._last_request_stats.cache_creation_tokens = cache_creation_tokens
            self._last_request_stats.cost = cost

            if cached_tokens > 0:
                logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens")
            if cache_creation_tokens > 0:
                logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache")

            logger.info(f"Usage stats: {self.usage_stats}")
        except Exception as e:  # noqa: BLE001
            logger.warning(f"Failed to update usage stats: {e}")