ai-agent-browser 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
agent_browser/utils.py ADDED
@@ -0,0 +1,365 @@
1
+ """
2
+ Shared utilities for agent-browser.
3
+
4
+ This module contains utility functions used across the package, including:
5
+ - Screenshot resizing
6
+ - State file management
7
+ - Process management
8
+ - Logging utilities
9
+ """
10
+
11
+ import json
12
+ import os
13
+ import sys
14
+ import tempfile
15
+ import time
16
+ from datetime import datetime
17
+ from pathlib import Path
18
+ from typing import Any, Dict, List, Optional, Union
19
+
20
+ # =============================================================================
21
+ # CONFIGURATION DEFAULTS
22
+ # =============================================================================
23
+
24
# Largest screenshot dimensions (in pixels) left untouched by
# resize_screenshot_if_needed; anything bigger is scaled down to fit.
MAX_SCREENSHOT_HEIGHT = 1500
MAX_SCREENSHOT_WIDTH = 1500
# NOTE(review): presumably milliseconds (browser-automation style timeouts) —
# not used in this module, confirm against callers.
DEFAULT_TIMEOUT = 5000
WAIT_FOR_TIMEOUT = 10000
# NOTE(review): presumably seconds for the file-based IPC wait — confirm
# against callers.
IPC_TIMEOUT = 10
29
+
30
+ # =============================================================================
31
+ # GENERAL UTILITIES
32
+ # =============================================================================
33
+
34
+
35
def sanitize_filename(name: str) -> str:
    """Make ``name`` usable as a single path component.

    Every forward slash and backslash is swapped for an underscore; all other
    characters are passed through unchanged.
    """
    for separator in ("/", "\\"):
        name = name.replace(separator, "_")
    return name
38
+
39
+
40
+ # =============================================================================
41
class PathTraversalError(Exception):
    """Signal that a path resolved to a location outside its sandbox root."""
44
+
45
+
46
def validate_path_in_sandbox(path: Path, sandbox: Path) -> Path:
    """
    Check that ``path`` lives inside ``sandbox`` once both are resolved.

    Args:
        path: Candidate path, relative or absolute.
        sandbox: Directory the candidate must stay under.

    Returns:
        The fully resolved candidate path.

    Raises:
        PathTraversalError: If the resolved candidate is not located under
            the resolved sandbox directory.
    """
    candidate = path.resolve()
    allowed_root = sandbox.resolve()

    try:
        # relative_to() raises ValueError when candidate is not below allowed_root.
        candidate.relative_to(allowed_root)
    except ValueError:
        details = (
            f"Path '{path}' escapes sandbox directory '{sandbox}'. "
            f"Resolved to '{candidate}' which is outside '{allowed_root}'."
        )
        raise PathTraversalError(details)
    else:
        return candidate
71
+
72
+
73
def validate_path(path: Union[str, Path], root: Optional[Path] = None) -> Path:
    """
    Resolve a path and ensure it stays within the sandbox root.

    Args:
        path: Path string or Path object to validate
        root: Sandbox root directory (defaults to current working directory)

    Returns:
        The resolved absolute path within the sandbox

    Raises:
        PathTraversalError: If the path escapes the sandbox root
    """
    if root is None:
        root = Path.cwd()
    # Pass the caller-supplied path through unresolved: validate_path_in_sandbox
    # resolves it itself, and its error message can then show both what the
    # caller asked for and what it resolved to (previously both were the
    # already-resolved path).
    return validate_path_in_sandbox(Path(path), root)
91
+
92
+
93
def validate_output_dir(output_dir: Path, cwd: Optional[Path] = None) -> Path:
    """
    Validate that output_dir is within the current working directory.

    Args:
        output_dir: The output directory path
        cwd: The allowed root directory (defaults to Path.cwd())

    Returns:
        The resolved output directory path

    Raises:
        PathTraversalError: If output_dir escapes cwd
    """
    # ``cwd`` is explicitly Optional: None means "use the process's current
    # working directory at call time".
    allowed_root = Path.cwd() if cwd is None else cwd
    return validate_path_in_sandbox(output_dir, allowed_root)
110
+
111
+
112
+ # =============================================================================
113
+ # FILE PATH HELPERS
114
+ # =============================================================================
115
+
116
+
117
def get_temp_file_path(session_id: str, suffix: str) -> Path:
    """
    Build the path of a per-session file inside the system temp directory.

    Args:
        session_id: Identifier of the session the file belongs to.
        suffix: Trailing part of the file name (e.g. ``"state.json"``).

    Returns:
        ``<tempdir>/agent_browser_<session_id>_<suffix>``, where any path
        separators in ``session_id`` have been replaced by underscores.
    """
    # A session id containing "/" or "\\" (e.g. "../x") would otherwise point
    # outside the temp directory, and atomic_write_text would happily create
    # the missing parent directories. Neutralise separators the same way
    # sanitize_filename does so the result is always a direct child of the
    # temp directory.
    safe_session_id = session_id.replace("/", "_").replace("\\", "_")
    return Path(tempfile.gettempdir()) / f"agent_browser_{safe_session_id}_{suffix}"
119
+
120
+
121
def get_state_file(session_id: str) -> Path:
    """Return the path of the session's ``state.json`` temp file."""
    return get_temp_file_path(session_id=session_id, suffix="state.json")
123
+
124
+
125
def get_command_file(session_id: str) -> Path:
    """Return the path of the session's ``cmd.json`` temp file."""
    return get_temp_file_path(session_id=session_id, suffix="cmd.json")
127
+
128
+
129
def get_result_file(session_id: str) -> Path:
    """Return the path of the session's ``result.json`` temp file."""
    return get_temp_file_path(session_id=session_id, suffix="result.json")
131
+
132
+
133
def get_console_log_file(session_id: str) -> Path:
    """Return the path of the session's ``console.json`` temp file."""
    return get_temp_file_path(session_id=session_id, suffix="console.json")
135
+
136
+
137
def get_network_log_file(session_id: str) -> Path:
    """Return the path of the session's ``network.json`` temp file."""
    return get_temp_file_path(session_id=session_id, suffix="network.json")
139
+
140
+
141
def get_pid_file(session_id: str) -> Path:
    """Return the path of the session's ``pid.txt`` temp file."""
    return get_temp_file_path(session_id=session_id, suffix="pid.txt")
143
+
144
+
145
+ # =============================================================================
146
+ # FILE IO HELPERS
147
+ # =============================================================================
148
+
149
+
150
def atomic_write_text(path: Path, content: str) -> None:
    """
    Write ``content`` to ``path`` so that readers never see a partial file.

    The text is first written (UTF-8) and fsynced to a temporary file created
    next to the destination, then moved over the destination with
    ``os.replace``. Keeping the temporary file in the destination directory
    keeps the replace on a single filesystem. Missing parent directories are
    created.

    Raises:
        PermissionError: If the replace still fails after five attempts.
    """
    destination_dir = path.parent
    destination_dir.mkdir(parents=True, exist_ok=True)

    staged: Optional[Path] = None
    try:
        with tempfile.NamedTemporaryFile(
            "w", encoding="utf-8", dir=destination_dir, delete=False
        ) as handle:
            staged = Path(handle.name)
            handle.write(content)
            handle.flush()
            os.fsync(handle.fileno())

        # On Windows, os.replace can fail if the destination was recently
        # closed or is being indexed/antivirus-scanned, so retry briefly.
        attempts_left = 5
        while True:
            try:
                os.replace(staged, path)
            except PermissionError:
                attempts_left -= 1
                if attempts_left == 0:
                    raise
                time.sleep(0.05)
            else:
                break
    finally:
        # After a successful replace the staged file no longer exists; this
        # only removes leftovers from a failed write or replace.
        if staged is not None and staged.exists():
            try:
                staged.unlink()
            except OSError:
                pass
186
+
187
+
188
+ # =============================================================================
189
+ # SCREENSHOT UTILITIES
190
+ # =============================================================================
191
+
192
+
193
def resize_screenshot_if_needed(filepath: Path) -> str:
    """
    Shrink a screenshot in place if it exceeds the configured maximum size.

    The image is scaled down, preserving aspect ratio, so that it fits within
    MAX_SCREENSHOT_WIDTH x MAX_SCREENSHOT_HEIGHT, and is written back to the
    same path. Images already within the limits are left untouched.

    Args:
        filepath: Path of the image file to inspect and possibly overwrite.

    Returns:
        A human-readable status string: ``"WxH (ok)"``,
        ``"WxH -> W2xH2 (resized)"``, ``"not resized (PIL not installed)"``
        or ``"resize error: <reason>"``. This function never raises.
    """
    try:
        from PIL import Image
    except ImportError:
        return "not resized (PIL not installed)"

    try:
        # The context manager closes the underlying file handle; leaving it
        # open leaks a descriptor and can block overwriting the same file on
        # Windows.
        with Image.open(filepath) as img:
            width, height = img.size
            if height <= MAX_SCREENSHOT_HEIGHT and width <= MAX_SCREENSHOT_WIDTH:
                return f"{width}x{height} (ok)"
            ratio = min(MAX_SCREENSHOT_WIDTH / width, MAX_SCREENSHOT_HEIGHT / height)
            # Clamp to 1px: for extreme aspect ratios int() would round a
            # dimension down to 0, which resize() rejects.
            new_width = max(1, int(width * ratio))
            new_height = max(1, int(height * ratio))
            # resize() loads the pixel data, so the result stays valid after
            # the source file is closed.
            img_resized = img.resize((new_width, new_height), Image.LANCZOS)
        img_resized.save(filepath, optimize=True)
        return f"{width}x{height} -> {new_width}x{new_height} (resized)"
    except Exception as e:
        # Best-effort helper: report the failure to the caller as text.
        return f"resize error: {e}"
210
+
211
+
212
+ # =============================================================================
213
+ # STATE MANAGEMENT
214
+ # =============================================================================
215
+
216
+
217
def get_state(session_id: str) -> Dict[str, Any]:
    """
    Load the persisted state dictionary for a session.

    Args:
        session_id: Identifier of the session whose state file is read.

    Returns:
        The decoded state mapping, or an empty dict when the state file is
        missing, unreadable, not valid JSON, or does not contain a JSON
        object.
    """
    state_file = get_state_file(session_id)
    try:
        # Read directly instead of checking exists() first: another process
        # may delete or replace the file between the check and the read.
        # The file is written as UTF-8 by atomic_write_text, so decode it the
        # same way regardless of the platform's default encoding.
        data = json.loads(state_file.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing/locked file. ValueError covers JSONDecodeError and
        # UnicodeDecodeError.
        return {}
    return data if isinstance(data, dict) else {}
225
+
226
+
227
def save_state(session_id: str, state: Dict[str, Any]) -> None:
    """Serialize ``state`` as indented JSON and atomically store it in the session's state file."""
    destination = get_state_file(session_id)
    serialized = json.dumps(state, indent=2)
    atomic_write_text(destination, serialized)
230
+
231
+
232
def clear_state(session_id: str) -> None:
    """
    Delete every temp file belonging to a session.

    Covers the state, command, result, console-log, network-log and PID
    files. Files that do not exist are skipped and deletion failures are
    ignored.
    """
    path_getters = (
        get_state_file,
        get_command_file,
        get_result_file,
        get_console_log_file,
        get_network_log_file,
        get_pid_file,
    )
    targets = [getter(session_id) for getter in path_getters]
    for target in targets:
        if not target.exists():
            continue
        try:
            target.unlink()
        except OSError:
            # Best-effort cleanup: a file that cannot be removed is left behind.
            pass
247
+
248
+
249
+ # =============================================================================
250
+ # PROCESS MANAGEMENT
251
+ # =============================================================================
252
+
253
+
254
def is_process_running(pid: int) -> bool:
    """
    Report whether a process with the given PID currently exists.

    Args:
        pid: Process ID to probe.

    Returns:
        True if the process exists (even when the current user lacks
        permission to signal or open it), False otherwise. Non-integer or
        non-positive PIDs always yield False.
    """
    # PID 0 and negative PIDs address process groups with os.kill rather than
    # a single process, so probing them would give a misleading True.
    if not isinstance(pid, int) or pid <= 0:
        return False

    if sys.platform == "win32":
        try:
            import ctypes
            from ctypes import wintypes

            PROCESS_QUERY_LIMITED_INFORMATION = 0x1000
            STILL_ACTIVE = 259
            ERROR_ACCESS_DENIED = 5

            # Private DLL handle so the prototypes set below do not leak into
            # the shared ctypes.windll.kernel32 object.
            kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
            kernel32.OpenProcess.argtypes = [wintypes.DWORD, wintypes.BOOL, wintypes.DWORD]
            # Without an explicit restype the handle is truncated to a C int.
            kernel32.OpenProcess.restype = wintypes.HANDLE
            kernel32.GetExitCodeProcess.argtypes = [
                wintypes.HANDLE,
                ctypes.POINTER(wintypes.DWORD),
            ]
            kernel32.GetExitCodeProcess.restype = wintypes.BOOL
            kernel32.CloseHandle.argtypes = [wintypes.HANDLE]
            kernel32.CloseHandle.restype = wintypes.BOOL

            handle = kernel32.OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, False, pid)
            if not handle:
                # Access denied means the process exists but is protected.
                return ctypes.get_last_error() == ERROR_ACCESS_DENIED
            try:
                # A process object can outlive the process while other handles
                # to it remain open, so a successful OpenProcess alone does
                # not prove the process is still running.
                exit_code = wintypes.DWORD()
                if kernel32.GetExitCodeProcess(handle, ctypes.byref(exit_code)):
                    return exit_code.value == STILL_ACTIVE
                return True
            finally:
                kernel32.CloseHandle(handle)
        except Exception:
            return False
    else:
        try:
            # Signal 0 performs the existence/permission checks without
            # delivering a signal.
            os.kill(pid, 0)
        except ProcessLookupError:
            return False
        except PermissionError:
            # The process exists but belongs to another user.
            return True
        except OSError:
            return False
        return True
272
+
273
+
274
def get_browser_pid(session_id: str) -> Optional[int]:
    """
    Read the PID recorded in the session's PID file.

    Returns:
        The stored PID as an int, or None when the file is missing, cannot be
        read, or does not contain an integer.
    """
    pid_file = get_pid_file(session_id)
    if not pid_file.exists():
        return None
    try:
        raw_text = pid_file.read_text()
        return int(raw_text.strip())
    except (ValueError, OSError):
        return None
282
+
283
+
284
def save_browser_pid(session_id: str) -> None:
    """Atomically record the PID of the calling process in the session's PID file."""
    destination = get_pid_file(session_id)
    current_pid = os.getpid()
    atomic_write_text(destination, str(current_pid))
287
+
288
+
289
+ # =============================================================================
290
+ # LOGGING UTILITIES
291
+ # =============================================================================
292
+
293
+
294
def get_console_logs(session_id: str) -> List[Dict[str, Any]]:
    """
    Load the stored console log entries for a session.

    Args:
        session_id: Identifier of the session whose console log is read.

    Returns:
        The decoded list of log entries, or an empty list when the log file
        is missing, unreadable, not valid JSON, or does not contain a JSON
        array.
    """
    log_file = get_console_log_file(session_id)
    try:
        # Read directly instead of checking exists() first: another process
        # may delete or replace the file between the check and the read.
        # Decode as UTF-8 to match how atomic_write_text writes the file.
        data = json.loads(log_file.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing/locked file. ValueError covers JSONDecodeError and
        # UnicodeDecodeError.
        return []
    return data if isinstance(data, list) else []
302
+
303
+
304
def save_console_log(session_id: str, entry: Dict[str, Any]) -> None:
    """
    Append one entry to the session's console log file.

    The existing log is loaded, the entry is added at the end, and only the
    most recent 100 entries are written back (atomically, as indented JSON).
    """
    history = get_console_logs(session_id)
    history.append(entry)
    recent = history[-100:] if len(history) > 100 else history
    atomic_write_text(
        get_console_log_file(session_id),
        json.dumps(recent, indent=2),
    )
311
+
312
+
313
def get_network_logs(session_id: str) -> Dict[str, Dict[str, Any]]:
    """
    Load the stored network log for a session.

    Args:
        session_id: Identifier of the session whose network log is read.

    Returns:
        The decoded mapping of request id to request entry, or an empty dict
        when the log file is missing, unreadable, not valid JSON, or does not
        contain a JSON object.
    """
    log_file = get_network_log_file(session_id)
    try:
        # Read directly instead of checking exists() first: another process
        # may delete or replace the file between the check and the read.
        # Decode as UTF-8 to match how atomic_write_text writes the file.
        data = json.loads(log_file.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing/locked file. ValueError covers JSONDecodeError and
        # UnicodeDecodeError.
        return {}
    return data if isinstance(data, dict) else {}
321
+
322
+
323
def save_network_logs(session_id: str, logs: Dict[str, Dict[str, Any]]) -> None:
    """
    Persist the session's network log, keeping at most 100 requests.

    When ``logs`` holds more than 100 entries, the ones that sort lowest by
    their ``"start_time"`` value (a missing value sorts as ``""``) are
    deleted from ``logs`` itself — the caller's dict is modified in place —
    before the remainder is written atomically as indented JSON.
    """
    surplus = len(logs) - 100
    if surplus > 0:
        by_start_time = sorted(logs, key=lambda rid: logs[rid].get("start_time", ""))
        for stale_id in by_start_time[:surplus]:
            del logs[stale_id]
    destination = get_network_log_file(session_id)
    atomic_write_text(destination, json.dumps(logs, indent=2))
330
+
331
+
332
def add_network_request(session_id: str, request_id: str, entry: Dict[str, Any]) -> None:
    """
    Record or update one request in the session's network log.

    The stored log is loaded; if ``request_id`` is already present its entry
    is updated with the fields of ``entry``, otherwise ``entry`` is stored
    under ``request_id``. The log is then saved again.
    """
    known_requests = get_network_logs(session_id)
    if request_id not in known_requests:
        known_requests[request_id] = entry
    else:
        known_requests[request_id].update(entry)
    save_network_logs(session_id, known_requests)
339
+
340
+
341
def clear_logs(session_id: str) -> None:
    """
    Delete the session's console and network log files.

    Files that do not exist are skipped and deletion failures are ignored.
    """
    targets = (
        get_console_log_file(session_id),
        get_network_log_file(session_id),
    )
    for target in targets:
        if not target.exists():
            continue
        try:
            target.unlink()
        except OSError:
            # Best-effort cleanup: a file that cannot be removed is left behind.
            pass
348
+
349
+
350
+ # =============================================================================
351
+ # FORMATTING UTILITIES
352
+ # =============================================================================
353
+
354
+
355
def format_assertion_result(passed: bool, message: str) -> str:
    """Prefix ``message`` with ``[PASS]`` or ``[FAIL]`` depending on ``passed``."""
    if passed:
        label = "PASS"
    else:
        label = "FAIL"
    return "[{}] {}".format(label, message)
358
+
359
+
360
def configure_windows_console() -> None:
    """
    Switch ``sys.stdout`` to UTF-8 when running on Windows.

    Does nothing on other platforms. A stdout object without a
    ``reconfigure`` method is left as is.
    """
    if sys.platform != "win32":
        return
    try:
        sys.stdout.reconfigure(encoding="utf-8")
    except AttributeError:
        pass