strix-agent 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99) hide show
  1. strix/__init__.py +0 -0
  2. strix/agents/StrixAgent/__init__.py +4 -0
  3. strix/agents/StrixAgent/strix_agent.py +60 -0
  4. strix/agents/StrixAgent/system_prompt.jinja +504 -0
  5. strix/agents/__init__.py +10 -0
  6. strix/agents/base_agent.py +394 -0
  7. strix/agents/state.py +139 -0
  8. strix/cli/__init__.py +4 -0
  9. strix/cli/app.py +1124 -0
  10. strix/cli/assets/cli.tcss +680 -0
  11. strix/cli/main.py +542 -0
  12. strix/cli/tool_components/__init__.py +39 -0
  13. strix/cli/tool_components/agents_graph_renderer.py +129 -0
  14. strix/cli/tool_components/base_renderer.py +61 -0
  15. strix/cli/tool_components/browser_renderer.py +107 -0
  16. strix/cli/tool_components/file_edit_renderer.py +95 -0
  17. strix/cli/tool_components/finish_renderer.py +32 -0
  18. strix/cli/tool_components/notes_renderer.py +108 -0
  19. strix/cli/tool_components/proxy_renderer.py +255 -0
  20. strix/cli/tool_components/python_renderer.py +34 -0
  21. strix/cli/tool_components/registry.py +72 -0
  22. strix/cli/tool_components/reporting_renderer.py +53 -0
  23. strix/cli/tool_components/scan_info_renderer.py +58 -0
  24. strix/cli/tool_components/terminal_renderer.py +99 -0
  25. strix/cli/tool_components/thinking_renderer.py +29 -0
  26. strix/cli/tool_components/user_message_renderer.py +43 -0
  27. strix/cli/tool_components/web_search_renderer.py +28 -0
  28. strix/cli/tracer.py +308 -0
  29. strix/llm/__init__.py +14 -0
  30. strix/llm/config.py +19 -0
  31. strix/llm/llm.py +310 -0
  32. strix/llm/memory_compressor.py +206 -0
  33. strix/llm/request_queue.py +63 -0
  34. strix/llm/utils.py +84 -0
  35. strix/prompts/__init__.py +113 -0
  36. strix/prompts/coordination/root_agent.jinja +41 -0
  37. strix/prompts/vulnerabilities/authentication_jwt.jinja +129 -0
  38. strix/prompts/vulnerabilities/business_logic.jinja +143 -0
  39. strix/prompts/vulnerabilities/csrf.jinja +168 -0
  40. strix/prompts/vulnerabilities/idor.jinja +164 -0
  41. strix/prompts/vulnerabilities/race_conditions.jinja +194 -0
  42. strix/prompts/vulnerabilities/rce.jinja +222 -0
  43. strix/prompts/vulnerabilities/sql_injection.jinja +216 -0
  44. strix/prompts/vulnerabilities/ssrf.jinja +168 -0
  45. strix/prompts/vulnerabilities/xss.jinja +221 -0
  46. strix/prompts/vulnerabilities/xxe.jinja +276 -0
  47. strix/runtime/__init__.py +19 -0
  48. strix/runtime/docker_runtime.py +298 -0
  49. strix/runtime/runtime.py +25 -0
  50. strix/runtime/tool_server.py +97 -0
  51. strix/tools/__init__.py +64 -0
  52. strix/tools/agents_graph/__init__.py +16 -0
  53. strix/tools/agents_graph/agents_graph_actions.py +610 -0
  54. strix/tools/agents_graph/agents_graph_actions_schema.xml +223 -0
  55. strix/tools/argument_parser.py +120 -0
  56. strix/tools/browser/__init__.py +4 -0
  57. strix/tools/browser/browser_actions.py +236 -0
  58. strix/tools/browser/browser_actions_schema.xml +183 -0
  59. strix/tools/browser/browser_instance.py +533 -0
  60. strix/tools/browser/tab_manager.py +342 -0
  61. strix/tools/executor.py +302 -0
  62. strix/tools/file_edit/__init__.py +4 -0
  63. strix/tools/file_edit/file_edit_actions.py +141 -0
  64. strix/tools/file_edit/file_edit_actions_schema.xml +128 -0
  65. strix/tools/finish/__init__.py +4 -0
  66. strix/tools/finish/finish_actions.py +167 -0
  67. strix/tools/finish/finish_actions_schema.xml +45 -0
  68. strix/tools/notes/__init__.py +14 -0
  69. strix/tools/notes/notes_actions.py +191 -0
  70. strix/tools/notes/notes_actions_schema.xml +150 -0
  71. strix/tools/proxy/__init__.py +20 -0
  72. strix/tools/proxy/proxy_actions.py +101 -0
  73. strix/tools/proxy/proxy_actions_schema.xml +267 -0
  74. strix/tools/proxy/proxy_manager.py +785 -0
  75. strix/tools/python/__init__.py +4 -0
  76. strix/tools/python/python_actions.py +47 -0
  77. strix/tools/python/python_actions_schema.xml +131 -0
  78. strix/tools/python/python_instance.py +172 -0
  79. strix/tools/python/python_manager.py +131 -0
  80. strix/tools/registry.py +196 -0
  81. strix/tools/reporting/__init__.py +6 -0
  82. strix/tools/reporting/reporting_actions.py +63 -0
  83. strix/tools/reporting/reporting_actions_schema.xml +30 -0
  84. strix/tools/terminal/__init__.py +4 -0
  85. strix/tools/terminal/terminal_actions.py +53 -0
  86. strix/tools/terminal/terminal_actions_schema.xml +114 -0
  87. strix/tools/terminal/terminal_instance.py +231 -0
  88. strix/tools/terminal/terminal_manager.py +191 -0
  89. strix/tools/thinking/__init__.py +4 -0
  90. strix/tools/thinking/thinking_actions.py +18 -0
  91. strix/tools/thinking/thinking_actions_schema.xml +52 -0
  92. strix/tools/web_search/__init__.py +4 -0
  93. strix/tools/web_search/web_search_actions.py +80 -0
  94. strix/tools/web_search/web_search_actions_schema.xml +83 -0
  95. strix_agent-0.1.1.dist-info/LICENSE +201 -0
  96. strix_agent-0.1.1.dist-info/METADATA +200 -0
  97. strix_agent-0.1.1.dist-info/RECORD +99 -0
  98. strix_agent-0.1.1.dist-info/WHEEL +4 -0
  99. strix_agent-0.1.1.dist-info/entry_points.txt +3 -0
strix/cli/tracer.py ADDED
@@ -0,0 +1,308 @@
1
+ import logging
2
+ from datetime import UTC, datetime
3
+ from pathlib import Path
4
+ from typing import Any, Optional
5
+ from uuid import uuid4
6
+
7
+
8
logger = logging.getLogger(__name__)

# Process-wide singleton; read via get_global_tracer() and installed via
# set_global_tracer(). Quoted "Tracer" because the class is defined below.
_global_tracer: Optional["Tracer"] = None
11
+
12
+
13
def get_global_tracer() -> Optional["Tracer"]:
    """Return the process-wide Tracer singleton, or None if none was installed."""
    return _global_tracer
15
+
16
+
17
def set_global_tracer(tracer: "Tracer") -> None:
    """Install *tracer* as the process-wide singleton returned by get_global_tracer()."""
    global _global_tracer  # noqa: PLW0603
    _global_tracer = tracer
20
+
21
+
22
class Tracer:
    """In-memory record of a scan run: agents, tool executions, chat messages,
    and vulnerability reports, plus persistence of the results to disk.

    One instance is normally shared process-wide via set_global_tracer().
    """

    def __init__(self, run_name: str | None = None):
        """Create a tracer.

        Args:
            run_name: Optional human-readable name; when omitted, a random
                ``run-<8 hex chars>`` id is generated instead.
        """
        self.run_name = run_name
        # Falls back to a random id only when no run_name was given.
        self.run_id = run_name or f"run-{uuid4().hex[:8]}"
        self.start_time = datetime.now(UTC).isoformat()
        self.end_time: str | None = None

        # agent_id -> agent record (see log_agent_creation for the shape).
        self.agents: dict[str, dict[str, Any]] = {}
        # execution_id -> execution record (see log_tool_execution_start).
        self.tool_executions: dict[int, dict[str, Any]] = {}
        self.chat_messages: list[dict[str, Any]] = []

        self.vulnerability_reports: list[dict[str, Any]] = []
        self.final_scan_result: str | None = None

        self.scan_results: dict[str, Any] | None = None
        self.scan_config: dict[str, Any] | None = None
        self.run_metadata: dict[str, Any] = {
            "run_id": self.run_id,
            "run_name": self.run_name,
            "start_time": self.start_time,
            "end_time": None,
            "target": None,
            "scan_type": None,
            "status": "running",
        }
        # Lazily created by get_run_dir().
        self._run_dir: Path | None = None
        # Monotonic counters for execution and message ids.
        self._next_execution_id = 1
        self._next_message_id = 1

    def set_run_name(self, run_name: str) -> None:
        """Rename the run; also overwrites run_id with the new name."""
        self.run_name = run_name
        self.run_id = run_name

    def get_run_dir(self) -> Path:
        """Return (creating on first use) the directory where run data is saved.

        NOTE(review): ``Path(__file__).parent.parent.parent`` resolves relative
        to the package location, so for an installed wheel ``agent_runs`` may be
        created inside site-packages — confirm this is intended.
        """
        if self._run_dir is None:
            workspace_root = Path(__file__).parent.parent.parent
            runs_dir = workspace_root / "agent_runs"
            runs_dir.mkdir(exist_ok=True)

            run_dir_name = self.run_name if self.run_name else self.run_id
            self._run_dir = runs_dir / run_dir_name
            self._run_dir.mkdir(exist_ok=True)

        return self._run_dir

    def add_vulnerability_report(
        self,
        title: str,
        content: str,
        severity: str,
    ) -> str:
        """Record a vulnerability finding and return its generated id.

        Ids are sequential (``vuln-0001``, ``vuln-0002``, ...) based on the
        current number of stored reports. Severity is normalized to lowercase.
        """
        report_id = f"vuln-{len(self.vulnerability_reports) + 1:04d}"

        report = {
            "id": report_id,
            "title": title.strip(),
            "content": content.strip(),
            "severity": severity.lower().strip(),
            "timestamp": datetime.now(UTC).strftime("%Y-%m-%d %H:%M:%S UTC"),
        }

        self.vulnerability_reports.append(report)
        logger.info(f"Added vulnerability report: {report_id} - {title}")
        return report_id

    def set_final_scan_result(
        self,
        content: str,
        success: bool = True,
    ) -> None:
        """Store the final scan summary (markdown) and mark the scan complete."""
        self.final_scan_result = content.strip()

        self.scan_results = {
            "scan_completed": True,
            "content": content,
            "success": success,
        }

        logger.info(f"Set final scan result: success={success}")

    def log_agent_creation(
        self, agent_id: str, name: str, task: str, parent_id: str | None = None
    ) -> None:
        """Register a new agent in the run; parent_id links sub-agents to parents."""
        agent_data: dict[str, Any] = {
            "id": agent_id,
            "name": name,
            "task": task,
            "status": "running",
            "parent_id": parent_id,
            "created_at": datetime.now(UTC).isoformat(),
            "updated_at": datetime.now(UTC).isoformat(),
            "tool_executions": [],
        }

        self.agents[agent_id] = agent_data

    def log_chat_message(
        self,
        content: str,
        role: str,
        agent_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> int:
        """Append a chat message to the transcript and return its numeric id."""
        message_id = self._next_message_id
        self._next_message_id += 1

        message_data = {
            "message_id": message_id,
            "content": content,
            "role": role,
            "agent_id": agent_id,
            "timestamp": datetime.now(UTC).isoformat(),
            "metadata": metadata or {},
        }

        self.chat_messages.append(message_data)
        return message_id

    def log_tool_execution_start(self, agent_id: str, tool_name: str, args: dict[str, Any]) -> int:
        """Record the start of a tool call and return its execution id.

        The id is also appended to the owning agent's ``tool_executions`` list
        when that agent is known (unknown agent ids are silently tolerated).
        """
        execution_id = self._next_execution_id
        self._next_execution_id += 1

        now = datetime.now(UTC).isoformat()
        execution_data = {
            "execution_id": execution_id,
            "agent_id": agent_id,
            "tool_name": tool_name,
            "args": args,
            "status": "running",
            "result": None,
            "timestamp": now,
            "started_at": now,
            "completed_at": None,
        }

        self.tool_executions[execution_id] = execution_data

        if agent_id in self.agents:
            self.agents[agent_id]["tool_executions"].append(execution_id)

        return execution_id

    def update_tool_execution(
        self, execution_id: int, status: str, result: Any | None = None
    ) -> None:
        """Finalize a tool execution record; unknown ids are ignored."""
        if execution_id in self.tool_executions:
            self.tool_executions[execution_id]["status"] = status
            self.tool_executions[execution_id]["result"] = result
            self.tool_executions[execution_id]["completed_at"] = datetime.now(UTC).isoformat()

    def update_agent_status(self, agent_id: str, status: str) -> None:
        """Update an agent's status and bump its updated_at; unknown ids are ignored."""
        if agent_id in self.agents:
            self.agents[agent_id]["status"] = status
            self.agents[agent_id]["updated_at"] = datetime.now(UTC).isoformat()

    def set_scan_config(self, config: dict[str, Any]) -> None:
        """Store the scan configuration and mirror key fields into run_metadata."""
        self.scan_config = config
        self.run_metadata.update(
            {
                "target": config.get("target", {}),
                "scan_type": config.get("scan_type", "general"),
                "user_instructions": config.get("user_instructions", ""),
                "max_iterations": config.get("max_iterations", 200),
            }
        )

    def save_run_data(self) -> None:
        """Persist the scan report and vulnerability files under the run directory.

        Writes ``scan_report.md`` (if a final result was set), one markdown file
        per vulnerability under ``vulnerabilities/``, and a ``vulnerabilities.csv``
        index sorted by severity then timestamp. Failures are logged, not raised.
        """
        try:
            run_dir = self.get_run_dir()
            self.end_time = datetime.now(UTC).isoformat()

            if self.final_scan_result:
                scan_report_file = run_dir / "scan_report.md"
                with scan_report_file.open("w", encoding="utf-8") as f:
                    f.write("# Security Scan Report\n\n")
                    f.write(
                        f"**Generated:** {datetime.now(UTC).strftime('%Y-%m-%d %H:%M:%S UTC')}\n\n"
                    )
                    f.write(f"{self.final_scan_result}\n")
                logger.info(f"Saved final scan report to: {scan_report_file}")

            if self.vulnerability_reports:
                vuln_dir = run_dir / "vulnerabilities"
                vuln_dir.mkdir(exist_ok=True)

                # Critical first; unknown severities sort last (bucket 5).
                severity_order = {"critical": 0, "high": 1, "medium": 2, "low": 3, "info": 4}
                sorted_reports = sorted(
                    self.vulnerability_reports,
                    key=lambda x: (severity_order.get(x["severity"], 5), x["timestamp"]),
                )

                for report in sorted_reports:
                    vuln_file = vuln_dir / f"{report['id']}.md"
                    with vuln_file.open("w", encoding="utf-8") as f:
                        f.write(f"# {report['title']}\n\n")
                        f.write(f"**ID:** {report['id']}\n")
                        f.write(f"**Severity:** {report['severity'].upper()}\n")
                        f.write(f"**Found:** {report['timestamp']}\n\n")
                        f.write("## Description\n\n")
                        f.write(f"{report['content']}\n")

                vuln_csv_file = run_dir / "vulnerabilities.csv"
                with vuln_csv_file.open("w", encoding="utf-8", newline="") as f:
                    # Local import keeps csv off the module import path until needed.
                    import csv

                    fieldnames = ["id", "title", "severity", "timestamp", "file"]
                    writer = csv.DictWriter(f, fieldnames=fieldnames)
                    writer.writeheader()

                    for report in sorted_reports:
                        writer.writerow(
                            {
                                "id": report["id"],
                                "title": report["title"],
                                "severity": report["severity"].upper(),
                                "timestamp": report["timestamp"],
                                "file": f"vulnerabilities/{report['id']}.md",
                            }
                        )

                logger.info(
                    f"Saved {len(self.vulnerability_reports)} vulnerability reports to: {vuln_dir}"
                )
                logger.info(f"Saved vulnerability index to: {vuln_csv_file}")

            logger.info(f"📊 Essential scan data saved to: {run_dir}")

        except (OSError, RuntimeError):
            logger.exception("Failed to save scan data")

    def _calculate_duration(self) -> float:
        """Return run duration in seconds, or 0.0 if end_time is unset/unparseable."""
        try:
            start = datetime.fromisoformat(self.start_time.replace("Z", "+00:00"))
            if self.end_time:
                end = datetime.fromisoformat(self.end_time.replace("Z", "+00:00"))
                return (end - start).total_seconds()
        except (ValueError, TypeError):
            pass
        return 0.0

    def get_agent_tools(self, agent_id: str) -> list[dict[str, Any]]:
        """Return all tool execution records belonging to *agent_id*."""
        return [
            exec_data
            for exec_data in self.tool_executions.values()
            if exec_data.get("agent_id") == agent_id
        ]

    def get_real_tool_count(self) -> int:
        """Count tool executions, excluding the synthetic start-info pseudo-tools."""
        return sum(
            1
            for exec_data in self.tool_executions.values()
            if exec_data.get("tool_name") not in ["scan_start_info", "subagent_start_info"]
        )

    def get_total_llm_stats(self) -> dict[str, Any]:
        """Aggregate token/cost stats across all live agent LLM instances.

        NOTE(review): reaches into the private ``_total_stats`` of each agent's
        LLM and the module-private ``_agent_instances`` registry — a coupling to
        confirm if either module changes. Local import avoids a circular import.
        """
        from strix.tools.agents_graph.agents_graph_actions import _agent_instances

        total_stats = {
            "input_tokens": 0,
            "output_tokens": 0,
            "cached_tokens": 0,
            "cache_creation_tokens": 0,
            "cost": 0.0,
            "requests": 0,
            "failed_requests": 0,
        }

        for agent_instance in _agent_instances.values():
            if hasattr(agent_instance, "llm") and hasattr(agent_instance.llm, "_total_stats"):
                agent_stats = agent_instance.llm._total_stats
                total_stats["input_tokens"] += agent_stats.input_tokens
                total_stats["output_tokens"] += agent_stats.output_tokens
                total_stats["cached_tokens"] += agent_stats.cached_tokens
                total_stats["cache_creation_tokens"] += agent_stats.cache_creation_tokens
                total_stats["cost"] += agent_stats.cost
                total_stats["requests"] += agent_stats.requests
                total_stats["failed_requests"] += agent_stats.failed_requests

        total_stats["cost"] = round(total_stats["cost"], 4)

        return {
            "total": total_stats,
            "total_tokens": total_stats["input_tokens"] + total_stats["output_tokens"],
        }

    def cleanup(self) -> None:
        """Flush all collected run data to disk."""
        self.save_run_data()
strix/llm/__init__.py ADDED
@@ -0,0 +1,14 @@
1
import litellm

from .config import LLMConfig
from .llm import LLM


__all__ = [
    "LLM",
    "LLMConfig",
]

# Module-level litellm configuration applied on first import of strix.llm:
# drop provider-unsupported params instead of erroring, and silence debug output.
litellm.drop_params = True
litellm.suppress_debug_info = True
# NOTE(review): `set_verbose` is deprecated in newer litellm releases in favor
# of standard logging configuration — confirm against the pinned version.
litellm.set_verbose = False
strix/llm/config.py ADDED
@@ -0,0 +1,19 @@
1
+ import os
2
+
3
+
4
+ class LLMConfig:
5
+ def __init__(
6
+ self,
7
+ model_name: str | None = None,
8
+ temperature: float = 0,
9
+ enable_prompt_caching: bool = True,
10
+ prompt_modules: list[str] | None = None,
11
+ ):
12
+ self.model_name = model_name or os.getenv("STRIX_LLM", "anthropic/claude-sonnet-4-20250514")
13
+
14
+ if not self.model_name:
15
+ raise ValueError("STRIX_LLM environment variable must be set and not empty")
16
+
17
+ self.temperature = max(0.0, min(1.0, temperature))
18
+ self.enable_prompt_caching = enable_prompt_caching
19
+ self.prompt_modules = prompt_modules or []
strix/llm/llm.py ADDED
@@ -0,0 +1,310 @@
1
+ import logging
2
+ import os
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import litellm
9
+ from jinja2 import (
10
+ Environment,
11
+ FileSystemLoader,
12
+ select_autoescape,
13
+ )
14
+ from litellm import ModelResponse, completion_cost
15
+ from litellm.utils import supports_prompt_caching
16
+
17
+ from strix.llm.config import LLMConfig
18
+ from strix.llm.memory_compressor import MemoryCompressor
19
+ from strix.llm.request_queue import get_global_queue
20
+ from strix.llm.utils import _truncate_to_first_function, parse_tool_invocations
21
+ from strix.prompts import load_prompt_modules
22
+ from strix.tools import get_tools_prompt
23
+
24
+
25
logger = logging.getLogger(__name__)

# Optional global API key override: when LLM_API_KEY is set, hand it to litellm
# so individual requests don't need to pass credentials explicitly.
api_key = os.getenv("LLM_API_KEY")
if api_key:
    litellm.api_key = api_key
30
+
31
+
32
class StepRole(str, Enum):
    """Originator of a conversation step; str-valued for direct serialization."""

    AGENT = "agent"
    USER = "user"
    SYSTEM = "system"
36
+
37
+
38
@dataclass
class LLMResponse:
    """One model turn: the raw text plus any tool invocations parsed from it."""

    content: str
    # Parsed tool calls, or None when the response contained none.
    tool_invocations: list[dict[str, Any]] | None = None
    scan_id: str | None = None
    step_number: int = 1
    role: StepRole = StepRole.AGENT
45
+
46
+
47
+ @dataclass
48
+ class RequestStats:
49
+ input_tokens: int = 0
50
+ output_tokens: int = 0
51
+ cached_tokens: int = 0
52
+ cache_creation_tokens: int = 0
53
+ cost: float = 0.0
54
+ requests: int = 0
55
+ failed_requests: int = 0
56
+
57
+ def to_dict(self) -> dict[str, int | float]:
58
+ return {
59
+ "input_tokens": self.input_tokens,
60
+ "output_tokens": self.output_tokens,
61
+ "cached_tokens": self.cached_tokens,
62
+ "cache_creation_tokens": self.cache_creation_tokens,
63
+ "cost": round(self.cost, 4),
64
+ "requests": self.requests,
65
+ "failed_requests": self.failed_requests,
66
+ }
67
+
68
+
69
class LLM:
    """Wrapper around litellm chat completions with prompt assembly, history
    compression, Anthropic prompt caching, and usage/cost accounting.
    """

    def __init__(self, config: LLMConfig, agent_name: str | None = None):
        """Build the LLM and render the agent's system prompt.

        Args:
            config: Model/temperature/caching configuration.
            agent_name: When given, the system prompt is rendered from
                ``strix/agents/<agent_name>/system_prompt.jinja``; otherwise a
                generic fallback prompt is used.
        """
        self.config = config
        self.agent_name = agent_name
        self._total_stats = RequestStats()
        self._last_request_stats = RequestStats()

        self.memory_compressor = MemoryCompressor()

        if agent_name:
            # Templates are searched first in the agent's own directory, then
            # in the shared prompts directory.
            prompt_dir = Path(__file__).parent.parent / "agents" / agent_name
            prompts_dir = Path(__file__).parent.parent / "prompts"

            loader = FileSystemLoader([prompt_dir, prompts_dir])
            self.jinja_env = Environment(
                loader=loader,
                # Prompts are plain text, not HTML — autoescaping is disabled.
                autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
            )

            try:
                prompt_module_content = load_prompt_modules(
                    self.config.prompt_modules or [], self.jinja_env
                )

                def get_module(name: str) -> str:
                    # Template-side helper: fetch a module's text, "" if absent.
                    return prompt_module_content.get(name, "")

                self.jinja_env.globals["get_module"] = get_module

                self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render(
                    get_tools_prompt=get_tools_prompt,
                    loaded_module_names=list(prompt_module_content.keys()),
                    **prompt_module_content,
                )
            except (FileNotFoundError, OSError, ValueError) as e:
                # Degrade to a generic prompt rather than failing agent startup.
                logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
                self.system_prompt = "You are a helpful AI assistant."
        else:
            self.system_prompt = "You are a helpful AI assistant."

    def _add_cache_control_to_content(
        self, content: str | list[dict[str, Any]]
    ) -> str | list[dict[str, Any]]:
        """Mark a message's content as an ephemeral cache breakpoint.

        Strings are wrapped into the list-of-parts form; for list content the
        cache_control marker is attached to the trailing text part. Anything
        else is returned unchanged.
        """
        if isinstance(content, str):
            return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
        if isinstance(content, list) and content:
            last_item = content[-1]
            if isinstance(last_item, dict) and last_item.get("type") == "text":
                return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}]
        return content

    def _is_anthropic_model(self) -> bool:
        """Heuristic: model name mentions 'anthropic/' or 'claude'."""
        if not self.config.model_name:
            return False
        model_lower = self.config.model_name.lower()
        return any(provider in model_lower for provider in ["anthropic/", "claude"])

    def _calculate_cache_interval(self, total_messages: int) -> int:
        """Pick a message spacing (multiple of 10) that yields at most 3
        cache breakpoints over the non-system messages."""
        if total_messages <= 1:
            return 10

        max_cached_messages = 3
        non_system_messages = total_messages - 1

        interval = 10
        while non_system_messages // interval > max_cached_messages:
            interval += 10

        return interval

    def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Return a copy of *messages* with cache_control markers applied.

        No-op unless caching is enabled, the model supports prompt caching,
        and the model is Anthropic. The system message is always marked; up to
        3 additional messages are marked at the computed interval.
        """
        if (
            not self.config.enable_prompt_caching
            or not supports_prompt_caching(self.config.model_name)
            or not messages
        ):
            return messages

        if not self._is_anthropic_model():
            return messages

        # Shallow copy: individual messages are copied before modification below.
        cached_messages = list(messages)

        if cached_messages and cached_messages[0].get("role") == "system":
            system_message = cached_messages[0].copy()
            system_message["content"] = self._add_cache_control_to_content(
                system_message["content"]
            )
            cached_messages[0] = system_message

        total_messages = len(cached_messages)
        if total_messages > 1:
            interval = self._calculate_cache_interval(total_messages)

            cached_count = 0
            for i in range(interval, total_messages, interval):
                if cached_count >= 3:
                    break

                if i < len(cached_messages):
                    message = cached_messages[i].copy()
                    message["content"] = self._add_cache_control_to_content(message["content"])
                    cached_messages[i] = message
                    cached_count += 1

        return cached_messages

    async def generate(
        self,
        conversation_history: list[dict[str, Any]],
        scan_id: str | None = None,
        step_number: int = 1,
    ) -> LLMResponse:
        """Run one completion over the (compressed) history.

        Side effect: *conversation_history* is replaced in place with its
        compressed form, so the caller's list stays in sync. On failure an
        error-text LLMResponse is returned instead of raising.
        """
        messages = [{"role": "system", "content": self.system_prompt}]

        compressed_history = list(self.memory_compressor.compress_history(conversation_history))

        # In-place swap keeps the caller's reference pointing at the
        # compressed history for subsequent turns.
        conversation_history.clear()
        conversation_history.extend(compressed_history)
        messages.extend(compressed_history)

        cached_messages = self._prepare_cached_messages(messages)

        try:
            response = await self._make_request(cached_messages)
            self._update_usage_stats(response)

            content = ""
            if (
                response.choices
                and hasattr(response.choices[0], "message")
                and response.choices[0].message
            ):
                content = getattr(response.choices[0].message, "content", "") or ""

            # Keep only the first tool call; anything after it is discarded.
            content = _truncate_to_first_function(content)

            if "</function>" in content:
                function_end_index = content.find("</function>") + len("</function>")
                content = content[:function_end_index]

            tool_invocations = parse_tool_invocations(content)

            return LLMResponse(
                scan_id=scan_id,
                step_number=step_number,
                role=StepRole.AGENT,
                content=content,
                tool_invocations=tool_invocations if tool_invocations else None,
            )

        except (ValueError, TypeError, RuntimeError):
            logger.exception("Error in LLM generation")
            return LLMResponse(
                scan_id=scan_id,
                step_number=step_number,
                role=StepRole.AGENT,
                content="An error occurred while generating the response",
                tool_invocations=None,
            )

    @property
    def usage_stats(self) -> dict[str, dict[str, int | float]]:
        """Cumulative and last-request usage stats as plain dicts."""
        return {
            "total": self._total_stats.to_dict(),
            "last_request": self._last_request_stats.to_dict(),
        }

    def get_cache_config(self) -> dict[str, bool]:
        """Report whether prompt caching is enabled and model-supported."""
        return {
            "enabled": self.config.enable_prompt_caching,
            "supported": supports_prompt_caching(self.config.model_name),
        }

    async def _make_request(
        self,
        messages: list[dict[str, Any]],
    ) -> ModelResponse:
        """Dispatch the completion through the global request queue.

        Generation stops at the ``</function>`` tag so only one tool call is
        produced per turn. Also bumps the request counters.
        """
        completion_args = {
            "model": self.config.model_name,
            "messages": messages,
            "temperature": self.config.temperature,
            "stop": ["</function>"],
        }

        queue = get_global_queue()
        response = await queue.make_request(completion_args)

        self._total_stats.requests += 1
        self._last_request_stats = RequestStats(requests=1)

        return response

    def _update_usage_stats(self, response: ModelResponse) -> None:
        """Fold a response's token usage and computed cost into the stats.

        Defensive throughout: missing usage fields become zeros, and cost or
        attribute errors are logged rather than propagated.
        """
        try:
            if hasattr(response, "usage") and response.usage:
                input_tokens = getattr(response.usage, "prompt_tokens", 0)
                output_tokens = getattr(response.usage, "completion_tokens", 0)

                cached_tokens = 0
                cache_creation_tokens = 0

                # Anthropic-style cache accounting, when present on the response.
                if hasattr(response.usage, "prompt_tokens_details"):
                    prompt_details = response.usage.prompt_tokens_details
                    if hasattr(prompt_details, "cached_tokens"):
                        cached_tokens = prompt_details.cached_tokens or 0

                if hasattr(response.usage, "cache_creation_input_tokens"):
                    cache_creation_tokens = response.usage.cache_creation_input_tokens or 0

            else:
                input_tokens = 0
                output_tokens = 0
                cached_tokens = 0
                cache_creation_tokens = 0

            try:
                cost = completion_cost(response) or 0.0
            except (ValueError, TypeError, RuntimeError) as e:
                logger.warning(f"Failed to calculate cost: {e}")
                cost = 0.0

            self._total_stats.input_tokens += input_tokens
            self._total_stats.output_tokens += output_tokens
            self._total_stats.cached_tokens += cached_tokens
            self._total_stats.cache_creation_tokens += cache_creation_tokens
            self._total_stats.cost += cost

            # _last_request_stats was reset by _make_request; fill in the details.
            self._last_request_stats.input_tokens = input_tokens
            self._last_request_stats.output_tokens = output_tokens
            self._last_request_stats.cached_tokens = cached_tokens
            self._last_request_stats.cache_creation_tokens = cache_creation_tokens
            self._last_request_stats.cost = cost

            if cached_tokens > 0:
                logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens")
            if cache_creation_tokens > 0:
                logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache")

            logger.info(f"Usage stats: {self.usage_stats}")
        except (AttributeError, TypeError, ValueError) as e:
            logger.warning(f"Failed to update usage stats: {e}")