aiptx 2.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aiptx might be problematic. Click here for more details.

Files changed (165)
  1. aipt_v2/__init__.py +110 -0
  2. aipt_v2/__main__.py +24 -0
  3. aipt_v2/agents/AIPTxAgent/__init__.py +10 -0
  4. aipt_v2/agents/AIPTxAgent/aiptx_agent.py +211 -0
  5. aipt_v2/agents/__init__.py +24 -0
  6. aipt_v2/agents/base.py +520 -0
  7. aipt_v2/agents/ptt.py +406 -0
  8. aipt_v2/agents/state.py +168 -0
  9. aipt_v2/app.py +960 -0
  10. aipt_v2/browser/__init__.py +31 -0
  11. aipt_v2/browser/automation.py +458 -0
  12. aipt_v2/browser/crawler.py +453 -0
  13. aipt_v2/cli.py +321 -0
  14. aipt_v2/compliance/__init__.py +71 -0
  15. aipt_v2/compliance/compliance_report.py +449 -0
  16. aipt_v2/compliance/framework_mapper.py +424 -0
  17. aipt_v2/compliance/nist_mapping.py +345 -0
  18. aipt_v2/compliance/owasp_mapping.py +330 -0
  19. aipt_v2/compliance/pci_mapping.py +297 -0
  20. aipt_v2/config.py +288 -0
  21. aipt_v2/core/__init__.py +43 -0
  22. aipt_v2/core/agent.py +630 -0
  23. aipt_v2/core/llm.py +395 -0
  24. aipt_v2/core/memory.py +305 -0
  25. aipt_v2/core/ptt.py +329 -0
  26. aipt_v2/database/__init__.py +14 -0
  27. aipt_v2/database/models.py +232 -0
  28. aipt_v2/database/repository.py +384 -0
  29. aipt_v2/docker/__init__.py +23 -0
  30. aipt_v2/docker/builder.py +260 -0
  31. aipt_v2/docker/manager.py +222 -0
  32. aipt_v2/docker/sandbox.py +371 -0
  33. aipt_v2/evasion/__init__.py +58 -0
  34. aipt_v2/evasion/request_obfuscator.py +272 -0
  35. aipt_v2/evasion/tls_fingerprint.py +285 -0
  36. aipt_v2/evasion/ua_rotator.py +301 -0
  37. aipt_v2/evasion/waf_bypass.py +439 -0
  38. aipt_v2/execution/__init__.py +23 -0
  39. aipt_v2/execution/executor.py +302 -0
  40. aipt_v2/execution/parser.py +544 -0
  41. aipt_v2/execution/terminal.py +337 -0
  42. aipt_v2/health.py +437 -0
  43. aipt_v2/intelligence/__init__.py +85 -0
  44. aipt_v2/intelligence/auth.py +520 -0
  45. aipt_v2/intelligence/chaining.py +775 -0
  46. aipt_v2/intelligence/cve_aipt.py +334 -0
  47. aipt_v2/intelligence/cve_info.py +1111 -0
  48. aipt_v2/intelligence/rag.py +239 -0
  49. aipt_v2/intelligence/scope.py +442 -0
  50. aipt_v2/intelligence/searchers/__init__.py +5 -0
  51. aipt_v2/intelligence/searchers/exploitdb_searcher.py +523 -0
  52. aipt_v2/intelligence/searchers/github_searcher.py +467 -0
  53. aipt_v2/intelligence/searchers/google_searcher.py +281 -0
  54. aipt_v2/intelligence/tools.json +443 -0
  55. aipt_v2/intelligence/triage.py +670 -0
  56. aipt_v2/interface/__init__.py +5 -0
  57. aipt_v2/interface/cli.py +230 -0
  58. aipt_v2/interface/main.py +501 -0
  59. aipt_v2/interface/tui.py +1276 -0
  60. aipt_v2/interface/utils.py +583 -0
  61. aipt_v2/llm/__init__.py +39 -0
  62. aipt_v2/llm/config.py +26 -0
  63. aipt_v2/llm/llm.py +514 -0
  64. aipt_v2/llm/memory.py +214 -0
  65. aipt_v2/llm/request_queue.py +89 -0
  66. aipt_v2/llm/utils.py +89 -0
  67. aipt_v2/models/__init__.py +15 -0
  68. aipt_v2/models/findings.py +295 -0
  69. aipt_v2/models/phase_result.py +224 -0
  70. aipt_v2/models/scan_config.py +207 -0
  71. aipt_v2/monitoring/grafana/dashboards/aipt-dashboard.json +355 -0
  72. aipt_v2/monitoring/grafana/dashboards/default.yml +17 -0
  73. aipt_v2/monitoring/grafana/datasources/prometheus.yml +17 -0
  74. aipt_v2/monitoring/prometheus.yml +60 -0
  75. aipt_v2/orchestration/__init__.py +52 -0
  76. aipt_v2/orchestration/pipeline.py +398 -0
  77. aipt_v2/orchestration/progress.py +300 -0
  78. aipt_v2/orchestration/scheduler.py +296 -0
  79. aipt_v2/orchestrator.py +2284 -0
  80. aipt_v2/payloads/__init__.py +27 -0
  81. aipt_v2/payloads/cmdi.py +150 -0
  82. aipt_v2/payloads/sqli.py +263 -0
  83. aipt_v2/payloads/ssrf.py +204 -0
  84. aipt_v2/payloads/templates.py +222 -0
  85. aipt_v2/payloads/traversal.py +166 -0
  86. aipt_v2/payloads/xss.py +204 -0
  87. aipt_v2/prompts/__init__.py +60 -0
  88. aipt_v2/proxy/__init__.py +29 -0
  89. aipt_v2/proxy/history.py +352 -0
  90. aipt_v2/proxy/interceptor.py +452 -0
  91. aipt_v2/recon/__init__.py +44 -0
  92. aipt_v2/recon/dns.py +241 -0
  93. aipt_v2/recon/osint.py +367 -0
  94. aipt_v2/recon/subdomain.py +372 -0
  95. aipt_v2/recon/tech_detect.py +311 -0
  96. aipt_v2/reports/__init__.py +17 -0
  97. aipt_v2/reports/generator.py +313 -0
  98. aipt_v2/reports/html_report.py +378 -0
  99. aipt_v2/runtime/__init__.py +44 -0
  100. aipt_v2/runtime/base.py +30 -0
  101. aipt_v2/runtime/docker.py +401 -0
  102. aipt_v2/runtime/local.py +346 -0
  103. aipt_v2/runtime/tool_server.py +205 -0
  104. aipt_v2/scanners/__init__.py +28 -0
  105. aipt_v2/scanners/base.py +273 -0
  106. aipt_v2/scanners/nikto.py +244 -0
  107. aipt_v2/scanners/nmap.py +402 -0
  108. aipt_v2/scanners/nuclei.py +273 -0
  109. aipt_v2/scanners/web.py +454 -0
  110. aipt_v2/scripts/security_audit.py +366 -0
  111. aipt_v2/telemetry/__init__.py +7 -0
  112. aipt_v2/telemetry/tracer.py +347 -0
  113. aipt_v2/terminal/__init__.py +28 -0
  114. aipt_v2/terminal/executor.py +400 -0
  115. aipt_v2/terminal/sandbox.py +350 -0
  116. aipt_v2/tools/__init__.py +44 -0
  117. aipt_v2/tools/active_directory/__init__.py +78 -0
  118. aipt_v2/tools/active_directory/ad_config.py +238 -0
  119. aipt_v2/tools/active_directory/bloodhound_wrapper.py +447 -0
  120. aipt_v2/tools/active_directory/kerberos_attacks.py +430 -0
  121. aipt_v2/tools/active_directory/ldap_enum.py +533 -0
  122. aipt_v2/tools/active_directory/smb_attacks.py +505 -0
  123. aipt_v2/tools/agents_graph/__init__.py +19 -0
  124. aipt_v2/tools/agents_graph/agents_graph_actions.py +69 -0
  125. aipt_v2/tools/api_security/__init__.py +76 -0
  126. aipt_v2/tools/api_security/api_discovery.py +608 -0
  127. aipt_v2/tools/api_security/graphql_scanner.py +622 -0
  128. aipt_v2/tools/api_security/jwt_analyzer.py +577 -0
  129. aipt_v2/tools/api_security/openapi_fuzzer.py +761 -0
  130. aipt_v2/tools/browser/__init__.py +5 -0
  131. aipt_v2/tools/browser/browser_actions.py +238 -0
  132. aipt_v2/tools/browser/browser_instance.py +535 -0
  133. aipt_v2/tools/browser/tab_manager.py +344 -0
  134. aipt_v2/tools/cloud/__init__.py +70 -0
  135. aipt_v2/tools/cloud/cloud_config.py +273 -0
  136. aipt_v2/tools/cloud/cloud_scanner.py +639 -0
  137. aipt_v2/tools/cloud/prowler_tool.py +571 -0
  138. aipt_v2/tools/cloud/scoutsuite_tool.py +359 -0
  139. aipt_v2/tools/executor.py +307 -0
  140. aipt_v2/tools/parser.py +408 -0
  141. aipt_v2/tools/proxy/__init__.py +5 -0
  142. aipt_v2/tools/proxy/proxy_actions.py +103 -0
  143. aipt_v2/tools/proxy/proxy_manager.py +789 -0
  144. aipt_v2/tools/registry.py +196 -0
  145. aipt_v2/tools/scanners/__init__.py +343 -0
  146. aipt_v2/tools/scanners/acunetix_tool.py +712 -0
  147. aipt_v2/tools/scanners/burp_tool.py +631 -0
  148. aipt_v2/tools/scanners/config.py +156 -0
  149. aipt_v2/tools/scanners/nessus_tool.py +588 -0
  150. aipt_v2/tools/scanners/zap_tool.py +612 -0
  151. aipt_v2/tools/terminal/__init__.py +5 -0
  152. aipt_v2/tools/terminal/terminal_actions.py +37 -0
  153. aipt_v2/tools/terminal/terminal_manager.py +153 -0
  154. aipt_v2/tools/terminal/terminal_session.py +449 -0
  155. aipt_v2/tools/tool_processing.py +108 -0
  156. aipt_v2/utils/__init__.py +17 -0
  157. aipt_v2/utils/logging.py +201 -0
  158. aipt_v2/utils/model_manager.py +187 -0
  159. aipt_v2/utils/searchers/__init__.py +269 -0
  160. aiptx-2.0.2.dist-info/METADATA +324 -0
  161. aiptx-2.0.2.dist-info/RECORD +165 -0
  162. aiptx-2.0.2.dist-info/WHEEL +5 -0
  163. aiptx-2.0.2.dist-info/entry_points.txt +7 -0
  164. aiptx-2.0.2.dist-info/licenses/LICENSE +21 -0
  165. aiptx-2.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,108 @@
1
+ """
2
+ AIPT Tool Processing - Execute tool invocations from LLM responses
3
+ """
4
+
5
+ import logging
6
+ from typing import Any
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
def _normalize_tool_args(raw: Any) -> Any:
    """Coerce a tool call's raw 'arguments' value into a dict where possible.

    ``None`` and blank strings become ``{}``; a string holding a JSON object
    is decoded. Anything else is returned unchanged.
    """
    import json

    if raw is None:
        return {}
    if isinstance(raw, str):
        if not raw.strip():
            return {}
        try:
            decoded = json.loads(raw)
        except ValueError:
            return raw
        return decoded if isinstance(decoded, dict) else raw
    return raw


async def process_tool_invocations(
    actions: list[dict[str, Any]],
    conversation_history: list[dict[str, Any]],
    state: Any,
) -> bool:
    """
    Process tool invocations from LLM response.

    Actions run in order. Each outcome (result text or error text) is appended
    to ``conversation_history`` as a ``"user"`` message. Processing stops at
    the first finish tool (``finish_scan`` / ``agent_finish``); actions after
    it are not executed.

    Args:
        actions: List of tool invocation dicts with 'name' and 'arguments'.
            'arguments' may be a dict, a JSON-encoded object string, or
            missing/None (treated as no arguments).
        conversation_history: Mutable conversation history (appended to)
        state: Agent state object, passed through to the tool executors

    Returns:
        True if agent should finish, False otherwise
    """
    for action in actions:
        tool_name = action.get("name", "")
        tool_args = _normalize_tool_args(action.get("arguments"))

        logger.info("Executing tool: %s", tool_name)

        # Check for finish tools
        if tool_name in ("finish_scan", "agent_finish"):
            if isinstance(tool_args, dict):
                result = tool_args.get("result", "Task completed")
            else:
                result = "Task completed"
            conversation_history.append({
                "role": "user",
                "content": f"Tool {tool_name} executed. Result: {result}",
            })
            return True

        # Execute the tool. Failures are reported back into the conversation
        # instead of aborting the remaining actions.
        try:
            result = await _execute_tool(tool_name, tool_args, state)
        except Exception as e:
            error_msg = f"Tool {tool_name} failed: {str(e)}"
            logger.exception(error_msg)
            conversation_history.append({
                "role": "user",
                "content": error_msg,
            })
        else:
            conversation_history.append({
                "role": "user",
                "content": f"Tool {tool_name} result:\n{result}",
            })

    return False
58
+
59
+
60
+ async def _execute_tool(name: str, args: dict[str, Any], state: Any) -> str:
61
+ """Execute a single tool and return result"""
62
+ # Import tool executors lazily
63
+ if name == "execute_command":
64
+ return await _execute_command(args, state)
65
+ elif name == "browser_navigate":
66
+ return await _browser_navigate(args, state)
67
+ elif name == "browser_screenshot":
68
+ return await _browser_screenshot(args, state)
69
+ else:
70
+ return f"Tool '{name}' executed with args: {args}"
71
+
72
+
73
+ async def _execute_command(args: dict[str, Any], state: Any) -> str:
74
+ """Execute a shell command in the sandbox"""
75
+ import asyncio
76
+
77
+ command = args.get("command", "")
78
+ timeout = args.get("timeout", 60)
79
+
80
+ # Use subprocess for now (Docker integration later)
81
+ try:
82
+ proc = await asyncio.create_subprocess_shell(
83
+ command,
84
+ stdout=asyncio.subprocess.PIPE,
85
+ stderr=asyncio.subprocess.PIPE,
86
+ )
87
+ stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
88
+ output = stdout.decode() if stdout else ""
89
+ errors = stderr.decode() if stderr else ""
90
+ return output + errors if errors else output
91
+ except asyncio.TimeoutError:
92
+ return f"Command timed out after {timeout} seconds"
93
+ except Exception as e:
94
+ return f"Command failed: {str(e)}"
95
+
96
+
97
+ async def _browser_navigate(args: dict[str, Any], state: Any) -> str:
98
+ """Navigate browser to URL"""
99
+ url = args.get("url", "")
100
+ return f"Navigated to: {url}"
101
+
102
+
103
+ async def _browser_screenshot(args: dict[str, Any], state: Any) -> str:
104
+ """Take browser screenshot"""
105
+ return "Screenshot taken"
106
+
107
+
108
+ __all__ = ["process_tool_invocations"]
@@ -0,0 +1,17 @@
1
+ """
2
+ AIPT v2 Utilities Module
3
+ ========================
4
+
5
+ Provides common utilities used across the framework:
6
+ - Structured logging with secret redaction
7
+ - Model management wrappers
8
+ - Searcher utilities
9
+ """
10
+
11
+ from .logging import logger, setup_logging, get_logger
12
+
13
+ __all__ = [
14
+ "logger",
15
+ "setup_logging",
16
+ "get_logger",
17
+ ]
@@ -0,0 +1,201 @@
1
+ """
2
+ Structured Logging Configuration for AIPT v2
3
+ =============================================
4
+
5
+ Provides:
6
+ - Structured logging via structlog
7
+ - Automatic secret redaction
8
+ - JSON format for production
9
+ - Console format for development
10
+ """
11
+
12
+ import logging
13
+ import os
14
+ import sys
15
+ import re
16
+ from typing import Any, Optional
17
+ from functools import lru_cache
18
+
19
+ try:
20
+ import structlog
21
+ STRUCTLOG_AVAILABLE = True
22
+ except ImportError:
23
+ STRUCTLOG_AVAILABLE = False
24
+
25
+
26
+ # Patterns for secret redaction
27
+ SECRET_PATTERNS = [
28
+ r"api[_-]?key",
29
+ r"apikey",
30
+ r"token",
31
+ r"secret",
32
+ r"password",
33
+ r"credential",
34
+ r"auth",
35
+ r"bearer",
36
+ r"sk-[a-zA-Z0-9]+",
37
+ r"pk-[a-zA-Z0-9]+",
38
+ r"access[_-]?key",
39
+ r"private[_-]?key",
40
+ ]
41
+
42
+ SECRET_REGEX = re.compile("|".join(SECRET_PATTERNS), re.IGNORECASE)
43
+
44
+
45
+ def _should_redact(key: str) -> bool:
46
+ """Check if a key should be redacted."""
47
+ return bool(SECRET_REGEX.search(key))
48
+
49
+
50
+ def _redact_value(value: str) -> str:
51
+ """Redact a sensitive value, keeping first/last chars for debugging."""
52
+ if len(value) <= 8:
53
+ return "[REDACTED]"
54
+ return f"{value[:4]}...{value[-4:]}"
55
+
56
+
57
+ def _redact_processor(logger: Any, method_name: str, event_dict: dict) -> dict:
58
+ """Structlog processor to redact sensitive information."""
59
+ for key, value in list(event_dict.items()):
60
+ if isinstance(value, str):
61
+ if _should_redact(key):
62
+ event_dict[key] = "[REDACTED]"
63
+ elif len(value) > 20 and SECRET_REGEX.search(value):
64
+ event_dict[key] = _redact_value(value)
65
+ return event_dict
66
+
67
+
def setup_logging(
    level: str = "INFO",
    json_format: bool = False,
    redact_secrets: bool = True,
) -> Any:
    """
    Configure structured logging for the application.

    With structlog installed, the processor chain is: level filter, logger
    name, log level, ISO timestamp, stack info, then (JSON only) exception
    formatting, optional secret redaction, and finally a JSON or colored
    console renderer. Output goes through stdlib logging to stdout.
    Without structlog, plain ``logging.basicConfig`` is used.

    Note: ``logging.basicConfig`` does nothing if the root logger already has
    handlers, so calling this again will not change the level or format.

    Args:
        level: Log level name (DEBUG, INFO, WARNING, ERROR, CRITICAL);
            unrecognized names fall back to INFO
        json_format: Use JSON format (for production)
        redact_secrets: Automatically redact sensitive values

    Returns:
        A structlog logger, or ``logging.getLogger("aipt_v2")`` when
        structlog is unavailable
    """
    log_level = getattr(logging, level.upper(), logging.INFO)
    if not isinstance(log_level, int):
        # e.g. level="root" resolves to logging.root, not a level number.
        log_level = logging.INFO

    if not STRUCTLOG_AVAILABLE:
        # Fallback to standard logging
        logging.basicConfig(
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            stream=sys.stdout,
            level=log_level,
        )
        return logging.getLogger("aipt_v2")

    processors = [
        structlog.stdlib.filter_by_level,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
    ]

    if json_format:
        # JSONRenderer cannot render exc_info itself; turn it into a
        # traceback string first. (ConsoleRenderer formats exceptions on its
        # own and must not receive this processor.)
        processors.append(structlog.processors.format_exc_info)

    if redact_secrets:
        # Placed after exception formatting so tracebacks are redacted too.
        processors.append(_redact_processor)

    if json_format:
        processors.append(structlog.processors.JSONRenderer())
    else:
        processors.append(structlog.dev.ConsoleRenderer(colors=True))

    structlog.configure(
        processors=processors,
        wrapper_class=structlog.stdlib.BoundLogger,
        context_class=dict,
        logger_factory=structlog.stdlib.LoggerFactory(),
        cache_logger_on_first_use=True,
    )

    logging.basicConfig(
        format="%(message)s",
        stream=sys.stdout,
        level=log_level,
    )

    return structlog.get_logger()
126
+
127
+
@lru_cache(maxsize=1)
def get_logger() -> Any:
    """
    Get or create the global logger instance.

    Configuration is read from the environment on the first call only (the
    result is cached):

    - ``AIPT_LOG_LEVEL``: level name (default "INFO")
    - ``AIPT_LOG_FORMAT``: "json" for JSON output; anything else gives console
      output
    - ``AIPT_REDACT_SECRETS``: redaction stays ON unless explicitly disabled
      with "false", "0", "no" or "off" (case-insensitive)
    """
    log_level = os.getenv("AIPT_LOG_LEVEL", "INFO").strip()
    json_format = os.getenv("AIPT_LOG_FORMAT", "console").strip().lower() == "json"
    # Fail safe: a typo or a truthy spelling like "1"/"yes" must not silently
    # turn secret redaction off.
    redact = os.getenv("AIPT_REDACT_SECRETS", "true").strip().lower() not in {
        "false",
        "0",
        "no",
        "off",
    }

    return setup_logging(
        level=log_level,
        json_format=json_format,
        redact_secrets=redact,
    )
140
+
141
+
# Global logger instance
# Built at import time from the AIPT_LOG_LEVEL / AIPT_LOG_FORMAT /
# AIPT_REDACT_SECRETS environment variables (via get_logger()), so importing
# this module configures logging as a side effect.
logger = get_logger()
144
+
145
+
class LoggerAdapter:
    """
    Uniform logging facade over either a structlog logger or a stdlib logger.

    With structlog, keyword arguments pass through as structured fields.
    With the stdlib fallback, they are appended to the message as
    ``key=value`` pairs, because stdlib loggers reject arbitrary kwargs.
    """

    def __init__(self, logger_instance: Any):
        self._logger = logger_instance
        # Mode is taken from the module-level flag, not from the logger's type.
        self._is_structlog = STRUCTLOG_AVAILABLE

    @staticmethod
    def _flatten(msg: str, fields: dict) -> str:
        """Append ``key=value`` pairs to *msg* (format used by the stdlib path)."""
        if not fields:
            return msg
        pairs = " ".join(f"{name}={val}" for name, val in fields.items())
        return f"{msg} {pairs}"

    def _log(self, level: str, msg: str, **kwargs):
        """Send *msg* to the wrapped logger's method named *level*."""
        emit = getattr(self._logger, level)
        if self._is_structlog:
            emit(msg, **kwargs)
        else:
            emit(self._flatten(msg, kwargs))

    def debug(self, msg: str, **kwargs):
        self._log("debug", msg, **kwargs)

    def info(self, msg: str, **kwargs):
        self._log("info", msg, **kwargs)

    def warning(self, msg: str, **kwargs):
        self._log("warning", msg, **kwargs)

    def error(self, msg: str, exc_info: bool = False, **kwargs):
        """Log at ERROR level; ``exc_info=True`` attaches the active exception."""
        if not self._is_structlog:
            self._logger.error(self._flatten(msg, kwargs), exc_info=exc_info)
            return
        if exc_info:
            kwargs["exc_info"] = True
        self._logger.error(msg, **kwargs)

    def critical(self, msg: str, **kwargs):
        self._log("critical", msg, **kwargs)

    def exception(self, msg: str, **kwargs):
        """Log at ERROR level together with the current traceback."""
        self.error(msg, exc_info=True, **kwargs)
191
+
192
+
# Export a wrapped logger for consistent interface
def create_logger(name: Optional[str] = None) -> LoggerAdapter:
    """
    Build a LoggerAdapter around a (possibly named) logger.

    With structlog, a given *name* selects ``structlog.get_logger(name)`` and
    no name wraps the shared ``get_logger()`` instance. Without structlog,
    the stdlib logger *name* (default "aipt_v2") is wrapped.
    """
    if not STRUCTLOG_AVAILABLE:
        return LoggerAdapter(logging.getLogger(name or "aipt_v2"))
    wrapped = structlog.get_logger(name) if name else get_logger()
    return LoggerAdapter(wrapped)
@@ -0,0 +1,187 @@
1
+ """
2
+ Model Manager for AIPT v2
3
+ =========================
4
+
5
+ Provides a unified interface for LLM model access.
6
+ This is a compatibility layer for intelligence modules that
7
+ reference utils.model_manager.
8
+ """
9
+
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from aipt_v2.utils.logging import create_logger, logger
15
+
16
+
@dataclass()
class ModelConfig:
    """Settings that ModelWrapper forwards to litellm on each completion call."""

    # litellm model identifier (passed as ``model=``)
    model_name: str = "gpt-4"
    # passed as ``temperature=``
    temperature: float = 0.7
    # passed as ``max_tokens=``
    max_tokens: int = 4096
    # passed as ``timeout=`` (presumably seconds — TODO confirm with litellm)
    timeout: int = 120
24
+
25
+
class ModelWrapper:
    """
    Wrapper around litellm for consistent model access.

    Provides synchronous (``complete``) and asynchronous (``acomplete``) chat
    completion, plus ``embed``. litellm is imported lazily on first use.
    """

    # Embedding model used by embed() when the caller does not choose one.
    DEFAULT_EMBEDDING_MODEL = "text-embedding-ada-002"

    def __init__(self, config: "ModelConfig"):
        self.config = config
        self._litellm = None

    def _get_litellm(self):
        """Lazy load litellm.

        Raises:
            ImportError: if litellm is not installed (chained to the
                original import failure).
        """
        if self._litellm is None:
            try:
                import litellm
            except ImportError as err:
                raise ImportError(
                    "litellm is required for model_manager. "
                    "Install with: pip install litellm"
                ) from err
            self._litellm = litellm
        return self._litellm

    @staticmethod
    def _build_messages(prompt: str, system_prompt: Optional[str]) -> List[Dict[str, str]]:
        """Return chat messages: an optional system message, then the user prompt."""
        messages: List[Dict[str, str]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    def _completion_request(
        self,
        prompt: str,
        system_prompt: Optional[str],
        kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Build keyword arguments for litellm ``completion``/``acompletion``.

        ``temperature``, ``max_tokens`` and ``timeout`` in *kwargs* override
        the config. Other kwargs are forwarded to litellm unchanged, except
        that ``model`` and ``messages`` are always set by this wrapper.
        """
        options = dict(kwargs)
        temperature = options.pop("temperature", self.config.temperature)
        max_tokens = options.pop("max_tokens", self.config.max_tokens)
        timeout = options.pop("timeout", self.config.timeout)
        return {
            **options,
            "model": self.config.model_name,
            "messages": self._build_messages(prompt, system_prompt),
            "temperature": temperature,
            "max_tokens": max_tokens,
            "timeout": timeout,
        }

    @staticmethod
    def _response_text(response: Any) -> str:
        """Text of the first choice; a ``None`` content is returned as ""."""
        content = response.choices[0].message.content
        return "" if content is None else content

    def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs
    ) -> str:
        """
        Synchronous completion.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            **kwargs: Additional arguments passed to litellm
                (temperature/max_tokens/timeout override the config)

        Returns:
            Model response text ("" if the model returned no content)
        """
        litellm = self._get_litellm()
        request = self._completion_request(prompt, system_prompt, kwargs)

        try:
            response = litellm.completion(**request)
            return self._response_text(response)
        except Exception as e:
            # LoggerAdapter accepts these keyword fields with or without
            # structlog; a plain stdlib Logger would raise TypeError on them.
            create_logger().error(
                "Model completion failed", model=self.config.model_name, error=str(e)
            )
            raise

    async def acomplete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs
    ) -> str:
        """
        Asynchronous completion.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            **kwargs: Additional arguments passed to litellm
                (temperature/max_tokens/timeout override the config)

        Returns:
            Model response text ("" if the model returned no content)
        """
        litellm = self._get_litellm()
        request = self._completion_request(prompt, system_prompt, kwargs)

        try:
            response = await litellm.acompletion(**request)
            return self._response_text(response)
        except Exception as e:
            create_logger().error(
                "Async model completion failed", model=self.config.model_name, error=str(e)
            )
            raise

    def embed(self, text: str, model: Optional[str] = None) -> List[float]:
        """
        Get embedding for text.

        Args:
            text: Text to embed
            model: Embedding model (defaults to DEFAULT_EMBEDDING_MODEL)

        Returns:
            Embedding vector
        """
        litellm = self._get_litellm()
        embedding_model = model or self.DEFAULT_EMBEDDING_MODEL

        try:
            response = litellm.embedding(
                model=embedding_model,
                input=text,
            )
            return response.data[0]["embedding"]
        except Exception as e:
            create_logger().error("Embedding failed", model=embedding_model, error=str(e))
            raise
145
+
146
+
# Cache for model instances
# Shared ModelWrapper objects keyed by the string cache key that get_model()
# computes; reset by clear_model_cache().
_model_cache: Dict[str, ModelWrapper] = {}
149
+
150
+
def get_model(
    model_name: Optional[str] = None,
    **kwargs
) -> ModelWrapper:
    """
    Get or create a model instance.

    Instances are cached per effective configuration (model name,
    temperature, max_tokens, timeout), so calls with the same settings share
    one ModelWrapper.

    Args:
        model_name: Model identifier (default: $AIPT_LLM_MODEL, else "gpt-4")
        **kwargs: Config overrides: temperature, max_tokens, timeout.
            Other keys are ignored.

    Returns:
        ModelWrapper instance
    """
    if model_name is None:
        # An empty AIPT_LLM_MODEL is treated the same as an unset one.
        model_name = os.getenv("AIPT_LLM_MODEL") or "gpt-4"

    config = ModelConfig(
        model_name=model_name,
        temperature=kwargs.get("temperature", 0.7),
        max_tokens=kwargs.get("max_tokens", 4096),
        timeout=kwargs.get("timeout", 120),
    )
    # Key on the values that actually reach ModelConfig. A hash of the raw
    # kwargs could collide, fails on unhashable values, and would split the
    # cache on keys that have no effect.
    cache_key = repr(
        (config.model_name, config.temperature, config.max_tokens, config.timeout)
    )

    wrapper = _model_cache.get(cache_key)
    if wrapper is None:
        wrapper = ModelWrapper(config)
        _model_cache[cache_key] = wrapper
        # LoggerAdapter accepts keyword fields even without structlog.
        create_logger().info("Created model instance", model=model_name)

    return wrapper
181
+
182
+
def clear_model_cache():
    """Remove every cached ModelWrapper so later get_model() calls build fresh ones."""
    # Clear in place rather than rebinding the global, so any other
    # reference to the cache dict also sees it emptied.
    _model_cache.clear()
    logger.info("Model cache cleared")