caption-flow 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -2
- caption_flow/cli.py +56 -39
- caption_flow/models.py +6 -4
- caption_flow/monitor.py +12 -2
- caption_flow/orchestrator.py +729 -217
- caption_flow/storage.py +579 -222
- caption_flow/utils/__init__.py +3 -1
- caption_flow/utils/auth.py +24 -25
- caption_flow/utils/checkpoint_tracker.py +92 -0
- caption_flow/utils/chunk_tracker.py +278 -194
- caption_flow/utils/dataset_loader.py +392 -73
- caption_flow/utils/image_processor.py +121 -1
- caption_flow/utils/prompt_template.py +137 -0
- caption_flow/utils/shard_processor.py +315 -0
- caption_flow/utils/shard_tracker.py +87 -0
- caption_flow/workers/base.py +228 -0
- caption_flow/workers/caption.py +1321 -0
- caption_flow/{worker_data.py → workers/data.py} +162 -234
- caption_flow-0.2.0.dist-info/METADATA +369 -0
- caption_flow-0.2.0.dist-info/RECORD +29 -0
- caption_flow/worker.py +0 -300
- caption_flow/worker_vllm.py +0 -1028
- caption_flow-0.1.0.dist-info/METADATA +0 -427
- caption_flow-0.1.0.dist-info/RECORD +0 -25
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.0.dist-info}/WHEEL +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.0.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,228 @@
|
|
1
|
+
"""Base worker class for WebSocket-based distributed workers."""
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
import json
|
5
|
+
import logging
|
6
|
+
import ssl
|
7
|
+
import time
|
8
|
+
from abc import ABC, abstractmethod
|
9
|
+
from typing import Dict, Any, Optional
|
10
|
+
from threading import Event
|
11
|
+
|
12
|
+
import websockets
|
13
|
+
from websockets.client import WebSocketClientProtocol
|
14
|
+
|
15
|
+
# Module-level logger; BaseWorker's connection, heartbeat, message and shutdown paths log through it.
logger = logging.getLogger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
class BaseWorker(ABC):
    """Base class for all WebSocket-based workers with common connection logic.

    Drives the connect / authenticate / run / reconnect cycle against the
    orchestrator. Subclasses supply the protocol specifics through the abstract
    hooks (``_get_auth_data``, ``_handle_welcome``, ``_handle_message``,
    ``_get_heartbeat_data``, ``_create_tasks``) and may override the optional
    lifecycle hooks (``_pre_start``, ``_on_disconnect``, ``_pre_shutdown``).

    Config keys read by this class: ``server`` and ``token`` (required),
    ``name`` (default ``"worker"``), ``verify_ssl`` (default True),
    ``reconnect_delay`` (default 5 seconds) and ``max_reconnect_delay``
    (default 60 seconds).
    """

    def __init__(self, config: Dict[str, Any]):
        """Store configuration and initialize connection state.

        Args:
            config: Worker configuration. Must contain ``server`` (the
                orchestrator WebSocket URL) and ``token``; a missing key raises
                KeyError.
        """
        self.config = config
        self.server_url = config["server"]
        self.token = config["token"]
        self.name = config.get("name", "worker")

        # SSL configuration (None for plain ws:// URLs)
        self.ssl_context = self._setup_ssl()

        # State
        self.worker_id: Optional[str] = None
        self.websocket: Optional[WebSocketClientProtocol] = None
        self.running = False
        # NOTE(review): a threading.Event (not asyncio.Event) -- presumably so code
        # outside the event-loop thread can check connectivity; confirm in subclasses.
        self.connected = Event()
        self.main_loop: Optional[asyncio.AbstractEventLoop] = None

        # Metrics (subclasses can extend)
        self._init_metrics()

    def _init_metrics(self):
        """Initialize basic metrics. Subclasses can override to add more."""
        pass

    def _setup_ssl(self) -> Optional[ssl.SSLContext]:
        """Build the SSL context passed to ``websockets.connect``.

        Returns:
            None for ``ws://`` URLs (unencrypted connection, logged as a warning);
            a context with hostname and certificate checks disabled when
            ``config["verify_ssl"]`` is falsy; otherwise
            ``ssl.create_default_context()``.
        """
        if self.server_url.startswith("ws://"):
            logger.warning("Using insecure WebSocket connection")
            return None

        if not self.config.get("verify_ssl", True):
            context = ssl.create_default_context()
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
            return context

        return ssl.create_default_context()

    async def start(self):
        """Run the worker until ``self.running`` becomes False, reconnecting as needed.

        Each session is delegated to ``_connect_and_run``. After every session --
        whether it raised or ended cleanly -- connection state is reset and
        ``_on_disconnect`` is awaited; if the worker is still running it then
        waits before reconnecting. Failed sessions back off exponentially
        (doubling, capped at ``config["max_reconnect_delay"]``, default 60 s);
        a session that ends without raising resets the delay to
        ``config["reconnect_delay"]`` (default 5 s).
        """
        self.running = True

        # Allow subclasses to initialize before connection
        await self._pre_start()

        # Capture the main event loop
        self.main_loop = asyncio.get_running_loop()

        # Reconnection with exponential backoff
        initial_delay = self.config.get("reconnect_delay", 5)
        max_delay = self.config.get("max_reconnect_delay", 60)
        reconnect_delay = initial_delay

        while self.running:
            try:
                await self._connect_and_run()
            except Exception as e:
                logger.error(f"Connection error: {e}")
                wait_for = reconnect_delay
                reconnect_delay = min(reconnect_delay * 2, max_delay)
            else:
                # Session ended without error (e.g. the orchestrator closed the
                # socket cleanly): reset the backoff, but still pause so a server
                # that keeps closing connections cannot cause a hot reconnect loop.
                reconnect_delay = initial_delay
                wait_for = reconnect_delay

            self.connected.clear()
            self.websocket = None

            # Let subclass handle disconnection (runs after clean and failed sessions)
            await self._on_disconnect()

            if self.running:
                logger.info(f"Reconnecting in {wait_for} seconds...")
                await asyncio.sleep(wait_for)

    async def _connect_and_run(self):
        """Open one connection, authenticate, and run the subclass tasks.

        Returns normally when the session ends without error, including an
        authentication rejection (which also sets ``self.running`` to False).
        Raises if connecting or authenticating fails, or if any task created by
        ``_create_tasks`` finishes with an exception, so ``start`` can log it and
        apply backoff. Remaining tasks are cancelled and ``connected`` /
        ``websocket`` are cleared on every exit path.
        """
        logger.info(f"Connecting to {self.server_url}")

        async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
            self.websocket = websocket
            self.connected.set()
            tasks = []
            try:
                # Authenticate
                auth_data = self._get_auth_data()
                await websocket.send(json.dumps(auth_data))

                # Wait for welcome message
                welcome = await websocket.recv()
                welcome_data = json.loads(welcome)

                if "error" in welcome_data:
                    logger.error(f"Authentication failed: {welcome_data['error']}")
                    self.running = False
                    return

                self.worker_id = welcome_data.get("worker_id")
                logger.info(f"Connected as {self.worker_id}")

                # Let subclass handle welcome data
                await self._handle_welcome(welcome_data)

                # Start processing. asyncio.wait() only accepts tasks/futures, so
                # wrap any bare coroutines the subclass returns.
                tasks = [asyncio.ensure_future(t) for t in await self._create_tasks()]
                done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

                # asyncio.wait() does not raise a finished task's exception; surface
                # it so a dropped connection is reported instead of looking like a
                # clean session end.
                for task in done:
                    if not task.cancelled() and task.exception() is not None:
                        raise task.exception()
            finally:
                # Cancel whatever is still running (also when this coroutine is
                # itself cancelled) and wait for it. return_exceptions=True keeps one
                # task's failure from preventing the others from being awaited.
                for task in tasks:
                    task.cancel()
                if tasks:
                    await asyncio.gather(*tasks, return_exceptions=True)
                self.connected.clear()
                self.websocket = None

    async def _heartbeat_loop(self):
        """Send a heartbeat to the orchestrator every 30 seconds.

        Loops while the worker is running and connected, returning once either
        flag drops. Send failures (including ``ConnectionClosed``) are logged and
        re-raised so the enclosing session ends.
        """
        try:
            while self.running and self.connected.is_set():
                try:
                    if self.websocket:
                        heartbeat_data = self._get_heartbeat_data()
                        await self.websocket.send(json.dumps(heartbeat_data))
                    await asyncio.sleep(30)
                except websockets.exceptions.ConnectionClosed as e:
                    logger.info(f"Connection lost during heartbeat: {e}")
                    raise
                except Exception as e:
                    logger.error(f"Heartbeat error: {e}")
                    raise
        except asyncio.CancelledError:
            logger.debug("Heartbeat loop cancelled")
            raise

    async def _base_message_handler(self):
        """Read messages from the websocket and dispatch them to ``_handle_message``.

        Each message is JSON-decoded first. Undecodable messages and errors raised
        by ``_handle_message`` are logged and skipped, so one bad message does not
        end the session. Returns when iteration over the websocket ends; a
        ``ConnectionClosed`` or any other failure of the iteration itself is
        logged and re-raised.
        """
        try:
            async for message in self.websocket:
                try:
                    data = json.loads(message)
                    await self._handle_message(data)
                except json.JSONDecodeError as e:
                    logger.error(f"Invalid message format: {e}")
                except Exception as e:
                    # Keep the traceback: without it subclass bugs are hard to locate.
                    logger.error(f"Error handling message: {e}", exc_info=True)

        except websockets.exceptions.ConnectionClosed as e:
            logger.info(f"Connection closed by orchestrator: {e}")
            raise
        except Exception as e:
            logger.error(f"Message handler error: {e}")
            raise

    async def shutdown(self):
        """Stop the worker: clear run/connected flags, run subclass cleanup, close the socket."""
        logger.info(f"Shutting down {self.__class__.__name__}...")
        self.running = False
        self.connected.clear()

        # Let subclass do cleanup
        await self._pre_shutdown()

        if self.websocket:
            try:
                await self.websocket.close()
            except Exception as e:
                # Best effort: the socket may already be closed or broken. Only
                # Exception is caught so cancellation still propagates.
                logger.debug(f"Ignoring error while closing websocket: {e}")
            self.websocket = None

        logger.info(f"{self.__class__.__name__} shutdown complete")

    # Abstract methods that subclasses must implement

    @abstractmethod
    def _get_auth_data(self) -> Dict[str, Any]:
        """Return the JSON-serializable authentication payload sent on connect."""
        pass

    @abstractmethod
    async def _handle_welcome(self, welcome_data: Dict[str, Any]):
        """Handle the decoded welcome message received after authentication."""
        pass

    @abstractmethod
    async def _handle_message(self, data: Dict[str, Any]):
        """Handle one decoded message from the orchestrator."""
        pass

    @abstractmethod
    def _get_heartbeat_data(self) -> Dict[str, Any]:
        """Return the JSON-serializable payload sent with each heartbeat."""
        pass

    @abstractmethod
    async def _create_tasks(self) -> list:
        """Return the tasks (or coroutines) to run for one session.

        Should include ``_heartbeat_loop`` and ``_base_message_handler``. The
        session ends as soon as any of them finishes; the rest are cancelled.
        Must not be empty (``asyncio.wait`` rejects an empty collection).
        """
        pass

    # Optional hooks for subclasses

    async def _pre_start(self):
        """Called before starting connection loop. Override to initialize components."""
        pass

    async def _on_disconnect(self):
        """Called after every session ends (cleanly or not). Override to clean up state."""
        pass

    async def _pre_shutdown(self):
        """Called before shutdown. Override for cleanup."""
        pass
|