strix-agent 0.4.0 (strix_agent-0.4.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. strix/__init__.py +0 -0
  2. strix/agents/StrixAgent/__init__.py +4 -0
  3. strix/agents/StrixAgent/strix_agent.py +89 -0
  4. strix/agents/StrixAgent/system_prompt.jinja +404 -0
  5. strix/agents/__init__.py +10 -0
  6. strix/agents/base_agent.py +518 -0
  7. strix/agents/state.py +163 -0
  8. strix/interface/__init__.py +4 -0
  9. strix/interface/assets/tui_styles.tcss +694 -0
  10. strix/interface/cli.py +230 -0
  11. strix/interface/main.py +500 -0
  12. strix/interface/tool_components/__init__.py +39 -0
  13. strix/interface/tool_components/agents_graph_renderer.py +123 -0
  14. strix/interface/tool_components/base_renderer.py +62 -0
  15. strix/interface/tool_components/browser_renderer.py +120 -0
  16. strix/interface/tool_components/file_edit_renderer.py +99 -0
  17. strix/interface/tool_components/finish_renderer.py +31 -0
  18. strix/interface/tool_components/notes_renderer.py +108 -0
  19. strix/interface/tool_components/proxy_renderer.py +255 -0
  20. strix/interface/tool_components/python_renderer.py +34 -0
  21. strix/interface/tool_components/registry.py +72 -0
  22. strix/interface/tool_components/reporting_renderer.py +53 -0
  23. strix/interface/tool_components/scan_info_renderer.py +64 -0
  24. strix/interface/tool_components/terminal_renderer.py +131 -0
  25. strix/interface/tool_components/thinking_renderer.py +29 -0
  26. strix/interface/tool_components/user_message_renderer.py +43 -0
  27. strix/interface/tool_components/web_search_renderer.py +28 -0
  28. strix/interface/tui.py +1274 -0
  29. strix/interface/utils.py +559 -0
  30. strix/llm/__init__.py +15 -0
  31. strix/llm/config.py +20 -0
  32. strix/llm/llm.py +465 -0
  33. strix/llm/memory_compressor.py +212 -0
  34. strix/llm/request_queue.py +87 -0
  35. strix/llm/utils.py +87 -0
  36. strix/prompts/README.md +64 -0
  37. strix/prompts/__init__.py +109 -0
  38. strix/prompts/cloud/.gitkeep +0 -0
  39. strix/prompts/coordination/root_agent.jinja +41 -0
  40. strix/prompts/custom/.gitkeep +0 -0
  41. strix/prompts/frameworks/fastapi.jinja +142 -0
  42. strix/prompts/frameworks/nextjs.jinja +126 -0
  43. strix/prompts/protocols/graphql.jinja +215 -0
  44. strix/prompts/reconnaissance/.gitkeep +0 -0
  45. strix/prompts/technologies/firebase_firestore.jinja +177 -0
  46. strix/prompts/technologies/supabase.jinja +189 -0
  47. strix/prompts/vulnerabilities/authentication_jwt.jinja +147 -0
  48. strix/prompts/vulnerabilities/broken_function_level_authorization.jinja +146 -0
  49. strix/prompts/vulnerabilities/business_logic.jinja +171 -0
  50. strix/prompts/vulnerabilities/csrf.jinja +174 -0
  51. strix/prompts/vulnerabilities/idor.jinja +195 -0
  52. strix/prompts/vulnerabilities/information_disclosure.jinja +222 -0
  53. strix/prompts/vulnerabilities/insecure_file_uploads.jinja +188 -0
  54. strix/prompts/vulnerabilities/mass_assignment.jinja +141 -0
  55. strix/prompts/vulnerabilities/open_redirect.jinja +177 -0
  56. strix/prompts/vulnerabilities/path_traversal_lfi_rfi.jinja +142 -0
  57. strix/prompts/vulnerabilities/race_conditions.jinja +164 -0
  58. strix/prompts/vulnerabilities/rce.jinja +154 -0
  59. strix/prompts/vulnerabilities/sql_injection.jinja +151 -0
  60. strix/prompts/vulnerabilities/ssrf.jinja +135 -0
  61. strix/prompts/vulnerabilities/subdomain_takeover.jinja +155 -0
  62. strix/prompts/vulnerabilities/xss.jinja +169 -0
  63. strix/prompts/vulnerabilities/xxe.jinja +184 -0
  64. strix/runtime/__init__.py +19 -0
  65. strix/runtime/docker_runtime.py +399 -0
  66. strix/runtime/runtime.py +29 -0
  67. strix/runtime/tool_server.py +205 -0
  68. strix/telemetry/__init__.py +4 -0
  69. strix/telemetry/tracer.py +337 -0
  70. strix/tools/__init__.py +64 -0
  71. strix/tools/agents_graph/__init__.py +16 -0
  72. strix/tools/agents_graph/agents_graph_actions.py +621 -0
  73. strix/tools/agents_graph/agents_graph_actions_schema.xml +226 -0
  74. strix/tools/argument_parser.py +121 -0
  75. strix/tools/browser/__init__.py +4 -0
  76. strix/tools/browser/browser_actions.py +236 -0
  77. strix/tools/browser/browser_actions_schema.xml +183 -0
  78. strix/tools/browser/browser_instance.py +533 -0
  79. strix/tools/browser/tab_manager.py +342 -0
  80. strix/tools/executor.py +305 -0
  81. strix/tools/file_edit/__init__.py +4 -0
  82. strix/tools/file_edit/file_edit_actions.py +141 -0
  83. strix/tools/file_edit/file_edit_actions_schema.xml +128 -0
  84. strix/tools/finish/__init__.py +4 -0
  85. strix/tools/finish/finish_actions.py +174 -0
  86. strix/tools/finish/finish_actions_schema.xml +45 -0
  87. strix/tools/notes/__init__.py +14 -0
  88. strix/tools/notes/notes_actions.py +191 -0
  89. strix/tools/notes/notes_actions_schema.xml +150 -0
  90. strix/tools/proxy/__init__.py +20 -0
  91. strix/tools/proxy/proxy_actions.py +101 -0
  92. strix/tools/proxy/proxy_actions_schema.xml +267 -0
  93. strix/tools/proxy/proxy_manager.py +785 -0
  94. strix/tools/python/__init__.py +4 -0
  95. strix/tools/python/python_actions.py +47 -0
  96. strix/tools/python/python_actions_schema.xml +131 -0
  97. strix/tools/python/python_instance.py +172 -0
  98. strix/tools/python/python_manager.py +131 -0
  99. strix/tools/registry.py +196 -0
  100. strix/tools/reporting/__init__.py +6 -0
  101. strix/tools/reporting/reporting_actions.py +63 -0
  102. strix/tools/reporting/reporting_actions_schema.xml +30 -0
  103. strix/tools/terminal/__init__.py +4 -0
  104. strix/tools/terminal/terminal_actions.py +35 -0
  105. strix/tools/terminal/terminal_actions_schema.xml +146 -0
  106. strix/tools/terminal/terminal_manager.py +151 -0
  107. strix/tools/terminal/terminal_session.py +447 -0
  108. strix/tools/thinking/__init__.py +4 -0
  109. strix/tools/thinking/thinking_actions.py +18 -0
  110. strix/tools/thinking/thinking_actions_schema.xml +52 -0
  111. strix/tools/web_search/__init__.py +4 -0
  112. strix/tools/web_search/web_search_actions.py +80 -0
  113. strix/tools/web_search/web_search_actions_schema.xml +83 -0
  114. strix_agent-0.4.0.dist-info/LICENSE +201 -0
  115. strix_agent-0.4.0.dist-info/METADATA +282 -0
  116. strix_agent-0.4.0.dist-info/RECORD +118 -0
  117. strix_agent-0.4.0.dist-info/WHEEL +4 -0
  118. strix_agent-0.4.0.dist-info/entry_points.txt +3 -0
strix/llm/llm.py ADDED
@@ -0,0 +1,465 @@
+ import logging
+ import os
+ from dataclasses import dataclass
+ from enum import Enum
+ from fnmatch import fnmatch
+ from pathlib import Path
+ from typing import Any
+
+ import litellm
+ from jinja2 import (
+     Environment,
+     FileSystemLoader,
+     select_autoescape,
+ )
+ from litellm import ModelResponse, completion_cost
+ from litellm.utils import supports_prompt_caching
+
+ from strix.llm.config import LLMConfig
+ from strix.llm.memory_compressor import MemoryCompressor
+ from strix.llm.request_queue import get_global_queue
+ from strix.llm.utils import _truncate_to_first_function, parse_tool_invocations
+ from strix.prompts import load_prompt_modules
+ from strix.tools import get_tools_prompt
+
+
+ logger = logging.getLogger(__name__)
+
+ api_key = os.getenv("LLM_API_KEY")
+ if api_key:
+     litellm.api_key = api_key
+
+ api_base = (
+     os.getenv("LLM_API_BASE")
+     or os.getenv("OPENAI_API_BASE")
+     or os.getenv("LITELLM_BASE_URL")
+     or os.getenv("OLLAMA_API_BASE")
+ )
+ if api_base:
+     litellm.api_base = api_base
+
+
+ class LLMRequestFailedError(Exception):
+     def __init__(self, message: str, details: str | None = None):
+         super().__init__(message)
+         self.message = message
+         self.details = details
+
+
+ SUPPORTS_STOP_WORDS_FALSE_PATTERNS: list[str] = [
+     "o1*",
+     "grok-4-0709",
+     "grok-code-fast-1",
+     "deepseek-r1-0528*",
+ ]
+
+ REASONING_EFFORT_PATTERNS: list[str] = [
+     "o1-2024-12-17",
+     "o1",
+     "o3",
+     "o3-2025-04-16",
+     "o3-mini-2025-01-31",
+     "o3-mini",
+     "o4-mini",
+     "o4-mini-2025-04-16",
+     "gemini-2.5-flash",
+     "gemini-2.5-pro",
+     "gpt-5*",
+     "deepseek-r1-0528*",
+     "claude-sonnet-4-5*",
+     "claude-haiku-4-5*",
+ ]
+
+
+ def normalize_model_name(model: str) -> str:
+     raw = (model or "").strip().lower()
+     if "/" in raw:
+         name = raw.split("/")[-1]
+         if ":" in name:
+             name = name.split(":", 1)[0]
+     else:
+         name = raw
+     if name.endswith("-gguf"):
+         name = name[: -len("-gguf")]
+     return name
+
+
+ def model_matches(model: str, patterns: list[str]) -> bool:
+     raw = (model or "").strip().lower()
+     name = normalize_model_name(model)
+     for pat in patterns:
+         pat_l = pat.lower()
+         if "/" in pat_l:
+             if fnmatch(raw, pat_l):
+                 return True
+         elif fnmatch(name, pat_l):
+             return True
+     return False
+
+
+ class StepRole(str, Enum):
+     AGENT = "agent"
+     USER = "user"
+     SYSTEM = "system"
+
+
+ @dataclass
+ class LLMResponse:
+     content: str
+     tool_invocations: list[dict[str, Any]] | None = None
+     scan_id: str | None = None
+     step_number: int = 1
+     role: StepRole = StepRole.AGENT
+
+
+ @dataclass
+ class RequestStats:
+     input_tokens: int = 0
+     output_tokens: int = 0
+     cached_tokens: int = 0
+     cache_creation_tokens: int = 0
+     cost: float = 0.0
+     requests: int = 0
+     failed_requests: int = 0
+
+     def to_dict(self) -> dict[str, int | float]:
+         return {
+             "input_tokens": self.input_tokens,
+             "output_tokens": self.output_tokens,
+             "cached_tokens": self.cached_tokens,
+             "cache_creation_tokens": self.cache_creation_tokens,
+             "cost": round(self.cost, 4),
+             "requests": self.requests,
+             "failed_requests": self.failed_requests,
+         }
+
+
+ class LLM:
+     def __init__(
+         self, config: LLMConfig, agent_name: str | None = None, agent_id: str | None = None
+     ):
+         self.config = config
+         self.agent_name = agent_name
+         self.agent_id = agent_id
+         self._total_stats = RequestStats()
+         self._last_request_stats = RequestStats()
+
+         self.memory_compressor = MemoryCompressor(
+             model_name=self.config.model_name,
+             timeout=self.config.timeout,
+         )
+
+         if agent_name:
+             prompt_dir = Path(__file__).parent.parent / "agents" / agent_name
+             prompts_dir = Path(__file__).parent.parent / "prompts"
+
+             loader = FileSystemLoader([prompt_dir, prompts_dir])
+             self.jinja_env = Environment(
+                 loader=loader,
+                 autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
+             )
+
+             try:
+                 prompt_module_content = load_prompt_modules(
+                     self.config.prompt_modules or [], self.jinja_env
+                 )
+
+                 def get_module(name: str) -> str:
+                     return prompt_module_content.get(name, "")
+
+                 self.jinja_env.globals["get_module"] = get_module
+
+                 self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render(
+                     get_tools_prompt=get_tools_prompt,
+                     loaded_module_names=list(prompt_module_content.keys()),
+                     **prompt_module_content,
+                 )
+             except (FileNotFoundError, OSError, ValueError) as e:
+                 logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
+                 self.system_prompt = "You are a helpful AI assistant."
+         else:
+             self.system_prompt = "You are a helpful AI assistant."
+
+     def set_agent_identity(self, agent_name: str | None, agent_id: str | None) -> None:
+         if agent_name:
+             self.agent_name = agent_name
+         if agent_id:
+             self.agent_id = agent_id
+
+     def _build_identity_message(self) -> dict[str, Any] | None:
+         if not (self.agent_name and str(self.agent_name).strip()):
+             return None
+         identity_name = self.agent_name
+         identity_id = self.agent_id
+         content = (
+             "\n\n"
+             "<agent_identity>\n"
+             "<meta>Internal metadata: do not echo or reference; "
+             "not part of history or tool calls.</meta>\n"
+             "<note>You are now assuming the role of this agent. "
+             "Act strictly as this agent and maintain self-identity for this step. "
+             "Now go answer the next needed step!</note>\n"
+             f"<agent_name>{identity_name}</agent_name>\n"
+             f"<agent_id>{identity_id}</agent_id>\n"
+             "</agent_identity>\n\n"
+         )
+         return {"role": "user", "content": content}
+
+     def _add_cache_control_to_content(
+         self, content: str | list[dict[str, Any]]
+     ) -> str | list[dict[str, Any]]:
+         if isinstance(content, str):
+             return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
+         if isinstance(content, list) and content:
+             last_item = content[-1]
+             if isinstance(last_item, dict) and last_item.get("type") == "text":
+                 return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}]
+         return content
+
+     def _is_anthropic_model(self) -> bool:
+         if not self.config.model_name:
+             return False
+         model_lower = self.config.model_name.lower()
+         return any(provider in model_lower for provider in ["anthropic/", "claude"])
+
+     def _calculate_cache_interval(self, total_messages: int) -> int:
+         if total_messages <= 1:
+             return 10
+
+         max_cached_messages = 3
+         non_system_messages = total_messages - 1
+
+         interval = 10
+         while non_system_messages // interval > max_cached_messages:
+             interval += 10
+
+         return interval
+
+     def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+         if (
+             not self.config.enable_prompt_caching
+             or not supports_prompt_caching(self.config.model_name)
+             or not messages
+         ):
+             return messages
+
+         if not self._is_anthropic_model():
+             return messages
+
+         cached_messages = list(messages)
+
+         if cached_messages and cached_messages[0].get("role") == "system":
+             system_message = cached_messages[0].copy()
+             system_message["content"] = self._add_cache_control_to_content(
+                 system_message["content"]
+             )
+             cached_messages[0] = system_message
+
+         total_messages = len(cached_messages)
+         if total_messages > 1:
+             interval = self._calculate_cache_interval(total_messages)
+
+             cached_count = 0
+             for i in range(interval, total_messages, interval):
+                 if cached_count >= 3:
+                     break
+
+                 if i < len(cached_messages):
+                     message = cached_messages[i].copy()
+                     message["content"] = self._add_cache_control_to_content(message["content"])
+                     cached_messages[i] = message
+                     cached_count += 1
+
+         return cached_messages
+
+     async def generate(  # noqa: PLR0912, PLR0915
+         self,
+         conversation_history: list[dict[str, Any]],
+         scan_id: str | None = None,
+         step_number: int = 1,
+     ) -> LLMResponse:
+         messages = [{"role": "system", "content": self.system_prompt}]
+
+         identity_message = self._build_identity_message()
+         if identity_message:
+             messages.append(identity_message)
+
+         compressed_history = list(self.memory_compressor.compress_history(conversation_history))
+
+         conversation_history.clear()
+         conversation_history.extend(compressed_history)
+         messages.extend(compressed_history)
+
+         cached_messages = self._prepare_cached_messages(messages)
+
+         try:
+             response = await self._make_request(cached_messages)
+             self._update_usage_stats(response)
+
+             content = ""
+             if (
+                 response.choices
+                 and hasattr(response.choices[0], "message")
+                 and response.choices[0].message
+             ):
+                 content = getattr(response.choices[0].message, "content", "") or ""
+
+             content = _truncate_to_first_function(content)
+
+             if "</function>" in content:
+                 function_end_index = content.find("</function>") + len("</function>")
+                 content = content[:function_end_index]
+
+             tool_invocations = parse_tool_invocations(content)
+
+             return LLMResponse(
+                 scan_id=scan_id,
+                 step_number=step_number,
+                 role=StepRole.AGENT,
+                 content=content,
+                 tool_invocations=tool_invocations if tool_invocations else None,
+             )
+
+         except litellm.RateLimitError as e:
+             raise LLMRequestFailedError("LLM request failed: Rate limit exceeded", str(e)) from e
+         except litellm.AuthenticationError as e:
+             raise LLMRequestFailedError("LLM request failed: Invalid API key", str(e)) from e
+         except litellm.NotFoundError as e:
+             raise LLMRequestFailedError("LLM request failed: Model not found", str(e)) from e
+         except litellm.ContextWindowExceededError as e:
+             raise LLMRequestFailedError("LLM request failed: Context too long", str(e)) from e
+         except litellm.ContentPolicyViolationError as e:
+             raise LLMRequestFailedError(
+                 "LLM request failed: Content policy violation", str(e)
+             ) from e
+         except litellm.ServiceUnavailableError as e:
+             raise LLMRequestFailedError("LLM request failed: Service unavailable", str(e)) from e
+         except litellm.Timeout as e:
+             raise LLMRequestFailedError("LLM request failed: Request timed out", str(e)) from e
+         except litellm.UnprocessableEntityError as e:
+             raise LLMRequestFailedError("LLM request failed: Unprocessable entity", str(e)) from e
+         except litellm.InternalServerError as e:
+             raise LLMRequestFailedError("LLM request failed: Internal server error", str(e)) from e
+         except litellm.APIConnectionError as e:
+             raise LLMRequestFailedError("LLM request failed: Connection error", str(e)) from e
+         except litellm.UnsupportedParamsError as e:
+             raise LLMRequestFailedError("LLM request failed: Unsupported parameters", str(e)) from e
+         except litellm.BudgetExceededError as e:
+             raise LLMRequestFailedError("LLM request failed: Budget exceeded", str(e)) from e
+         except litellm.APIResponseValidationError as e:
+             raise LLMRequestFailedError(
+                 "LLM request failed: Response validation error", str(e)
+             ) from e
+         except litellm.JSONSchemaValidationError as e:
+             raise LLMRequestFailedError(
+                 "LLM request failed: JSON schema validation error", str(e)
+             ) from e
+         except litellm.InvalidRequestError as e:
+             raise LLMRequestFailedError("LLM request failed: Invalid request", str(e)) from e
+         except litellm.BadRequestError as e:
+             raise LLMRequestFailedError("LLM request failed: Bad request", str(e)) from e
+         except litellm.APIError as e:
+             raise LLMRequestFailedError("LLM request failed: API error", str(e)) from e
+         except litellm.OpenAIError as e:
+             raise LLMRequestFailedError("LLM request failed: OpenAI error", str(e)) from e
+         except Exception as e:
+             raise LLMRequestFailedError(f"LLM request failed: {type(e).__name__}", str(e)) from e
+
+     @property
+     def usage_stats(self) -> dict[str, dict[str, int | float]]:
+         return {
+             "total": self._total_stats.to_dict(),
+             "last_request": self._last_request_stats.to_dict(),
+         }
+
+     def get_cache_config(self) -> dict[str, bool]:
+         return {
+             "enabled": self.config.enable_prompt_caching,
+             "supported": supports_prompt_caching(self.config.model_name),
+         }
+
+     def _should_include_stop_param(self) -> bool:
+         if not self.config.model_name:
+             return True
+
+         return not model_matches(self.config.model_name, SUPPORTS_STOP_WORDS_FALSE_PATTERNS)
+
+     def _should_include_reasoning_effort(self) -> bool:
+         if not self.config.model_name:
+             return False
+
+         return model_matches(self.config.model_name, REASONING_EFFORT_PATTERNS)
+
+     async def _make_request(
+         self,
+         messages: list[dict[str, Any]],
+     ) -> ModelResponse:
+         completion_args: dict[str, Any] = {
+             "model": self.config.model_name,
+             "messages": messages,
+             "timeout": self.config.timeout,
+         }
+
+         if self._should_include_stop_param():
+             completion_args["stop"] = ["</function>"]
+
+         if self._should_include_reasoning_effort():
+             completion_args["reasoning_effort"] = "high"
+
+         queue = get_global_queue()
+         response = await queue.make_request(completion_args)
+
+         self._total_stats.requests += 1
+         self._last_request_stats = RequestStats(requests=1)
+
+         return response
+
+     def _update_usage_stats(self, response: ModelResponse) -> None:
+         try:
+             if hasattr(response, "usage") and response.usage:
+                 input_tokens = getattr(response.usage, "prompt_tokens", 0)
+                 output_tokens = getattr(response.usage, "completion_tokens", 0)
+
+                 cached_tokens = 0
+                 cache_creation_tokens = 0
+
+                 if hasattr(response.usage, "prompt_tokens_details"):
+                     prompt_details = response.usage.prompt_tokens_details
+                     if hasattr(prompt_details, "cached_tokens"):
+                         cached_tokens = prompt_details.cached_tokens or 0
+
+                 if hasattr(response.usage, "cache_creation_input_tokens"):
+                     cache_creation_tokens = response.usage.cache_creation_input_tokens or 0
+
+             else:
+                 input_tokens = 0
+                 output_tokens = 0
+                 cached_tokens = 0
+                 cache_creation_tokens = 0
+
+             try:
+                 cost = completion_cost(response) or 0.0
+             except Exception as e:  # noqa: BLE001
+                 logger.warning(f"Failed to calculate cost: {e}")
+                 cost = 0.0
+
+             self._total_stats.input_tokens += input_tokens
+             self._total_stats.output_tokens += output_tokens
+             self._total_stats.cached_tokens += cached_tokens
+             self._total_stats.cache_creation_tokens += cache_creation_tokens
+             self._total_stats.cost += cost
+
+             self._last_request_stats.input_tokens = input_tokens
+             self._last_request_stats.output_tokens = output_tokens
+             self._last_request_stats.cached_tokens = cached_tokens
+             self._last_request_stats.cache_creation_tokens = cache_creation_tokens
+             self._last_request_stats.cost = cost
+
+             if cached_tokens > 0:
+                 logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens")
+             if cache_creation_tokens > 0:
+                 logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache")
+
+             logger.info(f"Usage stats: {self.usage_stats}")
+         except Exception as e:  # noqa: BLE001
+             logger.warning(f"Failed to update usage stats: {e}")
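For orientation, a minimal usage sketch of the LLM wrapper above (not part of the package diff). It assumes LLMConfig accepts model_name, timeout, enable_prompt_caching, and prompt_modules as constructor arguments, since those are the only fields llm.py reads; the real LLMConfig in strix/llm/config.py may differ. It also assumes LLM_API_KEY is set in the environment and that the agent name matches the strix/agents/StrixAgent prompt directory.

    # Hypothetical usage sketch; LLMConfig field names are inferred from the
    # attributes referenced in llm.py and may not match the real constructor.
    import asyncio

    from strix.llm.config import LLMConfig
    from strix.llm.llm import LLM


    async def main() -> None:
        config = LLMConfig(
            model_name="anthropic/claude-sonnet-4-5",  # assumed field name
            timeout=600,                               # assumed field name
            enable_prompt_caching=True,                # assumed field name
            prompt_modules=[],                         # assumed field name
        )
        llm = LLM(config, agent_name="StrixAgent", agent_id="agent-001")

        history = [{"role": "user", "content": "Start the assessment of the target."}]
        response = await llm.generate(history, scan_id="scan-001", step_number=1)

        print(response.content)           # model output, cut after the first </function>
        print(response.tool_invocations)  # parsed tool calls, or None
        print(llm.usage_stats)            # {"total": ..., "last_request": ...}


    asyncio.run(main())

Note that generate() compresses and rewrites conversation_history in place before each request, and _prepare_cached_messages only adds Anthropic-style ephemeral cache_control markers when the configured model is a Claude model that supports prompt caching.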
strix/llm/memory_compressor.py ADDED
@@ -0,0 +1,212 @@
+ import logging
+ import os
+ from typing import Any
+
+ import litellm
+
+
+ logger = logging.getLogger(__name__)
+
+
+ MAX_TOTAL_TOKENS = 100_000
+ MIN_RECENT_MESSAGES = 15
+
+ SUMMARY_PROMPT_TEMPLATE = """You are an agent performing context
+ condensation for a security agent. Your job is to compress scan data while preserving
+ ALL operationally critical information for continuing the security assessment.
+
+ CRITICAL ELEMENTS TO PRESERVE:
+ - Discovered vulnerabilities and potential attack vectors
+ - Scan results and tool outputs (compressed but maintaining key findings)
+ - Access credentials, tokens, or authentication details found
+ - System architecture insights and potential weak points
+ - Progress made in the assessment
+ - Failed attempts and dead ends (to avoid duplication)
+ - Any decisions made about the testing approach
+
+ COMPRESSION GUIDELINES:
+ - Preserve exact technical details (URLs, paths, parameters, payloads)
+ - Summarize verbose tool outputs while keeping critical findings
+ - Maintain version numbers, specific technologies identified
+ - Keep exact error messages that might indicate vulnerabilities
+ - Compress repetitive or similar findings into consolidated form
+
+ Remember: Another security agent will use this summary to continue the assessment.
+ They must be able to pick up exactly where you left off without losing any
+ operational advantage or context needed to find vulnerabilities.
+
+ CONVERSATION SEGMENT TO SUMMARIZE:
+ {conversation}
+
+ Provide a technically precise summary that preserves all operational security context while
+ keeping the summary concise and to the point."""
+
+
+ def _count_tokens(text: str, model: str) -> int:
+     try:
+         count = litellm.token_counter(model=model, text=text)
+         return int(count)
+     except Exception:
+         logger.exception("Failed to count tokens")
+         return len(text) // 4  # Rough estimate
+
+
+ def _get_message_tokens(msg: dict[str, Any], model: str) -> int:
+     content = msg.get("content", "")
+     if isinstance(content, str):
+         return _count_tokens(content, model)
+     if isinstance(content, list):
+         return sum(
+             _count_tokens(item.get("text", ""), model)
+             for item in content
+             if isinstance(item, dict) and item.get("type") == "text"
+         )
+     return 0
+
+
+ def _extract_message_text(msg: dict[str, Any]) -> str:
+     content = msg.get("content", "")
+     if isinstance(content, str):
+         return content
+
+     if isinstance(content, list):
+         parts = []
+         for item in content:
+             if isinstance(item, dict):
+                 if item.get("type") == "text":
+                     parts.append(item.get("text", ""))
+                 elif item.get("type") == "image_url":
+                     parts.append("[IMAGE]")
+         return " ".join(parts)
+
+     return str(content)
+
+
+ def _summarize_messages(
+     messages: list[dict[str, Any]],
+     model: str,
+     timeout: int = 600,
+ ) -> dict[str, Any]:
+     if not messages:
+         empty_summary = "<context_summary message_count='0'>{text}</context_summary>"
+         return {
+             "role": "assistant",
+             "content": empty_summary.format(text="No messages to summarize"),
+         }
+
+     formatted = []
+     for msg in messages:
+         role = msg.get("role", "unknown")
+         text = _extract_message_text(msg)
+         formatted.append(f"{role}: {text}")
+
+     conversation = "\n".join(formatted)
+     prompt = SUMMARY_PROMPT_TEMPLATE.format(conversation=conversation)
+
+     try:
+         completion_args = {
+             "model": model,
+             "messages": [{"role": "user", "content": prompt}],
+             "timeout": timeout,
+         }
+
+         response = litellm.completion(**completion_args)
+         summary = response.choices[0].message.content or ""
+         if not summary.strip():
+             return messages[0]
+         summary_msg = "<context_summary message_count='{count}'>{text}</context_summary>"
+         return {
+             "role": "assistant",
+             "content": summary_msg.format(count=len(messages), text=summary),
+         }
+     except Exception:
+         logger.exception("Failed to summarize messages")
+         return messages[0]
+
+
+ def _handle_images(messages: list[dict[str, Any]], max_images: int) -> None:
+     image_count = 0
+     for msg in reversed(messages):
+         content = msg.get("content", [])
+         if isinstance(content, list):
+             for item in content:
+                 if isinstance(item, dict) and item.get("type") == "image_url":
+                     if image_count >= max_images:
+                         item.update(
+                             {
+                                 "type": "text",
+                                 "text": "[Previously attached image removed to preserve context]",
+                             }
+                         )
+                     else:
+                         image_count += 1
+
+
+ class MemoryCompressor:
+     def __init__(
+         self,
+         max_images: int = 3,
+         model_name: str | None = None,
+         timeout: int = 600,
+     ):
+         self.max_images = max_images
+         self.model_name = model_name or os.getenv("STRIX_LLM", "openai/gpt-5")
+         self.timeout = timeout
+
+         if not self.model_name:
+             raise ValueError("STRIX_LLM environment variable must be set and not empty")
+
+     def compress_history(
+         self,
+         messages: list[dict[str, Any]],
+     ) -> list[dict[str, Any]]:
+         """Compress conversation history to stay within token limits.
+
+         Strategy:
+         1. Handle image limits first
+         2. Keep all system messages
+         3. Keep minimum recent messages
+         4. Summarize older messages when total tokens exceed limit
+
+         The compression preserves:
+         - All system messages unchanged
+         - Most recent messages intact
+         - Critical security context in summaries
+         - Recent images for visual context
+         - Technical details and findings
+         """
+         if not messages:
+             return messages
+
+         _handle_images(messages, self.max_images)
+
+         system_msgs = []
+         regular_msgs = []
+         for msg in messages:
+             if msg.get("role") == "system":
+                 system_msgs.append(msg)
+             else:
+                 regular_msgs.append(msg)
+
+         recent_msgs = regular_msgs[-MIN_RECENT_MESSAGES:]
+         old_msgs = regular_msgs[:-MIN_RECENT_MESSAGES]
+
+         # Type assertion since we ensure model_name is not None in __init__
+         model_name: str = self.model_name  # type: ignore[assignment]
+
+         total_tokens = sum(
+             _get_message_tokens(msg, model_name) for msg in system_msgs + regular_msgs
+         )
+
+         if total_tokens <= MAX_TOTAL_TOKENS * 0.9:
+             return messages
+
+         compressed = []
+         chunk_size = 10
+         for i in range(0, len(old_msgs), chunk_size):
+             chunk = old_msgs[i : i + chunk_size]
+             summary = _summarize_messages(chunk, model_name, self.timeout)
+             if summary:
+                 compressed.append(summary)
+
+         return system_msgs + compressed + recent_msgs
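MemoryCompressor is what LLM.generate calls before every request to keep the running history under MAX_TOTAL_TOKENS (100,000 tokens). A minimal behavioural sketch, not part of the package diff: a short history stays below the 90% threshold and passes through unchanged; summarization into <context_summary> blocks only kicks in for older messages once the total exceeds that threshold, and only images beyond max_images (counted from the most recent message backwards) are replaced with a text placeholder.

    # Sketch of MemoryCompressor on a small history; with the token total well
    # under 0.9 * MAX_TOTAL_TOKENS, compress_history returns the list unchanged.
    from strix.llm.memory_compressor import MemoryCompressor

    compressor = MemoryCompressor(max_images=3, model_name="openai/gpt-5")

    history = [
        {"role": "system", "content": "You are a security testing agent."},
        {"role": "user", "content": "Review the login flow for IDOR issues."},
        {"role": "assistant", "content": "Enumerating /api/users/{id} with two accounts..."},
    ]

    compressed = compressor.compress_history(history)
    assert compressed == history  # nothing is summarized below the token threshold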