strix-agent 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- strix/__init__.py +0 -0
- strix/agents/StrixAgent/__init__.py +4 -0
- strix/agents/StrixAgent/strix_agent.py +60 -0
- strix/agents/StrixAgent/system_prompt.jinja +504 -0
- strix/agents/__init__.py +10 -0
- strix/agents/base_agent.py +394 -0
- strix/agents/state.py +139 -0
- strix/cli/__init__.py +4 -0
- strix/cli/app.py +1124 -0
- strix/cli/assets/cli.tcss +680 -0
- strix/cli/main.py +542 -0
- strix/cli/tool_components/__init__.py +39 -0
- strix/cli/tool_components/agents_graph_renderer.py +129 -0
- strix/cli/tool_components/base_renderer.py +61 -0
- strix/cli/tool_components/browser_renderer.py +107 -0
- strix/cli/tool_components/file_edit_renderer.py +95 -0
- strix/cli/tool_components/finish_renderer.py +32 -0
- strix/cli/tool_components/notes_renderer.py +108 -0
- strix/cli/tool_components/proxy_renderer.py +255 -0
- strix/cli/tool_components/python_renderer.py +34 -0
- strix/cli/tool_components/registry.py +72 -0
- strix/cli/tool_components/reporting_renderer.py +53 -0
- strix/cli/tool_components/scan_info_renderer.py +58 -0
- strix/cli/tool_components/terminal_renderer.py +99 -0
- strix/cli/tool_components/thinking_renderer.py +29 -0
- strix/cli/tool_components/user_message_renderer.py +43 -0
- strix/cli/tool_components/web_search_renderer.py +28 -0
- strix/cli/tracer.py +308 -0
- strix/llm/__init__.py +14 -0
- strix/llm/config.py +19 -0
- strix/llm/llm.py +310 -0
- strix/llm/memory_compressor.py +206 -0
- strix/llm/request_queue.py +63 -0
- strix/llm/utils.py +84 -0
- strix/prompts/__init__.py +113 -0
- strix/prompts/coordination/root_agent.jinja +41 -0
- strix/prompts/vulnerabilities/authentication_jwt.jinja +129 -0
- strix/prompts/vulnerabilities/business_logic.jinja +143 -0
- strix/prompts/vulnerabilities/csrf.jinja +168 -0
- strix/prompts/vulnerabilities/idor.jinja +164 -0
- strix/prompts/vulnerabilities/race_conditions.jinja +194 -0
- strix/prompts/vulnerabilities/rce.jinja +222 -0
- strix/prompts/vulnerabilities/sql_injection.jinja +216 -0
- strix/prompts/vulnerabilities/ssrf.jinja +168 -0
- strix/prompts/vulnerabilities/xss.jinja +221 -0
- strix/prompts/vulnerabilities/xxe.jinja +276 -0
- strix/runtime/__init__.py +19 -0
- strix/runtime/docker_runtime.py +298 -0
- strix/runtime/runtime.py +25 -0
- strix/runtime/tool_server.py +97 -0
- strix/tools/__init__.py +64 -0
- strix/tools/agents_graph/__init__.py +16 -0
- strix/tools/agents_graph/agents_graph_actions.py +610 -0
- strix/tools/agents_graph/agents_graph_actions_schema.xml +223 -0
- strix/tools/argument_parser.py +120 -0
- strix/tools/browser/__init__.py +4 -0
- strix/tools/browser/browser_actions.py +236 -0
- strix/tools/browser/browser_actions_schema.xml +183 -0
- strix/tools/browser/browser_instance.py +533 -0
- strix/tools/browser/tab_manager.py +342 -0
- strix/tools/executor.py +302 -0
- strix/tools/file_edit/__init__.py +4 -0
- strix/tools/file_edit/file_edit_actions.py +141 -0
- strix/tools/file_edit/file_edit_actions_schema.xml +128 -0
- strix/tools/finish/__init__.py +4 -0
- strix/tools/finish/finish_actions.py +167 -0
- strix/tools/finish/finish_actions_schema.xml +45 -0
- strix/tools/notes/__init__.py +14 -0
- strix/tools/notes/notes_actions.py +191 -0
- strix/tools/notes/notes_actions_schema.xml +150 -0
- strix/tools/proxy/__init__.py +20 -0
- strix/tools/proxy/proxy_actions.py +101 -0
- strix/tools/proxy/proxy_actions_schema.xml +267 -0
- strix/tools/proxy/proxy_manager.py +785 -0
- strix/tools/python/__init__.py +4 -0
- strix/tools/python/python_actions.py +47 -0
- strix/tools/python/python_actions_schema.xml +131 -0
- strix/tools/python/python_instance.py +172 -0
- strix/tools/python/python_manager.py +131 -0
- strix/tools/registry.py +196 -0
- strix/tools/reporting/__init__.py +6 -0
- strix/tools/reporting/reporting_actions.py +63 -0
- strix/tools/reporting/reporting_actions_schema.xml +30 -0
- strix/tools/terminal/__init__.py +4 -0
- strix/tools/terminal/terminal_actions.py +53 -0
- strix/tools/terminal/terminal_actions_schema.xml +114 -0
- strix/tools/terminal/terminal_instance.py +231 -0
- strix/tools/terminal/terminal_manager.py +191 -0
- strix/tools/thinking/__init__.py +4 -0
- strix/tools/thinking/thinking_actions.py +18 -0
- strix/tools/thinking/thinking_actions_schema.xml +52 -0
- strix/tools/web_search/__init__.py +4 -0
- strix/tools/web_search/web_search_actions.py +80 -0
- strix/tools/web_search/web_search_actions_schema.xml +83 -0
- strix_agent-0.1.1.dist-info/LICENSE +201 -0
- strix_agent-0.1.1.dist-info/METADATA +200 -0
- strix_agent-0.1.1.dist-info/RECORD +99 -0
- strix_agent-0.1.1.dist-info/WHEEL +4 -0
- strix_agent-0.1.1.dist-info/entry_points.txt +3 -0
strix/cli/tracer.py
ADDED
@@ -0,0 +1,308 @@
import logging
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Optional
from uuid import uuid4


logger = logging.getLogger(__name__)

_global_tracer: Optional["Tracer"] = None


def get_global_tracer() -> Optional["Tracer"]:
    return _global_tracer


def set_global_tracer(tracer: "Tracer") -> None:
    global _global_tracer  # noqa: PLW0603
    _global_tracer = tracer


class Tracer:
    def __init__(self, run_name: str | None = None):
        self.run_name = run_name
        self.run_id = run_name or f"run-{uuid4().hex[:8]}"
        self.start_time = datetime.now(UTC).isoformat()
        self.end_time: str | None = None

        self.agents: dict[str, dict[str, Any]] = {}
        self.tool_executions: dict[int, dict[str, Any]] = {}
        self.chat_messages: list[dict[str, Any]] = []

        self.vulnerability_reports: list[dict[str, Any]] = []
        self.final_scan_result: str | None = None

        self.scan_results: dict[str, Any] | None = None
        self.scan_config: dict[str, Any] | None = None
        self.run_metadata: dict[str, Any] = {
            "run_id": self.run_id,
            "run_name": self.run_name,
            "start_time": self.start_time,
            "end_time": None,
            "target": None,
            "scan_type": None,
            "status": "running",
        }
        self._run_dir: Path | None = None
        self._next_execution_id = 1
        self._next_message_id = 1

    def set_run_name(self, run_name: str) -> None:
        self.run_name = run_name
        self.run_id = run_name

    def get_run_dir(self) -> Path:
        if self._run_dir is None:
            workspace_root = Path(__file__).parent.parent.parent
            runs_dir = workspace_root / "agent_runs"
            runs_dir.mkdir(exist_ok=True)

            run_dir_name = self.run_name if self.run_name else self.run_id
            self._run_dir = runs_dir / run_dir_name
            self._run_dir.mkdir(exist_ok=True)

        return self._run_dir

    def add_vulnerability_report(
        self,
        title: str,
        content: str,
        severity: str,
    ) -> str:
        report_id = f"vuln-{len(self.vulnerability_reports) + 1:04d}"

        report = {
            "id": report_id,
            "title": title.strip(),
            "content": content.strip(),
            "severity": severity.lower().strip(),
            "timestamp": datetime.now(UTC).strftime("%Y-%m-%d %H:%M:%S UTC"),
        }

        self.vulnerability_reports.append(report)
        logger.info(f"Added vulnerability report: {report_id} - {title}")
        return report_id

    def set_final_scan_result(
        self,
        content: str,
        success: bool = True,
    ) -> None:
        self.final_scan_result = content.strip()

        self.scan_results = {
            "scan_completed": True,
            "content": content,
            "success": success,
        }

        logger.info(f"Set final scan result: success={success}")

    def log_agent_creation(
        self, agent_id: str, name: str, task: str, parent_id: str | None = None
    ) -> None:
        agent_data: dict[str, Any] = {
            "id": agent_id,
            "name": name,
            "task": task,
            "status": "running",
            "parent_id": parent_id,
            "created_at": datetime.now(UTC).isoformat(),
            "updated_at": datetime.now(UTC).isoformat(),
            "tool_executions": [],
        }

        self.agents[agent_id] = agent_data

    def log_chat_message(
        self,
        content: str,
        role: str,
        agent_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> int:
        message_id = self._next_message_id
        self._next_message_id += 1

        message_data = {
            "message_id": message_id,
            "content": content,
            "role": role,
            "agent_id": agent_id,
            "timestamp": datetime.now(UTC).isoformat(),
            "metadata": metadata or {},
        }

        self.chat_messages.append(message_data)
        return message_id

    def log_tool_execution_start(self, agent_id: str, tool_name: str, args: dict[str, Any]) -> int:
        execution_id = self._next_execution_id
        self._next_execution_id += 1

        now = datetime.now(UTC).isoformat()
        execution_data = {
            "execution_id": execution_id,
            "agent_id": agent_id,
            "tool_name": tool_name,
            "args": args,
            "status": "running",
            "result": None,
            "timestamp": now,
            "started_at": now,
            "completed_at": None,
        }

        self.tool_executions[execution_id] = execution_data

        if agent_id in self.agents:
            self.agents[agent_id]["tool_executions"].append(execution_id)

        return execution_id

    def update_tool_execution(
        self, execution_id: int, status: str, result: Any | None = None
    ) -> None:
        if execution_id in self.tool_executions:
            self.tool_executions[execution_id]["status"] = status
            self.tool_executions[execution_id]["result"] = result
            self.tool_executions[execution_id]["completed_at"] = datetime.now(UTC).isoformat()

    def update_agent_status(self, agent_id: str, status: str) -> None:
        if agent_id in self.agents:
            self.agents[agent_id]["status"] = status
            self.agents[agent_id]["updated_at"] = datetime.now(UTC).isoformat()

    def set_scan_config(self, config: dict[str, Any]) -> None:
        self.scan_config = config
        self.run_metadata.update(
            {
                "target": config.get("target", {}),
                "scan_type": config.get("scan_type", "general"),
                "user_instructions": config.get("user_instructions", ""),
                "max_iterations": config.get("max_iterations", 200),
            }
        )

    def save_run_data(self) -> None:
        try:
            run_dir = self.get_run_dir()
            self.end_time = datetime.now(UTC).isoformat()

            if self.final_scan_result:
                scan_report_file = run_dir / "scan_report.md"
                with scan_report_file.open("w", encoding="utf-8") as f:
                    f.write("# Security Scan Report\n\n")
                    f.write(
                        f"**Generated:** {datetime.now(UTC).strftime('%Y-%m-%d %H:%M:%S UTC')}\n\n"
                    )
                    f.write(f"{self.final_scan_result}\n")
                logger.info(f"Saved final scan report to: {scan_report_file}")

            if self.vulnerability_reports:
                vuln_dir = run_dir / "vulnerabilities"
                vuln_dir.mkdir(exist_ok=True)

                severity_order = {"critical": 0, "high": 1, "medium": 2, "low": 3, "info": 4}
                sorted_reports = sorted(
                    self.vulnerability_reports,
                    key=lambda x: (severity_order.get(x["severity"], 5), x["timestamp"]),
                )

                for report in sorted_reports:
                    vuln_file = vuln_dir / f"{report['id']}.md"
                    with vuln_file.open("w", encoding="utf-8") as f:
                        f.write(f"# {report['title']}\n\n")
                        f.write(f"**ID:** {report['id']}\n")
                        f.write(f"**Severity:** {report['severity'].upper()}\n")
                        f.write(f"**Found:** {report['timestamp']}\n\n")
                        f.write("## Description\n\n")
                        f.write(f"{report['content']}\n")

                vuln_csv_file = run_dir / "vulnerabilities.csv"
                with vuln_csv_file.open("w", encoding="utf-8", newline="") as f:
                    import csv

                    fieldnames = ["id", "title", "severity", "timestamp", "file"]
                    writer = csv.DictWriter(f, fieldnames=fieldnames)
                    writer.writeheader()

                    for report in sorted_reports:
                        writer.writerow(
                            {
                                "id": report["id"],
                                "title": report["title"],
                                "severity": report["severity"].upper(),
                                "timestamp": report["timestamp"],
                                "file": f"vulnerabilities/{report['id']}.md",
                            }
                        )

                logger.info(
                    f"Saved {len(self.vulnerability_reports)} vulnerability reports to: {vuln_dir}"
                )
                logger.info(f"Saved vulnerability index to: {vuln_csv_file}")

            logger.info(f"📊 Essential scan data saved to: {run_dir}")

        except (OSError, RuntimeError):
            logger.exception("Failed to save scan data")

    def _calculate_duration(self) -> float:
        try:
            start = datetime.fromisoformat(self.start_time.replace("Z", "+00:00"))
            if self.end_time:
                end = datetime.fromisoformat(self.end_time.replace("Z", "+00:00"))
                return (end - start).total_seconds()
        except (ValueError, TypeError):
            pass
        return 0.0

    def get_agent_tools(self, agent_id: str) -> list[dict[str, Any]]:
        return [
            exec_data
            for exec_data in self.tool_executions.values()
            if exec_data.get("agent_id") == agent_id
        ]

    def get_real_tool_count(self) -> int:
        return sum(
            1
            for exec_data in self.tool_executions.values()
            if exec_data.get("tool_name") not in ["scan_start_info", "subagent_start_info"]
        )

    def get_total_llm_stats(self) -> dict[str, Any]:
        from strix.tools.agents_graph.agents_graph_actions import _agent_instances

        total_stats = {
            "input_tokens": 0,
            "output_tokens": 0,
            "cached_tokens": 0,
            "cache_creation_tokens": 0,
            "cost": 0.0,
            "requests": 0,
            "failed_requests": 0,
        }

        for agent_instance in _agent_instances.values():
            if hasattr(agent_instance, "llm") and hasattr(agent_instance.llm, "_total_stats"):
                agent_stats = agent_instance.llm._total_stats
                total_stats["input_tokens"] += agent_stats.input_tokens
                total_stats["output_tokens"] += agent_stats.output_tokens
                total_stats["cached_tokens"] += agent_stats.cached_tokens
                total_stats["cache_creation_tokens"] += agent_stats.cache_creation_tokens
                total_stats["cost"] += agent_stats.cost
                total_stats["requests"] += agent_stats.requests
                total_stats["failed_requests"] += agent_stats.failed_requests

        total_stats["cost"] = round(total_stats["cost"], 4)

        return {
            "total": total_stats,
            "total_tokens": total_stats["input_tokens"] + total_stats["output_tokens"],
        }

    def cleanup(self) -> None:
        self.save_run_data()
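The tracer above is a process-global, in-memory collector: agents, tool executions, and findings accumulate in dicts and lists, and disk is only touched when save_run_data() runs (directly or via cleanup()). A minimal lifecycle sketch, assuming only the functions defined above; the agent IDs, tool names, and finding text are illustrative, not taken from the package:

from strix.cli.tracer import Tracer, set_global_tracer

tracer = Tracer(run_name="demo-scan")  # unnamed runs get a run-<8 hex chars> id
set_global_tracer(tracer)              # other modules fetch it via get_global_tracer()

tracer.set_scan_config({"target": {"url": "https://example.com"}, "scan_type": "general"})
tracer.log_agent_creation("agent-1", "root", task="scan the target")

exec_id = tracer.log_tool_execution_start("agent-1", "terminal", {"command": "whoami"})
tracer.update_tool_execution(exec_id, status="completed", result="ok")

tracer.add_vulnerability_report("Reflected XSS in /search", "Details...", severity="High")
tracer.set_final_scan_result("1 finding confirmed", success=True)
tracer.cleanup()  # writes scan_report.md, vulnerabilities/*.md, and vulnerabilities.csv under agent_runs/demo-scan/

One design note: get_run_dir() resolves agent_runs/ three directories above tracer.py, i.e. alongside the installed strix package, so for a wheel install the reports can land inside site-packages rather than the current working directory.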
strix/llm/__init__.py
ADDED
strix/llm/config.py
ADDED
@@ -0,0 +1,19 @@
import os


class LLMConfig:
    def __init__(
        self,
        model_name: str | None = None,
        temperature: float = 0,
        enable_prompt_caching: bool = True,
        prompt_modules: list[str] | None = None,
    ):
        self.model_name = model_name or os.getenv("STRIX_LLM", "anthropic/claude-sonnet-4-20250514")

        if not self.model_name:
            raise ValueError("STRIX_LLM environment variable must be set and not empty")

        self.temperature = max(0.0, min(1.0, temperature))
        self.enable_prompt_caching = enable_prompt_caching
        self.prompt_modules = prompt_modules or []
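LLMConfig resolves the model in order: explicit model_name argument, then the STRIX_LLM environment variable, then the hard-coded Claude Sonnet default; temperature is clamped into [0.0, 1.0]. A short illustration, assuming only the code above (the model ids are hypothetical values, not recommendations):

import os

from strix.llm.config import LLMConfig

os.environ["STRIX_LLM"] = "openai/gpt-4o"  # hypothetical value; any litellm-style model id

cfg = LLMConfig(temperature=1.5)
assert cfg.model_name == "openai/gpt-4o"
assert cfg.temperature == 1.0  # clamped into [0.0, 1.0]

cfg2 = LLMConfig(model_name="anthropic/claude-sonnet-4-20250514")  # argument wins over the env var

# Edge case: the ValueError is reachable only when STRIX_LLM is set to an empty
# string, since os.getenv then returns "" rather than the built-in default.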
strix/llm/llm.py
ADDED
@@ -0,0 +1,310 @@
import logging
import os
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any

import litellm
from jinja2 import (
    Environment,
    FileSystemLoader,
    select_autoescape,
)
from litellm import ModelResponse, completion_cost
from litellm.utils import supports_prompt_caching

from strix.llm.config import LLMConfig
from strix.llm.memory_compressor import MemoryCompressor
from strix.llm.request_queue import get_global_queue
from strix.llm.utils import _truncate_to_first_function, parse_tool_invocations
from strix.prompts import load_prompt_modules
from strix.tools import get_tools_prompt


logger = logging.getLogger(__name__)

api_key = os.getenv("LLM_API_KEY")
if api_key:
    litellm.api_key = api_key


class StepRole(str, Enum):
    AGENT = "agent"
    USER = "user"
    SYSTEM = "system"


@dataclass
class LLMResponse:
    content: str
    tool_invocations: list[dict[str, Any]] | None = None
    scan_id: str | None = None
    step_number: int = 1
    role: StepRole = StepRole.AGENT


@dataclass
class RequestStats:
    input_tokens: int = 0
    output_tokens: int = 0
    cached_tokens: int = 0
    cache_creation_tokens: int = 0
    cost: float = 0.0
    requests: int = 0
    failed_requests: int = 0

    def to_dict(self) -> dict[str, int | float]:
        return {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cached_tokens": self.cached_tokens,
            "cache_creation_tokens": self.cache_creation_tokens,
            "cost": round(self.cost, 4),
            "requests": self.requests,
            "failed_requests": self.failed_requests,
        }


class LLM:
    def __init__(self, config: LLMConfig, agent_name: str | None = None):
        self.config = config
        self.agent_name = agent_name
        self._total_stats = RequestStats()
        self._last_request_stats = RequestStats()

        self.memory_compressor = MemoryCompressor()

        if agent_name:
            prompt_dir = Path(__file__).parent.parent / "agents" / agent_name
            prompts_dir = Path(__file__).parent.parent / "prompts"

            loader = FileSystemLoader([prompt_dir, prompts_dir])
            self.jinja_env = Environment(
                loader=loader,
                autoescape=select_autoescape(enabled_extensions=(), default_for_string=False),
            )

            try:
                prompt_module_content = load_prompt_modules(
                    self.config.prompt_modules or [], self.jinja_env
                )

                def get_module(name: str) -> str:
                    return prompt_module_content.get(name, "")

                self.jinja_env.globals["get_module"] = get_module

                self.system_prompt = self.jinja_env.get_template("system_prompt.jinja").render(
                    get_tools_prompt=get_tools_prompt,
                    loaded_module_names=list(prompt_module_content.keys()),
                    **prompt_module_content,
                )
            except (FileNotFoundError, OSError, ValueError) as e:
                logger.warning(f"Failed to load system prompt for {agent_name}: {e}")
                self.system_prompt = "You are a helpful AI assistant."
        else:
            self.system_prompt = "You are a helpful AI assistant."

    def _add_cache_control_to_content(
        self, content: str | list[dict[str, Any]]
    ) -> str | list[dict[str, Any]]:
        if isinstance(content, str):
            return [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
        if isinstance(content, list) and content:
            last_item = content[-1]
            if isinstance(last_item, dict) and last_item.get("type") == "text":
                return content[:-1] + [{**last_item, "cache_control": {"type": "ephemeral"}}]
        return content

    def _is_anthropic_model(self) -> bool:
        if not self.config.model_name:
            return False
        model_lower = self.config.model_name.lower()
        return any(provider in model_lower for provider in ["anthropic/", "claude"])

    def _calculate_cache_interval(self, total_messages: int) -> int:
        if total_messages <= 1:
            return 10

        max_cached_messages = 3
        non_system_messages = total_messages - 1

        interval = 10
        while non_system_messages // interval > max_cached_messages:
            interval += 10

        return interval

    def _prepare_cached_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        if (
            not self.config.enable_prompt_caching
            or not supports_prompt_caching(self.config.model_name)
            or not messages
        ):
            return messages

        if not self._is_anthropic_model():
            return messages

        cached_messages = list(messages)

        if cached_messages and cached_messages[0].get("role") == "system":
            system_message = cached_messages[0].copy()
            system_message["content"] = self._add_cache_control_to_content(
                system_message["content"]
            )
            cached_messages[0] = system_message

        total_messages = len(cached_messages)
        if total_messages > 1:
            interval = self._calculate_cache_interval(total_messages)

            cached_count = 0
            for i in range(interval, total_messages, interval):
                if cached_count >= 3:
                    break

                if i < len(cached_messages):
                    message = cached_messages[i].copy()
                    message["content"] = self._add_cache_control_to_content(message["content"])
                    cached_messages[i] = message
                    cached_count += 1

        return cached_messages

    async def generate(
        self,
        conversation_history: list[dict[str, Any]],
        scan_id: str | None = None,
        step_number: int = 1,
    ) -> LLMResponse:
        messages = [{"role": "system", "content": self.system_prompt}]

        compressed_history = list(self.memory_compressor.compress_history(conversation_history))

        conversation_history.clear()
        conversation_history.extend(compressed_history)
        messages.extend(compressed_history)

        cached_messages = self._prepare_cached_messages(messages)

        try:
            response = await self._make_request(cached_messages)
            self._update_usage_stats(response)

            content = ""
            if (
                response.choices
                and hasattr(response.choices[0], "message")
                and response.choices[0].message
            ):
                content = getattr(response.choices[0].message, "content", "") or ""

            content = _truncate_to_first_function(content)

            if "</function>" in content:
                function_end_index = content.find("</function>") + len("</function>")
                content = content[:function_end_index]

            tool_invocations = parse_tool_invocations(content)

            return LLMResponse(
                scan_id=scan_id,
                step_number=step_number,
                role=StepRole.AGENT,
                content=content,
                tool_invocations=tool_invocations if tool_invocations else None,
            )

        except (ValueError, TypeError, RuntimeError):
            logger.exception("Error in LLM generation")
            return LLMResponse(
                scan_id=scan_id,
                step_number=step_number,
                role=StepRole.AGENT,
                content="An error occurred while generating the response",
                tool_invocations=None,
            )

    @property
    def usage_stats(self) -> dict[str, dict[str, int | float]]:
        return {
            "total": self._total_stats.to_dict(),
            "last_request": self._last_request_stats.to_dict(),
        }

    def get_cache_config(self) -> dict[str, bool]:
        return {
            "enabled": self.config.enable_prompt_caching,
            "supported": supports_prompt_caching(self.config.model_name),
        }

    async def _make_request(
        self,
        messages: list[dict[str, Any]],
    ) -> ModelResponse:
        completion_args = {
            "model": self.config.model_name,
            "messages": messages,
            "temperature": self.config.temperature,
            "stop": ["</function>"],
        }

        queue = get_global_queue()
        response = await queue.make_request(completion_args)

        self._total_stats.requests += 1
        self._last_request_stats = RequestStats(requests=1)

        return response

    def _update_usage_stats(self, response: ModelResponse) -> None:
        try:
            if hasattr(response, "usage") and response.usage:
                input_tokens = getattr(response.usage, "prompt_tokens", 0)
                output_tokens = getattr(response.usage, "completion_tokens", 0)

                cached_tokens = 0
                cache_creation_tokens = 0

                if hasattr(response.usage, "prompt_tokens_details"):
                    prompt_details = response.usage.prompt_tokens_details
                    if hasattr(prompt_details, "cached_tokens"):
                        cached_tokens = prompt_details.cached_tokens or 0

                if hasattr(response.usage, "cache_creation_input_tokens"):
                    cache_creation_tokens = response.usage.cache_creation_input_tokens or 0

            else:
                input_tokens = 0
                output_tokens = 0
                cached_tokens = 0
                cache_creation_tokens = 0

            try:
                cost = completion_cost(response) or 0.0
            except (ValueError, TypeError, RuntimeError) as e:
                logger.warning(f"Failed to calculate cost: {e}")
                cost = 0.0

            self._total_stats.input_tokens += input_tokens
            self._total_stats.output_tokens += output_tokens
            self._total_stats.cached_tokens += cached_tokens
            self._total_stats.cache_creation_tokens += cache_creation_tokens
            self._total_stats.cost += cost

            self._last_request_stats.input_tokens = input_tokens
            self._last_request_stats.output_tokens = output_tokens
            self._last_request_stats.cached_tokens = cached_tokens
            self._last_request_stats.cache_creation_tokens = cache_creation_tokens
            self._last_request_stats.cost = cost

            if cached_tokens > 0:
                logger.info(f"Cache hit: {cached_tokens} cached tokens, {input_tokens} new tokens")
            if cache_creation_tokens > 0:
                logger.info(f"Cache creation: {cache_creation_tokens} tokens written to cache")

            logger.info(f"Usage stats: {self.usage_stats}")
        except (AttributeError, TypeError, ValueError) as e:
            logger.warning(f"Failed to update usage stats: {e}")
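Two details in LLM are easy to miss. First, _prepare_cached_messages marks at most four messages with Anthropic cache_control breakpoints: the system prompt plus up to three history messages, spaced by _calculate_cache_interval, which widens its stride in steps of 10 until no more than three breakpoints fit. Second, generate() compresses conversation_history in place before sending it. A hedged usage sketch, assuming an environment where LLM_API_KEY is set and the package's global request queue is wired up; the conversation content is illustrative:

# Interval arithmetic from _calculate_cache_interval:
#   25 messages -> 24 non-system -> 24 // 10 = 2 <= 3 -> interval 10 (breakpoints at 10, 20)
#   61 messages -> 60 non-system -> 60 // 10 = 6 > 3, 60 // 20 = 3 -> interval 20 (20, 40, 60)

import asyncio

from strix.llm.config import LLMConfig
from strix.llm.llm import LLM

async def main() -> None:
    # "StrixAgent" matches strix/agents/StrixAgent/system_prompt.jinja in the file list above
    llm = LLM(LLMConfig(), agent_name="StrixAgent")
    history = [{"role": "user", "content": "Enumerate the target's endpoints."}]
    response = await llm.generate(history, scan_id="scan-1", step_number=1)
    print(response.content)            # truncated after the first </function>, if any
    print(response.tool_invocations)   # parsed tool calls, or None
    print(llm.usage_stats["total"])    # cumulative RequestStats, cost rounded to 4 decimals

asyncio.run(main())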