strix-agent 0.4.0__py3-none-any.whl → 0.6.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- strix/agents/StrixAgent/strix_agent.py +3 -3
- strix/agents/StrixAgent/system_prompt.jinja +30 -26
- strix/agents/base_agent.py +159 -75
- strix/agents/state.py +5 -2
- strix/config/__init__.py +12 -0
- strix/config/config.py +172 -0
- strix/interface/assets/tui_styles.tcss +195 -230
- strix/interface/cli.py +16 -41
- strix/interface/main.py +151 -74
- strix/interface/streaming_parser.py +119 -0
- strix/interface/tool_components/__init__.py +4 -0
- strix/interface/tool_components/agent_message_renderer.py +190 -0
- strix/interface/tool_components/agents_graph_renderer.py +54 -38
- strix/interface/tool_components/base_renderer.py +68 -36
- strix/interface/tool_components/browser_renderer.py +106 -91
- strix/interface/tool_components/file_edit_renderer.py +117 -36
- strix/interface/tool_components/finish_renderer.py +43 -10
- strix/interface/tool_components/notes_renderer.py +63 -38
- strix/interface/tool_components/proxy_renderer.py +133 -92
- strix/interface/tool_components/python_renderer.py +121 -8
- strix/interface/tool_components/registry.py +19 -12
- strix/interface/tool_components/reporting_renderer.py +196 -28
- strix/interface/tool_components/scan_info_renderer.py +22 -19
- strix/interface/tool_components/terminal_renderer.py +270 -90
- strix/interface/tool_components/thinking_renderer.py +8 -6
- strix/interface/tool_components/todo_renderer.py +225 -0
- strix/interface/tool_components/user_message_renderer.py +26 -19
- strix/interface/tool_components/web_search_renderer.py +7 -6
- strix/interface/tui.py +907 -262
- strix/interface/utils.py +236 -4
- strix/llm/__init__.py +6 -2
- strix/llm/config.py +8 -5
- strix/llm/dedupe.py +217 -0
- strix/llm/llm.py +209 -356
- strix/llm/memory_compressor.py +6 -5
- strix/llm/utils.py +17 -8
- strix/runtime/__init__.py +12 -3
- strix/runtime/docker_runtime.py +121 -202
- strix/runtime/tool_server.py +55 -95
- strix/skills/README.md +64 -0
- strix/skills/__init__.py +110 -0
- strix/{prompts → skills}/frameworks/nextjs.jinja +26 -0
- strix/skills/scan_modes/deep.jinja +145 -0
- strix/skills/scan_modes/quick.jinja +63 -0
- strix/skills/scan_modes/standard.jinja +91 -0
- strix/telemetry/README.md +38 -0
- strix/telemetry/__init__.py +7 -1
- strix/telemetry/posthog.py +137 -0
- strix/telemetry/tracer.py +194 -54
- strix/tools/__init__.py +11 -4
- strix/tools/agents_graph/agents_graph_actions.py +20 -21
- strix/tools/agents_graph/agents_graph_actions_schema.xml +8 -8
- strix/tools/browser/browser_actions.py +10 -6
- strix/tools/browser/browser_actions_schema.xml +6 -1
- strix/tools/browser/browser_instance.py +96 -48
- strix/tools/browser/tab_manager.py +121 -102
- strix/tools/context.py +12 -0
- strix/tools/executor.py +63 -4
- strix/tools/file_edit/file_edit_actions.py +6 -3
- strix/tools/file_edit/file_edit_actions_schema.xml +45 -3
- strix/tools/finish/finish_actions.py +80 -105
- strix/tools/finish/finish_actions_schema.xml +121 -14
- strix/tools/notes/notes_actions.py +6 -33
- strix/tools/notes/notes_actions_schema.xml +50 -46
- strix/tools/proxy/proxy_actions.py +14 -2
- strix/tools/proxy/proxy_actions_schema.xml +0 -1
- strix/tools/proxy/proxy_manager.py +28 -16
- strix/tools/python/python_actions.py +2 -2
- strix/tools/python/python_actions_schema.xml +9 -1
- strix/tools/python/python_instance.py +39 -37
- strix/tools/python/python_manager.py +43 -31
- strix/tools/registry.py +73 -12
- strix/tools/reporting/reporting_actions.py +218 -31
- strix/tools/reporting/reporting_actions_schema.xml +256 -8
- strix/tools/terminal/terminal_actions.py +2 -2
- strix/tools/terminal/terminal_actions_schema.xml +6 -0
- strix/tools/terminal/terminal_manager.py +41 -30
- strix/tools/thinking/thinking_actions_schema.xml +27 -25
- strix/tools/todo/__init__.py +18 -0
- strix/tools/todo/todo_actions.py +568 -0
- strix/tools/todo/todo_actions_schema.xml +225 -0
- strix/utils/__init__.py +0 -0
- strix/utils/resource_paths.py +13 -0
- {strix_agent-0.4.0.dist-info → strix_agent-0.6.2.dist-info}/METADATA +90 -65
- strix_agent-0.6.2.dist-info/RECORD +134 -0
- {strix_agent-0.4.0.dist-info → strix_agent-0.6.2.dist-info}/WHEEL +1 -1
- strix/llm/request_queue.py +0 -87
- strix/prompts/README.md +0 -64
- strix/prompts/__init__.py +0 -109
- strix_agent-0.4.0.dist-info/RECORD +0 -118
- /strix/{prompts → skills}/cloud/.gitkeep +0 -0
- /strix/{prompts → skills}/coordination/root_agent.jinja +0 -0
- /strix/{prompts → skills}/custom/.gitkeep +0 -0
- /strix/{prompts → skills}/frameworks/fastapi.jinja +0 -0
- /strix/{prompts → skills}/protocols/graphql.jinja +0 -0
- /strix/{prompts → skills}/reconnaissance/.gitkeep +0 -0
- /strix/{prompts → skills}/technologies/firebase_firestore.jinja +0 -0
- /strix/{prompts → skills}/technologies/supabase.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/authentication_jwt.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/broken_function_level_authorization.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/business_logic.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/csrf.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/idor.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/information_disclosure.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/insecure_file_uploads.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/mass_assignment.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/open_redirect.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/path_traversal_lfi_rfi.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/race_conditions.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/rce.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/sql_injection.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/ssrf.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/subdomain_takeover.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/xss.jinja +0 -0
- /strix/{prompts → skills}/vulnerabilities/xxe.jinja +0 -0
- {strix_agent-0.4.0.dist-info → strix_agent-0.6.2.dist-info}/entry_points.txt +0 -0
- {strix_agent-0.4.0.dist-info → strix_agent-0.6.2.dist-info/licenses}/LICENSE +0 -0
strix/llm/llm.py
CHANGED
@@ -1,42 +1,28 @@
-import
-import
+import asyncio
+from collections.abc import AsyncIterator
 from dataclasses import dataclass
-from enum import Enum
-from fnmatch import fnmatch
-from pathlib import Path
 from typing import Any

 import litellm
-from jinja2 import
-
-
-    select_autoescape,
-)
-from litellm import ModelResponse, completion_cost
-from litellm.utils import supports_prompt_caching
+from jinja2 import Environment, FileSystemLoader, select_autoescape
+from litellm import acompletion, completion_cost, stream_chunk_builder, supports_reasoning
+from litellm.utils import supports_prompt_caching, supports_vision

+from strix.config import Config
 from strix.llm.config import LLMConfig
 from strix.llm.memory_compressor import MemoryCompressor
-from strix.llm.
-
-
+from strix.llm.utils import (
+    _truncate_to_first_function,
+    fix_incomplete_tool_call,
+    parse_tool_invocations,
+)
+from strix.skills import load_skills
 from strix.tools import get_tools_prompt
+from strix.utils.resource_paths import get_strix_resource_path


-
-
-api_key = os.getenv("LLM_API_KEY")
-if api_key:
-    litellm.api_key = api_key
-
-api_base = (
-    os.getenv("LLM_API_BASE")
-    or os.getenv("OPENAI_API_BASE")
-    or os.getenv("LITELLM_BASE_URL")
-    or os.getenv("OLLAMA_API_BASE")
-)
-if api_base:
-    litellm.api_base = api_base
+litellm.drop_params = True
+litellm.modify_params = True


 class LLMRequestFailedError(Exception):
@@ -46,70 +32,11 @@ class LLMRequestFailedError(Exception):
         self.details = details


-SUPPORTS_STOP_WORDS_FALSE_PATTERNS: list[str] = [
-    "o1*",
-    "grok-4-0709",
-    "grok-code-fast-1",
-    "deepseek-r1-0528*",
-]
-
-REASONING_EFFORT_PATTERNS: list[str] = [
-    "o1-2024-12-17",
-    "o1",
-    "o3",
-    "o3-2025-04-16",
-    "o3-mini-2025-01-31",
-    "o3-mini",
-    "o4-mini",
-    "o4-mini-2025-04-16",
-    "gemini-2.5-flash",
-    "gemini-2.5-pro",
-    "gpt-5*",
-    "deepseek-r1-0528*",
-    "claude-sonnet-4-5*",
-    "claude-haiku-4-5*",
-]
-
-
-def normalize_model_name(model: str) -> str:
-    raw = (model or "").strip().lower()
-    if "/" in raw:
-        name = raw.split("/")[-1]
-        if ":" in name:
-            name = name.split(":", 1)[0]
-    else:
-        name = raw
-    if name.endswith("-gguf"):
-        name = name[: -len("-gguf")]
-    return name
-
-
-def model_matches(model: str, patterns: list[str]) -> bool:
-    raw = (model or "").strip().lower()
-    name = normalize_model_name(model)
-    for pat in patterns:
-        pat_l = pat.lower()
-        if "/" in pat_l:
-            if fnmatch(raw, pat_l):
-                return True
-        elif fnmatch(name, pat_l):
-            return True
-    return False
-
-
-class StepRole(str, Enum):
-    AGENT = "agent"
-    USER = "user"
-    SYSTEM = "system"
-
-
 @dataclass
 class LLMResponse:
     content: str
     tool_invocations: list[dict[str, Any]] | None = None
-
-    step_number: int = 1
-    role: StepRole = StepRole.AGENT
+    thinking_blocks: list[dict[str, Any]] | None = None


 @dataclass
@@ -117,68 +44,63 @@ class RequestStats:
     input_tokens: int = 0
     output_tokens: int = 0
     cached_tokens: int = 0
-    cache_creation_tokens: int = 0
     cost: float = 0.0
     requests: int = 0
-    failed_requests: int = 0

     def to_dict(self) -> dict[str, int | float]:
         return {
             "input_tokens": self.input_tokens,
             "output_tokens": self.output_tokens,
             "cached_tokens": self.cached_tokens,
-            "cache_creation_tokens": self.cache_creation_tokens,
             "cost": round(self.cost, 4),
             "requests": self.requests,
-            "failed_requests": self.failed_requests,
         }


 class LLM:
-    def __init__(
-        self, config: LLMConfig, agent_name: str | None = None, agent_id: str | None = None
-    ):
+    def __init__(self, config: LLMConfig, agent_name: str | None = None):
         self.config = config
         self.agent_name = agent_name
-        self.agent_id =
+        self.agent_id: str | None = None
         self._total_stats = RequestStats()
-        self.
+        self.memory_compressor = MemoryCompressor(model_name=config.model_name)
+        self.system_prompt = self._load_system_prompt(agent_name)
+
+        reasoning = Config.get("strix_reasoning_effort")
+        if reasoning:
+            self._reasoning_effort = reasoning
+        elif config.scan_mode == "quick":
+            self._reasoning_effort = "medium"
+        else:
+            self._reasoning_effort = "high"

-
-
-
-        )
+    def _load_system_prompt(self, agent_name: str | None) -> str:
+        if not agent_name:
+            return ""

-
-        prompt_dir =
-
-
-
-        self.jinja_env = Environment(
-            loader=loader,
+        try:
+            prompt_dir = get_strix_resource_path("agents", agent_name)
+            skills_dir = get_strix_resource_path("skills")
+            env = Environment(
+                loader=FileSystemLoader([prompt_dir, skills_dir]),
                 autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
             )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        except (FileNotFoundError, OSError, ValueError) as e:
-            logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
-            self.system_prompt = "You are a helpful AI assistant."
-        else:
-            self.system_prompt = "You are a helpful AI assistant."
+            skills_to_load = [
+                *list(self.config.skills or []),
+                f"scan_modes/{self.config.scan_mode}",
+            ]
+            skill_content = load_skills(skills_to_load, env)
+            env.globals["get_skill"] = lambda name: skill_content.get(name, "")
+
+            result = env.get_template("system_prompt.jinja").render(
+                get_tools_prompt=get_tools_prompt,
+                loaded_skill_names=list(skill_content.keys()),
+                **skill_content,
+            )
+            return str(result)
+        except Exception:  # noqa: BLE001
+            return ""

     def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
         if agent_name:
@@ -186,280 +108,211 @@ class LLM:
         if agent_id:
             self.agent_id = agent_id

-    def
-
-
-
-
-        content = (
-            "\n\n"
-            "<agent_identity>\n"
-            "<meta>Internal metadata: do not echo or reference; "
-            "not part of history or tool calls.</meta>\n"
-            "<note>You are now assuming the role of this agent. "
-            "Act strictly as this agent and maintain self-identity for this step. "
-            "Now go answer the next needed step!</note>\n"
-            f"<agent_name>{identity_name}</agent_name>\n"
-            f"<agent_id>{identity_id}</agent_id>\n"
-            "</agent_identity>\n\n"
-        )
-        return {"role": "user", "content": content}
-
-    def _add_cache_control_to_content(
-        self, content: str | list[dict[str, Any]]
-    ) -> str | list[dict[str, Any]]:
-        if isinstance(content, str):
-            return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
-        if isinstance(content, list) and content:
-            last_item = content[-1]
-            if isinstance(last_item, dict) and last_item.get("type") == "text":
-                return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}]
-        return content
-
-    def _is_anthropic_model(self) -> bool:
-        if not self.config.model_name:
-            return False
-        model_lower = self.config.model_name.lower()
-        return any(provider in model_lower for provider in ["anthropic/", "claude"])
-
-    def _calculate_cache_interval(self, total_messages: int) -> int:
-        if total_messages <= 1:
-            return 10
-
-        max_cached_messages = 3
-        non_system_messages = total_messages - 1
-
-        interval = 10
-        while non_system_messages // interval > max_cached_messages:
-            interval += 10
-
-        return interval
-
-    def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
-        if (
-            not self.config.enable_prompt_caching
-            or not supports_prompt_caching(self.config.model_name)
-            or not messages
-        ):
-            return messages
-
-        if not self._is_anthropic_model():
-            return messages
-
-        cached_messages = list(messages)
+    async def generate(
+        self, conversation_history: list[dict[str, Any]]
+    ) -> AsyncIterator[LLMResponse]:
+        messages = self._prepare_messages(conversation_history)
+        max_retries = int(Config.get("strix_llm_max_retries") or "5")

-
-
-
-
-
-
+        for attempt in range(max_retries + 1):
+            try:
+                async for response in self._stream(messages):
+                    yield response
+                return  # noqa: TRY300
+            except Exception as e:  # noqa: BLE001
+                if attempt >= max_retries or not self._should_retry(e):
+                    self._raise_error(e)
+                wait = min(10, 2 * (2**attempt))
+                await asyncio.sleep(wait)

-
-
-
+    async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResponse]:
+        accumulated = ""
+        chunks: list[Any] = []

-
-
-
+        self._total_stats.requests += 1
+        response = await acompletion(**self._build_completion_args(messages), stream=True)
+
+        async for chunk in response:
+            chunks.append(chunk)
+            delta = self._get_chunk_content(chunk)
+            if delta:
+                accumulated += delta
+                if "</function>" in accumulated:
+                    accumulated = accumulated[
+                        : accumulated.find("</function>") + len("</function>")
+                    ]
+                    yield LLMResponse(content=accumulated)
                     break
+                yield LLMResponse(content=accumulated)

-
-
-                message["content"] = self._add_cache_control_to_content(message["content"])
-                cached_messages[i] = message
-                cached_count += 1
+        if chunks:
+            self._update_usage_stats(stream_chunk_builder(chunks))

-
+        accumulated = fix_incomplete_tool_call(_truncate_to_first_function(accumulated))
+        yield LLMResponse(
+            content=accumulated,
+            tool_invocations=parse_tool_invocations(accumulated),
+            thinking_blocks=self._extract_thinking(chunks),
+        )

-
-        self,
-        conversation_history: list[dict[str, Any]],
-        scan_id: str | None = None,
-        step_number: int = 1,
-    ) -> LLMResponse:
+    def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]:
         messages = [{"role": "system", "content": self.system_prompt}]

-
-
-
-
-
-
-
-
-
-
-
-
-        try:
-            response = await self._make_request(cached_messages)
-            self._update_usage_stats(response)
-
-            content = ""
-            if (
-                response.choices
-                and hasattr(response.choices[0], "message")
-                and response.choices[0].message
-            ):
-                content = getattr(response.choices[0].message, "content", "") or ""
-
-            content = _truncate_to_first_function(content)
-
-            if "</function>" in content:
-                function_end_index = content.find("</function>") + len("</function>")
-                content = content[:function_end_index]
-
-            tool_invocations = parse_tool_invocations(content)
-
-            return LLMResponse(
-                scan_id=scan_id,
-                step_number=step_number,
-                role=StepRole.AGENT,
-                content=content,
-                tool_invocations=tool_invocations if tool_invocations else None,
+        if self.agent_name:
+            messages.append(
+                {
+                    "role": "user",
+                    "content": (
+                        f"\n\n<agent_identity>\n"
+                        f"<meta>Internal metadata: do not echo or reference.</meta>\n"
+                        f"<agent_name>{self.agent_name}</agent_name>\n"
+                        f"<agent_id>{self.agent_id}</agent_id>\n"
+                        f"</agent_identity>\n\n"
+                    ),
+                }
             )

-
-
-
-
-        except litellm.NotFoundError as e:
-            raise LLMRequestFailedError("LLM request failed: Model not found", str(e)) from e
-        except litellm.ContextWindowExceededError as e:
-            raise LLMRequestFailedError("LLM request failed: Context too long", str(e)) from e
-        except litellm.ContentPolicyViolationError as e:
-            raise LLMRequestFailedError(
-                "LLM request failed: Content policy violation", str(e)
-            ) from e
-        except litellm.ServiceUnavailableError as e:
-            raise LLMRequestFailedError("LLM request failed: Service unavailable", str(e)) from e
-        except litellm.Timeout as e:
-            raise LLMRequestFailedError("LLM request failed: Request timed out", str(e)) from e
-        except litellm.UnprocessableEntityError as e:
-            raise LLMRequestFailedError("LLM request failed: Unprocessable entity", str(e)) from e
-        except litellm.InternalServerError as e:
-            raise LLMRequestFailedError("LLM request failed: Internal server error", str(e)) from e
-        except litellm.APIConnectionError as e:
-            raise LLMRequestFailedError("LLM request failed: Connection error", str(e)) from e
-        except litellm.UnsupportedParamsError as e:
-            raise LLMRequestFailedError("LLM request failed: Unsupported parameters", str(e)) from e
-        except litellm.BudgetExceededError as e:
-            raise LLMRequestFailedError("LLM request failed: Budget exceeded", str(e)) from e
-        except litellm.APIResponseValidationError as e:
-            raise LLMRequestFailedError(
-                "LLM request failed: Response validation error", str(e)
-            ) from e
-        except litellm.JSONSchemaValidationError as e:
-            raise LLMRequestFailedError(
-                "LLM request failed: JSON schema validation error", str(e)
-            ) from e
-        except litellm.InvalidRequestError as e:
-            raise LLMRequestFailedError("LLM request failed: Invalid request", str(e)) from e
-        except litellm.BadRequestError as e:
-            raise LLMRequestFailedError("LLM request failed: Bad request", str(e)) from e
-        except litellm.APIError as e:
-            raise LLMRequestFailedError("LLM request failed: API error", str(e)) from e
-        except litellm.OpenAIError as e:
-            raise LLMRequestFailedError("LLM request failed: OpenAI error", str(e)) from e
-        except Exception as e:
-            raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
-
-    @property
-    def usage_stats(self) -> dict[str, dict[str, int | float]]:
-        return {
-            "total": self._total_stats.to_dict(),
-            "last_request": self._last_request_stats.to_dict(),
-        }
+        compressed = list(self.memory_compressor.compress_history(conversation_history))
+        conversation_history.clear()
+        conversation_history.extend(compressed)
+        messages.extend(compressed)

-
-
-            "enabled": self.config.enable_prompt_caching,
-            "supported": supports_prompt_caching(self.config.model_name),
-        }
+        if self._is_anthropic() and self.config.enable_prompt_caching:
+            messages = self._add_cache_control(messages)

-
-        if not self.config.model_name:
-            return True
+        return messages

-
+    def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
+        if not self._supports_vision():
+            messages = self._strip_images(messages)

-
-        if not self.config.model_name:
-            return False
-
-        return model_matches(self.config.model_name, REASONING_EFFORT_PATTERNS)
-
-    async def _make_request(
-        self,
-        messages: list[dict[str, Any]],
-    ) -> ModelResponse:
-        completion_args: dict[str, Any] = {
+        args: dict[str, Any] = {
             "model": self.config.model_name,
             "messages": messages,
             "timeout": self.config.timeout,
+            "stream_options": {"include_usage": True},
         }

-        if
-
-
-
-
+        if api_key := Config.get("llm_api_key"):
+            args["api_key"] = api_key
+        if api_base := (
+            Config.get("llm_api_base")
+            or Config.get("openai_api_base")
+            or Config.get("litellm_base_url")
+            or Config.get("ollama_api_base")
+        ):
+            args["api_base"] = api_base
+        if self._supports_reasoning():
+            args["reasoning_effort"] = self._reasoning_effort

-
-        response = await queue.make_request(completion_args)
+        return args

-
-
+    def _get_chunk_content(self, chunk: Any) -> str:
+        if chunk.choices and hasattr(chunk.choices[0], "delta"):
+            return getattr(chunk.choices[0].delta, "content", "") or ""
+        return ""

-
-
-
+    def _extract_thinking(self, chunks: list[Any]) -> list[dict[str, Any]] | None:
+        if not chunks or not self._supports_reasoning():
+            return None
+        try:
+            resp = stream_chunk_builder(chunks)
+            if resp.choices and hasattr(resp.choices[0].message, "thinking_blocks"):
+                blocks: list[dict[str, Any]] = resp.choices[0].message.thinking_blocks
+                return blocks
+        except Exception:  # noqa: BLE001, S110 # nosec B110
+            pass
+        return None
+
+    def _update_usage_stats(self, response: Any) -> None:
         try:
             if hasattr(response, "usage") and response.usage:
                 input_tokens = getattr(response.usage, "prompt_tokens", 0)
                 output_tokens = getattr(response.usage, "completion_tokens", 0)

                 cached_tokens = 0
-                cache_creation_tokens = 0
-
                 if hasattr(response.usage, "prompt_tokens_details"):
                     prompt_details = response.usage.prompt_tokens_details
                     if hasattr(prompt_details, "cached_tokens"):
                         cached_tokens = prompt_details.cached_tokens or 0

-                if hasattr(response.usage, "cache_creation_input_tokens"):
-                    cache_creation_tokens = response.usage.cache_creation_input_tokens or 0
-
             else:
                 input_tokens = 0
                 output_tokens = 0
                 cached_tokens = 0
-                cache_creation_tokens = 0

             try:
                 cost = completion_cost(response) or 0.0
-            except Exception
-                logger.warning(f"Failed to calculate cost: {e}")
+            except Exception:  # noqa: BLE001
                 cost = 0.0

             self._total_stats.input_tokens += input_tokens
             self._total_stats.output_tokens += output_tokens
             self._total_stats.cached_tokens += cached_tokens
-            self._total_stats.cache_creation_tokens += cache_creation_tokens
             self._total_stats.cost += cost

-
-
-
-
-
+        except Exception:  # noqa: BLE001, S110 # nosec B110
+            pass
+
+    def _should_retry(self, e: Exception) -> bool:
+        code = getattr(e, "status_code", None) or getattr(
+            getattr(e, "response", None), "status_code", None
+        )
+        return code is None or litellm._should_retry(code)
+
+    def _raise_error(self, e: Exception) -> None:
+        from strix.telemetry import posthog
+
+        posthog.error("llm_error", type(e).__name__)
+        raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
+
+    def _is_anthropic(self) -> bool:
+        if not self.config.model_name:
+            return False
+        return any(p in self.config.model_name.lower() for p in ["anthropic/", "claude"])
+
+    def _supports_vision(self) -> bool:
+        try:
+            return bool(supports_vision(model=self.config.model_name))
+        except Exception:  # noqa: BLE001
+            return False

-
-
-
-
+    def _supports_reasoning(self) -> bool:
+        try:
+            return bool(supports_reasoning(model=self.config.model_name))
+        except Exception:  # noqa: BLE001
+            return False
+
+    def _strip_images(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        result = []
+        for msg in messages:
+            content = msg.get("content")
+            if isinstance(content, list):
+                text_parts = []
+                for item in content:
+                    if isinstance(item, dict) and item.get("type") == "text":
+                        text_parts.append(item.get("text", ""))
+                    elif isinstance(item, dict) and item.get("type") == "image_url":
+                        text_parts.append("[Image removed - model doesn't support vision]")
+                result.append({**msg, "content": "\n".join(text_parts)})
+            else:
+                result.append(msg)
        return result
+
+    def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        if not messages or not supports_prompt_caching(self.config.model_name):
+            return messages

-
-
-
+        result = list(messages)
+
+        if result[0].get("role") == "system":
+            content = result[0]["content"]
+            result[0] = {
+                **result[0],
+                "content": [
+                    {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
+                ]
+                if isinstance(content, str)
+                else content,
+            }
+        return result