kubrick-cli 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kubrick_cli/safety.py ADDED
@@ -0,0 +1,204 @@
1
+ """Safety manager for dangerous command detection and validation."""
2
+
3
+ import re
4
+ from dataclasses import dataclass
5
+ from typing import Dict, Optional, Tuple
6
+
7
+ from rich.console import Console
8
+ from rich.panel import Panel
9
+ from rich.prompt import Confirm
10
+
11
# Module-level Rich console; SafetyManager prints its warnings and the
# dangerous-command confirmation panel through it.
console = Console()
12
+
13
+
14
@dataclass
class SafetyConfig:
    """Safety configuration settings.

    The field defaults below are the single source of truth: ``from_config``
    falls back to them for any key absent from the supplied dictionary, so
    they only need to be changed in one place.
    """

    max_iterations: int = 15
    max_tools_per_turn: int = 5
    total_timeout_seconds: int = 600
    tool_timeout_seconds: int = 30
    max_file_size_mb: int = 10
    require_dangerous_command_confirmation: bool = True

    @classmethod
    def from_config(cls, config: Optional[Dict]) -> "SafetyConfig":
        """
        Create SafetyConfig from configuration dict.

        Keys that match a field name override that field (the value is used
        as-is, even if it is ``None``); unknown keys are ignored and missing
        keys keep the field's declared default.

        Args:
            config: Configuration dictionary (``None`` is treated as empty)

        Returns:
            SafetyConfig instance
        """
        from dataclasses import fields

        config = config or {}
        overrides = {
            spec.name: config[spec.name]
            for spec in fields(cls)
            if spec.init and spec.name in config
        }
        return cls(**overrides)
46
+
47
+
48
class SafetyManager:
    """
    Manages safety checks for tool execution.

    Features:
    - Dangerous command detection (case-insensitive regex patterns)
    - User confirmation prompts
    - File size limits
    - Iteration and tool-call limit checks

    Timeout values are carried on SafetyConfig; this class does not enforce
    them itself.
    """

    # Dangerous bash command patterns as (regex, description) pairs. They are
    # tried in order and the first match wins, so the specific "rm -rf <target>"
    # variants come before the generic recursive-force pattern to report the
    # more precise description.
    DANGEROUS_PATTERNS = [
        (r"\brm\s+-rf\s+/", "Recursive delete from root"),
        (r"\brm\s+-rf\s+~", "Recursive delete from home"),
        (r"\brm\s+-rf\s+\*", "Recursive delete all files"),
        # Any single flag cluster containing both r and f in either order:
        # -rf, -fr, -Rf, -rfv, -vfr, ...
        (r"\brm\s+-[a-z]*(?:r[a-z]*f|f[a-z]*r)", "Recursive force delete"),
        (r"\bsudo\b", "Elevated privileges"),
        (r"\bchmod\s+777", "Overly permissive permissions"),
        (r"\bchmod\s+-R\s+777", "Recursive overly permissive permissions"),
        # Redirects to harmless pseudo-devices (e.g. "2>/dev/null") are not
        # flagged; anything else under /dev/ is.
        (
            r">\s*/dev/(?!(?:null|stdout|stderr|tty)\b|fd/)",
            "Writing to device files",
        ),
        (r"\bgit\s+push\s+--force", "Force push to git"),
        (r"\bgit\s+push\s+-f", "Force push to git (short form)"),
        (r"\bmkfs\b", "Format filesystem"),
        (r"\bdd\s+.*of=/dev/(?!null\b)", "Writing to block device"),
        # Classic bash fork bomb ":(){ :|:& };:" with optional whitespace.
        # (No leading \b: ":" is not a word character, so \b would never
        # match at the start of the command.)
        (r":\s*\(\s*\)\s*\{\s*:\s*\|\s*:\s*&\s*\}\s*;\s*:", "Fork bomb"),
        (r"\bcurl\b.*\|\s*(?:ba)?sh\b", "Pipe curl to bash"),
        (r"\bwget\b.*\|\s*(?:ba)?sh\b", "Pipe wget to bash"),
        (r"\beval\b.*\$\(", "Eval with command substitution"),
    ]

    def __init__(self, config: "SafetyConfig"):
        """
        Initialize safety manager.

        Args:
            config: SafetyConfig instance
        """
        self.config = config

    def validate_bash_command(self, command: str) -> Tuple[bool, Optional[str]]:
        """
        Validate a bash command for dangerous patterns.

        Args:
            command: The bash command to validate

        Returns:
            Tuple of (is_safe, warning_message)
            - is_safe: True if safe, False if dangerous
            - warning_message: Description of the first matching danger
              (None when the command is considered safe)
        """
        # re.IGNORECASE already makes matching case-insensitive, so the
        # command does not need to be lower-cased first.
        for pattern, description in self.DANGEROUS_PATTERNS:
            if re.search(pattern, command, re.IGNORECASE):
                return False, f"Dangerous command detected: {description}"

        return True, None

    def get_user_confirmation(self, warning: str, command: str) -> bool:
        """
        Get user confirmation for a dangerous command.

        Args:
            warning: Warning message
            command: The dangerous command

        Returns:
            True if user confirms (or confirmation is disabled in the config),
            False otherwise -- including when no interactive input is
            available (EOF), so the check fails closed.
        """
        if not self.config.require_dangerous_command_confirmation:
            return True

        # The command text is untrusted: escape it so Rich does not interpret
        # "[...]" sequences as markup, which could raise a MarkupError or hide
        # parts of the command from the user being asked to approve it.
        from rich.markup import escape

        console.print(
            Panel(
                f"[bold red]⚠️ WARNING[/bold red]\n\n"
                f"{escape(warning)}\n\n"
                f"[bold]Command:[/bold]\n"
                f"[yellow]{escape(command)}[/yellow]\n\n"
                "This command could be dangerous.",
                border_style="red",
                title="Safety Check",
            )
        )

        try:
            confirmed = Confirm.ask(
                "[bold red]Do you want to execute this command?[/bold red]",
                default=False,
            )
        except EOFError:
            # stdin closed / non-interactive session: treat as "no".
            confirmed = False

        if confirmed:
            console.print("[yellow]⚠️ Proceeding with caution...[/yellow]")
        else:
            console.print("[green]✓ Command cancelled[/green]")

        return confirmed

    def validate_file_size(self, file_path: str, size_bytes: int) -> bool:
        """
        Validate file size against limits.

        Args:
            file_path: Path to the file (used only in the warning message)
            size_bytes: File size in bytes

        Returns:
            True if within limits, False otherwise (a warning is printed)
        """
        max_bytes = self.config.max_file_size_mb * 1024 * 1024

        if size_bytes > max_bytes:
            # Paths may contain "[" which Rich would treat as markup.
            from rich.markup import escape

            console.print(
                f"[yellow]⚠️ File {escape(str(file_path))} exceeds size limit "
                f"({size_bytes / 1024 / 1024:.1f}MB > {self.config.max_file_size_mb}MB)[/yellow]"
            )
            return False

        return True

    def check_iteration_limit(
        self, current: int, max_iterations: Optional[int] = None
    ) -> bool:
        """
        Check if iteration limit has been reached.

        Args:
            current: Current iteration number
            max_iterations: Maximum allowed iterations; defaults to
                ``self.config.max_iterations`` when omitted

        Returns:
            True while ``current < max_iterations``, False otherwise
        """
        if max_iterations is None:
            max_iterations = self.config.max_iterations

        if current >= max_iterations:
            console.print(
                f"[yellow]⚠️ Max iteration limit reached ({max_iterations})[/yellow]"
            )
            return False

        return True

    def check_tool_limit(self, current: int, max_tools: Optional[int] = None) -> bool:
        """
        Check if tool call limit has been reached.

        Args:
            current: Current number of tool calls
            max_tools: Maximum allowed tools; defaults to
                ``self.config.max_tools_per_turn`` when omitted

        Returns:
            True while ``current <= max_tools``, False otherwise
        """
        if max_tools is None:
            max_tools = self.config.max_tools_per_turn

        if current > max_tools:
            console.print(
                f"[yellow]⚠️ Too many tool calls ({current} > {max_tools})[/yellow]"
            )
            return False

        return True
@@ -0,0 +1,183 @@
1
+ """Parallel tool execution scheduler for improved performance."""
2
+
3
+ import concurrent.futures
4
+ from typing import Dict, List, Tuple
5
+
6
+ from rich.console import Console
7
+
8
# Module-level Rich console; ToolScheduler uses it to announce parallel
# read-only batches.
console = Console()
9
+
10
+
11
# Side-effect-free tools; the scheduler may run these concurrently.
READ_ONLY_TOOLS = {"read_file", "list_files", "search_files"}

# Tools with side effects that must run one at a time, in order.
# Note: ToolScheduler does not consult this set directly -- any tool name
# not listed in READ_ONLY_TOOLS is already executed sequentially.
WRITE_TOOLS = {"write_file", "edit_file", "create_directory", "run_bash"}
25
+
26
+
27
class ToolScheduler:
    """
    Schedules and executes tools with intelligent parallelization.

    Strategy: Conservative parallelization
    - Runs of *consecutive* read-only tools (read_file, list_files,
      search_files) execute in parallel
    - Every other tool (write_file, edit_file, run_bash, unrecognised names)
      executes alone and acts as a barrier: all earlier calls have finished
      before it starts, and it finishes before any later call starts, so a
      read never observes a write out of the order the caller requested
    - Graceful error handling per tool
    """

    def __init__(
        self, tool_executor, max_workers: int = 3, enable_parallel: bool = True
    ):
        """
        Initialize tool scheduler.

        Args:
            tool_executor: ToolExecutor instance; its ``execute(tool_name,
                params)`` is expected to return a result dict and may be
                called from several worker threads at once for read-only
                tools
            max_workers: Maximum parallel workers (default: 3)
            enable_parallel: Whether to enable parallel execution
        """
        self.tool_executor = tool_executor
        self.max_workers = max_workers
        self.enable_parallel = enable_parallel

    def execute_tools(self, tool_calls: List[Tuple[str, Dict]]) -> List[Dict]:
        """
        Execute a list of tool calls with intelligent scheduling.

        Only adjacent read-only calls are batched for parallel execution;
        every other call runs by itself at its original position, so the
        observable execution order matches the input order.

        Args:
            tool_calls: List of (tool_name, parameters) tuples

        Returns:
            List of result dictionaries in the same order as input
        """
        if not self.enable_parallel or len(tool_calls) <= 1:
            return self._execute_sequential(tool_calls)

        # Results are stored by original position, so no separate
        # bookkeeping of per-category offsets is needed.
        results = [None] * len(tool_calls)
        read_batch: List[Tuple[int, str, Dict]] = []

        for index, (tool_name, params) in enumerate(tool_calls):
            if tool_name in READ_ONLY_TOOLS:
                read_batch.append((index, tool_name, params))
                continue

            # Barrier: reads requested before this call must complete first
            # so they see the state as it was before this call runs.
            self._run_read_batch(read_batch, results)
            read_batch = []
            results[index] = self._execute_single(tool_name, params)

        self._run_read_batch(read_batch, results)
        return results

    def _run_read_batch(
        self, batch: List[Tuple[int, str, Dict]], results: List
    ) -> None:
        """
        Execute a batch of consecutive read-only calls into ``results``.

        Args:
            batch: List of (index, tool_name, parameters) tuples
            results: Output list, written at each call's original index
        """
        if not batch:
            return

        if len(batch) == 1:
            # A thread pool buys nothing for a single call.
            index, tool_name, params = batch[0]
            results[index] = self._execute_single(tool_name, params)
            return

        console.print(
            f"[dim]→ Executing {len(batch)} read-only tool(s) in parallel[/dim]"
        )
        for index, result in self._execute_parallel(batch).items():
            results[index] = result

    def _execute_parallel(
        self, indexed_calls: List[Tuple[int, str, Dict]]
    ) -> Dict[int, Dict]:
        """
        Execute tools in parallel using ThreadPoolExecutor.

        Args:
            indexed_calls: List of (index, tool_name, parameters) tuples

        Returns:
            Dict mapping each call's index to its result
        """
        results = {}

        # Never spawn more threads than there are calls, and always at least
        # one (ThreadPoolExecutor rejects max_workers <= 0).
        workers = max(1, min(self.max_workers, len(indexed_calls)))

        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            future_to_index = {}
            for index, tool_name, params in indexed_calls:
                future = executor.submit(self._execute_single, tool_name, params)
                future_to_index[future] = index

            for future in concurrent.futures.as_completed(future_to_index):
                index = future_to_index[future]
                try:
                    results[index] = future.result()
                except Exception as e:
                    # _execute_single already converts tool errors into result
                    # dicts; this is a last-resort guard for anything else.
                    results[index] = {
                        "success": False,
                        "error": f"Parallel execution error: {str(e)}",
                    }

        return results

    def _execute_sequential(self, tool_calls: List[Tuple[str, Dict]]) -> List[Dict]:
        """
        Execute tools sequentially.

        Args:
            tool_calls: List of (tool_name, parameters) tuples

        Returns:
            List of result dictionaries, in input order
        """
        return [
            self._execute_single(tool_name, params) for tool_name, params in tool_calls
        ]

    def _execute_sequential_indexed(
        self, indexed_calls: List[Tuple[int, str, Dict]]
    ) -> Dict[int, Dict]:
        """
        Execute tools sequentially with index mapping.

        Args:
            indexed_calls: List of (index, tool_name, parameters) tuples

        Returns:
            Dict mapping each call's index to its result
        """
        results = {}
        for index, tool_name, params in indexed_calls:
            results[index] = self._execute_single(tool_name, params)
        return results

    def _execute_single(self, tool_name: str, params: Dict) -> Dict:
        """
        Execute a single tool.

        Args:
            tool_name: Name of tool
            params: Tool parameters

        Returns:
            Result dictionary; any exception raised by the executor is
            converted into ``{"success": False, "error": ...}``
        """
        try:
            return self.tool_executor.execute(tool_name, params)
        except Exception as e:
            return {
                "success": False,
                "error": f"Tool execution error: {str(e)}",
            }
@@ -0,0 +1,161 @@
1
+ """Setup wizard for first-time Kubrick configuration."""
2
+
3
+ from rich.console import Console
4
+ from rich.panel import Panel
5
+ from rich.prompt import Prompt
6
+ from rich.table import Table
7
+
8
+ from .providers.factory import ProviderFactory
9
+
10
# Module-level Rich console; all SetupWizard output is printed through it.
console = Console()
11
+
12
+
13
class SetupWizard:
    """
    Interactive setup wizard for Kubrick configuration.

    This wizard automatically discovers all available providers and
    generates the configuration UI based on provider metadata.
    """

    @staticmethod
    def run() -> dict:
        """
        Run the setup wizard.

        Returns:
            Dictionary with provider configuration (the provider-specific
            fields plus a ``"provider"`` key holding the provider name)
        """
        console.print(
            Panel.fit(
                "[bold cyan]Welcome to Kubrick![/bold cyan]\n\n"
                "Let's set up your AI provider.\n"
                "You can change these settings later in ~/.kubrick/config.json",
                border_style="cyan",
            )
        )

        provider_metadata = SetupWizard._select_provider()

        config = SetupWizard._configure_provider(provider_metadata)
        config["provider"] = provider_metadata.name

        SetupWizard._show_summary(provider_metadata, config)

        # NOTE(review): this method only returns the config; the message
        # assumes the caller persists it -- confirm at the call site.
        console.print("\n[green]✓ Setup complete![/green]")
        console.print(
            "[dim]Your configuration has been saved to ~/.kubrick/config.json[/dim]\n"
        )

        return config

    @staticmethod
    def _select_provider():
        """
        Prompt user to select a provider from all discovered providers.

        Returns:
            ProviderMetadata object for selected provider

        Raises:
            RuntimeError: If no providers are discovered
        """
        console.print("\n[bold]Step 1: Select Your AI Provider[/bold]\n")

        providers = ProviderFactory.list_available_providers()

        if not providers:
            raise RuntimeError(
                "No providers found! Please ensure provider files are in kubrick_cli/providers/"
            )

        table = Table(show_header=True, header_style="bold cyan")
        table.add_column("Option", style="cyan")
        table.add_column("Provider", style="green")
        table.add_column("Description")

        # Menu options are 1-based strings so they can be used directly as
        # Prompt choices.
        choice_map = {}
        for idx, provider in enumerate(providers, start=1):
            choice_map[str(idx)] = provider
            table.add_row(
                str(idx),
                provider.display_name,
                provider.description,
            )

        console.print(table)
        console.print()

        choices = list(choice_map.keys())
        choice = Prompt.ask(
            "[bold yellow]Choose your provider[/bold yellow]",
            choices=choices,
            default="1",
        )

        return choice_map[choice]

    @staticmethod
    def _configure_provider(metadata) -> dict:
        """
        Get provider-specific configuration based on metadata.

        Each entry of ``metadata.config_fields`` is read as a dict with
        required ``"key"`` and ``"label"`` and optional ``"type"`` (default
        ``"text"``), ``"default"`` and ``"help_text"``.

        Args:
            metadata: ProviderMetadata object

        Returns:
            Configuration dictionary mapping each field key to the entered
            string value
        """
        console.print(f"\n[bold]Step 2: Configure {metadata.display_name}[/bold]\n")

        config = {}

        for field in metadata.config_fields:
            key = field["key"]
            label = field["label"]
            field_type = field.get("type", "text")
            default = field.get("default")
            help_text = field.get("help_text")

            if help_text:
                console.print(f"[dim]{help_text}[/dim]\n")

            if field_type == "password":
                # Secrets are never pre-filled, so no default is offered.
                value = Prompt.ask(
                    f"[cyan]{label}[/cyan]",
                    password=True,
                )
            elif default is not None:
                value = Prompt.ask(
                    f"[cyan]{label}[/cyan]",
                    default=str(default),
                )
            else:
                value = Prompt.ask(f"[cyan]{label}[/cyan]")

            config[key] = value

        return config

    @staticmethod
    def _show_summary(metadata, config: dict):
        """
        Show configuration summary.

        Password-type values are masked; only long secrets reveal their last
        four characters.

        Args:
            metadata: ProviderMetadata object
            config: Configuration dict
        """
        # User-entered values may contain "[" which Rich would parse as markup.
        from rich.markup import escape

        console.print("\n[bold]Configuration Summary[/bold]\n")

        console.print(f"[cyan]Provider:[/cyan] {metadata.display_name}")

        for field in metadata.config_fields:
            key = field["key"]
            label = field["label"]
            field_type = field.get("type", "text")
            value = config.get(key)

            if field_type == "password" and value:
                # Revealing a 4-character suffix is only acceptable when the
                # secret is long enough that the suffix is a small part of
                # it; for short secrets it would expose most or all of it.
                suffix = value[-4:] if len(value) >= 12 else ""
                display_value = f"{'*' * 20}{suffix}"
            else:
                display_value = value

            console.print(f"[cyan]{label}:[/cyan] {escape(str(display_value))}")