aiptx-2.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiptx might be problematic; see the registry's advisory for more details.
- aipt_v2/__init__.py +110 -0
- aipt_v2/__main__.py +24 -0
- aipt_v2/agents/AIPTxAgent/__init__.py +10 -0
- aipt_v2/agents/AIPTxAgent/aiptx_agent.py +211 -0
- aipt_v2/agents/__init__.py +24 -0
- aipt_v2/agents/base.py +520 -0
- aipt_v2/agents/ptt.py +406 -0
- aipt_v2/agents/state.py +168 -0
- aipt_v2/app.py +960 -0
- aipt_v2/browser/__init__.py +31 -0
- aipt_v2/browser/automation.py +458 -0
- aipt_v2/browser/crawler.py +453 -0
- aipt_v2/cli.py +321 -0
- aipt_v2/compliance/__init__.py +71 -0
- aipt_v2/compliance/compliance_report.py +449 -0
- aipt_v2/compliance/framework_mapper.py +424 -0
- aipt_v2/compliance/nist_mapping.py +345 -0
- aipt_v2/compliance/owasp_mapping.py +330 -0
- aipt_v2/compliance/pci_mapping.py +297 -0
- aipt_v2/config.py +288 -0
- aipt_v2/core/__init__.py +43 -0
- aipt_v2/core/agent.py +630 -0
- aipt_v2/core/llm.py +395 -0
- aipt_v2/core/memory.py +305 -0
- aipt_v2/core/ptt.py +329 -0
- aipt_v2/database/__init__.py +14 -0
- aipt_v2/database/models.py +232 -0
- aipt_v2/database/repository.py +384 -0
- aipt_v2/docker/__init__.py +23 -0
- aipt_v2/docker/builder.py +260 -0
- aipt_v2/docker/manager.py +222 -0
- aipt_v2/docker/sandbox.py +371 -0
- aipt_v2/evasion/__init__.py +58 -0
- aipt_v2/evasion/request_obfuscator.py +272 -0
- aipt_v2/evasion/tls_fingerprint.py +285 -0
- aipt_v2/evasion/ua_rotator.py +301 -0
- aipt_v2/evasion/waf_bypass.py +439 -0
- aipt_v2/execution/__init__.py +23 -0
- aipt_v2/execution/executor.py +302 -0
- aipt_v2/execution/parser.py +544 -0
- aipt_v2/execution/terminal.py +337 -0
- aipt_v2/health.py +437 -0
- aipt_v2/intelligence/__init__.py +85 -0
- aipt_v2/intelligence/auth.py +520 -0
- aipt_v2/intelligence/chaining.py +775 -0
- aipt_v2/intelligence/cve_aipt.py +334 -0
- aipt_v2/intelligence/cve_info.py +1111 -0
- aipt_v2/intelligence/rag.py +239 -0
- aipt_v2/intelligence/scope.py +442 -0
- aipt_v2/intelligence/searchers/__init__.py +5 -0
- aipt_v2/intelligence/searchers/exploitdb_searcher.py +523 -0
- aipt_v2/intelligence/searchers/github_searcher.py +467 -0
- aipt_v2/intelligence/searchers/google_searcher.py +281 -0
- aipt_v2/intelligence/tools.json +443 -0
- aipt_v2/intelligence/triage.py +670 -0
- aipt_v2/interface/__init__.py +5 -0
- aipt_v2/interface/cli.py +230 -0
- aipt_v2/interface/main.py +501 -0
- aipt_v2/interface/tui.py +1276 -0
- aipt_v2/interface/utils.py +583 -0
- aipt_v2/llm/__init__.py +39 -0
- aipt_v2/llm/config.py +26 -0
- aipt_v2/llm/llm.py +514 -0
- aipt_v2/llm/memory.py +214 -0
- aipt_v2/llm/request_queue.py +89 -0
- aipt_v2/llm/utils.py +89 -0
- aipt_v2/models/__init__.py +15 -0
- aipt_v2/models/findings.py +295 -0
- aipt_v2/models/phase_result.py +224 -0
- aipt_v2/models/scan_config.py +207 -0
- aipt_v2/monitoring/grafana/dashboards/aipt-dashboard.json +355 -0
- aipt_v2/monitoring/grafana/dashboards/default.yml +17 -0
- aipt_v2/monitoring/grafana/datasources/prometheus.yml +17 -0
- aipt_v2/monitoring/prometheus.yml +60 -0
- aipt_v2/orchestration/__init__.py +52 -0
- aipt_v2/orchestration/pipeline.py +398 -0
- aipt_v2/orchestration/progress.py +300 -0
- aipt_v2/orchestration/scheduler.py +296 -0
- aipt_v2/orchestrator.py +2284 -0
- aipt_v2/payloads/__init__.py +27 -0
- aipt_v2/payloads/cmdi.py +150 -0
- aipt_v2/payloads/sqli.py +263 -0
- aipt_v2/payloads/ssrf.py +204 -0
- aipt_v2/payloads/templates.py +222 -0
- aipt_v2/payloads/traversal.py +166 -0
- aipt_v2/payloads/xss.py +204 -0
- aipt_v2/prompts/__init__.py +60 -0
- aipt_v2/proxy/__init__.py +29 -0
- aipt_v2/proxy/history.py +352 -0
- aipt_v2/proxy/interceptor.py +452 -0
- aipt_v2/recon/__init__.py +44 -0
- aipt_v2/recon/dns.py +241 -0
- aipt_v2/recon/osint.py +367 -0
- aipt_v2/recon/subdomain.py +372 -0
- aipt_v2/recon/tech_detect.py +311 -0
- aipt_v2/reports/__init__.py +17 -0
- aipt_v2/reports/generator.py +313 -0
- aipt_v2/reports/html_report.py +378 -0
- aipt_v2/runtime/__init__.py +44 -0
- aipt_v2/runtime/base.py +30 -0
- aipt_v2/runtime/docker.py +401 -0
- aipt_v2/runtime/local.py +346 -0
- aipt_v2/runtime/tool_server.py +205 -0
- aipt_v2/scanners/__init__.py +28 -0
- aipt_v2/scanners/base.py +273 -0
- aipt_v2/scanners/nikto.py +244 -0
- aipt_v2/scanners/nmap.py +402 -0
- aipt_v2/scanners/nuclei.py +273 -0
- aipt_v2/scanners/web.py +454 -0
- aipt_v2/scripts/security_audit.py +366 -0
- aipt_v2/telemetry/__init__.py +7 -0
- aipt_v2/telemetry/tracer.py +347 -0
- aipt_v2/terminal/__init__.py +28 -0
- aipt_v2/terminal/executor.py +400 -0
- aipt_v2/terminal/sandbox.py +350 -0
- aipt_v2/tools/__init__.py +44 -0
- aipt_v2/tools/active_directory/__init__.py +78 -0
- aipt_v2/tools/active_directory/ad_config.py +238 -0
- aipt_v2/tools/active_directory/bloodhound_wrapper.py +447 -0
- aipt_v2/tools/active_directory/kerberos_attacks.py +430 -0
- aipt_v2/tools/active_directory/ldap_enum.py +533 -0
- aipt_v2/tools/active_directory/smb_attacks.py +505 -0
- aipt_v2/tools/agents_graph/__init__.py +19 -0
- aipt_v2/tools/agents_graph/agents_graph_actions.py +69 -0
- aipt_v2/tools/api_security/__init__.py +76 -0
- aipt_v2/tools/api_security/api_discovery.py +608 -0
- aipt_v2/tools/api_security/graphql_scanner.py +622 -0
- aipt_v2/tools/api_security/jwt_analyzer.py +577 -0
- aipt_v2/tools/api_security/openapi_fuzzer.py +761 -0
- aipt_v2/tools/browser/__init__.py +5 -0
- aipt_v2/tools/browser/browser_actions.py +238 -0
- aipt_v2/tools/browser/browser_instance.py +535 -0
- aipt_v2/tools/browser/tab_manager.py +344 -0
- aipt_v2/tools/cloud/__init__.py +70 -0
- aipt_v2/tools/cloud/cloud_config.py +273 -0
- aipt_v2/tools/cloud/cloud_scanner.py +639 -0
- aipt_v2/tools/cloud/prowler_tool.py +571 -0
- aipt_v2/tools/cloud/scoutsuite_tool.py +359 -0
- aipt_v2/tools/executor.py +307 -0
- aipt_v2/tools/parser.py +408 -0
- aipt_v2/tools/proxy/__init__.py +5 -0
- aipt_v2/tools/proxy/proxy_actions.py +103 -0
- aipt_v2/tools/proxy/proxy_manager.py +789 -0
- aipt_v2/tools/registry.py +196 -0
- aipt_v2/tools/scanners/__init__.py +343 -0
- aipt_v2/tools/scanners/acunetix_tool.py +712 -0
- aipt_v2/tools/scanners/burp_tool.py +631 -0
- aipt_v2/tools/scanners/config.py +156 -0
- aipt_v2/tools/scanners/nessus_tool.py +588 -0
- aipt_v2/tools/scanners/zap_tool.py +612 -0
- aipt_v2/tools/terminal/__init__.py +5 -0
- aipt_v2/tools/terminal/terminal_actions.py +37 -0
- aipt_v2/tools/terminal/terminal_manager.py +153 -0
- aipt_v2/tools/terminal/terminal_session.py +449 -0
- aipt_v2/tools/tool_processing.py +108 -0
- aipt_v2/utils/__init__.py +17 -0
- aipt_v2/utils/logging.py +201 -0
- aipt_v2/utils/model_manager.py +187 -0
- aipt_v2/utils/searchers/__init__.py +269 -0
- aiptx-2.0.2.dist-info/METADATA +324 -0
- aiptx-2.0.2.dist-info/RECORD +165 -0
- aiptx-2.0.2.dist-info/WHEEL +5 -0
- aiptx-2.0.2.dist-info/entry_points.txt +7 -0
- aiptx-2.0.2.dist-info/licenses/LICENSE +21 -0
- aiptx-2.0.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AIPT Tool Processing - Execute tool invocations from LLM responses
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)

# Tool names that end the agent loop instead of being executed.
_FINISH_TOOLS = frozenset({"finish_scan", "agent_finish"})


def _normalize_tool_args(raw_args: Any) -> dict[str, Any]:
    """Coerce an invocation's ``arguments`` value into a dict.

    ``None`` (including a missing key) and blank strings become ``{}``; a
    string holding a JSON object is decoded. Anything else is rejected.

    Raises:
        ValueError: if the arguments are neither a dict nor a JSON-object string.
    """
    import json

    if raw_args is None:
        return {}
    value = raw_args
    if isinstance(raw_args, str):
        if not raw_args.strip():
            return {}
        try:
            value = json.loads(raw_args)
        except json.JSONDecodeError as err:
            raise ValueError(f"arguments are not valid JSON: {err}") from err
    if not isinstance(value, dict):
        raise ValueError(
            f"arguments must be a JSON object, got {type(value).__name__}"
        )
    return value


async def process_tool_invocations(
    actions: list[dict[str, Any]],
    conversation_history: list[dict[str, Any]],
    state: Any,
) -> bool:
    """
    Process tool invocations from an LLM response, in order.

    Each tool result (or error) is appended to ``conversation_history`` as a
    ``"user"`` message. A finish tool stops processing immediately, so later
    actions in the same batch are not executed.

    Args:
        actions: Tool invocation dicts with 'name' and 'arguments'. 'arguments'
            may be a dict, a JSON-object string, or missing/None.
        conversation_history: Mutable conversation history (appended to).
        state: Agent state object, passed through to the tool handlers.

    Returns:
        True if agent should finish (a finish tool was invoked), False otherwise
    """
    for action in actions:
        tool_name = action.get("name", "")

        try:
            tool_args = _normalize_tool_args(action.get("arguments"))
        except ValueError as exc:
            # Report malformed arguments back to the model instead of crashing
            # the loop (previously a str/None value broke tool_args.get below).
            error_msg = f"Tool {tool_name} failed: {exc}"
            logger.error(error_msg)
            conversation_history.append({"role": "user", "content": error_msg})
            continue

        logger.info("Executing tool: %s", tool_name)

        # Finish tools end the loop without dispatching to a handler.
        if tool_name in _FINISH_TOOLS:
            result = tool_args.get("result", "Task completed")
            conversation_history.append({
                "role": "user",
                "content": f"Tool {tool_name} executed. Result: {result}",
            })
            return True

        try:
            result = await _execute_tool(tool_name, tool_args, state)
        except Exception as e:
            # Boundary handler: any tool failure is fed back to the model as text.
            error_msg = f"Tool {tool_name} failed: {str(e)}"
            logger.exception(error_msg)
            conversation_history.append({"role": "user", "content": error_msg})
        else:
            conversation_history.append({
                "role": "user",
                "content": f"Tool {tool_name} result:\n{result}",
            })

    return False
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
async def _execute_tool(name: str, args: dict[str, Any], state: Any) -> str:
|
|
61
|
+
"""Execute a single tool and return result"""
|
|
62
|
+
# Import tool executors lazily
|
|
63
|
+
if name == "execute_command":
|
|
64
|
+
return await _execute_command(args, state)
|
|
65
|
+
elif name == "browser_navigate":
|
|
66
|
+
return await _browser_navigate(args, state)
|
|
67
|
+
elif name == "browser_screenshot":
|
|
68
|
+
return await _browser_screenshot(args, state)
|
|
69
|
+
else:
|
|
70
|
+
return f"Tool '{name}' executed with args: {args}"
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
async def _execute_command(args: dict[str, Any], state: Any) -> str:
|
|
74
|
+
"""Execute a shell command in the sandbox"""
|
|
75
|
+
import asyncio
|
|
76
|
+
|
|
77
|
+
command = args.get("command", "")
|
|
78
|
+
timeout = args.get("timeout", 60)
|
|
79
|
+
|
|
80
|
+
# Use subprocess for now (Docker integration later)
|
|
81
|
+
try:
|
|
82
|
+
proc = await asyncio.create_subprocess_shell(
|
|
83
|
+
command,
|
|
84
|
+
stdout=asyncio.subprocess.PIPE,
|
|
85
|
+
stderr=asyncio.subprocess.PIPE,
|
|
86
|
+
)
|
|
87
|
+
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
|
|
88
|
+
output = stdout.decode() if stdout else ""
|
|
89
|
+
errors = stderr.decode() if stderr else ""
|
|
90
|
+
return output + errors if errors else output
|
|
91
|
+
except asyncio.TimeoutError:
|
|
92
|
+
return f"Command timed out after {timeout} seconds"
|
|
93
|
+
except Exception as e:
|
|
94
|
+
return f"Command failed: {str(e)}"
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
async def _browser_navigate(args: dict[str, Any], state: Any) -> str:
|
|
98
|
+
"""Navigate browser to URL"""
|
|
99
|
+
url = args.get("url", "")
|
|
100
|
+
return f"Navigated to: {url}"
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
async def _browser_screenshot(args: dict[str, Any], state: Any) -> str:
|
|
104
|
+
"""Take browser screenshot"""
|
|
105
|
+
return "Screenshot taken"
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# Only the dispatcher is public; the _-prefixed tool handlers are internal.
__all__ = ["process_tool_invocations"]
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AIPT v2 Utilities Module
|
|
3
|
+
========================
|
|
4
|
+
|
|
5
|
+
Provides common utilities used across the framework:
|
|
6
|
+
- Structured logging with secret redaction
|
|
7
|
+
- Model management wrappers
|
|
8
|
+
- Searcher utilities
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from .logging import logger, setup_logging, get_logger
|
|
12
|
+
|
|
13
|
+
# Public names of aipt_v2.utils: the logging helpers imported from .logging.
__all__ = ["logger", "setup_logging", "get_logger"]
|
aipt_v2/utils/logging.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Structured Logging Configuration for AIPT v2
|
|
3
|
+
=============================================
|
|
4
|
+
|
|
5
|
+
Provides:
|
|
6
|
+
- Structured logging via structlog
|
|
7
|
+
- Automatic secret redaction
|
|
8
|
+
- JSON format for production
|
|
9
|
+
- Console format for development
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import logging
|
|
13
|
+
import os
|
|
14
|
+
import sys
|
|
15
|
+
import re
|
|
16
|
+
from typing import Any, Optional
|
|
17
|
+
from functools import lru_cache
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
import structlog
|
|
21
|
+
STRUCTLOG_AVAILABLE = True
|
|
22
|
+
except ImportError:
|
|
23
|
+
STRUCTLOG_AVAILABLE = False
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Patterns for secret redaction
|
|
27
|
+
SECRET_PATTERNS = [
|
|
28
|
+
r"api[_-]?key",
|
|
29
|
+
r"apikey",
|
|
30
|
+
r"token",
|
|
31
|
+
r"secret",
|
|
32
|
+
r"password",
|
|
33
|
+
r"credential",
|
|
34
|
+
r"auth",
|
|
35
|
+
r"bearer",
|
|
36
|
+
r"sk-[a-zA-Z0-9]+",
|
|
37
|
+
r"pk-[a-zA-Z0-9]+",
|
|
38
|
+
r"access[_-]?key",
|
|
39
|
+
r"private[_-]?key",
|
|
40
|
+
]
|
|
41
|
+
|
|
42
|
+
SECRET_REGEX = re.compile("|".join(SECRET_PATTERNS), re.IGNORECASE)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _should_redact(key: str) -> bool:
|
|
46
|
+
"""Check if a key should be redacted."""
|
|
47
|
+
return bool(SECRET_REGEX.search(key))
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _redact_value(value: str) -> str:
|
|
51
|
+
"""Redact a sensitive value, keeping first/last chars for debugging."""
|
|
52
|
+
if len(value) <= 8:
|
|
53
|
+
return "[REDACTED]"
|
|
54
|
+
return f"{value[:4]}...{value[-4:]}"
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _redact_processor(logger: Any, method_name: str, event_dict: dict) -> dict:
|
|
58
|
+
"""Structlog processor to redact sensitive information."""
|
|
59
|
+
for key, value in list(event_dict.items()):
|
|
60
|
+
if isinstance(value, str):
|
|
61
|
+
if _should_redact(key):
|
|
62
|
+
event_dict[key] = "[REDACTED]"
|
|
63
|
+
elif len(value) > 20 and SECRET_REGEX.search(value):
|
|
64
|
+
event_dict[key] = _redact_value(value)
|
|
65
|
+
return event_dict
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def setup_logging(
    level: str = "INFO",
    json_format: bool = False,
    redact_secrets: bool = True,
) -> Any:
    """
    Configure structured logging for the application.

    When structlog is importable, structlog is configured to route through the
    stdlib ``logging`` module and a structlog logger is returned. Otherwise the
    stdlib root logger is configured and the "aipt_v2" stdlib logger is returned.

    Args:
        level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL); case-insensitive,
            unknown names fall back to INFO.
        json_format: Use JSON format (for production). Structlog path only.
        redact_secrets: Automatically redact sensitive values. Structlog path only.

    Returns:
        Configured logger instance

    Note:
        ``logging.basicConfig`` does nothing if the root logger already has
        handlers, so repeated calls do not change the stdlib level/format.
    """
    log_level = getattr(logging, str(level).strip().upper(), None)
    if not isinstance(log_level, int):
        # Unknown names, or logging attributes that are not levels
        # (e.g. "BASIC_FORMAT"), would otherwise crash basicConfig().
        log_level = logging.INFO

    if STRUCTLOG_AVAILABLE:
        processors = [
            structlog.stdlib.filter_by_level,
            structlog.stdlib.add_logger_name,
            structlog.stdlib.add_log_level,
            structlog.processors.TimeStamper(fmt="iso"),
            structlog.processors.StackInfoRenderer(),
        ]

        if json_format:
            # JSONRenderer cannot render exc_info itself; turn it into an
            # "exception" string first so tracebacks are not lost. Done before
            # redaction so secrets inside tracebacks are masked too.
            # (ConsoleRenderer renders exc_info on its own, so not added there.)
            processors.append(structlog.processors.format_exc_info)

        if redact_secrets:
            processors.append(_redact_processor)

        if json_format:
            processors.append(structlog.processors.JSONRenderer())
        else:
            processors.append(structlog.dev.ConsoleRenderer(colors=True))

        structlog.configure(
            processors=processors,
            wrapper_class=structlog.stdlib.BoundLogger,
            context_class=dict,
            logger_factory=structlog.stdlib.LoggerFactory(),
            cache_logger_on_first_use=True,
        )

        # The final structlog processor already renders the line; stdlib only
        # emits "%(message)s" to stdout.
        logging.basicConfig(
            format="%(message)s",
            stream=sys.stdout,
            level=log_level,
        )

        return structlog.get_logger()

    # Fallback to standard logging
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        stream=sys.stdout,
        level=log_level,
    )
    return logging.getLogger("aipt_v2")
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@lru_cache(maxsize=1)
|
|
129
|
+
def get_logger() -> Any:
|
|
130
|
+
"""Get or create the global logger instance."""
|
|
131
|
+
log_level = os.getenv("AIPT_LOG_LEVEL", "INFO")
|
|
132
|
+
json_format = os.getenv("AIPT_LOG_FORMAT", "console").lower() == "json"
|
|
133
|
+
redact = os.getenv("AIPT_REDACT_SECRETS", "true").lower() == "true"
|
|
134
|
+
|
|
135
|
+
return setup_logging(
|
|
136
|
+
level=log_level,
|
|
137
|
+
json_format=json_format,
|
|
138
|
+
redact_secrets=redact,
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
# Global logger instance
|
|
143
|
+
# NOTE: evaluated at import time, so importing this module runs setup_logging()
# (structlog.configure / logging.basicConfig on stdout) as a side effect.
logger = get_logger()
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class LoggerAdapter:
    """
    Uniform logging facade over either a structlog logger or a stdlib logger.

    With structlog, keyword arguments are forwarded as structured event fields.
    Without it, they are flattened into space-separated ``key=value`` pairs
    appended to the message, because stdlib loggers reject arbitrary keywords.
    """

    def __init__(self, logger_instance: Any):
        self._logger = logger_instance
        # Chosen from the module-level import probe, not from the wrapped
        # object's actual type.
        self._is_structlog = STRUCTLOG_AVAILABLE

    @staticmethod
    def _flatten(msg: str, fields: dict) -> str:
        """Append ``k=v`` pairs to *msg*; return *msg* unchanged when there are none."""
        if not fields:
            return msg
        pairs = " ".join(f"{name}={val}" for name, val in fields.items())
        return f"{msg} {pairs}"

    def _log(self, level: str, msg: str, **kwargs):
        """Send *msg* to the wrapped logger's method named *level*."""
        if not self._is_structlog:
            text = self._flatten(msg, kwargs)
            getattr(self._logger, level)(text)
            return
        getattr(self._logger, level)(msg, **kwargs)

    def debug(self, msg: str, **kwargs):
        self._log("debug", msg, **kwargs)

    def info(self, msg: str, **kwargs):
        self._log("info", msg, **kwargs)

    def warning(self, msg: str, **kwargs):
        self._log("warning", msg, **kwargs)

    def error(self, msg: str, exc_info: bool = False, **kwargs):
        """Log at error level; *exc_info* attaches the current exception."""
        if not self._is_structlog:
            text = self._flatten(msg, kwargs)
            self._logger.error(text, exc_info=exc_info)
            return
        # structlog only receives exc_info when it was actually requested.
        if exc_info:
            kwargs["exc_info"] = True
        self._logger.error(msg, **kwargs)

    def critical(self, msg: str, **kwargs):
        self._log("critical", msg, **kwargs)

    def exception(self, msg: str, **kwargs):
        """Log exception with traceback."""
        self.error(msg, exc_info=True, **kwargs)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
# Export a wrapped logger for consistent interface
|
|
194
|
+
def create_logger(name: Optional[str] = None) -> LoggerAdapter:
    """Return a LoggerAdapter around a (possibly named) logger.

    With structlog: ``structlog.get_logger(name)`` for a non-empty *name*,
    otherwise the cached global logger from ``get_logger()``. Without
    structlog: the stdlib logger called *name*, or "aipt_v2" when *name* is falsy.
    """
    if not STRUCTLOG_AVAILABLE:
        return LoggerAdapter(logging.getLogger(name or "aipt_v2"))

    if not name:
        wrapped = get_logger()
    else:
        wrapped = structlog.get_logger(name)
    return LoggerAdapter(wrapped)
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model Manager for AIPT v2
|
|
3
|
+
=========================
|
|
4
|
+
|
|
5
|
+
Provides a unified interface for LLM model access.
|
|
6
|
+
This is a compatibility layer for intelligence modules that
|
|
7
|
+
reference utils.model_manager.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import os
|
|
11
|
+
from typing import Any, Optional, Dict, List
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
|
|
14
|
+
from aipt_v2.utils.logging import logger
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class ModelConfig:
    """Settings a ModelWrapper uses when calling litellm.

    Every field has a default, so ``ModelConfig()`` is a complete configuration.
    """

    # litellm model identifier, passed as ``model=``.
    model_name: str = "gpt-4"
    # Default sampling temperature for completion calls.
    temperature: float = 0.7
    # Default cap on generated tokens per completion.
    max_tokens: int = 4096
    # Default request timeout forwarded to litellm (presumably seconds — verify).
    timeout: int = 120
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ModelWrapper:
    """
    Wrapper around litellm for consistent model access.

    Provides sync (``complete``) and async (``acomplete``) chat completion and
    ``embed``. litellm is imported lazily on first use, so this module can be
    imported without it.
    """

    def __init__(self, config: "ModelConfig"):
        self.config = config
        self._litellm = None

    def _get_litellm(self):
        """Lazy load litellm and cache the module on the instance.

        Raises:
            ImportError: if litellm is not installed.
        """
        if self._litellm is None:
            try:
                import litellm
            except ImportError as err:
                raise ImportError(
                    "litellm is required for model_manager. "
                    "Install with: pip install litellm"
                ) from err
            self._litellm = litellm
        return self._litellm

    @staticmethod
    def _build_messages(prompt: str, system_prompt: Optional[str]) -> List[Dict[str, str]]:
        """Chat messages: an optional system message followed by the user prompt."""
        messages: List[Dict[str, str]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    def _completion_params(
        self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Merge per-call kwargs with config defaults into litellm call arguments.

        ``temperature``, ``max_tokens`` and ``timeout`` fall back to the config;
        other kwargs are forwarded unchanged. ``model`` and ``messages`` always
        come from this wrapper (such kwargs are ignored, as before).
        """
        return {
            **kwargs,
            "model": self.config.model_name,
            "messages": messages,
            "temperature": kwargs.get("temperature", self.config.temperature),
            "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
            "timeout": kwargs.get("timeout", self.config.timeout),
        }

    @staticmethod
    def _response_text(response: Any) -> str:
        """First choice's message content, or "" when the content is None."""
        # OpenAI-style responses may carry null content (e.g. tool-call-only
        # replies); normalize so callers always get the declared str.
        return response.choices[0].message.content or ""

    @staticmethod
    def _log_error(message: str, **fields: Any) -> None:
        """Log an error with structured fields, with or without structlog."""
        # The shared `logger` is a plain stdlib logging.Logger when structlog
        # is missing (see setup_logging), and stdlib loggers raise TypeError on
        # keyword fields such as model=...; LoggerAdapter handles both cases.
        from aipt_v2.utils.logging import LoggerAdapter

        LoggerAdapter(logger).error(message, **fields)

    def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs
    ) -> str:
        """
        Synchronous completion.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt (omitted when empty)
            **kwargs: ``temperature``, ``max_tokens`` and ``timeout`` override
                the config; any other keyword is forwarded to
                ``litellm.completion`` (``model``/``messages`` are ignored).

        Returns:
            Model response text ("" if the response had no text content)

        Raises:
            ImportError: if litellm is not installed.
            Exception: whatever litellm raises; it is logged, then re-raised.
        """
        litellm = self._get_litellm()
        params = self._completion_params(self._build_messages(prompt, system_prompt), kwargs)

        try:
            response = litellm.completion(**params)
            return self._response_text(response)
        except Exception as e:
            self._log_error("Model completion failed", model=self.config.model_name, error=str(e))
            raise

    async def acomplete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs
    ) -> str:
        """
        Asynchronous completion (same contract as ``complete``).

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt (omitted when empty)
            **kwargs: ``temperature``, ``max_tokens`` and ``timeout`` override
                the config; any other keyword is forwarded to
                ``litellm.acompletion`` (``model``/``messages`` are ignored).

        Returns:
            Model response text ("" if the response had no text content)
        """
        litellm = self._get_litellm()
        params = self._completion_params(self._build_messages(prompt, system_prompt), kwargs)

        try:
            response = await litellm.acompletion(**params)
            return self._response_text(response)
        except Exception as e:
            self._log_error("Async model completion failed", model=self.config.model_name, error=str(e))
            raise

    def embed(self, text: str, model: str = "text-embedding-ada-002") -> List[float]:
        """
        Get embedding for text.

        Args:
            text: Text to embed
            model: litellm embedding model identifier (default unchanged:
                "text-embedding-ada-002")

        Returns:
            Embedding vector
        """
        litellm = self._get_litellm()

        try:
            response = litellm.embedding(
                model=model,
                input=text,
            )
            return response.data[0]["embedding"]
        except Exception as e:
            self._log_error("Embedding failed", error=str(e))
            raise
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# Cache for model instances
|
|
148
|
+
# Module-level cache of ModelWrapper instances: filled by get_model(), emptied
# by clear_model_cache(). No locking is done around it.
_model_cache: Dict[str, ModelWrapper] = {}
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def get_model(
    model_name: Optional[str] = None,
    **kwargs
) -> ModelWrapper:
    """
    Get or create a cached model instance.

    Args:
        model_name: Model identifier (default from the AIPT_LLM_MODEL env var,
            then "gpt-4")
        **kwargs: Optional ``temperature``, ``max_tokens`` and ``timeout``
            overrides; other keys are ignored. Omitted values use the
            ModelConfig defaults.

    Returns:
        ModelWrapper instance, shared by every caller that requests the same
        model name and effective settings
    """
    if model_name is None:
        model_name = os.getenv("AIPT_LLM_MODEL", "gpt-4")

    overrides = {
        key: kwargs[key]
        for key in ("temperature", "max_tokens", "timeout")
        if key in kwargs
    }
    config = ModelConfig(model_name=model_name, **overrides)

    # Key on the effective configuration. The old hash(frozenset(kwargs.items()))
    # raised TypeError for unhashable values, could collide, and created
    # duplicate wrappers for kwargs that do not affect the config.
    cache_key = repr((config.model_name, config.temperature, config.max_tokens, config.timeout))

    wrapper = _model_cache.get(cache_key)
    if wrapper is None:
        wrapper = ModelWrapper(config)
        _model_cache[cache_key] = wrapper
        # LoggerAdapter: the shared `logger` is a stdlib Logger when structlog
        # is absent, and stdlib loggers raise TypeError on keyword fields.
        from aipt_v2.utils.logging import LoggerAdapter

        LoggerAdapter(logger).info("Created model instance", model=model_name)

    return wrapper
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def clear_model_cache():
    """Clear the model cache so the next get_model() call builds fresh wrappers.

    The existing dict is emptied in place rather than rebinding the global, so
    any other reference to ``_model_cache`` also sees the cleared state.
    """
    _model_cache.clear()
    logger.info("Model cache cleared")
|