strix-agent 0.4.0__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as published in that registry.
Files changed (117)
  1. strix/agents/StrixAgent/strix_agent.py +3 -3
  2. strix/agents/StrixAgent/system_prompt.jinja +30 -26
  3. strix/agents/base_agent.py +159 -75
  4. strix/agents/state.py +5 -2
  5. strix/config/__init__.py +12 -0
  6. strix/config/config.py +172 -0
  7. strix/interface/assets/tui_styles.tcss +195 -230
  8. strix/interface/cli.py +16 -41
  9. strix/interface/main.py +151 -74
  10. strix/interface/streaming_parser.py +119 -0
  11. strix/interface/tool_components/__init__.py +4 -0
  12. strix/interface/tool_components/agent_message_renderer.py +190 -0
  13. strix/interface/tool_components/agents_graph_renderer.py +54 -38
  14. strix/interface/tool_components/base_renderer.py +68 -36
  15. strix/interface/tool_components/browser_renderer.py +106 -91
  16. strix/interface/tool_components/file_edit_renderer.py +117 -36
  17. strix/interface/tool_components/finish_renderer.py +43 -10
  18. strix/interface/tool_components/notes_renderer.py +63 -38
  19. strix/interface/tool_components/proxy_renderer.py +133 -92
  20. strix/interface/tool_components/python_renderer.py +121 -8
  21. strix/interface/tool_components/registry.py +19 -12
  22. strix/interface/tool_components/reporting_renderer.py +196 -28
  23. strix/interface/tool_components/scan_info_renderer.py +22 -19
  24. strix/interface/tool_components/terminal_renderer.py +270 -90
  25. strix/interface/tool_components/thinking_renderer.py +8 -6
  26. strix/interface/tool_components/todo_renderer.py +225 -0
  27. strix/interface/tool_components/user_message_renderer.py +26 -19
  28. strix/interface/tool_components/web_search_renderer.py +7 -6
  29. strix/interface/tui.py +907 -262
  30. strix/interface/utils.py +236 -4
  31. strix/llm/__init__.py +6 -2
  32. strix/llm/config.py +8 -5
  33. strix/llm/dedupe.py +217 -0
  34. strix/llm/llm.py +209 -356
  35. strix/llm/memory_compressor.py +6 -5
  36. strix/llm/utils.py +17 -8
  37. strix/runtime/__init__.py +12 -3
  38. strix/runtime/docker_runtime.py +121 -202
  39. strix/runtime/tool_server.py +55 -95
  40. strix/skills/README.md +64 -0
  41. strix/skills/__init__.py +110 -0
  42. strix/{prompts → skills}/frameworks/nextjs.jinja +26 -0
  43. strix/skills/scan_modes/deep.jinja +145 -0
  44. strix/skills/scan_modes/quick.jinja +63 -0
  45. strix/skills/scan_modes/standard.jinja +91 -0
  46. strix/telemetry/README.md +38 -0
  47. strix/telemetry/__init__.py +7 -1
  48. strix/telemetry/posthog.py +137 -0
  49. strix/telemetry/tracer.py +194 -54
  50. strix/tools/__init__.py +11 -4
  51. strix/tools/agents_graph/agents_graph_actions.py +20 -21
  52. strix/tools/agents_graph/agents_graph_actions_schema.xml +8 -8
  53. strix/tools/browser/browser_actions.py +10 -6
  54. strix/tools/browser/browser_actions_schema.xml +6 -1
  55. strix/tools/browser/browser_instance.py +96 -48
  56. strix/tools/browser/tab_manager.py +121 -102
  57. strix/tools/context.py +12 -0
  58. strix/tools/executor.py +63 -4
  59. strix/tools/file_edit/file_edit_actions.py +6 -3
  60. strix/tools/file_edit/file_edit_actions_schema.xml +45 -3
  61. strix/tools/finish/finish_actions.py +80 -105
  62. strix/tools/finish/finish_actions_schema.xml +121 -14
  63. strix/tools/notes/notes_actions.py +6 -33
  64. strix/tools/notes/notes_actions_schema.xml +50 -46
  65. strix/tools/proxy/proxy_actions.py +14 -2
  66. strix/tools/proxy/proxy_actions_schema.xml +0 -1
  67. strix/tools/proxy/proxy_manager.py +28 -16
  68. strix/tools/python/python_actions.py +2 -2
  69. strix/tools/python/python_actions_schema.xml +9 -1
  70. strix/tools/python/python_instance.py +39 -37
  71. strix/tools/python/python_manager.py +43 -31
  72. strix/tools/registry.py +73 -12
  73. strix/tools/reporting/reporting_actions.py +218 -31
  74. strix/tools/reporting/reporting_actions_schema.xml +256 -8
  75. strix/tools/terminal/terminal_actions.py +2 -2
  76. strix/tools/terminal/terminal_actions_schema.xml +6 -0
  77. strix/tools/terminal/terminal_manager.py +41 -30
  78. strix/tools/thinking/thinking_actions_schema.xml +27 -25
  79. strix/tools/todo/__init__.py +18 -0
  80. strix/tools/todo/todo_actions.py +568 -0
  81. strix/tools/todo/todo_actions_schema.xml +225 -0
  82. strix/utils/__init__.py +0 -0
  83. strix/utils/resource_paths.py +13 -0
  84. {strix_agent-0.4.0.dist-info → strix_agent-0.6.2.dist-info}/METADATA +90 -65
  85. strix_agent-0.6.2.dist-info/RECORD +134 -0
  86. {strix_agent-0.4.0.dist-info → strix_agent-0.6.2.dist-info}/WHEEL +1 -1
  87. strix/llm/request_queue.py +0 -87
  88. strix/prompts/README.md +0 -64
  89. strix/prompts/__init__.py +0 -109
  90. strix_agent-0.4.0.dist-info/RECORD +0 -118
  91. /strix/{prompts → skills}/cloud/.gitkeep +0 -0
  92. /strix/{prompts → skills}/coordination/root_agent.jinja +0 -0
  93. /strix/{prompts → skills}/custom/.gitkeep +0 -0
  94. /strix/{prompts → skills}/frameworks/fastapi.jinja +0 -0
  95. /strix/{prompts → skills}/protocols/graphql.jinja +0 -0
  96. /strix/{prompts → skills}/reconnaissance/.gitkeep +0 -0
  97. /strix/{prompts → skills}/technologies/firebase_firestore.jinja +0 -0
  98. /strix/{prompts → skills}/technologies/supabase.jinja +0 -0
  99. /strix/{prompts → skills}/vulnerabilities/authentication_jwt.jinja +0 -0
  100. /strix/{prompts → skills}/vulnerabilities/broken_function_level_authorization.jinja +0 -0
  101. /strix/{prompts → skills}/vulnerabilities/business_logic.jinja +0 -0
  102. /strix/{prompts → skills}/vulnerabilities/csrf.jinja +0 -0
  103. /strix/{prompts → skills}/vulnerabilities/idor.jinja +0 -0
  104. /strix/{prompts → skills}/vulnerabilities/information_disclosure.jinja +0 -0
  105. /strix/{prompts → skills}/vulnerabilities/insecure_file_uploads.jinja +0 -0
  106. /strix/{prompts → skills}/vulnerabilities/mass_assignment.jinja +0 -0
  107. /strix/{prompts → skills}/vulnerabilities/open_redirect.jinja +0 -0
  108. /strix/{prompts → skills}/vulnerabilities/path_traversal_lfi_rfi.jinja +0 -0
  109. /strix/{prompts → skills}/vulnerabilities/race_conditions.jinja +0 -0
  110. /strix/{prompts → skills}/vulnerabilities/rce.jinja +0 -0
  111. /strix/{prompts → skills}/vulnerabilities/sql_injection.jinja +0 -0
  112. /strix/{prompts → skills}/vulnerabilities/ssrf.jinja +0 -0
  113. /strix/{prompts → skills}/vulnerabilities/subdomain_takeover.jinja +0 -0
  114. /strix/{prompts → skills}/vulnerabilities/xss.jinja +0 -0
  115. /strix/{prompts → skills}/vulnerabilities/xxe.jinja +0 -0
  116. {strix_agent-0.4.0.dist-info → strix_agent-0.6.2.dist-info}/entry_points.txt +0 -0
  117. {strix_agent-0.4.0.dist-info → strix_agent-0.6.2.dist-info/licenses}/LICENSE +0 -0
strix/llm/llm.py CHANGED
@@ -1,42 +1,28 @@
- import logging
- import os
+ import asyncio
+ from collections.abc import AsyncIterator
  from dataclasses import dataclass
- from enum import Enum
- from fnmatch import fnmatch
- from pathlib import Path
  from typing import Any

  import litellm
- from jinja2 import (
-     Environment,
-     FileSystemLoader,
-     select_autoescape,
- )
- from litellm import ModelResponse, completion_cost
- from litellm.utils import supports_prompt_caching
+ from jinja2 import Environment, FileSystemLoader, select_autoescape
+ from litellm import acompletion, completion_cost, stream_chunk_builder, supports_reasoning
+ from litellm.utils import supports_prompt_caching, supports_vision

+ from strix.config import Config
  from strix.llm.config import LLMConfig
  from strix.llm.memory_compressor import MemoryCompressor
- from strix.llm.request_queue import get_global_queue
- from strix.llm.utils import _truncate_to_first_function, parse_tool_invocations
- from strix.prompts import load_prompt_modules
+ from strix.llm.utils import (
+     _truncate_to_first_function,
+     fix_incomplete_tool_call,
+     parse_tool_invocations,
+ )
+ from strix.skills import load_skills
  from strix.tools import get_tools_prompt
+ from strix.utils.resource_paths import get_strix_resource_path


- logger = logging.getLogger(__name__)
-
- api_key = os.getenv("LLM_API_KEY")
- if api_key:
-     litellm.api_key = api_key
-
- api_base = (
-     os.getenv("LLM_API_BASE")
-     or os.getenv("OPENAI_API_BASE")
-     or os.getenv("LITELLM_BASE_URL")
-     or os.getenv("OLLAMA_API_BASE")
- )
- if api_base:
-     litellm.api_base = api_base
+ litellm.drop_params = True
+ litellm.modify_params = True


  class LLMRequestFailedError(Exception):
@@ -46,70 +32,11 @@ class LLMRequestFailedError(Exception):
          self.details = details


- SUPPORTS_STOP_WORDS_FALSE_PATTERNS: list[str] = [
-     "o1*",
-     "grok-4-0709",
-     "grok-code-fast-1",
-     "deepseek-r1-0528*",
- ]
-
- REASONING_EFFORT_PATTERNS: list[str] = [
-     "o1-2024-12-17",
-     "o1",
-     "o3",
-     "o3-2025-04-16",
-     "o3-mini-2025-01-31",
-     "o3-mini",
-     "o4-mini",
-     "o4-mini-2025-04-16",
-     "gemini-2.5-flash",
-     "gemini-2.5-pro",
-     "gpt-5*",
-     "deepseek-r1-0528*",
-     "claude-sonnet-4-5*",
-     "claude-haiku-4-5*",
- ]
-
-
- def normalize_model_name(model: str) -> str:
-     raw = (model or "").strip().lower()
-     if "/" in raw:
-         name = raw.split("/")[-1]
-         if ":" in name:
-             name = name.split(":", 1)[0]
-     else:
-         name = raw
-     if name.endswith("-gguf"):
-         name = name[: -len("-gguf")]
-     return name
-
-
- def model_matches(model: str, patterns: list[str]) -> bool:
-     raw = (model or "").strip().lower()
-     name = normalize_model_name(model)
-     for pat in patterns:
-         pat_l = pat.lower()
-         if "/" in pat_l:
-             if fnmatch(raw, pat_l):
-                 return True
-         elif fnmatch(name, pat_l):
-             return True
-     return False
-
-
- class StepRole(str, Enum):
-     AGENT = "agent"
-     USER = "user"
-     SYSTEM = "system"
-
-
  @dataclass
  class LLMResponse:
      content: str
      tool_invocations: list[dict[str, Any]] | None = None
-     scan_id: str | None = None
-     step_number: int = 1
-     role: StepRole = StepRole.AGENT
+     thinking_blocks: list[dict[str, Any]] | None = None


  @dataclass
@@ -117,68 +44,63 @@ class RequestStats:
      input_tokens: int = 0
      output_tokens: int = 0
      cached_tokens: int = 0
-     cache_creation_tokens: int = 0
      cost: float = 0.0
      requests: int = 0
-     failed_requests: int = 0

      def to_dict(self) -> dict[str, int | float]:
          return {
              "input_tokens": self.input_tokens,
              "output_tokens": self.output_tokens,
              "cached_tokens": self.cached_tokens,
-             "cache_creation_tokens": self.cache_creation_tokens,
              "cost": round(self.cost, 4),
              "requests": self.requests,
-             "failed_requests": self.failed_requests,
          }


  class LLM:
-     def __init__(
-         self, config: LLMConfig, agent_name: str | None = None, agent_id: str | None = None
-     ):
+     def __init__(self, config: LLMConfig, agent_name: str | None = None):
          self.config = config
          self.agent_name = agent_name
-         self.agent_id = agent_id
+         self.agent_id: str | None = None
          self._total_stats = RequestStats()
-         self._last_request_stats = RequestStats()
+         self.memory_compressor = MemoryCompressor(model_name=config.model_name)
+         self.system_prompt = self._load_system_prompt(agent_name)
+
+         reasoning = Config.get("strix_reasoning_effort")
+         if reasoning:
+             self._reasoning_effort = reasoning
+         elif config.scan_mode == "quick":
+             self._reasoning_effort = "medium"
+         else:
+             self._reasoning_effort = "high"

-         self.memory_compressor = MemoryCompressor(
-             model_name=self.config.model_name,
-             timeout=self.config.timeout,
-         )
+     def _load_system_prompt(self, agent_name: str | None) -> str:
+         if not agent_name:
+             return ""

-         if agent_name:
-             prompt_dir = Path(__file__).parent.parent / "agents" / agent_name
-             prompts_dir = Path(__file__).parent.parent / "prompts"
-
-             loader = FileSystemLoader([prompt_dir, prompts_dir])
-             self.jinja_env = Environment(
-                 loader=loader,
+         try:
+             prompt_dir = get_strix_resource_path("agents", agent_name)
+             skills_dir = get_strix_resource_path("skills")
+             env = Environment(
+                 loader=FileSystemLoader([prompt_dir, skills_dir]),
                  autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
              )

-             try:
-                 prompt_module_content = load_prompt_modules(
-                     self.config.prompt_modules or [], self.jinja_env
-                 )
-
-                 def get_module(name: str) -> str:
-                     return prompt_module_content.get(name, "")
-
-                 self.jinja_env.globals["get_module"] = get_module
-
-                 self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render(
-                     get_tools_prompt=get_tools_prompt,
-                     loaded_module_names=list(prompt_module_content.keys()),
-                     **prompt_module_content,
-                 )
-             except (FileNotFoundError, OSError, ValueError) as e:
-                 logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
-                 self.system_prompt = "You are a helpful AI assistant."
-         else:
-             self.system_prompt = "You are a helpful AI assistant."
+             skills_to_load = [
+                 *list(self.config.skills or []),
+                 f"scan_modes/{self.config.scan_mode}",
+             ]
+             skill_content = load_skills(skills_to_load, env)
+             env.globals["get_skill"] = lambda name: skill_content.get(name, "")
+
+             result = env.get_template("system_prompt.jinja").render(
+                 get_tools_prompt=get_tools_prompt,
+                 loaded_skill_names=list(skill_content.keys()),
+                 **skill_content,
+             )
+             return str(result)
+         except Exception:  # noqa: BLE001
+             return ""

      def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
          if agent_name:
@@ -186,280 +108,211 @@ class LLM:
          if agent_id:
              self.agent_id = agent_id

-     def _build_identity_message(self) -> dict[str, Any] | None:
-         if not (self.agent_name and str(self.agent_name).strip()):
-             return None
-         identity_name = self.agent_name
-         identity_id = self.agent_id
-         content = (
-             "\n\n"
-             "<agent_identity>\n"
-             "<meta>Internal metadata: do not echo or reference; "
-             "not part of history or tool calls.</meta>\n"
-             "<note>You are now assuming the role of this agent. "
-             "Act strictly as this agent and maintain self-identity for this step. "
-             "Now go answer the next needed step!</note>\n"
-             f"<agent_name>{identity_name}</agent_name>\n"
-             f"<agent_id>{identity_id}</agent_id>\n"
-             "</agent_identity>\n\n"
-         )
-         return {"role": "user", "content": content}
-
-     def _add_cache_control_to_content(
-         self, content: str | list[dict[str, Any]]
-     ) -> str | list[dict[str, Any]]:
-         if isinstance(content, str):
-             return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
-         if isinstance(content, list) and content:
-             last_item = content[-1]
-             if isinstance(last_item, dict) and last_item.get("type") == "text":
-                 return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}]
-         return content
-
-     def _is_anthropic_model(self) -> bool:
-         if not self.config.model_name:
-             return False
-         model_lower = self.config.model_name.lower()
-         return any(provider in model_lower for provider in ["anthropic/", "claude"])
-
-     def _calculate_cache_interval(self, total_messages: int) -> int:
-         if total_messages <= 1:
-             return 10
-
-         max_cached_messages = 3
-         non_system_messages = total_messages - 1
-
-         interval = 10
-         while non_system_messages // interval > max_cached_messages:
-             interval += 10
-
-         return interval
-
-     def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
-         if (
-             not self.config.enable_prompt_caching
-             or not supports_prompt_caching(self.config.model_name)
-             or not messages
-         ):
-             return messages
-
-         if not self._is_anthropic_model():
-             return messages
-
-         cached_messages = list(messages)
+     async def generate(
+         self, conversation_history: list[dict[str, Any]]
+     ) -> AsyncIterator[LLMResponse]:
+         messages = self._prepare_messages(conversation_history)
+         max_retries = int(Config.get("strix_llm_max_retries") or "5")

-         if cached_messages and cached_messages[0].get("role") == "system":
-             system_message = cached_messages[0].copy()
-             system_message["content"] = self._add_cache_control_to_content(
-                 system_message["content"]
-             )
-             cached_messages[0] = system_message
+         for attempt in range(max_retries + 1):
+             try:
+                 async for response in self._stream(messages):
+                     yield response
+                 return  # noqa: TRY300
+             except Exception as e:  # noqa: BLE001
+                 if attempt >= max_retries or not self._should_retry(e):
+                     self._raise_error(e)
+                 wait = min(10, 2 * (2**attempt))
+                 await asyncio.sleep(wait)

-         total_messages = len(cached_messages)
-         if total_messages > 1:
-             interval = self._calculate_cache_interval(total_messages)
+     async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResponse]:
+         accumulated = ""
+         chunks: list[Any] = []

-             cached_count = 0
-             for i in range(interval, total_messages, interval):
-                 if cached_count >= 3:
+         self._total_stats.requests += 1
+         response = await acompletion(**self._build_completion_args(messages), stream=True)
+
+         async for chunk in response:
+             chunks.append(chunk)
+             delta = self._get_chunk_content(chunk)
+             if delta:
+                 accumulated += delta
+                 if "</function>" in accumulated:
+                     accumulated = accumulated[
+                         : accumulated.find("</function>") + len("</function>")
+                     ]
+                     yield LLMResponse(content=accumulated)
                      break
+                 yield LLMResponse(content=accumulated)

-                 if i < len(cached_messages):
-                     message = cached_messages[i].copy()
-                     message["content"] = self._add_cache_control_to_content(message["content"])
-                     cached_messages[i] = message
-                     cached_count += 1
+         if chunks:
+             self._update_usage_stats(stream_chunk_builder(chunks))

-         return cached_messages
+         accumulated = fix_incomplete_tool_call(_truncate_to_first_function(accumulated))
+         yield LLMResponse(
+             content=accumulated,
+             tool_invocations=parse_tool_invocations(accumulated),
+             thinking_blocks=self._extract_thinking(chunks),
+         )

-     async def generate(  # noqa: PLR0912, PLR0915
-         self,
-         conversation_history: list[dict[str, Any]],
-         scan_id: str | None = None,
-         step_number: int = 1,
-     ) -> LLMResponse:
+     def _prepare_messages(self, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]]:
          messages = [{"role": "system", "content": self.system_prompt}]

-         identity_message = self._build_identity_message()
-         if identity_message:
-             messages.append(identity_message)
-
-         compressed_history = list(self.memory_compressor.compress_history(conversation_history))
-
-         conversation_history.clear()
-         conversation_history.extend(compressed_history)
-         messages.extend(compressed_history)
-
-         cached_messages = self._prepare_cached_messages(messages)
-
-         try:
-             response = await self._make_request(cached_messages)
-             self._update_usage_stats(response)
-
-             content = ""
-             if (
-                 response.choices
-                 and hasattr(response.choices[0], "message")
-                 and response.choices[0].message
-             ):
-                 content = getattr(response.choices[0].message, "content", "") or ""
-
-             content = _truncate_to_first_function(content)
-
-             if "</function>" in content:
-                 function_end_index = content.find("</function>") + len("</function>")
-                 content = content[:function_end_index]
-
-             tool_invocations = parse_tool_invocations(content)
-
-             return LLMResponse(
-                 scan_id=scan_id,
-                 step_number=step_number,
-                 role=StepRole.AGENT,
-                 content=content,
-                 tool_invocations=tool_invocations if tool_invocations else None,
+         if self.agent_name:
+             messages.append(
+                 {
+                     "role": "user",
+                     "content": (
+                         f"\n\n<agent_identity>\n"
+                         f"<meta>Internal metadata: do not echo or reference.</meta>\n"
+                         f"<agent_name>{self.agent_name}</agent_name>\n"
+                         f"<agent_id>{self.agent_id}</agent_id>\n"
+                         f"</agent_identity>\n\n"
+                     ),
+                 }
              )

-         except litellm.RateLimitError as e:
-             raise LLMRequestFailedError("LLM request failed: Rate limit exceeded", str(e)) from e
-         except litellm.AuthenticationError as e:
-             raise LLMRequestFailedError("LLM request failed: Invalid API key", str(e)) from e
-         except litellm.NotFoundError as e:
-             raise LLMRequestFailedError("LLM request failed: Model not found", str(e)) from e
-         except litellm.ContextWindowExceededError as e:
-             raise LLMRequestFailedError("LLM request failed: Context too long", str(e)) from e
-         except litellm.ContentPolicyViolationError as e:
-             raise LLMRequestFailedError(
-                 "LLM request failed: Content policy violation", str(e)
-             ) from e
-         except litellm.ServiceUnavailableError as e:
-             raise LLMRequestFailedError("LLM request failed: Service unavailable", str(e)) from e
-         except litellm.Timeout as e:
-             raise LLMRequestFailedError("LLM request failed: Request timed out", str(e)) from e
-         except litellm.UnprocessableEntityError as e:
-             raise LLMRequestFailedError("LLM request failed: Unprocessable entity", str(e)) from e
-         except litellm.InternalServerError as e:
-             raise LLMRequestFailedError("LLM request failed: Internal server error", str(e)) from e
-         except litellm.APIConnectionError as e:
-             raise LLMRequestFailedError("LLM request failed: Connection error", str(e)) from e
-         except litellm.UnsupportedParamsError as e:
-             raise LLMRequestFailedError("LLM request failed: Unsupported parameters", str(e)) from e
-         except litellm.BudgetExceededError as e:
-             raise LLMRequestFailedError("LLM request failed: Budget exceeded", str(e)) from e
-         except litellm.APIResponseValidationError as e:
-             raise LLMRequestFailedError(
-                 "LLM request failed: Response validation error", str(e)
-             ) from e
-         except litellm.JSONSchemaValidationError as e:
-             raise LLMRequestFailedError(
-                 "LLM request failed: JSON schema validation error", str(e)
-             ) from e
-         except litellm.InvalidRequestError as e:
-             raise LLMRequestFailedError("LLM request failed: Invalid request", str(e)) from e
-         except litellm.BadRequestError as e:
-             raise LLMRequestFailedError("LLM request failed: Bad request", str(e)) from e
-         except litellm.APIError as e:
-             raise LLMRequestFailedError("LLM request failed: API error", str(e)) from e
-         except litellm.OpenAIError as e:
-             raise LLMRequestFailedError("LLM request failed: OpenAI error", str(e)) from e
-         except Exception as e:
-             raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
-
-     @property
-     def usage_stats(self) -> dict[str, dict[str, int | float]]:
-         return {
-             "total": self._total_stats.to_dict(),
-             "last_request": self._last_request_stats.to_dict(),
-         }
+         compressed = list(self.memory_compressor.compress_history(conversation_history))
+         conversation_history.clear()
+         conversation_history.extend(compressed)
+         messages.extend(compressed)

-     def get_cache_config(self) -> dict[str, bool]:
-         return {
-             "enabled": self.config.enable_prompt_caching,
-             "supported": supports_prompt_caching(self.config.model_name),
-         }
+         if self._is_anthropic() and self.config.enable_prompt_caching:
+             messages = self._add_cache_control(messages)

-     def _should_include_stop_param(self) -> bool:
-         if not self.config.model_name:
-             return True
+         return messages

-         return not model_matches(self.config.model_name, SUPPORTS_STOP_WORDS_FALSE_PATTERNS)
+     def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
+         if not self._supports_vision():
+             messages = self._strip_images(messages)

-     def _should_include_reasoning_effort(self) -> bool:
-         if not self.config.model_name:
-             return False
-
-         return model_matches(self.config.model_name, REASONING_EFFORT_PATTERNS)
-
-     async def _make_request(
-         self,
-         messages: list[dict[str, Any]],
-     ) -> ModelResponse:
-         completion_args: dict[str, Any] = {
+         args: dict[str, Any] = {
              "model": self.config.model_name,
              "messages": messages,
              "timeout": self.config.timeout,
+             "stream_options": {"include_usage": True},
          }

-         if self._should_include_stop_param():
-             completion_args["stop"] = ["</function>"]
-
-         if self._should_include_reasoning_effort():
-             completion_args["reasoning_effort"] = "high"
+         if api_key := Config.get("llm_api_key"):
+             args["api_key"] = api_key
+         if api_base := (
+             Config.get("llm_api_base")
+             or Config.get("openai_api_base")
+             or Config.get("litellm_base_url")
+             or Config.get("ollama_api_base")
+         ):
+             args["api_base"] = api_base
+         if self._supports_reasoning():
+             args["reasoning_effort"] = self._reasoning_effort

-         queue = get_global_queue()
-         response = await queue.make_request(completion_args)
+         return args

-         self._total_stats.requests += 1
-         self._last_request_stats = RequestStats(requests=1)
+     def _get_chunk_content(self, chunk: Any) -> str:
+         if chunk.choices and hasattr(chunk.choices[0], "delta"):
+             return getattr(chunk.choices[0].delta, "content", "") or ""
+         return ""

-         return response
-
-     def _update_usage_stats(self, response: ModelResponse) -> None:
+     def _extract_thinking(self, chunks: list[Any]) -> list[dict[str, Any]] | None:
+         if not chunks or not self._supports_reasoning():
+             return None
+         try:
+             resp = stream_chunk_builder(chunks)
+             if resp.choices and hasattr(resp.choices[0].message, "thinking_blocks"):
+                 blocks: list[dict[str, Any]] = resp.choices[0].message.thinking_blocks
+                 return blocks
+         except Exception:  # noqa: BLE001, S110 # nosec B110
+             pass
+         return None
+
+     def _update_usage_stats(self, response: Any) -> None:
          try:
              if hasattr(response, "usage") and response.usage:
                  input_tokens = getattr(response.usage, "prompt_tokens", 0)
                  output_tokens = getattr(response.usage, "completion_tokens", 0)

                  cached_tokens = 0
-                 cache_creation_tokens = 0
-
                  if hasattr(response.usage, "prompt_tokens_details"):
                      prompt_details = response.usage.prompt_tokens_details
                      if hasattr(prompt_details, "cached_tokens"):
                          cached_tokens = prompt_details.cached_tokens or 0

-                 if hasattr(response.usage, "cache_creation_input_tokens"):
-                     cache_creation_tokens = response.usage.cache_creation_input_tokens or 0
-
              else:
                  input_tokens = 0
                  output_tokens = 0
                  cached_tokens = 0
-                 cache_creation_tokens = 0

              try:
                  cost = completion_cost(response) or 0.0
-             except Exception as e:  # noqa: BLE001
-                 logger.warning(f"Failed to calculate cost: {e}")
+             except Exception:  # noqa: BLE001
                  cost = 0.0

              self._total_stats.input_tokens += input_tokens
              self._total_stats.output_tokens += output_tokens
              self._total_stats.cached_tokens += cached_tokens
-             self._total_stats.cache_creation_tokens += cache_creation_tokens
              self._total_stats.cost += cost

-             self._last_request_stats.input_tokens = input_tokens
-             self._last_request_stats.output_tokens = output_tokens
-             self._last_request_stats.cached_tokens = cached_tokens
-             self._last_request_stats.cache_creation_tokens = cache_creation_tokens
-             self._last_request_stats.cost = cost
+         except Exception:  # noqa: BLE001, S110 # nosec B110
+             pass
+
+     def _should_retry(self, e: Exception) -> bool:
+         code = getattr(e, "status_code", None) or getattr(
+             getattr(e, "response", None), "status_code", None
+         )
+         return code is None or litellm._should_retry(code)
+
+     def _raise_error(self, e: Exception) -> None:
+         from strix.telemetry import posthog
+
+         posthog.error("llm_error", type(e).__name__)
+         raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
+
+     def _is_anthropic(self) -> bool:
+         if not self.config.model_name:
+             return False
+         return any(p in self.config.model_name.lower() for p in ["anthropic/", "claude"])
+
+     def _supports_vision(self) -> bool:
+         try:
+             return bool(supports_vision(model=self.config.model_name))
+         except Exception:  # noqa: BLE001
+             return False

-             if cached_tokens > 0:
-                 logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens")
-             if cache_creation_tokens > 0:
-                 logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache")
+     def _supports_reasoning(self) -> bool:
+         try:
+             return bool(supports_reasoning(model=self.config.model_name))
+         except Exception:  # noqa: BLE001
+             return False
+
+     def _strip_images(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+         result = []
+         for msg in messages:
+             content = msg.get("content")
+             if isinstance(content, list):
+                 text_parts = []
+                 for item in content:
+                     if isinstance(item, dict) and item.get("type") == "text":
+                         text_parts.append(item.get("text", ""))
+                     elif isinstance(item, dict) and item.get("type") == "image_url":
+                         text_parts.append("[Image removed - model doesn't support vision]")
+                 result.append({**msg, "content": "\n".join(text_parts)})
+             else:
+                 result.append(msg)
+         return result
+
+     def _add_cache_control(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+         if not messages or not supports_prompt_caching(self.config.model_name):
+             return messages

-             logger.info(f"Usage stats: {self.usage_stats}")
-         except Exception as e:  # noqa: BLE001
-             logger.warning(f"Failed to update usage stats: {e}")
+         result = list(messages)
+
+         if result[0].get("role") == "system":
+             content = result[0]["content"]
+             result[0] = {
+                 **result[0],
+                 "content": [
+                     {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
+                 ]
+                 if isinstance(content, str)
+                 else content,
+             }
+         return result
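
Note on the API change visible in this diff: in 0.4.0 `LLM.generate()` returned a single `LLMResponse`, while in 0.6.2 it is an async generator that yields partial `LLMResponse` objects as the completion streams, with the final yield carrying the parsed `tool_invocations` and `thinking_blocks`. The sketch below is one plausible way a caller could consume the new API; the `LLMConfig` constructor arguments and the agent id are assumptions for illustration, not taken from this diff.

# Hypothetical consumer of the streaming generate() API in strix-agent 0.6.2.
# Assumptions: LLMConfig(model_name=..., scan_mode=...) as a constructor, and the
# "agent-001" identifier; only the attribute names appear in this diff.
import asyncio

from strix.llm.config import LLMConfig
from strix.llm.llm import LLM


async def run_once() -> None:
    config = LLMConfig(model_name="anthropic/claude-sonnet-4-5", scan_mode="standard")
    llm = LLM(config, agent_name="StrixAgent")
    llm.set_agent_identity("StrixAgent", "agent-001")

    history = [{"role": "user", "content": "Plan the next scan step."}]

    final = None
    # Each yield carries the accumulated content so far; the last one also has
    # tool_invocations and thinking_blocks populated.
    async for partial in llm.generate(history):
        final = partial

    if final is not None:
        print(final.content)
        if final.tool_invocations:
            print("Parsed tool calls:", final.tool_invocations)


asyncio.run(run_once())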