aiptx 2.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aipt_v2/__init__.py +110 -0
- aipt_v2/__main__.py +24 -0
- aipt_v2/agents/AIPTxAgent/__init__.py +10 -0
- aipt_v2/agents/AIPTxAgent/aiptx_agent.py +211 -0
- aipt_v2/agents/__init__.py +46 -0
- aipt_v2/agents/base.py +520 -0
- aipt_v2/agents/exploit_agent.py +688 -0
- aipt_v2/agents/ptt.py +406 -0
- aipt_v2/agents/state.py +168 -0
- aipt_v2/app.py +957 -0
- aipt_v2/browser/__init__.py +31 -0
- aipt_v2/browser/automation.py +458 -0
- aipt_v2/browser/crawler.py +453 -0
- aipt_v2/cli.py +2933 -0
- aipt_v2/compliance/__init__.py +71 -0
- aipt_v2/compliance/compliance_report.py +449 -0
- aipt_v2/compliance/framework_mapper.py +424 -0
- aipt_v2/compliance/nist_mapping.py +345 -0
- aipt_v2/compliance/owasp_mapping.py +330 -0
- aipt_v2/compliance/pci_mapping.py +297 -0
- aipt_v2/config.py +341 -0
- aipt_v2/core/__init__.py +43 -0
- aipt_v2/core/agent.py +630 -0
- aipt_v2/core/llm.py +395 -0
- aipt_v2/core/memory.py +305 -0
- aipt_v2/core/ptt.py +329 -0
- aipt_v2/database/__init__.py +14 -0
- aipt_v2/database/models.py +232 -0
- aipt_v2/database/repository.py +384 -0
- aipt_v2/docker/__init__.py +23 -0
- aipt_v2/docker/builder.py +260 -0
- aipt_v2/docker/manager.py +222 -0
- aipt_v2/docker/sandbox.py +371 -0
- aipt_v2/evasion/__init__.py +58 -0
- aipt_v2/evasion/request_obfuscator.py +272 -0
- aipt_v2/evasion/tls_fingerprint.py +285 -0
- aipt_v2/evasion/ua_rotator.py +301 -0
- aipt_v2/evasion/waf_bypass.py +439 -0
- aipt_v2/execution/__init__.py +23 -0
- aipt_v2/execution/executor.py +302 -0
- aipt_v2/execution/parser.py +544 -0
- aipt_v2/execution/terminal.py +337 -0
- aipt_v2/health.py +437 -0
- aipt_v2/intelligence/__init__.py +194 -0
- aipt_v2/intelligence/adaptation.py +474 -0
- aipt_v2/intelligence/auth.py +520 -0
- aipt_v2/intelligence/chaining.py +775 -0
- aipt_v2/intelligence/correlation.py +536 -0
- aipt_v2/intelligence/cve_aipt.py +334 -0
- aipt_v2/intelligence/cve_info.py +1111 -0
- aipt_v2/intelligence/knowledge_graph.py +590 -0
- aipt_v2/intelligence/learning.py +626 -0
- aipt_v2/intelligence/llm_analyzer.py +502 -0
- aipt_v2/intelligence/llm_tool_selector.py +518 -0
- aipt_v2/intelligence/payload_generator.py +562 -0
- aipt_v2/intelligence/rag.py +239 -0
- aipt_v2/intelligence/scope.py +442 -0
- aipt_v2/intelligence/searchers/__init__.py +5 -0
- aipt_v2/intelligence/searchers/exploitdb_searcher.py +523 -0
- aipt_v2/intelligence/searchers/github_searcher.py +467 -0
- aipt_v2/intelligence/searchers/google_searcher.py +281 -0
- aipt_v2/intelligence/tools.json +443 -0
- aipt_v2/intelligence/triage.py +670 -0
- aipt_v2/interactive_shell.py +559 -0
- aipt_v2/interface/__init__.py +5 -0
- aipt_v2/interface/cli.py +230 -0
- aipt_v2/interface/main.py +501 -0
- aipt_v2/interface/tui.py +1276 -0
- aipt_v2/interface/utils.py +583 -0
- aipt_v2/llm/__init__.py +39 -0
- aipt_v2/llm/config.py +26 -0
- aipt_v2/llm/llm.py +514 -0
- aipt_v2/llm/memory.py +214 -0
- aipt_v2/llm/request_queue.py +89 -0
- aipt_v2/llm/utils.py +89 -0
- aipt_v2/local_tool_installer.py +1467 -0
- aipt_v2/models/__init__.py +15 -0
- aipt_v2/models/findings.py +295 -0
- aipt_v2/models/phase_result.py +224 -0
- aipt_v2/models/scan_config.py +207 -0
- aipt_v2/monitoring/grafana/dashboards/aipt-dashboard.json +355 -0
- aipt_v2/monitoring/grafana/dashboards/default.yml +17 -0
- aipt_v2/monitoring/grafana/datasources/prometheus.yml +17 -0
- aipt_v2/monitoring/prometheus.yml +60 -0
- aipt_v2/orchestration/__init__.py +52 -0
- aipt_v2/orchestration/pipeline.py +398 -0
- aipt_v2/orchestration/progress.py +300 -0
- aipt_v2/orchestration/scheduler.py +296 -0
- aipt_v2/orchestrator.py +2427 -0
- aipt_v2/payloads/__init__.py +27 -0
- aipt_v2/payloads/cmdi.py +150 -0
- aipt_v2/payloads/sqli.py +263 -0
- aipt_v2/payloads/ssrf.py +204 -0
- aipt_v2/payloads/templates.py +222 -0
- aipt_v2/payloads/traversal.py +166 -0
- aipt_v2/payloads/xss.py +204 -0
- aipt_v2/prompts/__init__.py +60 -0
- aipt_v2/proxy/__init__.py +29 -0
- aipt_v2/proxy/history.py +352 -0
- aipt_v2/proxy/interceptor.py +452 -0
- aipt_v2/recon/__init__.py +44 -0
- aipt_v2/recon/dns.py +241 -0
- aipt_v2/recon/osint.py +367 -0
- aipt_v2/recon/subdomain.py +372 -0
- aipt_v2/recon/tech_detect.py +311 -0
- aipt_v2/reports/__init__.py +17 -0
- aipt_v2/reports/generator.py +313 -0
- aipt_v2/reports/html_report.py +378 -0
- aipt_v2/runtime/__init__.py +53 -0
- aipt_v2/runtime/base.py +30 -0
- aipt_v2/runtime/docker.py +401 -0
- aipt_v2/runtime/local.py +346 -0
- aipt_v2/runtime/tool_server.py +205 -0
- aipt_v2/runtime/vps.py +830 -0
- aipt_v2/scanners/__init__.py +28 -0
- aipt_v2/scanners/base.py +273 -0
- aipt_v2/scanners/nikto.py +244 -0
- aipt_v2/scanners/nmap.py +402 -0
- aipt_v2/scanners/nuclei.py +273 -0
- aipt_v2/scanners/web.py +454 -0
- aipt_v2/scripts/security_audit.py +366 -0
- aipt_v2/setup_wizard.py +941 -0
- aipt_v2/skills/__init__.py +80 -0
- aipt_v2/skills/agents/__init__.py +14 -0
- aipt_v2/skills/agents/api_tester.py +706 -0
- aipt_v2/skills/agents/base.py +477 -0
- aipt_v2/skills/agents/code_review.py +459 -0
- aipt_v2/skills/agents/security_agent.py +336 -0
- aipt_v2/skills/agents/web_pentest.py +818 -0
- aipt_v2/skills/prompts/__init__.py +647 -0
- aipt_v2/system_detector.py +539 -0
- aipt_v2/telemetry/__init__.py +7 -0
- aipt_v2/telemetry/tracer.py +347 -0
- aipt_v2/terminal/__init__.py +28 -0
- aipt_v2/terminal/executor.py +400 -0
- aipt_v2/terminal/sandbox.py +350 -0
- aipt_v2/tools/__init__.py +44 -0
- aipt_v2/tools/active_directory/__init__.py +78 -0
- aipt_v2/tools/active_directory/ad_config.py +238 -0
- aipt_v2/tools/active_directory/bloodhound_wrapper.py +447 -0
- aipt_v2/tools/active_directory/kerberos_attacks.py +430 -0
- aipt_v2/tools/active_directory/ldap_enum.py +533 -0
- aipt_v2/tools/active_directory/smb_attacks.py +505 -0
- aipt_v2/tools/agents_graph/__init__.py +19 -0
- aipt_v2/tools/agents_graph/agents_graph_actions.py +69 -0
- aipt_v2/tools/api_security/__init__.py +76 -0
- aipt_v2/tools/api_security/api_discovery.py +608 -0
- aipt_v2/tools/api_security/graphql_scanner.py +622 -0
- aipt_v2/tools/api_security/jwt_analyzer.py +577 -0
- aipt_v2/tools/api_security/openapi_fuzzer.py +761 -0
- aipt_v2/tools/browser/__init__.py +5 -0
- aipt_v2/tools/browser/browser_actions.py +238 -0
- aipt_v2/tools/browser/browser_instance.py +535 -0
- aipt_v2/tools/browser/tab_manager.py +344 -0
- aipt_v2/tools/cloud/__init__.py +70 -0
- aipt_v2/tools/cloud/cloud_config.py +273 -0
- aipt_v2/tools/cloud/cloud_scanner.py +639 -0
- aipt_v2/tools/cloud/prowler_tool.py +571 -0
- aipt_v2/tools/cloud/scoutsuite_tool.py +359 -0
- aipt_v2/tools/executor.py +307 -0
- aipt_v2/tools/parser.py +408 -0
- aipt_v2/tools/proxy/__init__.py +5 -0
- aipt_v2/tools/proxy/proxy_actions.py +103 -0
- aipt_v2/tools/proxy/proxy_manager.py +789 -0
- aipt_v2/tools/registry.py +196 -0
- aipt_v2/tools/scanners/__init__.py +343 -0
- aipt_v2/tools/scanners/acunetix_tool.py +712 -0
- aipt_v2/tools/scanners/burp_tool.py +631 -0
- aipt_v2/tools/scanners/config.py +156 -0
- aipt_v2/tools/scanners/nessus_tool.py +588 -0
- aipt_v2/tools/scanners/zap_tool.py +612 -0
- aipt_v2/tools/terminal/__init__.py +5 -0
- aipt_v2/tools/terminal/terminal_actions.py +37 -0
- aipt_v2/tools/terminal/terminal_manager.py +153 -0
- aipt_v2/tools/terminal/terminal_session.py +449 -0
- aipt_v2/tools/tool_processing.py +108 -0
- aipt_v2/utils/__init__.py +17 -0
- aipt_v2/utils/logging.py +202 -0
- aipt_v2/utils/model_manager.py +187 -0
- aipt_v2/utils/searchers/__init__.py +269 -0
- aipt_v2/verify_install.py +793 -0
- aiptx-2.0.7.dist-info/METADATA +345 -0
- aiptx-2.0.7.dist-info/RECORD +187 -0
- aiptx-2.0.7.dist-info/WHEEL +5 -0
- aiptx-2.0.7.dist-info/entry_points.txt +7 -0
- aiptx-2.0.7.dist-info/licenses/LICENSE +21 -0
- aiptx-2.0.7.dist-info/top_level.txt +1 -0
aipt_v2/llm/memory.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import litellm
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Compression threshold: history is summarized once the total token count
# nears this value (compress_history triggers at 90% of it).
MAX_TOTAL_TOKENS = 100_000
# The newest N non-system messages are always kept verbatim.
MIN_RECENT_MESSAGES = 15

# Prompt sent to the LLM when condensing a chunk of older conversation
# history; {conversation} is filled with a "role: text" transcript.
SUMMARY_PROMPT_TEMPLATE = """You are an agent performing context
condensation for a security agent. Your job is to compress scan data while preserving
ALL operationally critical information for continuing the security assessment.

CRITICAL ELEMENTS TO PRESERVE:
- Discovered vulnerabilities and potential attack vectors
- Scan results and tool outputs (compressed but maintaining key findings)
- Access credentials, tokens, or authentication details found
- System architecture insights and potential weak points
- Progress made in the assessment
- Failed attempts and dead ends (to avoid duplication)
- Any decisions made about the testing approach

COMPRESSION GUIDELINES:
- Preserve exact technical details (URLs, paths, parameters, payloads)
- Summarize verbose tool outputs while keeping critical findings
- Maintain version numbers, specific technologies identified
- Keep exact error messages that might indicate vulnerabilities
- Compress repetitive or similar findings into consolidated form

Remember: Another security agent will use this summary to continue the assessment.
They must be able to pick up exactly where you left off without losing any
operational advantage or context needed to find vulnerabilities.

CONVERSATION SEGMENT TO SUMMARIZE:
{conversation}

Provide a technically precise summary that preserves all operational security context while
keeping the summary concise and to the point."""
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _count_tokens(text: str, model: str) -> int:
    """Return the token count of *text* for *model*.

    Uses litellm's counter; if that fails for any reason, logs the error
    and falls back to a rough chars/4 heuristic.
    """
    try:
        return int(litellm.token_counter(model=model, text=text))
    except Exception:
        logger.exception("Failed to count tokens")
        return len(text) // 4  # Rough estimate
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _get_message_tokens(msg: dict[str, Any], model: str) -> int:
|
|
57
|
+
content = msg.get("content", "")
|
|
58
|
+
if isinstance(content, str):
|
|
59
|
+
return _count_tokens(content, model)
|
|
60
|
+
if isinstance(content, list):
|
|
61
|
+
return sum(
|
|
62
|
+
_count_tokens(item.get("text", ""), model)
|
|
63
|
+
for item in content
|
|
64
|
+
if isinstance(item, dict) and item.get("type") == "text"
|
|
65
|
+
)
|
|
66
|
+
return 0
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _extract_message_text(msg: dict[str, Any]) -> str:
|
|
70
|
+
content = msg.get("content", "")
|
|
71
|
+
if isinstance(content, str):
|
|
72
|
+
return content
|
|
73
|
+
|
|
74
|
+
if isinstance(content, list):
|
|
75
|
+
parts = []
|
|
76
|
+
for item in content:
|
|
77
|
+
if isinstance(item, dict):
|
|
78
|
+
if item.get("type") == "text":
|
|
79
|
+
parts.append(item.get("text", ""))
|
|
80
|
+
elif item.get("type") == "image_url":
|
|
81
|
+
parts.append("[IMAGE]")
|
|
82
|
+
return " ".join(parts)
|
|
83
|
+
|
|
84
|
+
return str(content)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _summarize_messages(
|
|
88
|
+
messages: list[dict[str, Any]],
|
|
89
|
+
model: str,
|
|
90
|
+
timeout: int = 600,
|
|
91
|
+
) -> dict[str, Any]:
|
|
92
|
+
if not messages:
|
|
93
|
+
empty_summary = "<context_summary message_count='0'>{text}</context_summary>"
|
|
94
|
+
return {
|
|
95
|
+
"role": "assistant",
|
|
96
|
+
"content": empty_summary.format(text="No messages to summarize"),
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
formatted = []
|
|
100
|
+
for msg in messages:
|
|
101
|
+
role = msg.get("role", "unknown")
|
|
102
|
+
text = _extract_message_text(msg)
|
|
103
|
+
formatted.append(f"{role}: {text}")
|
|
104
|
+
|
|
105
|
+
conversation = "\n".join(formatted)
|
|
106
|
+
prompt = SUMMARY_PROMPT_TEMPLATE.format(conversation=conversation)
|
|
107
|
+
|
|
108
|
+
try:
|
|
109
|
+
completion_args = {
|
|
110
|
+
"model": model,
|
|
111
|
+
"messages": [{"role": "user", "content": prompt}],
|
|
112
|
+
"timeout": timeout,
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
response = litellm.completion(**completion_args)
|
|
116
|
+
summary = response.choices[0].message.content or ""
|
|
117
|
+
if not summary.strip():
|
|
118
|
+
return messages[0]
|
|
119
|
+
summary_msg = "<context_summary message_count='{count}'>{text}</context_summary>"
|
|
120
|
+
return {
|
|
121
|
+
"role": "assistant",
|
|
122
|
+
"content": summary_msg.format(count=len(messages), text=summary),
|
|
123
|
+
}
|
|
124
|
+
except Exception:
|
|
125
|
+
logger.exception("Failed to summarize messages")
|
|
126
|
+
return messages[0]
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _handle_images(messages: list[dict[str, Any]], max_images: int) -> None:
|
|
130
|
+
image_count = 0
|
|
131
|
+
for msg in reversed(messages):
|
|
132
|
+
content = msg.get("content", [])
|
|
133
|
+
if isinstance(content, list):
|
|
134
|
+
for item in content:
|
|
135
|
+
if isinstance(item, dict) and item.get("type") == "image_url":
|
|
136
|
+
if image_count >= max_images:
|
|
137
|
+
item.update(
|
|
138
|
+
{
|
|
139
|
+
"type": "text",
|
|
140
|
+
"text": "[Previously attached image removed to preserve context]",
|
|
141
|
+
}
|
|
142
|
+
)
|
|
143
|
+
else:
|
|
144
|
+
image_count += 1
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class MemoryCompressor:
    """Compresses chat history for long-running security-agent sessions.

    System messages and the newest MIN_RECENT_MESSAGES regular messages are
    kept verbatim; once the history nears MAX_TOTAL_TOKENS, older messages
    are replaced with LLM-generated summaries.
    """

    def __init__(
        self,
        max_images: int = 3,
        model_name: str | None = None,
        timeout: int = 600,
    ):
        self.max_images = max_images
        # Fall back to the AIPT_LLM env var (default "openai/gpt-5").
        self.model_name = model_name or os.getenv("AIPT_LLM", "openai/gpt-5")
        self.timeout = timeout

        # Only reachable when AIPT_LLM is explicitly set to an empty string.
        if not self.model_name:
            raise ValueError("AIPT_LLM environment variable must be set and not empty")

    def compress_history(
        self,
        messages: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Compress conversation history to stay within token limits.

        Steps:
        1. Replace all but the newest images with placeholders (in place).
        2. Split system messages from regular ones.
        3. If the total token count is within 90% of MAX_TOTAL_TOKENS,
           return the history untouched.
        4. Otherwise summarize the older regular messages in chunks of 10
           and return system messages + summaries + recent messages.

        Note: system messages are moved to the front of the returned list.
        """
        if not messages:
            return messages

        _handle_images(messages, self.max_images)

        system_msgs = [m for m in messages if m.get("role") == "system"]
        regular_msgs = [m for m in messages if m.get("role") != "system"]

        recent_msgs = regular_msgs[-MIN_RECENT_MESSAGES:]
        old_msgs = regular_msgs[:-MIN_RECENT_MESSAGES]

        # __init__ guarantees a non-empty model name.
        model_name: str = self.model_name  # type: ignore[assignment]

        total_tokens = sum(
            _get_message_tokens(m, model_name) for m in system_msgs + regular_msgs
        )

        # Leave 10% headroom before compressing anything.
        if total_tokens <= MAX_TOTAL_TOKENS * 0.9:
            return messages

        summaries = []
        chunk_size = 10
        for start in range(0, len(old_msgs), chunk_size):
            chunk = old_msgs[start : start + chunk_size]
            summary = _summarize_messages(chunk, model_name, self.timeout)
            if summary:
                summaries.append(summary)

        return system_msgs + summaries + recent_msgs
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import threading
|
|
7
|
+
import time
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import litellm
|
|
11
|
+
from litellm import ModelResponse, completion
|
|
12
|
+
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def should_retry_exception(exception: Exception) -> bool:
    """Decide whether a failed LLM call should be retried.

    Looks for an HTTP status code on the exception itself or, failing
    that, on an attached ``response`` object, and defers to litellm's
    retry policy. Exceptions carrying no status code are always retried.
    """
    if hasattr(exception, "status_code"):
        status_code = exception.status_code
    else:
        # Fall back to the response object, if one is attached.
        response = getattr(exception, "response", None)
        status_code = getattr(response, "status_code", None)

    if status_code is None:
        return True
    # NOTE(review): litellm._should_retry is a private API — confirm it is
    # still available when upgrading litellm.
    return bool(litellm._should_retry(status_code))
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class LLMRequestQueue:
    """Serializes LLM requests: caps concurrency and enforces a minimum
    delay between requests, retrying transient failures.

    Environment overrides:
        LLM_RATE_LIMIT_DELAY: seconds between requests (overrides the
            ``delay_between_requests`` argument).
        LLM_RATE_LIMIT_CONCURRENT: max concurrent requests (overrides the
            ``max_concurrent`` argument).
    """

    def __init__(self, max_concurrent: int = 1, delay_between_requests: float = 4.0):
        rate_limit_delay = os.getenv("LLM_RATE_LIMIT_DELAY")
        if rate_limit_delay:
            delay_between_requests = float(rate_limit_delay)

        rate_limit_concurrent = os.getenv("LLM_RATE_LIMIT_CONCURRENT")
        if rate_limit_concurrent:
            max_concurrent = int(rate_limit_concurrent)

        self.max_concurrent = max_concurrent
        self.delay_between_requests = delay_between_requests
        self._semaphore = threading.BoundedSemaphore(max_concurrent)
        self._last_request_time = 0.0
        self._lock = threading.Lock()

    async def make_request(self, completion_args: dict[str, Any]) -> ModelResponse:
        """Queue one completion call, respecting concurrency and delay limits.

        Raises whatever the underlying litellm call raises after retries
        are exhausted.
        """
        # Poll the (thread-level) semaphore with short timeouts so the
        # event loop stays responsive while we wait for a slot.
        while not self._semaphore.acquire(timeout=0.2):
            await asyncio.sleep(0.1)
        # BUGFIX: the acquire loop used to live inside the try, so a
        # cancellation raised while still waiting reached the finally and
        # released a BoundedSemaphore that was never acquired — raising
        # ValueError and corrupting the concurrency limit. Acquire first,
        # then guard only the acquired section.
        try:
            with self._lock:
                now = time.time()
                time_since_last = now - self._last_request_time
                sleep_needed = max(0, self.delay_between_requests - time_since_last)
                # Reserve our slot before sleeping so concurrent callers
                # queue behind us rather than racing for the same window.
                self._last_request_time = now + sleep_needed

            if sleep_needed > 0:
                await asyncio.sleep(sleep_needed)

            return await self._reliable_request(completion_args)
        finally:
            self._semaphore.release()

    @retry(  # type: ignore[misc]
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=8, min=8, max=64),
        retry=retry_if_exception(should_retry_exception),
        reraise=True,
    )
    async def _reliable_request(self, completion_args: dict[str, Any]) -> ModelResponse:
        """Perform one completion call with retry on transient failures.

        NOTE(review): litellm.completion is a blocking call inside an async
        method — it will stall the event loop for the request's duration.
        """
        response = completion(**completion_args, stream=False)
        if isinstance(response, ModelResponse):
            return response
        self._raise_unexpected_response()
        raise RuntimeError("Unreachable code")

    def _raise_unexpected_response(self) -> None:
        # Separate raiser keeps the happy path free of abstract raises.
        raise RuntimeError("Unexpected response type")
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# Lazily-created singleton shared by all callers in this process.
_global_queue: LLMRequestQueue | None = None


def get_global_queue() -> LLMRequestQueue:
    """Return the process-wide LLMRequestQueue, creating it on first call.

    NOTE(review): creation is not guarded by a lock; two threads racing the
    first call could each construct a queue — confirm first use happens on
    a single thread.
    """
    global _global_queue  # noqa: PLW0603
    if _global_queue is None:
        _global_queue = LLMRequestQueue()
    return _global_queue
|
aipt_v2/llm/utils.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import html
|
|
4
|
+
import re
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _truncate_to_first_function(content: str) -> str:
|
|
9
|
+
if not content:
|
|
10
|
+
return content
|
|
11
|
+
|
|
12
|
+
function_starts = [match.start() for match in re.finditer(r"<function=", content)]
|
|
13
|
+
|
|
14
|
+
if len(function_starts) >= 2:
|
|
15
|
+
second_function_start = function_starts[1]
|
|
16
|
+
|
|
17
|
+
return content[:second_function_start].rstrip()
|
|
18
|
+
|
|
19
|
+
return content
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def parse_tool_invocations(content: str) -> list[dict[str, Any]] | None:
    """Extract <function=...> tool calls from raw LLM output.

    Returns a list of {"toolName": ..., "args": {...}} dicts, or None when
    no invocations are found. Parameter values are stripped and
    HTML-unescaped; a truncated closing tag is repaired first.
    """
    content = _fix_stopword(content)

    fn_pattern = re.compile(r"<function=([^>]+)>\n?(.*?)</function>", re.DOTALL)
    param_pattern = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)

    invocations: list[dict[str, Any]] = []
    for fn_name, fn_body in fn_pattern.findall(content):
        args = {
            name: html.unescape(value.strip())
            for name, value in param_pattern.findall(fn_body)
        }
        invocations.append({"toolName": fn_name, "args": args})

    return invocations or None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _fix_stopword(content: str) -> str:
|
|
52
|
+
if "<function=" in content and content.count("<function=") == 1:
|
|
53
|
+
if content.endswith("</"):
|
|
54
|
+
content = content.rstrip() + "function>"
|
|
55
|
+
elif not content.rstrip().endswith("</function>"):
|
|
56
|
+
content = content + "\n</function>"
|
|
57
|
+
return content
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def format_tool_call(tool_name: str, args: dict[str, Any]) -> str:
    """Render a tool call as the <function=...>/<parameter=...> XML format
    that parse_tool_invocations understands.

    NOTE(review): values are interpolated without HTML escaping, while the
    parser unescapes on read — confirm values never contain markup.
    """
    body = "\n".join(f"<parameter={key}>{value}</parameter>" for key, value in args.items())
    if body:
        return f"<function={tool_name}>\n{body}\n</function>"
    return f"<function={tool_name}>\n</function>"
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def clean_content(content: str) -> str:
    """Strip tool-call XML and internal coordination markup from LLM
    output, collapse runs of blank lines, and trim the result."""
    if not content:
        return ""

    text = _fix_stopword(content)

    # Remove tool invocations entirely.
    text = re.sub(r"<function=[^>]+>.*?</function>", "", text, flags=re.DOTALL)

    # Remove inter-agent markup that should never reach the user.
    for hidden in (
        r"<inter_agent_message>.*?</inter_agent_message>",
        r"<agent_completion_report>.*?</agent_completion_report>",
    ):
        text = re.sub(hidden, "", text, flags=re.DOTALL | re.IGNORECASE)

    # Collapse multiple blank lines into a single blank line.
    text = re.sub(r"\n\s*\n", "\n\n", text)

    return text.strip()
|