caption-flow 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,228 @@
1
+ """Base worker class for WebSocket-based distributed workers."""
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ import ssl
7
+ import time
8
+ from abc import ABC, abstractmethod
9
+ from typing import Dict, Any, Optional
10
+ from threading import Event
11
+
12
+ import websockets
13
+ from websockets.client import WebSocketClientProtocol
14
+
15
# Module-level logger, named after this module so handlers can be configured per module.
logger: logging.Logger = logging.getLogger(name=__name__)
16
+
17
+
18
class BaseWorker(ABC):
    """Base class for all WebSocket-based workers with common connection logic.

    Lifecycle: :meth:`start` connects to ``config["server"]``, sends the payload
    from :meth:`_get_auth_data`, waits for a JSON welcome message and hands it to
    :meth:`_handle_welcome`, then runs the tasks returned by
    :meth:`_create_tasks` until the first of them finishes. After every session
    it calls :meth:`_on_disconnect` and reconnects with exponential backoff,
    until :meth:`shutdown` is called or the orchestrator rejects authentication.

    Subclasses may override the class attributes below to tune timing.
    """

    # Seconds between heartbeat messages sent by ``_heartbeat_loop``.
    heartbeat_interval: float = 30
    # Initial and maximum delay (seconds) between reconnection attempts.
    reconnect_delay_initial: float = 5
    reconnect_delay_max: float = 60

    def __init__(self, config: Dict[str, Any]):
        """Store configuration and initialise connection state.

        Args:
            config: Worker configuration. Requires ``"server"`` (WebSocket URL)
                and ``"token"``; optional ``"name"`` (default ``"worker"``) and
                ``"verify_ssl"`` (default ``True``, read by ``_setup_ssl``).

        Raises:
            KeyError: If ``"server"`` or ``"token"`` is missing.
        """
        self.config = config
        self.server_url = config["server"]
        self.token = config["token"]
        self.name = config.get("name", "worker")

        # SSL configuration (None for plain ws:// URLs)
        self.ssl_context = self._setup_ssl()

        # State
        self.worker_id: Optional[str] = None
        self.websocket: Optional[WebSocketClientProtocol] = None
        self.running = False
        # threading.Event rather than asyncio.Event: presumably so code outside
        # the event-loop thread can observe connection state — confirm in subclasses.
        self.connected = Event()
        self.main_loop: Optional[asyncio.AbstractEventLoop] = None

        # Metrics (subclasses can extend)
        self._init_metrics()

    def _init_metrics(self):
        """Initialize basic metrics. Subclasses can override to add more."""
        pass

    def _setup_ssl(self) -> Optional[ssl.SSLContext]:
        """Build the SSL context passed to ``websockets.connect``.

        Returns:
            ``None`` for ``ws://`` URLs (websockets refuses an ``ssl`` argument
            for them); otherwise a default context, with hostname and
            certificate checks disabled when ``config["verify_ssl"]`` is falsy.
        """
        # URL schemes are case-insensitive; websockets accepts "WS://" as insecure.
        if self.server_url.lower().startswith("ws://"):
            logger.warning("Using insecure WebSocket connection")
            return None

        context = ssl.create_default_context()
        if not self.config.get("verify_ssl", True):
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
        return context

    async def start(self):
        """Start the worker with automatic reconnection.

        Runs until ``self.running`` becomes False (via :meth:`shutdown` or a
        rejected authentication). Connection errors are logged and retried with
        exponential backoff. ``_on_disconnect`` runs after every session,
        whether it failed or closed cleanly.
        """
        self.running = True

        # Allow subclasses to initialize before connection
        await self._pre_start()

        # Capture the main event loop so subclasses can reference it later
        self.main_loop = asyncio.get_running_loop()

        # Reconnection with exponential backoff
        reconnect_delay = self.reconnect_delay_initial

        while self.running:
            try:
                await self._connect_and_run()
                reconnect_delay = self.reconnect_delay_initial  # Reset delay on successful connection
            except Exception as e:
                logger.error(f"Connection error: {e}")
                self.connected.clear()
                self.websocket = None

            # Let subclass handle disconnection (failure or clean close alike)
            await self._on_disconnect()

            if self.running:
                logger.info(f"Reconnecting in {reconnect_delay} seconds...")
                await asyncio.sleep(reconnect_delay)
                reconnect_delay = min(reconnect_delay * 2, self.reconnect_delay_max)

    async def _connect_and_run(self):
        """Connect to the orchestrator, authenticate, and run the worker tasks.

        Returns normally when authentication is rejected (``self.running`` is
        then False) or when the session ends without a task failure.

        Raises:
            Exception: The first exception raised by any worker task (for
                example ``websockets.exceptions.ConnectionClosed``) while the
                worker is still running, so :meth:`start` applies its backoff.
                Errors from connecting or authenticating also propagate.
        """
        logger.info(f"Connecting to {self.server_url}")

        async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
            self.websocket = websocket
            self.connected.set()
            # Every exit path (auth rejection, errors, normal end) must reset state.
            try:
                # Authenticate
                auth_data = self._get_auth_data()
                await websocket.send(json.dumps(auth_data))

                # Wait for welcome message
                welcome = await websocket.recv()
                welcome_data = json.loads(welcome)

                if "error" in welcome_data:
                    logger.error(f"Authentication failed: {welcome_data['error']}")
                    self.running = False
                    return

                self.worker_id = welcome_data.get("worker_id")
                logger.info(f"Connected as {self.worker_id}")

                # Let subclass handle welcome data
                await self._handle_welcome(welcome_data)

                # Start processing
                failure = await self._wait_first_completed(await self._create_tasks())
                # During shutdown the socket is closed on purpose; don't report that.
                if failure is not None and self.running:
                    raise failure
            finally:
                self.connected.clear()
                self.websocket = None

    @staticmethod
    async def _wait_first_completed(tasks) -> Optional[BaseException]:
        """Run *tasks* until one finishes, cancel the rest, and report failures.

        Args:
            tasks: Iterable of Tasks/Futures or coroutines (coroutines are
                wrapped with ``asyncio.ensure_future``).

        Returns:
            The exception of the first finished task that failed, or ``None``.

        Raises:
            ValueError: If *tasks* is empty (from ``asyncio.wait``).
        """
        tasks = [asyncio.ensure_future(t) for t in tasks]
        try:
            done, _pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # Also runs if we are cancelled mid-wait, so child tasks never leak.
            for task in tasks:
                task.cancel()  # no-op for tasks that already finished
            # Await everything so cancellations complete and exceptions are retrieved.
            await asyncio.gather(*tasks, return_exceptions=True)

        for task in done:
            if not task.cancelled() and task.exception() is not None:
                return task.exception()
        return None

    async def _heartbeat_loop(self):
        """Send a heartbeat every ``heartbeat_interval`` seconds.

        Loops while the worker is running and connected. Send failures are
        logged and re-raised, so the session ends and :meth:`start` reconnects.
        """
        try:
            while self.running and self.connected.is_set():
                try:
                    websocket = self.websocket
                    if websocket is not None:
                        heartbeat_data = self._get_heartbeat_data()
                        await websocket.send(json.dumps(heartbeat_data))
                    await asyncio.sleep(self.heartbeat_interval)
                except websockets.exceptions.ConnectionClosed as e:
                    logger.info(f"Connection lost during heartbeat: {e}")
                    raise
                except Exception as e:
                    logger.error(f"Heartbeat error: {e}")
                    raise
        except asyncio.CancelledError:
            logger.debug("Heartbeat loop cancelled")
            raise

    async def _base_message_handler(self):
        """Receive JSON messages and dispatch each to :meth:`_handle_message`.

        Malformed messages and errors raised by the subclass handler are logged
        and skipped, so one bad message does not end the session. Connection
        errors are logged and re-raised.
        """
        try:
            async for message in self.websocket:
                try:
                    data = json.loads(message)
                except (ValueError, TypeError) as e:  # JSONDecodeError is a ValueError
                    logger.error(f"Invalid message format: {e}")
                    continue
                try:
                    await self._handle_message(data)
                except Exception as e:
                    # Keep the traceback: these are bugs in subclass handlers.
                    logger.exception(f"Error handling message: {e}")

        except websockets.exceptions.ConnectionClosed as e:
            logger.info(f"Connection closed by orchestrator: {e}")
            raise
        except Exception as e:
            logger.error(f"Message handler error: {e}")
            raise

    async def shutdown(self):
        """Graceful shutdown: stop reconnecting, run the subclass hook, close the socket."""
        logger.info(f"Shutting down {self.__class__.__name__}...")
        self.running = False
        self.connected.clear()

        # Let subclass do cleanup
        await self._pre_shutdown()

        websocket = self.websocket
        if websocket is not None:
            try:
                await websocket.close()
            except Exception as e:
                # Best effort: the socket may already be closed or broken.
                logger.debug(f"Ignoring error while closing websocket: {e}")
            self.websocket = None

        logger.info(f"{self.__class__.__name__} shutdown complete")

    # Abstract methods that subclasses must implement

    @abstractmethod
    def _get_auth_data(self) -> Dict[str, Any]:
        """Return the JSON-serialisable authentication payload sent right after connecting."""
        pass

    @abstractmethod
    async def _handle_welcome(self, welcome_data: Dict[str, Any]):
        """Handle the orchestrator's welcome message (already checked for an ``"error"`` key)."""
        pass

    @abstractmethod
    async def _handle_message(self, data: Dict[str, Any]):
        """Handle one decoded JSON message from the orchestrator."""
        pass

    @abstractmethod
    def _get_heartbeat_data(self) -> Dict[str, Any]:
        """Return the JSON-serialisable payload sent on each heartbeat."""
        pass

    @abstractmethod
    async def _create_tasks(self) -> list:
        """Create async tasks to run. Must include _heartbeat_loop and _base_message_handler."""
        pass

    # Optional hooks for subclasses

    async def _pre_start(self):
        """Called before starting connection loop. Override to initialize components."""
        pass

    async def _on_disconnect(self):
        """Called after every connection session ends. Override to clean up state."""
        pass

    async def _pre_shutdown(self):
        """Called before shutdown. Override for cleanup."""
        pass