janito 2.27.1__py3-none-any.whl → 2.29.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- janito/README.md +9 -9
- janito/agent/setup_agent.py +29 -16
- janito/cli/chat_mode/script_runner.py +1 -1
- janito/cli/chat_mode/session.py +160 -56
- janito/cli/chat_mode/session_profile_select.py +8 -2
- janito/cli/chat_mode/shell/commands/execute.py +4 -2
- janito/cli/chat_mode/shell/commands/help.py +2 -0
- janito/cli/chat_mode/shell/commands/privileges.py +6 -2
- janito/cli/chat_mode/shell/commands/provider.py +7 -4
- janito/cli/chat_mode/shell/commands/read.py +4 -2
- janito/cli/chat_mode/shell/commands/security/__init__.py +1 -1
- janito/cli/chat_mode/shell/commands/security/allowed_sites.py +16 -13
- janito/cli/chat_mode/shell/commands/security_command.py +14 -10
- janito/cli/chat_mode/shell/commands/tools.py +4 -2
- janito/cli/chat_mode/shell/commands/unrestricted.py +17 -12
- janito/cli/chat_mode/shell/commands/write.py +4 -2
- janito/cli/chat_mode/toolbar.py +4 -4
- janito/cli/cli_commands/enable_disable_plugin.py +48 -25
- janito/cli/cli_commands/list_models.py +2 -2
- janito/cli/cli_commands/list_plugins.py +18 -18
- janito/cli/cli_commands/list_profiles.py +6 -6
- janito/cli/cli_commands/list_providers.py +1 -1
- janito/cli/cli_commands/model_utils.py +45 -20
- janito/cli/cli_commands/ping_providers.py +10 -10
- janito/cli/cli_commands/set_api_key.py +5 -3
- janito/cli/cli_commands/show_config.py +13 -7
- janito/cli/cli_commands/show_system_prompt.py +13 -6
- janito/cli/core/getters.py +1 -0
- janito/cli/core/model_guesser.py +18 -15
- janito/cli/core/runner.py +15 -7
- janito/cli/core/setters.py +9 -6
- janito/cli/main_cli.py +15 -12
- janito/cli/prompt_setup.py +4 -4
- janito/cli/rich_terminal_reporter.py +2 -1
- janito/config_manager.py +2 -0
- janito/docs/GETTING_STARTED.md +9 -9
- janito/drivers/cerebras/__init__.py +1 -1
- janito/exceptions.py +6 -4
- janito/plugins/__init__.py +2 -2
- janito/plugins/base.py +48 -40
- janito/plugins/builtin.py +13 -9
- janito/plugins/config.py +16 -19
- janito/plugins/discovery.py +73 -66
- janito/plugins/manager.py +62 -60
- janito/provider_registry.py +10 -10
- janito/providers/__init__.py +1 -1
- janito/providers/alibaba/model_info.py +3 -5
- janito/providers/alibaba/provider.py +3 -1
- janito/providers/cerebras/__init__.py +1 -1
- janito/providers/cerebras/model_info.py +12 -27
- janito/providers/cerebras/provider.py +11 -9
- janito/providers/mistral/__init__.py +1 -1
- janito/providers/mistral/model_info.py +1 -1
- janito/providers/mistral/provider.py +1 -1
- janito/providers/moonshot/__init__.py +1 -0
- janito/providers/{moonshotai → moonshot}/model_info.py +3 -3
- janito/providers/{moonshotai → moonshot}/provider.py +8 -8
- janito/providers/openai/provider.py +3 -1
- janito/report_events.py +0 -1
- janito/tools/adapters/local/create_file.py +1 -1
- janito/tools/adapters/local/fetch_url.py +45 -29
- janito/tools/adapters/local/python_command_run.py +2 -1
- janito/tools/adapters/local/python_file_run.py +1 -0
- janito/tools/adapters/local/run_powershell_command.py +1 -1
- janito/tools/adapters/local/validate_file_syntax/jinja2_validator.py +14 -11
- janito/tools/base.py +4 -3
- janito/tools/loop_protection.py +24 -22
- janito/tools/path_utils.py +7 -7
- janito/tools/tool_base.py +0 -2
- janito/tools/tools_adapter.py +15 -5
- janito/tools/url_whitelist.py +27 -26
- {janito-2.27.1.dist-info → janito-2.29.0.dist-info}/METADATA +1 -1
- {janito-2.27.1.dist-info → janito-2.29.0.dist-info}/RECORD +77 -77
- janito/providers/moonshotai/__init__.py +0 -1
- {janito-2.27.1.dist-info → janito-2.29.0.dist-info}/WHEEL +0 -0
- {janito-2.27.1.dist-info → janito-2.29.0.dist-info}/entry_points.txt +0 -0
- {janito-2.27.1.dist-info → janito-2.29.0.dist-info}/licenses/LICENSE +0 -0
- {janito-2.27.1.dist-info → janito-2.29.0.dist-info}/top_level.txt +0 -0
@@ -68,19 +68,21 @@ class FetchUrlTool(ToolBase):
|
|
68
68
|
{}
|
69
69
|
) # In-memory session cache - lifetime matches tool instance
|
70
70
|
self._load_cache()
|
71
|
-
|
71
|
+
|
72
72
|
# Browser-like session with cookies and headers
|
73
73
|
self.session = requests.Session()
|
74
|
-
self.session.headers.update(
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
74
|
+
self.session.headers.update(
|
75
|
+
{
|
76
|
+
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
|
77
|
+
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8",
|
78
|
+
"Accept-Language": "en-US,en;q=0.5",
|
79
|
+
"Accept-Encoding": "gzip, deflate, br",
|
80
|
+
"DNT": "1",
|
81
|
+
"Connection": "keep-alive",
|
82
|
+
"Upgrade-Insecure-Requests": "1",
|
83
|
+
}
|
84
|
+
)
|
85
|
+
|
84
86
|
# Load cookies from disk if they exist
|
85
87
|
self.cookies_file = self.cache_dir / "cookies.json"
|
86
88
|
self._load_cookies()
|
@@ -120,12 +122,14 @@ class FetchUrlTool(ToolBase):
|
|
120
122
|
try:
|
121
123
|
cookies_data = []
|
122
124
|
for cookie in self.session.cookies:
|
123
|
-
cookies_data.append(
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
125
|
+
cookies_data.append(
|
126
|
+
{
|
127
|
+
"name": cookie.name,
|
128
|
+
"value": cookie.value,
|
129
|
+
"domain": cookie.domain,
|
130
|
+
"path": cookie.path,
|
131
|
+
}
|
132
|
+
)
|
129
133
|
with open(self.cookies_file, "w", encoding="utf-8") as f:
|
130
134
|
json.dump(cookies_data, f, indent=2)
|
131
135
|
except IOError:
|
@@ -170,8 +174,14 @@ class FetchUrlTool(ToolBase):
|
|
170
174
|
}
|
171
175
|
self._save_cache()
|
172
176
|
|
173
|
-
def _fetch_url_content(
|
174
|
-
|
177
|
+
def _fetch_url_content(
|
178
|
+
self,
|
179
|
+
url: str,
|
180
|
+
timeout: int = 10,
|
181
|
+
headers: Optional[Dict[str, str]] = None,
|
182
|
+
cookies: Optional[Dict[str, str]] = None,
|
183
|
+
follow_redirects: bool = True,
|
184
|
+
) -> str:
|
175
185
|
"""Fetch URL content and handle HTTP errors.
|
176
186
|
|
177
187
|
Implements two-tier caching:
|
@@ -224,23 +234,23 @@ class FetchUrlTool(ToolBase):
|
|
224
234
|
request_headers = self.session.headers.copy()
|
225
235
|
if headers:
|
226
236
|
request_headers.update(headers)
|
227
|
-
|
237
|
+
|
228
238
|
# Merge custom cookies
|
229
239
|
if cookies:
|
230
240
|
self.session.cookies.update(cookies)
|
231
241
|
|
232
242
|
response = self.session.get(
|
233
|
-
url,
|
234
|
-
timeout=timeout,
|
243
|
+
url,
|
244
|
+
timeout=timeout,
|
235
245
|
headers=request_headers,
|
236
|
-
allow_redirects=follow_redirects
|
246
|
+
allow_redirects=follow_redirects,
|
237
247
|
)
|
238
248
|
response.raise_for_status()
|
239
249
|
content = response.text
|
240
|
-
|
250
|
+
|
241
251
|
# Save cookies after successful request
|
242
252
|
self._save_cookies()
|
243
|
-
|
253
|
+
|
244
254
|
# Cache successful responses in session cache
|
245
255
|
self.session_cache[url] = content
|
246
256
|
return content
|
@@ -354,8 +364,11 @@ class FetchUrlTool(ToolBase):
|
|
354
364
|
# Check if we should save to file
|
355
365
|
if save_to_file:
|
356
366
|
html_content = self._fetch_url_content(
|
357
|
-
url,
|
358
|
-
|
367
|
+
url,
|
368
|
+
timeout=timeout,
|
369
|
+
headers=headers,
|
370
|
+
cookies=cookies,
|
371
|
+
follow_redirects=follow_redirects,
|
359
372
|
)
|
360
373
|
if html_content.startswith("Warning:"):
|
361
374
|
return html_content
|
@@ -380,8 +393,11 @@ class FetchUrlTool(ToolBase):
|
|
380
393
|
|
381
394
|
# Normal processing path
|
382
395
|
html_content = self._fetch_url_content(
|
383
|
-
url,
|
384
|
-
|
396
|
+
url,
|
397
|
+
timeout=timeout,
|
398
|
+
headers=headers,
|
399
|
+
cookies=cookies,
|
400
|
+
follow_redirects=follow_redirects,
|
385
401
|
)
|
386
402
|
if html_content.startswith("Warning:"):
|
387
403
|
return html_content
|
@@ -32,7 +32,8 @@ class PythonCommandRunTool(ToolBase):
|
|
32
32
|
return tr("Warning: Empty code provided. Operation skipped.")
|
33
33
|
if not silent:
|
34
34
|
self.report_action(
|
35
|
-
tr("🐍 Running: python -c ...\n{code}\n", code=code),
|
35
|
+
tr("🐍 Running: python -c ...\n{code}\n", code=code),
|
36
|
+
ReportAction.EXECUTE,
|
36
37
|
)
|
37
38
|
self.report_stdout("\n")
|
38
39
|
else:
|
@@ -43,7 +43,7 @@ class RunPowershellCommandTool(ToolBase):
|
|
43
43
|
if require_confirmation:
|
44
44
|
self.report_warning(
|
45
45
|
tr("⚠️ Confirmation requested, but no handler (auto-confirmed)."),
|
46
|
-
ReportAction.EXECUTE
|
46
|
+
ReportAction.EXECUTE,
|
47
47
|
)
|
48
48
|
return True # Auto-confirm for now
|
49
49
|
return True
|
@@ -8,40 +8,43 @@ def validate_jinja2(path: str) -> str:
|
|
8
8
|
"""Validate Jinja2 template syntax."""
|
9
9
|
try:
|
10
10
|
from jinja2 import Environment, TemplateSyntaxError
|
11
|
-
|
11
|
+
|
12
12
|
with open(path, "r", encoding="utf-8") as f:
|
13
13
|
content = f.read()
|
14
|
-
|
14
|
+
|
15
15
|
# Create a Jinja2 environment and try to parse the template
|
16
16
|
env = Environment()
|
17
17
|
try:
|
18
18
|
env.parse(content)
|
19
19
|
return tr("✅ Syntax OK")
|
20
20
|
except TemplateSyntaxError as e:
|
21
|
-
line_num = getattr(e,
|
22
|
-
return tr(
|
23
|
-
|
21
|
+
line_num = getattr(e, "lineno", 0)
|
22
|
+
return tr(
|
23
|
+
"⚠️ Warning: Syntax error: {error} at line {line}",
|
24
|
+
error=str(e),
|
25
|
+
line=line_num,
|
26
|
+
)
|
24
27
|
except Exception as e:
|
25
28
|
return tr("⚠️ Warning: Syntax error: {error}", error=str(e))
|
26
|
-
|
29
|
+
|
27
30
|
except ImportError:
|
28
31
|
# If jinja2 is not available, just check basic structure
|
29
32
|
try:
|
30
33
|
with open(path, "r", encoding="utf-8") as f:
|
31
34
|
content = f.read()
|
32
|
-
|
35
|
+
|
33
36
|
# Basic checks for common Jinja2 syntax issues
|
34
37
|
open_tags = content.count("{%")
|
35
38
|
close_tags = content.count("%}")
|
36
39
|
open_vars = content.count("{{")
|
37
40
|
close_vars = content.count("}}")
|
38
|
-
|
41
|
+
|
39
42
|
if open_tags != close_tags:
|
40
43
|
return tr("⚠️ Warning: Syntax error: Mismatched Jinja2 tags")
|
41
44
|
if open_vars != close_vars:
|
42
45
|
return tr("⚠️ Warning: Syntax error: Mismatched Jinja2 variables")
|
43
|
-
|
46
|
+
|
44
47
|
return tr("✅ Syntax OK (basic validation)")
|
45
|
-
|
48
|
+
|
46
49
|
except Exception as e:
|
47
|
-
return tr("⚠️ Warning: Syntax error: {error}", error=str(e))
|
50
|
+
return tr("⚠️ Warning: Syntax error: {error}", error=str(e))
|
janito/tools/base.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1
1
|
class BaseTool:
|
2
2
|
"""Base class for all tools."""
|
3
|
+
|
3
4
|
tool_name: str = ""
|
4
|
-
|
5
|
+
|
5
6
|
def __init__(self):
|
6
7
|
if not self.tool_name:
|
7
8
|
self.tool_name = self.__class__.__name__.lower()
|
8
|
-
|
9
|
+
|
9
10
|
def run(self, *args, **kwargs) -> str:
|
10
11
|
"""Execute the tool."""
|
11
|
-
raise NotImplementedError
|
12
|
+
raise NotImplementedError
|
janito/tools/loop_protection.py
CHANGED
@@ -8,19 +8,19 @@ class LoopProtection:
|
|
8
8
|
"""
|
9
9
|
Provides loop protection for tool calls by tracking repeated operations
|
10
10
|
on the same resources within a short time period.
|
11
|
-
|
11
|
+
|
12
12
|
This class monitors file operations and prevents excessive reads on the same
|
13
13
|
file within a configurable time window. It helps prevent infinite loops or
|
14
14
|
excessive resource consumption when tools repeatedly access the same files.
|
15
|
-
|
15
|
+
|
16
16
|
The default configuration allows up to 5 operations on the same file within
|
17
17
|
a 10-second window. Operations outside this window are automatically cleaned
|
18
18
|
up to prevent memory accumulation.
|
19
19
|
"""
|
20
|
-
|
20
|
+
|
21
21
|
_instance = None
|
22
22
|
_lock = threading.Lock()
|
23
|
-
|
23
|
+
|
24
24
|
def __new__(cls):
|
25
25
|
if not cls._instance:
|
26
26
|
with cls._lock:
|
@@ -28,7 +28,7 @@ class LoopProtection:
|
|
28
28
|
cls._instance = super().__new__(cls)
|
29
29
|
cls._instance._init_protection()
|
30
30
|
return cls._instance
|
31
|
-
|
31
|
+
|
32
32
|
def _init_protection(self):
|
33
33
|
# Track file operations: {normalized_path: [(timestamp, operation_type), ...]}
|
34
34
|
self._file_operations: Dict[str, List[Tuple[float, str]]] = {}
|
@@ -36,7 +36,7 @@ class LoopProtection:
|
|
36
36
|
self._time_window = 10.0
|
37
37
|
# Maximum allowed operations on the same file within time window
|
38
38
|
self._max_operations = 5
|
39
|
-
|
39
|
+
|
40
40
|
"""
|
41
41
|
Configuration parameters:
|
42
42
|
|
@@ -46,23 +46,23 @@ class LoopProtection:
|
|
46
46
|
_max_operations: Maximum number of operations allowed on the same file
|
47
47
|
within the time window. Default is 5 operations.
|
48
48
|
"""
|
49
|
-
|
49
|
+
|
50
50
|
def check_file_operation_limit(self, path: str, operation_type: str) -> bool:
|
51
51
|
"""
|
52
52
|
Check if performing an operation on a file would exceed the limit.
|
53
|
-
|
53
|
+
|
54
54
|
This method tracks file operations and prevents excessive reads on the same
|
55
55
|
file within a configurable time window (default 10 seconds). It helps prevent
|
56
56
|
infinite loops or excessive resource consumption when tools repeatedly access
|
57
57
|
the same files.
|
58
|
-
|
58
|
+
|
59
59
|
Args:
|
60
60
|
path: The file path being operated on
|
61
61
|
operation_type: Type of operation (e.g., "view_file", "read_files")
|
62
|
-
|
62
|
+
|
63
63
|
Returns:
|
64
64
|
bool: True if operation is allowed, False if it would exceed the limit
|
65
|
-
|
65
|
+
|
66
66
|
Example:
|
67
67
|
>>> loop_protection = LoopProtection.instance()
|
68
68
|
>>> if loop_protection.check_file_operation_limit("/path/to/file.txt", "view_file"):
|
@@ -74,42 +74,44 @@ class LoopProtection:
|
|
74
74
|
"""
|
75
75
|
norm_path = normalize_path(path)
|
76
76
|
current_time = time.time()
|
77
|
-
|
77
|
+
|
78
78
|
# Clean up old operations outside the time window
|
79
79
|
if norm_path in self._file_operations:
|
80
80
|
self._file_operations[norm_path] = [
|
81
|
-
(timestamp, op_type)
|
81
|
+
(timestamp, op_type)
|
82
82
|
for timestamp, op_type in self._file_operations[norm_path]
|
83
83
|
if current_time - timestamp <= self._time_window
|
84
84
|
]
|
85
|
-
|
85
|
+
|
86
86
|
# Check if we're exceeding the limit
|
87
87
|
if norm_path in self._file_operations:
|
88
88
|
operations = self._file_operations[norm_path]
|
89
89
|
if len(operations) >= self._max_operations:
|
90
90
|
# Check if all recent operations are within the time window
|
91
|
-
if all(
|
92
|
-
|
91
|
+
if all(
|
92
|
+
current_time - timestamp <= self._time_window
|
93
|
+
for timestamp, _ in operations
|
94
|
+
):
|
93
95
|
return False # Would exceed limit - potential loop
|
94
|
-
|
96
|
+
|
95
97
|
# Record this operation
|
96
98
|
if norm_path not in self._file_operations:
|
97
99
|
self._file_operations[norm_path] = []
|
98
100
|
self._file_operations[norm_path].append((current_time, operation_type))
|
99
|
-
|
101
|
+
|
100
102
|
return True # Operation allowed
|
101
|
-
|
103
|
+
|
102
104
|
def reset_tracking(self):
|
103
105
|
"""
|
104
106
|
Reset all tracking data.
|
105
|
-
|
107
|
+
|
106
108
|
This method clears all recorded file operations, effectively resetting
|
107
109
|
the loop protection state. This can be useful in testing scenarios or
|
108
110
|
when you want to explicitly clear the tracking history.
|
109
111
|
"""
|
110
112
|
with self._lock:
|
111
113
|
self._file_operations.clear()
|
112
|
-
|
114
|
+
|
113
115
|
@classmethod
|
114
116
|
def instance(cls):
|
115
|
-
return cls()
|
117
|
+
return cls()
|
janito/tools/path_utils.py
CHANGED
@@ -9,19 +9,19 @@ from pathlib import Path
|
|
9
9
|
def expand_path(path: str) -> str:
|
10
10
|
"""
|
11
11
|
Expand a path, handling tilde (~) expansion for user home directory.
|
12
|
-
|
12
|
+
|
13
13
|
Args:
|
14
14
|
path (str): The path to expand.
|
15
|
-
|
15
|
+
|
16
16
|
Returns:
|
17
17
|
str: The expanded absolute path.
|
18
18
|
"""
|
19
19
|
if not path:
|
20
20
|
return path
|
21
|
-
|
21
|
+
|
22
22
|
# Handle tilde expansion
|
23
23
|
expanded = os.path.expanduser(path)
|
24
|
-
|
24
|
+
|
25
25
|
# Convert to absolute path
|
26
26
|
return os.path.abspath(expanded)
|
27
27
|
|
@@ -29,11 +29,11 @@ def expand_path(path: str) -> str:
|
|
29
29
|
def normalize_path(path: str) -> str:
|
30
30
|
"""
|
31
31
|
Normalize a path by expanding tilde and resolving any relative paths.
|
32
|
-
|
32
|
+
|
33
33
|
Args:
|
34
34
|
path (str): The path to normalize.
|
35
|
-
|
35
|
+
|
36
36
|
Returns:
|
37
37
|
str: The normalized absolute path.
|
38
38
|
"""
|
39
|
-
return expand_path(path)
|
39
|
+
return expand_path(path)
|
janito/tools/tool_base.py
CHANGED
janito/tools/tools_adapter.py
CHANGED
@@ -150,11 +150,17 @@ class ToolsAdapterBase:
|
|
150
150
|
unexpected = [k for k in arguments.keys() if k not in params]
|
151
151
|
if unexpected:
|
152
152
|
# Build detailed error message with received arguments
|
153
|
-
error_parts = [
|
154
|
-
|
153
|
+
error_parts = [
|
154
|
+
"Unexpected argument(s): " + ", ".join(sorted(unexpected))
|
155
|
+
]
|
156
|
+
error_parts.append(
|
157
|
+
"Valid parameters: " + ", ".join(sorted(params.keys()))
|
158
|
+
)
|
155
159
|
error_parts.append("Arguments received:")
|
156
160
|
for key, value in arguments.items():
|
157
|
-
error_parts.append(
|
161
|
+
error_parts.append(
|
162
|
+
f" {key}: {repr(value)} ({type(value).__name__})"
|
163
|
+
)
|
158
164
|
return "\n".join(error_parts)
|
159
165
|
|
160
166
|
# Check for missing required arguments (ignoring *args / **kwargs / self)
|
@@ -172,11 +178,15 @@ class ToolsAdapterBase:
|
|
172
178
|
missing = [name for name in required_params if name not in arguments]
|
173
179
|
if missing:
|
174
180
|
# Build detailed error message with received arguments
|
175
|
-
error_parts = [
|
181
|
+
error_parts = [
|
182
|
+
"Missing required argument(s): " + ", ".join(sorted(missing))
|
183
|
+
]
|
176
184
|
error_parts.append("Arguments received:")
|
177
185
|
if isinstance(arguments, dict):
|
178
186
|
for key, value in arguments.items():
|
179
|
-
error_parts.append(
|
187
|
+
error_parts.append(
|
188
|
+
f" {key}: {repr(value)} ({type(value).__name__})"
|
189
|
+
)
|
180
190
|
elif arguments is not None:
|
181
191
|
error_parts.append(f" {repr(arguments)} ({type(arguments).__name__})")
|
182
192
|
else:
|
janito/tools/url_whitelist.py
CHANGED
@@ -8,102 +8,102 @@ from urllib.parse import urlparse
|
|
8
8
|
|
9
9
|
class UrlWhitelistManager:
|
10
10
|
"""Manages allowed sites for the fetch_url tool."""
|
11
|
-
|
11
|
+
|
12
12
|
def __init__(self):
|
13
13
|
self.config_path = Path.home() / ".janito" / "url_whitelist.json"
|
14
14
|
self._allowed_sites = self._load_whitelist()
|
15
15
|
self._unrestricted_mode = False
|
16
|
-
|
16
|
+
|
17
17
|
def set_unrestricted_mode(self, enabled: bool = True):
|
18
18
|
"""Enable or disable unrestricted mode (bypasses whitelist)."""
|
19
19
|
self._unrestricted_mode = enabled
|
20
|
-
|
20
|
+
|
21
21
|
def _load_whitelist(self) -> Set[str]:
|
22
22
|
"""Load the whitelist from config file."""
|
23
23
|
if not self.config_path.exists():
|
24
24
|
return set()
|
25
|
-
|
25
|
+
|
26
26
|
try:
|
27
|
-
with open(self.config_path,
|
27
|
+
with open(self.config_path, "r", encoding="utf-8") as f:
|
28
28
|
data = json.load(f)
|
29
|
-
return set(data.get(
|
29
|
+
return set(data.get("allowed_sites", []))
|
30
30
|
except (json.JSONDecodeError, IOError):
|
31
31
|
return set()
|
32
|
-
|
32
|
+
|
33
33
|
def _save_whitelist(self):
|
34
34
|
"""Save the whitelist to config file."""
|
35
35
|
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
36
36
|
try:
|
37
|
-
with open(self.config_path,
|
38
|
-
json.dump({
|
37
|
+
with open(self.config_path, "w", encoding="utf-8") as f:
|
38
|
+
json.dump({"allowed_sites": list(self._allowed_sites)}, f, indent=2)
|
39
39
|
except IOError:
|
40
40
|
pass # Silently fail if we can't write
|
41
|
-
|
41
|
+
|
42
42
|
def is_url_allowed(self, url: str) -> bool:
|
43
43
|
"""Check if a URL is allowed based on the whitelist."""
|
44
44
|
if self._unrestricted_mode:
|
45
45
|
return True # Unrestricted mode bypasses all whitelist checks
|
46
|
-
|
46
|
+
|
47
47
|
if not self._allowed_sites:
|
48
48
|
return True # No whitelist means all sites allowed
|
49
|
-
|
49
|
+
|
50
50
|
try:
|
51
51
|
parsed = urlparse(url)
|
52
52
|
domain = parsed.netloc.lower()
|
53
|
-
|
53
|
+
|
54
54
|
# Check exact matches and subdomain matches
|
55
55
|
for allowed in self._allowed_sites:
|
56
56
|
allowed = allowed.lower()
|
57
|
-
if domain == allowed or domain.endswith(
|
57
|
+
if domain == allowed or domain.endswith("." + allowed):
|
58
58
|
return True
|
59
|
-
|
59
|
+
|
60
60
|
return False
|
61
61
|
except Exception:
|
62
62
|
return False # Invalid URLs are blocked
|
63
|
-
|
63
|
+
|
64
64
|
def add_allowed_site(self, site: str) -> bool:
|
65
65
|
"""Add a site to the whitelist."""
|
66
66
|
# Clean up the site format
|
67
67
|
site = site.strip().lower()
|
68
|
-
if site.startswith(
|
68
|
+
if site.startswith("http://") or site.startswith("https://"):
|
69
69
|
parsed = urlparse(site)
|
70
70
|
site = parsed.netloc
|
71
|
-
|
71
|
+
|
72
72
|
if site and site not in self._allowed_sites:
|
73
73
|
self._allowed_sites.add(site)
|
74
74
|
self._save_whitelist()
|
75
75
|
return True
|
76
76
|
return False
|
77
|
-
|
77
|
+
|
78
78
|
def remove_allowed_site(self, site: str) -> bool:
|
79
79
|
"""Remove a site from the whitelist."""
|
80
80
|
site = site.strip().lower()
|
81
|
-
if site.startswith(
|
81
|
+
if site.startswith("http://") or site.startswith("https://"):
|
82
82
|
parsed = urlparse(site)
|
83
83
|
site = parsed.netloc
|
84
|
-
|
84
|
+
|
85
85
|
if site in self._allowed_sites:
|
86
86
|
self._allowed_sites.remove(site)
|
87
87
|
self._save_whitelist()
|
88
88
|
return True
|
89
89
|
return False
|
90
|
-
|
90
|
+
|
91
91
|
def get_allowed_sites(self) -> List[str]:
|
92
92
|
"""Get the list of allowed sites."""
|
93
93
|
return sorted(self._allowed_sites)
|
94
|
-
|
94
|
+
|
95
95
|
def set_allowed_sites(self, sites: List[str]):
|
96
96
|
"""Set the complete list of allowed sites."""
|
97
97
|
self._allowed_sites = set()
|
98
98
|
for site in sites:
|
99
99
|
site = site.strip().lower()
|
100
|
-
if site.startswith(
|
100
|
+
if site.startswith("http://") or site.startswith("https://"):
|
101
101
|
parsed = urlparse(site)
|
102
102
|
site = parsed.netloc
|
103
103
|
if site:
|
104
104
|
self._allowed_sites.add(site)
|
105
105
|
self._save_whitelist()
|
106
|
-
|
106
|
+
|
107
107
|
def clear_whitelist(self):
|
108
108
|
"""Clear all allowed sites."""
|
109
109
|
self._allowed_sites.clear()
|
@@ -113,9 +113,10 @@ class UrlWhitelistManager:
|
|
113
113
|
# Global singleton
|
114
114
|
_url_whitelist_manager = None
|
115
115
|
|
116
|
+
|
116
117
|
def get_url_whitelist_manager() -> UrlWhitelistManager:
|
117
118
|
"""Get the global URL whitelist manager instance."""
|
118
119
|
global _url_whitelist_manager
|
119
120
|
if _url_whitelist_manager is None:
|
120
121
|
_url_whitelist_manager = UrlWhitelistManager()
|
121
|
-
return _url_whitelist_manager
|
122
|
+
return _url_whitelist_manager
|