strix_agent-0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- strix/__init__.py +0 -0
- strix/agents/StrixAgent/__init__.py +4 -0
- strix/agents/StrixAgent/strix_agent.py +89 -0
- strix/agents/StrixAgent/system_prompt.jinja +404 -0
- strix/agents/__init__.py +10 -0
- strix/agents/base_agent.py +518 -0
- strix/agents/state.py +163 -0
- strix/interface/__init__.py +4 -0
- strix/interface/assets/tui_styles.tcss +694 -0
- strix/interface/cli.py +230 -0
- strix/interface/main.py +500 -0
- strix/interface/tool_components/__init__.py +39 -0
- strix/interface/tool_components/agents_graph_renderer.py +123 -0
- strix/interface/tool_components/base_renderer.py +62 -0
- strix/interface/tool_components/browser_renderer.py +120 -0
- strix/interface/tool_components/file_edit_renderer.py +99 -0
- strix/interface/tool_components/finish_renderer.py +31 -0
- strix/interface/tool_components/notes_renderer.py +108 -0
- strix/interface/tool_components/proxy_renderer.py +255 -0
- strix/interface/tool_components/python_renderer.py +34 -0
- strix/interface/tool_components/registry.py +72 -0
- strix/interface/tool_components/reporting_renderer.py +53 -0
- strix/interface/tool_components/scan_info_renderer.py +64 -0
- strix/interface/tool_components/terminal_renderer.py +131 -0
- strix/interface/tool_components/thinking_renderer.py +29 -0
- strix/interface/tool_components/user_message_renderer.py +43 -0
- strix/interface/tool_components/web_search_renderer.py +28 -0
- strix/interface/tui.py +1274 -0
- strix/interface/utils.py +559 -0
- strix/llm/__init__.py +15 -0
- strix/llm/config.py +20 -0
- strix/llm/llm.py +465 -0
- strix/llm/memory_compressor.py +212 -0
- strix/llm/request_queue.py +87 -0
- strix/llm/utils.py +87 -0
- strix/prompts/README.md +64 -0
- strix/prompts/__init__.py +109 -0
- strix/prompts/cloud/.gitkeep +0 -0
- strix/prompts/coordination/root_agent.jinja +41 -0
- strix/prompts/custom/.gitkeep +0 -0
- strix/prompts/frameworks/fastapi.jinja +142 -0
- strix/prompts/frameworks/nextjs.jinja +126 -0
- strix/prompts/protocols/graphql.jinja +215 -0
- strix/prompts/reconnaissance/.gitkeep +0 -0
- strix/prompts/technologies/firebase_firestore.jinja +177 -0
- strix/prompts/technologies/supabase.jinja +189 -0
- strix/prompts/vulnerabilities/authentication_jwt.jinja +147 -0
- strix/prompts/vulnerabilities/broken_function_level_authorization.jinja +146 -0
- strix/prompts/vulnerabilities/business_logic.jinja +171 -0
- strix/prompts/vulnerabilities/csrf.jinja +174 -0
- strix/prompts/vulnerabilities/idor.jinja +195 -0
- strix/prompts/vulnerabilities/information_disclosure.jinja +222 -0
- strix/prompts/vulnerabilities/insecure_file_uploads.jinja +188 -0
- strix/prompts/vulnerabilities/mass_assignment.jinja +141 -0
- strix/prompts/vulnerabilities/open_redirect.jinja +177 -0
- strix/prompts/vulnerabilities/path_traversal_lfi_rfi.jinja +142 -0
- strix/prompts/vulnerabilities/race_conditions.jinja +164 -0
- strix/prompts/vulnerabilities/rce.jinja +154 -0
- strix/prompts/vulnerabilities/sql_injection.jinja +151 -0
- strix/prompts/vulnerabilities/ssrf.jinja +135 -0
- strix/prompts/vulnerabilities/subdomain_takeover.jinja +155 -0
- strix/prompts/vulnerabilities/xss.jinja +169 -0
- strix/prompts/vulnerabilities/xxe.jinja +184 -0
- strix/runtime/__init__.py +19 -0
- strix/runtime/docker_runtime.py +399 -0
- strix/runtime/runtime.py +29 -0
- strix/runtime/tool_server.py +205 -0
- strix/telemetry/__init__.py +4 -0
- strix/telemetry/tracer.py +337 -0
- strix/tools/__init__.py +64 -0
- strix/tools/agents_graph/__init__.py +16 -0
- strix/tools/agents_graph/agents_graph_actions.py +621 -0
- strix/tools/agents_graph/agents_graph_actions_schema.xml +226 -0
- strix/tools/argument_parser.py +121 -0
- strix/tools/browser/__init__.py +4 -0
- strix/tools/browser/browser_actions.py +236 -0
- strix/tools/browser/browser_actions_schema.xml +183 -0
- strix/tools/browser/browser_instance.py +533 -0
- strix/tools/browser/tab_manager.py +342 -0
- strix/tools/executor.py +305 -0
- strix/tools/file_edit/__init__.py +4 -0
- strix/tools/file_edit/file_edit_actions.py +141 -0
- strix/tools/file_edit/file_edit_actions_schema.xml +128 -0
- strix/tools/finish/__init__.py +4 -0
- strix/tools/finish/finish_actions.py +174 -0
- strix/tools/finish/finish_actions_schema.xml +45 -0
- strix/tools/notes/__init__.py +14 -0
- strix/tools/notes/notes_actions.py +191 -0
- strix/tools/notes/notes_actions_schema.xml +150 -0
- strix/tools/proxy/__init__.py +20 -0
- strix/tools/proxy/proxy_actions.py +101 -0
- strix/tools/proxy/proxy_actions_schema.xml +267 -0
- strix/tools/proxy/proxy_manager.py +785 -0
- strix/tools/python/__init__.py +4 -0
- strix/tools/python/python_actions.py +47 -0
- strix/tools/python/python_actions_schema.xml +131 -0
- strix/tools/python/python_instance.py +172 -0
- strix/tools/python/python_manager.py +131 -0
- strix/tools/registry.py +196 -0
- strix/tools/reporting/__init__.py +6 -0
- strix/tools/reporting/reporting_actions.py +63 -0
- strix/tools/reporting/reporting_actions_schema.xml +30 -0
- strix/tools/terminal/__init__.py +4 -0
- strix/tools/terminal/terminal_actions.py +35 -0
- strix/tools/terminal/terminal_actions_schema.xml +146 -0
- strix/tools/terminal/terminal_manager.py +151 -0
- strix/tools/terminal/terminal_session.py +447 -0
- strix/tools/thinking/__init__.py +4 -0
- strix/tools/thinking/thinking_actions.py +18 -0
- strix/tools/thinking/thinking_actions_schema.xml +52 -0
- strix/tools/web_search/__init__.py +4 -0
- strix/tools/web_search/web_search_actions.py +80 -0
- strix/tools/web_search/web_search_actions_schema.xml +83 -0
- strix_agent-0.4.0.dist-info/LICENSE +201 -0
- strix_agent-0.4.0.dist-info/METADATA +282 -0
- strix_agent-0.4.0.dist-info/RECORD +118 -0
- strix_agent-0.4.0.dist-info/WHEEL +4 -0
- strix_agent-0.4.0.dist-info/entry_points.txt +3 -0
strix/llm/llm.py
ADDED
@@ -0,0 +1,465 @@
import logging
import os
from dataclasses import dataclass
from enum import Enum
from fnmatch import fnmatch
from pathlib import Path
from typing import Any

import litellm
from jinja2 import (
    Environment,
    FileSystemLoader,
    select_autoescape,
)
from litellm import ModelResponse, completion_cost
from litellm.utils import supports_prompt_caching

from strix.llm.config import LLMConfig
from strix.llm.memory_compressor import MemoryCompressor
from strix.llm.request_queue import get_global_queue
from strix.llm.utils import _truncate_to_first_function, parse_tool_invocations
from strix.prompts import load_prompt_modules
from strix.tools import get_tools_prompt


logger = logging.getLogger(__name__)

api_key = os.getenv("LLM_API_KEY")
if api_key:
    litellm.api_key = api_key

api_base = (
    os.getenv("LLM_API_BASE")
    or os.getenv("OPENAI_API_BASE")
    or os.getenv("LITELLM_BASE_URL")
    or os.getenv("OLLAMA_API_BASE")
)
if api_base:
    litellm.api_base = api_base


class LLMRequestFailedError(Exception):
    def __init__(self, message: str, details: str | None = None):
        super().__init__(message)
        self.message = message
        self.details = details


SUPPORTS_STOP_WORDS_FALSE_PATTERNS: list[str] = [
    "o1*",
    "grok-4-0709",
    "grok-code-fast-1",
    "deepseek-r1-0528*",
]

REASONING_EFFORT_PATTERNS: list[str] = [
    "o1-2024-12-17",
    "o1",
    "o3",
    "o3-2025-04-16",
    "o3-mini-2025-01-31",
    "o3-mini",
    "o4-mini",
    "o4-mini-2025-04-16",
    "gemini-2.5-flash",
    "gemini-2.5-pro",
    "gpt-5*",
    "deepseek-r1-0528*",
    "claude-sonnet-4-5*",
    "claude-haiku-4-5*",
]


def normalize_model_name(model: str) -> str:
    raw = (model or "").strip().lower()
    if "/" in raw:
        name = raw.split("/")[-1]
        if ":" in name:
            name = name.split(":", 1)[0]
    else:
        name = raw
    if name.endswith("-gguf"):
        name = name[: -len("-gguf")]
    return name


def model_matches(model: str, patterns: list[str]) -> bool:
    raw = (model or "").strip().lower()
    name = normalize_model_name(model)
    for pat in patterns:
        pat_l = pat.lower()
        if "/" in pat_l:
            if fnmatch(raw, pat_l):
                return True
        elif fnmatch(name, pat_l):
            return True
    return False


class StepRole(str, Enum):
    AGENT = "agent"
    USER = "user"
    SYSTEM = "system"


@dataclass
class LLMResponse:
    content: str
    tool_invocations: list[dict[str, Any]] | None = None
    scan_id: str | None = None
    step_number: int = 1
    role: StepRole = StepRole.AGENT


@dataclass
class RequestStats:
    input_tokens: int = 0
    output_tokens: int = 0
    cached_tokens: int = 0
    cache_creation_tokens: int = 0
    cost: float = 0.0
    requests: int = 0
    failed_requests: int = 0

    def to_dict(self) -> dict[str, int | float]:
        return {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cached_tokens": self.cached_tokens,
            "cache_creation_tokens": self.cache_creation_tokens,
            "cost": round(self.cost, 4),
            "requests": self.requests,
            "failed_requests": self.failed_requests,
        }


class LLM:
    def __init__(
        self, config: LLMConfig, agent_name: str | None = None, agent_id: str | None = None
    ):
        self.config = config
        self.agent_name = agent_name
        self.agent_id = agent_id
        self._total_stats = RequestStats()
        self._last_request_stats = RequestStats()

        self.memory_compressor = MemoryCompressor(
            model_name=self.config.model_name,
            timeout=self.config.timeout,
        )

        if agent_name:
            prompt_dir = Path(__file__).parent.parent / "agents" / agent_name
            prompts_dir = Path(__file__).parent.parent / "prompts"

            loader = FileSystemLoader([prompt_dir, prompts_dir])
            self.jinja_env = Environment(
                loader=loader,
                autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
            )

            try:
                prompt_module_content = load_prompt_modules(
                    self.config.prompt_modules or [], self.jinja_env
                )

                def get_module(name: str) -> str:
                    return prompt_module_content.get(name, "")

                self.jinja_env.globals["get_module"] = get_module

                self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render(
                    get_tools_prompt=get_tools_prompt,
                    loaded_module_names=list(prompt_module_content.keys()),
                    **prompt_module_content,
                )
            except (FileNotFoundError, OSError, ValueError) as e:
                logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
                self.system_prompt = "You are a helpful AI assistant."
        else:
            self.system_prompt = "You are a helpful AI assistant."

    def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
        if agent_name:
            self.agent_name = agent_name
        if agent_id:
            self.agent_id = agent_id

    def _build_identity_message(self) -> dict[str, Any] | None:
        if not (self.agent_name and str(self.agent_name).strip()):
            return None
        identity_name = self.agent_name
        identity_id = self.agent_id
        content = (
            "\n\n"
            "<agent_identity>\n"
            "<meta>Internal metadata: do not echo or reference; "
            "not part of history or tool calls.</meta>\n"
            "<note>You are now assuming the role of this agent. "
            "Act strictly as this agent and maintain self-identity for this step. "
            "Now go answer the next needed step!</note>\n"
            f"<agent_name>{identity_name}</agent_name>\n"
            f"<agent_id>{identity_id}</agent_id>\n"
            "</agent_identity>\n\n"
        )
        return {"role": "user", "content": content}

    def _add_cache_control_to_content(
        self, content: str | list[dict[str, Any]]
    ) -> str | list[dict[str, Any]]:
        if isinstance(content, str):
            return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
        if isinstance(content, list) and content:
            last_item = content[-1]
            if isinstance(last_item, dict) and last_item.get("type") == "text":
                return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}]
        return content

    def _is_anthropic_model(self) -> bool:
        if not self.config.model_name:
            return False
        model_lower = self.config.model_name.lower()
        return any(provider in model_lower for provider in ["anthropic/", "claude"])

    def _calculate_cache_interval(self, total_messages: int) -> int:
        if total_messages <= 1:
            return 10

        max_cached_messages = 3
        non_system_messages = total_messages - 1

        interval = 10
        while non_system_messages // interval > max_cached_messages:
            interval += 10

        return interval

    def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        if (
            not self.config.enable_prompt_caching
            or not supports_prompt_caching(self.config.model_name)
            or not messages
        ):
            return messages

        if not self._is_anthropic_model():
            return messages

        cached_messages = list(messages)

        if cached_messages and cached_messages[0].get("role") == "system":
            system_message = cached_messages[0].copy()
            system_message["content"] = self._add_cache_control_to_content(
                system_message["content"]
            )
            cached_messages[0] = system_message

        total_messages = len(cached_messages)
        if total_messages > 1:
            interval = self._calculate_cache_interval(total_messages)

            cached_count = 0
            for i in range(interval, total_messages, interval):
                if cached_count >= 3:
                    break

                if i < len(cached_messages):
                    message = cached_messages[i].copy()
                    message["content"] = self._add_cache_control_to_content(message["content"])
                    cached_messages[i] = message
                    cached_count += 1

        return cached_messages

    async def generate(  # noqa: PLR0912, PLR0915
        self,
        conversation_history: list[dict[str, Any]],
        scan_id: str | None = None,
        step_number: int = 1,
    ) -> LLMResponse:
        messages = [{"role": "system", "content": self.system_prompt}]

        identity_message = self._build_identity_message()
        if identity_message:
            messages.append(identity_message)

        compressed_history = list(self.memory_compressor.compress_history(conversation_history))

        conversation_history.clear()
        conversation_history.extend(compressed_history)
        messages.extend(compressed_history)

        cached_messages = self._prepare_cached_messages(messages)

        try:
            response = await self._make_request(cached_messages)
            self._update_usage_stats(response)

            content = ""
            if (
                response.choices
                and hasattr(response.choices[0], "message")
                and response.choices[0].message
            ):
                content = getattr(response.choices[0].message, "content", "") or ""

            content = _truncate_to_first_function(content)

            if "</function>" in content:
                function_end_index = content.find("</function>") + len("</function>")
                content = content[:function_end_index]

            tool_invocations = parse_tool_invocations(content)

            return LLMResponse(
                scan_id=scan_id,
                step_number=step_number,
                role=StepRole.AGENT,
                content=content,
                tool_invocations=tool_invocations if tool_invocations else None,
            )

        except litellm.RateLimitError as e:
            raise LLMRequestFailedError("LLM request failed: Rate limit exceeded", str(e)) from e
        except litellm.AuthenticationError as e:
            raise LLMRequestFailedError("LLM request failed: Invalid API key", str(e)) from e
        except litellm.NotFoundError as e:
            raise LLMRequestFailedError("LLM request failed: Model not found", str(e)) from e
        except litellm.ContextWindowExceededError as e:
            raise LLMRequestFailedError("LLM request failed: Context too long", str(e)) from e
        except litellm.ContentPolicyViolationError as e:
            raise LLMRequestFailedError(
                "LLM request failed: Content policy violation", str(e)
            ) from e
        except litellm.ServiceUnavailableError as e:
            raise LLMRequestFailedError("LLM request failed: Service unavailable", str(e)) from e
        except litellm.Timeout as e:
            raise LLMRequestFailedError("LLM request failed: Request timed out", str(e)) from e
        except litellm.UnprocessableEntityError as e:
            raise LLMRequestFailedError("LLM request failed: Unprocessable entity", str(e)) from e
        except litellm.InternalServerError as e:
            raise LLMRequestFailedError("LLM request failed: Internal server error", str(e)) from e
        except litellm.APIConnectionError as e:
            raise LLMRequestFailedError("LLM request failed: Connection error", str(e)) from e
        except litellm.UnsupportedParamsError as e:
            raise LLMRequestFailedError("LLM request failed: Unsupported parameters", str(e)) from e
        except litellm.BudgetExceededError as e:
            raise LLMRequestFailedError("LLM request failed: Budget exceeded", str(e)) from e
        except litellm.APIResponseValidationError as e:
            raise LLMRequestFailedError(
                "LLM request failed: Response validation error", str(e)
            ) from e
        except litellm.JSONSchemaValidationError as e:
            raise LLMRequestFailedError(
                "LLM request failed: JSON schema validation error", str(e)
            ) from e
        except litellm.InvalidRequestError as e:
            raise LLMRequestFailedError("LLM request failed: Invalid request", str(e)) from e
        except litellm.BadRequestError as e:
            raise LLMRequestFailedError("LLM request failed: Bad request", str(e)) from e
        except litellm.APIError as e:
            raise LLMRequestFailedError("LLM request failed: API error", str(e)) from e
        except litellm.OpenAIError as e:
            raise LLMRequestFailedError("LLM request failed: OpenAI error", str(e)) from e
        except Exception as e:
            raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e

    @property
    def usage_stats(self) -> dict[str, dict[str, int | float]]:
        return {
            "total": self._total_stats.to_dict(),
            "last_request": self._last_request_stats.to_dict(),
        }

    def get_cache_config(self) -> dict[str, bool]:
        return {
            "enabled": self.config.enable_prompt_caching,
            "supported": supports_prompt_caching(self.config.model_name),
        }

    def _should_include_stop_param(self) -> bool:
        if not self.config.model_name:
            return True

        return not model_matches(self.config.model_name, SUPPORTS_STOP_WORDS_FALSE_PATTERNS)

    def _should_include_reasoning_effort(self) -> bool:
        if not self.config.model_name:
            return False

        return model_matches(self.config.model_name, REASONING_EFFORT_PATTERNS)

    async def _make_request(
        self,
        messages: list[dict[str, Any]],
    ) -> ModelResponse:
        completion_args: dict[str, Any] = {
            "model": self.config.model_name,
            "messages": messages,
            "timeout": self.config.timeout,
        }

        if self._should_include_stop_param():
            completion_args["stop"] = ["</function>"]

        if self._should_include_reasoning_effort():
            completion_args["reasoning_effort"] = "high"

        queue = get_global_queue()
        response = await queue.make_request(completion_args)

        self._total_stats.requests += 1
        self._last_request_stats = RequestStats(requests=1)

        return response

    def _update_usage_stats(self, response: ModelResponse) -> None:
        try:
            if hasattr(response, "usage") and response.usage:
                input_tokens = getattr(response.usage, "prompt_tokens", 0)
                output_tokens = getattr(response.usage, "completion_tokens", 0)

                cached_tokens = 0
                cache_creation_tokens = 0

                if hasattr(response.usage, "prompt_tokens_details"):
                    prompt_details = response.usage.prompt_tokens_details
                    if hasattr(prompt_details, "cached_tokens"):
                        cached_tokens = prompt_details.cached_tokens or 0

                if hasattr(response.usage, "cache_creation_input_tokens"):
                    cache_creation_tokens = response.usage.cache_creation_input_tokens or 0

            else:
                input_tokens = 0
                output_tokens = 0
                cached_tokens = 0
                cache_creation_tokens = 0

            try:
                cost = completion_cost(response) or 0.0
            except Exception as e:  # noqa: BLE001
                logger.warning(f"Failed to calculate cost: {e}")
                cost = 0.0

            self._total_stats.input_tokens += input_tokens
            self._total_stats.output_tokens += output_tokens
            self._total_stats.cached_tokens += cached_tokens
            self._total_stats.cache_creation_tokens += cache_creation_tokens
            self._total_stats.cost += cost

            self._last_request_stats.input_tokens = input_tokens
            self._last_request_stats.output_tokens = output_tokens
            self._last_request_stats.cached_tokens = cached_tokens
            self._last_request_stats.cache_creation_tokens = cache_creation_tokens
            self._last_request_stats.cost = cost

            if cached_tokens > 0:
                logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens")
            if cache_creation_tokens > 0:
                logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache")

            logger.info(f"Usage stats: {self.usage_stats}")
        except Exception as e:  # noqa: BLE001
            logger.warning(f"Failed to update usage stats: {e}")
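
For orientation, the short sketch below is not part of the package. It exercises the model-name helpers defined above with illustrative, provider-prefixed model strings; the comments restate how _make_request uses the two pattern lists.

# Illustrative only: exercises helpers from strix/llm/llm.py; model strings are examples.
from strix.llm.llm import (
    REASONING_EFFORT_PATTERNS,
    SUPPORTS_STOP_WORDS_FALSE_PATTERNS,
    model_matches,
    normalize_model_name,
)

# Provider prefixes, ":tag" suffixes, and a trailing "-gguf" are stripped before matching.
assert normalize_model_name("openrouter/anthropic/claude-sonnet-4-5:beta") == "claude-sonnet-4-5"
assert normalize_model_name("ollama/deepseek-r1-0528-GGUF") == "deepseek-r1-0528"

# Models matching REASONING_EFFORT_PATTERNS get reasoning_effort="high" in _make_request,
# while models matching SUPPORTS_STOP_WORDS_FALSE_PATTERNS are sent without stop=["</function>"].
assert model_matches("openai/gpt-5-mini", REASONING_EFFORT_PATTERNS)
assert model_matches("xai/grok-code-fast-1", SUPPORTS_STOP_WORDS_FALSE_PATTERNS)
assert not model_matches("anthropic/claude-sonnet-4-5", SUPPORTS_STOP_WORDS_FALSE_PATTERNS)
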
strix/llm/memory_compressor.py
ADDED
@@ -0,0 +1,212 @@
import logging
import os
from typing import Any

import litellm


logger = logging.getLogger(__name__)


MAX_TOTAL_TOKENS = 100_000
MIN_RECENT_MESSAGES = 15

SUMMARY_PROMPT_TEMPLATE = """You are an agent performing context
condensation for a security agent. Your job is to compress scan data while preserving
ALL operationally critical information for continuing the security assessment.

CRITICAL ELEMENTS TO PRESERVE:
- Discovered vulnerabilities and potential attack vectors
- Scan results and tool outputs (compressed but maintaining key findings)
- Access credentials, tokens, or authentication details found
- System architecture insights and potential weak points
- Progress made in the assessment
- Failed attempts and dead ends (to avoid duplication)
- Any decisions made about the testing approach

COMPRESSION GUIDELINES:
- Preserve exact technical details (URLs, paths, parameters, payloads)
- Summarize verbose tool outputs while keeping critical findings
- Maintain version numbers, specific technologies identified
- Keep exact error messages that might indicate vulnerabilities
- Compress repetitive or similar findings into consolidated form

Remember: Another security agent will use this summary to continue the assessment.
They must be able to pick up exactly where you left off without losing any
operational advantage or context needed to find vulnerabilities.

CONVERSATION SEGMENT TO SUMMARIZE:
{conversation}

Provide a technically precise summary that preserves all operational security context while
keeping the summary concise and to the point."""


def _count_tokens(text: str, model: str) -> int:
    try:
        count = litellm.token_counter(model=model, text=text)
        return int(count)
    except Exception:
        logger.exception("Failed to count tokens")
        return len(text) // 4  # Rough estimate


def _get_message_tokens(msg: dict[str, Any], model: str) -> int:
    content = msg.get("content", "")
    if isinstance(content, str):
        return _count_tokens(content, model)
    if isinstance(content, list):
        return sum(
            _count_tokens(item.get("text", ""), model)
            for item in content
            if isinstance(item, dict) and item.get("type") == "text"
        )
    return 0


def _extract_message_text(msg: dict[str, Any]) -> str:
    content = msg.get("content", "")
    if isinstance(content, str):
        return content

    if isinstance(content, list):
        parts = []
        for item in content:
            if isinstance(item, dict):
                if item.get("type") == "text":
                    parts.append(item.get("text", ""))
                elif item.get("type") == "image_url":
                    parts.append("[IMAGE]")
        return " ".join(parts)

    return str(content)


def _summarize_messages(
    messages: list[dict[str, Any]],
    model: str,
    timeout: int = 600,
) -> dict[str, Any]:
    if not messages:
        empty_summary = "<context_summary message_count='0'>{text}</context_summary>"
        return {
            "role": "assistant",
            "content": empty_summary.format(text="No messages to summarize"),
        }

    formatted = []
    for msg in messages:
        role = msg.get("role", "unknown")
        text = _extract_message_text(msg)
        formatted.append(f"{role}: {text}")

    conversation = "\n".join(formatted)
    prompt = SUMMARY_PROMPT_TEMPLATE.format(conversation=conversation)

    try:
        completion_args = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "timeout": timeout,
        }

        response = litellm.completion(**completion_args)
        summary = response.choices[0].message.content or ""
        if not summary.strip():
            return messages[0]
        summary_msg = "<context_summary message_count='{count}'>{text}</context_summary>"
        return {
            "role": "assistant",
            "content": summary_msg.format(count=len(messages), text=summary),
        }
    except Exception:
        logger.exception("Failed to summarize messages")
        return messages[0]


def _handle_images(messages: list[dict[str, Any]], max_images: int) -> None:
    image_count = 0
    for msg in reversed(messages):
        content = msg.get("content", [])
        if isinstance(content, list):
            for item in content:
                if isinstance(item, dict) and item.get("type") == "image_url":
                    if image_count >= max_images:
                        item.update(
                            {
                                "type": "text",
                                "text": "[Previously attached image removed to preserve context]",
                            }
                        )
                    else:
                        image_count += 1


class MemoryCompressor:
    def __init__(
        self,
        max_images: int = 3,
        model_name: str | None = None,
        timeout: int = 600,
    ):
        self.max_images = max_images
        self.model_name = model_name or os.getenv("STRIX_LLM", "openai/gpt-5")
        self.timeout = timeout

        if not self.model_name:
            raise ValueError("STRIX_LLM environment variable must be set and not empty")

    def compress_history(
        self,
        messages: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Compress conversation history to stay within token limits.

        Strategy:
        1. Handle image limits first
        2. Keep all system messages
        3. Keep minimum recent messages
        4. Summarize older messages when total tokens exceed limit

        The compression preserves:
        - All system messages unchanged
        - Most recent messages intact
        - Critical security context in summaries
        - Recent images for visual context
        - Technical details and findings
        """
        if not messages:
            return messages

        _handle_images(messages, self.max_images)

        system_msgs = []
        regular_msgs = []
        for msg in messages:
            if msg.get("role") == "system":
                system_msgs.append(msg)
            else:
                regular_msgs.append(msg)

        recent_msgs = regular_msgs[-MIN_RECENT_MESSAGES:]
        old_msgs = regular_msgs[:-MIN_RECENT_MESSAGES]

        # Type assertion since we ensure model_name is not None in __init__
        model_name: str = self.model_name  # type: ignore[assignment]

        total_tokens = sum(
            _get_message_tokens(msg, model_name) for msg in system_msgs + regular_msgs
        )

        if total_tokens <= MAX_TOTAL_TOKENS * 0.9:
            return messages

        compressed = []
        chunk_size = 10
        for i in range(0, len(old_msgs), chunk_size):
            chunk = old_msgs[i : i + chunk_size]
            summary = _summarize_messages(chunk, model_name, self.timeout)
            if summary:
                compressed.append(summary)

        return system_msgs + compressed + recent_msgs
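
As a rough usage sketch (also not shipped in the wheel), the snippet below drives MemoryCompressor directly with made-up messages; exercising the over-limit path would call litellm.completion, so working model credentials are assumed.

# Hypothetical driver for MemoryCompressor; the message contents are invented.
from strix.llm.memory_compressor import MemoryCompressor

compressor = MemoryCompressor(model_name="openai/gpt-5", timeout=600)

history = [
    {"role": "system", "content": "You are a security testing agent."},
    {"role": "user", "content": "Scan https://example.com for IDOR issues."},
    {"role": "assistant", "content": "Enumerated /api/users/{id}; other ids return 200."},
]

# Below ~90k tokens (0.9 * MAX_TOTAL_TOKENS) the history passes through unchanged; above it,
# everything but the newest 15 non-system messages is summarized in chunks of 10 into
# <context_summary message_count='N'>...</context_summary> assistant messages, and
# _handle_images keeps only the 3 most recent image attachments.
compressed = compressor.compress_history(history)
assert compressed == history  # small histories are returned untouched

In llm.py above, LLM.generate wraps the compressed result in list(...) and writes it back into conversation_history in place, so subsequent steps continue from the compressed view.
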