cua-computer 0.1.29__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,27 +1,599 @@
1
- """Linux computer interface implementation."""
1
+ import asyncio
2
+ import json
3
+ import time
4
+ from typing import Any, Dict, List, Optional, Tuple
5
+ from PIL import Image
2
6
 
3
- from typing import Dict
7
+ import websockets
8
+
9
+ from ..logger import Logger, LogLevel
4
10
  from .base import BaseComputerInterface
11
+ from ..utils import decode_base64_image, bytes_to_image, draw_box, resize_image
12
+ from .models import Key, KeyType
13
+
14
+
15
+ class LinuxComputerInterface(BaseComputerInterface):
16
+ """Interface for Linux."""
17
+
18
+ def __init__(self, ip_address: str, username: str = "lume", password: str = "lume"):
19
+ super().__init__(ip_address, username, password)
20
+ self._ws = None
21
+ self._reconnect_task = None
22
+ self._closed = False
23
+ self._last_ping = 0
24
+ self._ping_interval = 5 # Send ping every 5 seconds
25
+ self._ping_timeout = 10 # Wait 10 seconds for pong response
26
+ self._reconnect_delay = 1 # Start with 1 second delay
27
+ self._max_reconnect_delay = 30 # Maximum delay between reconnection attempts
28
+ self._log_connection_attempts = True # Flag to control connection attempt logging
29
+
30
+ # Set logger name for Linux interface
31
+ self.logger = Logger("cua.interface.linux", LogLevel.NORMAL)
32
+
33
+ @property
34
+ def ws_uri(self) -> str:
35
+ """Get the WebSocket URI using the current IP address.
36
+
37
+ Returns:
38
+ WebSocket URI for the Computer API Server
39
+ """
40
+ return f"ws://{self.ip_address}:8000/ws"
41
+
42
+ async def _keep_alive(self):
43
+ """Keep the WebSocket connection alive with automatic reconnection."""
44
+ retry_count = 0
45
+ max_log_attempts = 1 # Only log the first attempt at INFO level
46
+ log_interval = 500 # Then log every 500th attempt (significantly increased from 30)
47
+ last_warning_time = 0
48
+ min_warning_interval = 30 # Minimum seconds between connection lost warnings
49
+ min_retry_delay = 0.5 # Minimum delay between connection attempts (500ms)
50
+
51
+ while not self._closed:
52
+ try:
53
+ if self._ws is None or (
54
+ self._ws and self._ws.state == websockets.protocol.State.CLOSED
55
+ ):
56
+ try:
57
+ retry_count += 1
58
+
59
+ # Add a minimum delay between connection attempts to avoid flooding
60
+ if retry_count > 1:
61
+ await asyncio.sleep(min_retry_delay)
62
+
63
+ # Only log the first attempt at INFO level, then every Nth attempt
64
+ if retry_count == 1:
65
+ self.logger.info(f"Attempting WebSocket connection to {self.ws_uri}")
66
+ elif retry_count % log_interval == 0:
67
+ self.logger.info(
68
+ f"Still attempting WebSocket connection (attempt {retry_count})..."
69
+ )
70
+ else:
71
+ # All other attempts are logged at DEBUG level
72
+ self.logger.debug(
73
+ f"Attempting WebSocket connection to {self.ws_uri} (attempt {retry_count})"
74
+ )
75
+
76
+ self._ws = await asyncio.wait_for(
77
+ websockets.connect(
78
+ self.ws_uri,
79
+ max_size=1024 * 1024 * 10, # 10MB limit
80
+ max_queue=32,
81
+ ping_interval=self._ping_interval,
82
+ ping_timeout=self._ping_timeout,
83
+ close_timeout=5,
84
+ compression=None, # Disable compression to reduce overhead
85
+ ),
86
+ timeout=30,
87
+ )
88
+ self.logger.info("WebSocket connection established")
89
+ self._reconnect_delay = 1 # Reset reconnect delay on successful connection
90
+ self._last_ping = time.time()
91
+ retry_count = 0 # Reset retry count on successful connection
92
+ except (asyncio.TimeoutError, websockets.exceptions.WebSocketException) as e:
93
+ next_retry = self._reconnect_delay
94
+
95
+ # Only log the first error at WARNING level, then every Nth attempt
96
+ if retry_count == 1:
97
+ self.logger.warning(
98
+ f"Computer API Server not ready yet. Will retry automatically."
99
+ )
100
+ elif retry_count % log_interval == 0:
101
+ self.logger.warning(
102
+ f"Still waiting for Computer API Server (attempt {retry_count})..."
103
+ )
104
+ else:
105
+ # All other errors are logged at DEBUG level
106
+ self.logger.debug(f"Connection attempt {retry_count} failed: {e}")
107
+
108
+ if self._ws:
109
+ try:
110
+ await self._ws.close()
111
+ except:
112
+ pass
113
+ self._ws = None
114
+
115
+ # Use exponential backoff for connection retries
116
+ await asyncio.sleep(self._reconnect_delay)
117
+ self._reconnect_delay = min(
118
+ self._reconnect_delay * 2, self._max_reconnect_delay
119
+ )
120
+ continue
121
+
122
+ # Regular ping to check connection
123
+ if self._ws and self._ws.state == websockets.protocol.State.OPEN:
124
+ try:
125
+ if time.time() - self._last_ping >= self._ping_interval:
126
+ pong_waiter = await self._ws.ping()
127
+ await asyncio.wait_for(pong_waiter, timeout=self._ping_timeout)
128
+ self._last_ping = time.time()
129
+ except Exception as e:
130
+ self.logger.debug(f"Ping failed: {e}")
131
+ if self._ws:
132
+ try:
133
+ await self._ws.close()
134
+ except:
135
+ pass
136
+ self._ws = None
137
+ continue
138
+
139
+ await asyncio.sleep(1)
140
+
141
+ except Exception as e:
142
+ current_time = time.time()
143
+ # Only log connection lost warnings at most once every min_warning_interval seconds
144
+ if current_time - last_warning_time >= min_warning_interval:
145
+ self.logger.warning(
146
+ f"Computer API Server connection lost. Will retry automatically."
147
+ )
148
+ last_warning_time = current_time
149
+ else:
150
+ # Log at debug level instead
151
+ self.logger.debug(f"Connection lost: {e}")
152
+
153
+ if self._ws:
154
+ try:
155
+ await self._ws.close()
156
+ except:
157
+ pass
158
+ self._ws = None
159
+
160
+ async def _ensure_connection(self):
161
+ """Ensure WebSocket connection is established."""
162
+ if self._reconnect_task is None or self._reconnect_task.done():
163
+ self._reconnect_task = asyncio.create_task(self._keep_alive())
164
+
165
+ retry_count = 0
166
+ max_retries = 5
167
+
168
+ while retry_count < max_retries:
169
+ try:
170
+ if self._ws and self._ws.state == websockets.protocol.State.OPEN:
171
+ return
172
+ retry_count += 1
173
+ await asyncio.sleep(1)
174
+ except Exception as e:
175
+ # Only log at ERROR level for the last retry attempt
176
+ if retry_count == max_retries - 1:
177
+ self.logger.error(
178
+ f"Persistent connection check error after {retry_count} attempts: {e}"
179
+ )
180
+ else:
181
+ self.logger.debug(f"Connection check error (attempt {retry_count}): {e}")
182
+ retry_count += 1
183
+ await asyncio.sleep(1)
184
+ continue
185
+
186
+ raise ConnectionError("Failed to establish WebSocket connection after multiple retries")
187
+
188
+ async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
189
+ """Send command through WebSocket."""
190
+ max_retries = 3
191
+ retry_count = 0
192
+ last_error = None
193
+
194
+ while retry_count < max_retries:
195
+ try:
196
+ await self._ensure_connection()
197
+ if not self._ws:
198
+ raise ConnectionError("WebSocket connection is not established")
199
+
200
+ message = {"command": command, "params": params or {}}
201
+ await self._ws.send(json.dumps(message))
202
+ response = await asyncio.wait_for(self._ws.recv(), timeout=30)
203
+ return json.loads(response)
204
+ except Exception as e:
205
+ last_error = e
206
+ retry_count += 1
207
+ if retry_count < max_retries:
208
+ # Only log at debug level for intermediate retries
209
+ self.logger.debug(
210
+ f"Command '{command}' failed (attempt {retry_count}/{max_retries}): {e}"
211
+ )
212
+ await asyncio.sleep(1)
213
+ continue
214
+ else:
215
+ # Only log at error level for the final failure
216
+ self.logger.error(
217
+ f"Failed to send command '{command}' after {max_retries} retries"
218
+ )
219
+ self.logger.debug(f"Command failure details: {e}")
220
+ raise
221
+
222
+ raise last_error if last_error else RuntimeError("Failed to send command")
223
+
224
+ async def wait_for_ready(self, timeout: int = 60, interval: float = 1.0):
225
+ """Wait for WebSocket connection to become available."""
226
+ start_time = time.time()
227
+ last_error = None
228
+ attempt_count = 0
229
+ progress_interval = 10 # Log progress every 10 seconds
230
+ last_progress_time = start_time
231
+
232
+ # Disable detailed logging for connection attempts
233
+ self._log_connection_attempts = False
234
+
235
+ try:
236
+ self.logger.info(
237
+ f"Waiting for Computer API Server to be ready (timeout: {timeout}s)..."
238
+ )
239
+
240
+ # Start the keep-alive task if it's not already running
241
+ if self._reconnect_task is None or self._reconnect_task.done():
242
+ self._reconnect_task = asyncio.create_task(self._keep_alive())
243
+
244
+ # Wait for the connection to be established
245
+ while time.time() - start_time < timeout:
246
+ try:
247
+ attempt_count += 1
248
+ current_time = time.time()
249
+
250
+ # Log progress periodically without flooding logs
251
+ if current_time - last_progress_time >= progress_interval:
252
+ elapsed = current_time - start_time
253
+ self.logger.info(
254
+ f"Still waiting for Computer API Server... (elapsed: {elapsed:.1f}s, attempts: {attempt_count})"
255
+ )
256
+ last_progress_time = current_time
257
+
258
+ # Check if we have a connection
259
+ if self._ws and self._ws.state == websockets.protocol.State.OPEN:
260
+ # Test the connection with a simple command
261
+ try:
262
+ await self._send_command("get_screen_size")
263
+ elapsed = time.time() - start_time
264
+ self.logger.info(
265
+ f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"
266
+ )
267
+ return # Connection is fully working
268
+ except Exception as e:
269
+ last_error = e
270
+ self.logger.debug(f"Connection test failed: {e}")
271
+
272
+ # Wait before trying again
273
+ await asyncio.sleep(interval)
274
+
275
+ except Exception as e:
276
+ last_error = e
277
+ self.logger.debug(f"Connection attempt {attempt_count} failed: {e}")
278
+ await asyncio.sleep(interval)
279
+
280
+ # If we get here, we've timed out
281
+ error_msg = f"Could not connect to {self.ip_address} after {timeout} seconds"
282
+ if last_error:
283
+ error_msg += f": {str(last_error)}"
284
+ self.logger.error(error_msg)
285
+ raise TimeoutError(error_msg)
286
+ finally:
287
+ # Reset to default logging behavior
288
+ self._log_connection_attempts = False
289
+
290
+ def close(self):
291
+ """Close WebSocket connection.
292
+
293
+ Note: In host computer server mode, we leave the connection open
294
+ to allow other clients to connect to the same server. The server
295
+ will handle cleaning up idle connections.
296
+ """
297
+ # Only cancel the reconnect task
298
+ if self._reconnect_task:
299
+ self._reconnect_task.cancel()
300
+
301
+ # Don't set closed flag or close websocket by default
302
+ # This allows the server to stay connected for other clients
303
+ # self._closed = True
304
+ # if self._ws:
305
+ # asyncio.create_task(self._ws.close())
306
+ # self._ws = None
307
+
308
+ def force_close(self):
309
+ """Force close the WebSocket connection.
310
+
311
+ This method should be called when you want to completely
312
+ shut down the connection, not just for regular cleanup.
313
+ """
314
+ self._closed = True
315
+ if self._reconnect_task:
316
+ self._reconnect_task.cancel()
317
+ if self._ws:
318
+ asyncio.create_task(self._ws.close())
319
+ self._ws = None
320
+
321
+ # Mouse Actions
322
+ async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
323
+ await self._send_command("left_click", {"x": x, "y": y})
324
+
325
+ async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
326
+ await self._send_command("right_click", {"x": x, "y": y})
327
+
328
+ async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
329
+ await self._send_command("double_click", {"x": x, "y": y})
330
+
331
+ async def move_cursor(self, x: int, y: int) -> None:
332
+ await self._send_command("move_cursor", {"x": x, "y": y})
333
+
334
+ async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> None:
335
+ await self._send_command(
336
+ "drag_to", {"x": x, "y": y, "button": button, "duration": duration}
337
+ )
338
+
339
+ async def drag(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5) -> None:
340
+ await self._send_command(
341
+ "drag", {"path": path, "button": button, "duration": duration}
342
+ )
343
+
344
+ # Keyboard Actions
345
+ async def type_text(self, text: str) -> None:
346
+ # Temporary fix for https://github.com/trycua/cua/issues/165
347
+ # Check if text contains Unicode characters
348
+ if any(ord(char) > 127 for char in text):
349
+ # For Unicode text, use clipboard and paste
350
+ await self.set_clipboard(text)
351
+ await self.hotkey(Key.COMMAND, 'v')
352
+ else:
353
+ # For ASCII text, use the regular typing method
354
+ await self._send_command("type_text", {"text": text})
355
+
356
+ async def press(self, key: "KeyType") -> None:
357
+ """Press a single key.
358
+
359
+ Args:
360
+ key: The key to press. Can be any of:
361
+ - A Key enum value (recommended), e.g. Key.PAGE_DOWN
362
+ - A direct key value string, e.g. 'pagedown'
363
+ - A single character string, e.g. 'a'
364
+
365
+ Examples:
366
+ ```python
367
+ # Using enum (recommended)
368
+ await interface.press(Key.PAGE_DOWN)
369
+ await interface.press(Key.ENTER)
370
+
371
+ # Using direct values
372
+ await interface.press('pagedown')
373
+ await interface.press('enter')
374
+
375
+ # Using single characters
376
+ await interface.press('a')
377
+ ```
378
+
379
+ Raises:
380
+ ValueError: If the key type is invalid or the key is not recognized
381
+ """
382
+ if isinstance(key, Key):
383
+ actual_key = key.value
384
+ elif isinstance(key, str):
385
+ # Try to convert to enum if it matches a known key
386
+ key_or_enum = Key.from_string(key)
387
+ actual_key = key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum
388
+ else:
389
+ raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
390
+
391
+ await self._send_command("press_key", {"key": actual_key})
392
+
393
+ async def press_key(self, key: "KeyType") -> None:
394
+ """DEPRECATED: Use press() instead.
395
+
396
+ This method is kept for backward compatibility but will be removed in a future version.
397
+ Please use the press() method instead.
398
+ """
399
+ await self.press(key)
400
+
401
+ async def hotkey(self, *keys: "KeyType") -> None:
402
+ """Press multiple keys simultaneously.
403
+
404
+ Args:
405
+ *keys: Multiple keys to press simultaneously. Each key can be any of:
406
+ - A Key enum value (recommended), e.g. Key.COMMAND
407
+ - A direct key value string, e.g. 'command'
408
+ - A single character string, e.g. 'a'
409
+
410
+ Examples:
411
+ ```python
412
+ # Using enums (recommended)
413
+ await interface.hotkey(Key.COMMAND, Key.C) # Copy
414
+ await interface.hotkey(Key.COMMAND, Key.V) # Paste
415
+
416
+ # Using mixed formats
417
+ await interface.hotkey(Key.COMMAND, 'a') # Select all
418
+ ```
419
+
420
+ Raises:
421
+ ValueError: If any key type is invalid or not recognized
422
+ """
423
+ actual_keys = []
424
+ for key in keys:
425
+ if isinstance(key, Key):
426
+ actual_keys.append(key.value)
427
+ elif isinstance(key, str):
428
+ # Try to convert to enum if it matches a known key
429
+ key_or_enum = Key.from_string(key)
430
+ actual_keys.append(key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum)
431
+ else:
432
+ raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
433
+
434
+ await self._send_command("hotkey", {"keys": actual_keys})
435
+
436
+ # Scrolling Actions
437
+ async def scroll_down(self, clicks: int = 1) -> None:
438
+ await self._send_command("scroll_down", {"clicks": clicks})
439
+
440
+ async def scroll_up(self, clicks: int = 1) -> None:
441
+ await self._send_command("scroll_up", {"clicks": clicks})
442
+
443
+ # Screen Actions
444
+ async def screenshot(
445
+ self,
446
+ boxes: Optional[List[Tuple[int, int, int, int]]] = None,
447
+ box_color: str = "#FF0000",
448
+ box_thickness: int = 2,
449
+ scale_factor: float = 1.0,
450
+ ) -> bytes:
451
+ """Take a screenshot with optional box drawing and scaling.
452
+
453
+ Args:
454
+ boxes: Optional list of (x, y, width, height) tuples defining boxes to draw in screen coordinates
455
+ box_color: Color of the boxes in hex format (default: "#FF0000" red)
456
+ box_thickness: Thickness of the box borders in pixels (default: 2)
457
+ scale_factor: Factor to scale the final image by (default: 1.0)
458
+ Use > 1.0 to enlarge, < 1.0 to shrink (e.g., 0.5 for half size, 2.0 for double)
459
+
460
+ Returns:
461
+ bytes: The screenshot image data, optionally with boxes drawn on it and scaled
462
+ """
463
+ result = await self._send_command("screenshot")
464
+ if not result.get("image_data"):
465
+ raise RuntimeError("Failed to take screenshot")
466
+
467
+ screenshot = decode_base64_image(result["image_data"])
468
+
469
+ if boxes:
470
+ # Get the natural scaling between screen and screenshot
471
+ screen_size = await self.get_screen_size()
472
+ screenshot_width, screenshot_height = bytes_to_image(screenshot).size
473
+ width_scale = screenshot_width / screen_size["width"]
474
+ height_scale = screenshot_height / screen_size["height"]
475
+
476
+ # Scale box coordinates from screen space to screenshot space
477
+ for box in boxes:
478
+ scaled_box = (
479
+ int(box[0] * width_scale), # x
480
+ int(box[1] * height_scale), # y
481
+ int(box[2] * width_scale), # width
482
+ int(box[3] * height_scale), # height
483
+ )
484
+ screenshot = draw_box(
485
+ screenshot,
486
+ x=scaled_box[0],
487
+ y=scaled_box[1],
488
+ width=scaled_box[2],
489
+ height=scaled_box[3],
490
+ color=box_color,
491
+ thickness=box_thickness,
492
+ )
493
+
494
+ if scale_factor != 1.0:
495
+ screenshot = resize_image(screenshot, scale_factor)
496
+
497
+ return screenshot
5
498
 
6
- class LinuxInterface(BaseComputerInterface):
7
- """Linux-specific computer interface."""
8
-
9
- async def wait_for_ready(self, timeout: int = 60) -> None:
10
- """Wait for interface to be ready."""
11
- # Placeholder implementation
12
- pass
13
-
14
- def close(self) -> None:
15
- """Close the interface connection."""
16
- # Placeholder implementation
17
- pass
18
-
19
499
  async def get_screen_size(self) -> Dict[str, int]:
20
- """Get the screen dimensions."""
21
- # Placeholder implementation
22
- return {"width": 1920, "height": 1080}
23
-
24
- async def screenshot(self) -> bytes:
25
- """Take a screenshot."""
26
- # Placeholder implementation
27
- return b""
500
+ result = await self._send_command("get_screen_size")
501
+ if result["success"] and result["size"]:
502
+ return result["size"]
503
+ raise RuntimeError("Failed to get screen size")
504
+
505
+ async def get_cursor_position(self) -> Dict[str, int]:
506
+ result = await self._send_command("get_cursor_position")
507
+ if result["success"] and result["position"]:
508
+ return result["position"]
509
+ raise RuntimeError("Failed to get cursor position")
510
+
511
+ # Clipboard Actions
512
+ async def copy_to_clipboard(self) -> str:
513
+ result = await self._send_command("copy_to_clipboard")
514
+ if result["success"] and result["content"]:
515
+ return result["content"]
516
+ raise RuntimeError("Failed to get clipboard content")
517
+
518
+ async def set_clipboard(self, text: str) -> None:
519
+ await self._send_command("set_clipboard", {"text": text})
520
+
521
+ # File System Actions
522
+ async def file_exists(self, path: str) -> bool:
523
+ result = await self._send_command("file_exists", {"path": path})
524
+ return result.get("exists", False)
525
+
526
+ async def directory_exists(self, path: str) -> bool:
527
+ result = await self._send_command("directory_exists", {"path": path})
528
+ return result.get("exists", False)
529
+
530
+ async def run_command(self, command: str) -> Tuple[str, str]:
531
+ result = await self._send_command("run_command", {"command": command})
532
+ if not result.get("success", False):
533
+ raise RuntimeError(result.get("error", "Failed to run command"))
534
+ return result.get("stdout", ""), result.get("stderr", "")
535
+
536
+ # Accessibility Actions
537
+ async def get_accessibility_tree(self) -> Dict[str, Any]:
538
+ """Get the accessibility tree of the current screen."""
539
+ result = await self._send_command("get_accessibility_tree")
540
+ if not result.get("success", False):
541
+ raise RuntimeError(result.get("error", "Failed to get accessibility tree"))
542
+ return result
543
+
544
+ async def get_active_window_bounds(self) -> Dict[str, int]:
545
+ """Get the bounds of the currently active window."""
546
+ result = await self._send_command("get_active_window_bounds")
547
+ if result["success"] and result["bounds"]:
548
+ return result["bounds"]
549
+ raise RuntimeError("Failed to get active window bounds")
550
+
551
+ async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]:
552
+ """Convert screenshot coordinates to screen coordinates.
553
+
554
+ Args:
555
+ x: X coordinate in screenshot space
556
+ y: Y coordinate in screenshot space
557
+
558
+ Returns:
559
+ tuple[float, float]: (x, y) coordinates in screen space
560
+ """
561
+ screen_size = await self.get_screen_size()
562
+ screenshot = await self.screenshot()
563
+ screenshot_img = bytes_to_image(screenshot)
564
+ screenshot_width, screenshot_height = screenshot_img.size
565
+
566
+ # Calculate scaling factors
567
+ width_scale = screen_size["width"] / screenshot_width
568
+ height_scale = screen_size["height"] / screenshot_height
569
+
570
+ # Convert coordinates
571
+ screen_x = x * width_scale
572
+ screen_y = y * height_scale
573
+
574
+ return screen_x, screen_y
575
+
576
+ async def to_screenshot_coordinates(self, x: float, y: float) -> tuple[float, float]:
577
+ """Convert screen coordinates to screenshot coordinates.
578
+
579
+ Args:
580
+ x: X coordinate in screen space
581
+ y: Y coordinate in screen space
582
+
583
+ Returns:
584
+ tuple[float, float]: (x, y) coordinates in screenshot space
585
+ """
586
+ screen_size = await self.get_screen_size()
587
+ screenshot = await self.screenshot()
588
+ screenshot_img = bytes_to_image(screenshot)
589
+ screenshot_width, screenshot_height = screenshot_img.size
590
+
591
+ # Calculate scaling factors
592
+ width_scale = screenshot_width / screen_size["width"]
593
+ height_scale = screenshot_height / screen_size["height"]
594
+
595
+ # Convert coordinates
596
+ screenshot_x = x * width_scale
597
+ screenshot_y = y * height_scale
598
+
599
+ return screenshot_x, screenshot_y
@@ -17,7 +17,6 @@ class MacOSComputerInterface(BaseComputerInterface):
17
17
 
18
18
  def __init__(self, ip_address: str, username: str = "lume", password: str = "lume"):
19
19
  super().__init__(ip_address, username, password)
20
- self.ws_uri = f"ws://{ip_address}:8000/ws"
21
20
  self._ws = None
22
21
  self._reconnect_task = None
23
22
  self._closed = False
@@ -31,6 +30,15 @@ class MacOSComputerInterface(BaseComputerInterface):
31
30
  # Set logger name for MacOS interface
32
31
  self.logger = Logger("cua.interface.macos", LogLevel.NORMAL)
33
32
 
33
+ @property
34
+ def ws_uri(self) -> str:
35
+ """Get the WebSocket URI using the current IP address.
36
+
37
+ Returns:
38
+ WebSocket URI for the Computer API Server
39
+ """
40
+ return f"ws://{self.ip_address}:8000/ws"
41
+
34
42
  async def _keep_alive(self):
35
43
  """Keep the WebSocket connection alive with automatic reconnection."""
36
44
  retry_count = 0
computer/models.py CHANGED
@@ -1,8 +1,10 @@
1
1
  """Models for computer configuration."""
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Optional
5
- from pylume import PyLume
4
+ from typing import Optional, Any, Dict
5
+
6
+ # Import base provider interface
7
+ from .providers.base import BaseVMProvider
6
8
 
7
9
  @dataclass
8
10
  class Display:
@@ -26,10 +28,20 @@ class Computer:
26
28
  display: Display
27
29
  memory: str
28
30
  cpu: str
29
- pylume: Optional[PyLume] = None
31
+ vm_provider: Optional[BaseVMProvider] = None
30
32
 
31
33
  # @property # Remove the property decorator
32
34
  async def get_ip(self) -> Optional[str]:
33
35
  """Get the IP address of the VM."""
34
- vm = await self.pylume.get_vm(self.name) # type: ignore[attr-defined]
35
- return vm.ip_address if vm else None
36
+ if not self.vm_provider:
37
+ return None
38
+
39
+ vm = await self.vm_provider.get_vm(self.name)
40
+ # Handle both object attribute and dictionary access for ip_address
41
+ if vm:
42
+ if isinstance(vm, dict):
43
+ return vm.get("ip_address")
44
+ else:
45
+ # Access as attribute for object-based return values
46
+ return getattr(vm, "ip_address", None)
47
+ return None
@@ -0,0 +1,4 @@
1
+ """Provider implementations for different VM backends."""
2
+
3
+ # Import specific providers only when needed to avoid circular imports
4
+ __all__ = [] # Let each provider module handle its own exports