caption-flow 0.1.0-py3-none-any.whl → 0.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -2
- caption_flow/cli.py +65 -42
- caption_flow/models.py +6 -4
- caption_flow/monitor.py +13 -3
- caption_flow/orchestrator.py +1049 -264
- caption_flow/storage.py +579 -222
- caption_flow/utils/__init__.py +3 -1
- caption_flow/utils/auth.py +24 -25
- caption_flow/utils/checkpoint_tracker.py +92 -0
- caption_flow/utils/chunk_tracker.py +278 -194
- caption_flow/utils/dataset_loader.py +567 -73
- caption_flow/utils/image_processor.py +121 -1
- caption_flow/utils/prompt_template.py +137 -0
- caption_flow/utils/shard_processor.py +315 -0
- caption_flow/utils/shard_tracker.py +87 -0
- caption_flow/workers/base.py +228 -0
- caption_flow/workers/caption.py +1321 -0
- caption_flow/{worker_data.py → workers/data.py} +162 -234
- caption_flow-0.2.1.dist-info/METADATA +370 -0
- caption_flow-0.2.1.dist-info/RECORD +29 -0
- caption_flow/worker.py +0 -300
- caption_flow/worker_vllm.py +0 -1028
- caption_flow-0.1.0.dist-info/METADATA +0 -427
- caption_flow-0.1.0.dist-info/RECORD +0 -25
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/top_level.txt +0 -0
caption_flow/{worker_data.py → workers/data.py} (renamed, +162 -234)

@@ -3,7 +3,6 @@
 import asyncio
 import json
 import logging
-import ssl
 import io
 import time
 from dataclasses import dataclass
@@ -12,14 +11,14 @@ from typing import Dict, Any, Optional, List, AsyncIterator
 from queue import Queue, Empty
 from threading import Thread, Event
 
-import websockets
-from websockets.client import WebSocketClientProtocol
 import pandas as pd
 import pyarrow.parquet as pq
 from PIL import Image
 import boto3
 from botocore.config import Config
 
+from .base import BaseWorker
+
 logger = logging.getLogger(__name__)
 
 
@@ -33,59 +32,94 @@ class DataSample:
     metadata: Optional[Dict[str, Any]] = None
 
 
-class DataWorker:
+class DataWorker(BaseWorker):
     """Worker that retrieves data from various sources and forwards to orchestrator/storage."""
 
     def __init__(self, config: Dict[str, Any]):
-
-        self.server_url = config["server"]
-        self.token = config["token"]
-        self.name = config.get("name", "data_worker")
+        super().__init__(config)
 
         # Data source configuration
-        self.data_source = config.get(
-            "data_source"
-        )  # Path to .jsonl, .csv, .parquet, or HF dataset
-        self.source_type = config.get(
-            "source_type", "auto"
-        )  # auto, jsonl, csv, parquet, huggingface
+        self.data_source = config.get("data_source")
+        self.source_type = config.get("source_type", "auto")
         self.batch_size = config.get("batch_size", 10)
 
         # Storage configuration (will be updated from orchestrator)
         self.storage_config = None
         self.s3_client = None
 
-        #
-        self.ssl_context = self._setup_ssl()
-
-        # State
-        self.worker_id: Optional[str] = None
-        self.websocket: Optional[WebSocketClientProtocol] = None
-        self.running = False
-        self.connected = Event()
+        # State specific to data worker
         self.can_send = Event()  # For backpressure
 
         # Queues
         self.send_queue = Queue(maxsize=100)
 
-
+    def _init_metrics(self):
+        """Initialize data worker metrics."""
         self.samples_sent = 0
         self.samples_stored = 0
         self.samples_failed = 0
 
-    def
-        """
-
-
-
-
-
-
-
-
-
-
+    def _get_auth_data(self) -> Dict[str, Any]:
+        """Get authentication data."""
+        return {"token": self.token, "name": self.name, "role": "data_worker"}
+
+    async def _handle_welcome(self, welcome_data: Dict[str, Any]):
+        """Handle welcome message from orchestrator."""
+        self.storage_config = welcome_data.get("storage_config", {})
+
+        # Setup S3 if configured
+        if self.storage_config.get("s3", {}).get("enabled"):
+            self._setup_s3_client(self.storage_config["s3"])
+
+        logger.info(f"Storage config: {self.storage_config}")
+
+        # Start with ability to send
+        self.can_send.set()
+
+    async def _handle_message(self, data: Dict[str, Any]):
+        """Handle message from orchestrator."""
+        msg_type = data.get("type")
+
+        if msg_type == "backpressure":
+            # Orchestrator is overwhelmed
+            self.can_send.clear()
+            logger.info("Received backpressure signal")
+
+        elif msg_type == "resume":
+            # Orchestrator ready for more
+            self.can_send.set()
+            logger.info("Received resume signal")
+
+    def _get_heartbeat_data(self) -> Dict[str, Any]:
+        """Get heartbeat data."""
+        return {
+            "type": "heartbeat",
+            "sent": self.samples_sent,
+            "stored": self.samples_stored,
+            "failed": self.samples_failed,
+            "queue_size": self.send_queue.qsize(),
+        }
+
+    async def _create_tasks(self) -> list:
+        """Create async tasks to run."""
+        return [
+            asyncio.create_task(self._heartbeat_loop()),
+            asyncio.create_task(self._base_message_handler()),
+            asyncio.create_task(self._data_processor()),
+            asyncio.create_task(self._send_loop()),
+        ]
+
+    async def _on_disconnect(self):
+        """Handle disconnection."""
+        # Clear send capability
+        self.can_send.clear()
+
+        # Clear send queue
+        try:
+            while True:
+                self.send_queue.get_nowait()
+        except Empty:
+            pass
 
     def _setup_s3_client(self, s3_config: Dict[str, Any]):
         """Setup S3 client from config."""
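Note on the hunk above: the connection lifecycle (server URL, token, SSL context, websocket state) moves out of DataWorker into the new BaseWorker (caption_flow/workers/base.py, +228 lines, not shown in this diff). Below is a minimal sketch of the hook contract these overrides imply — only the method names come from the hunk above; the base-class shape is an assumption, not the actual base.py:

# Hypothetical sketch of the BaseWorker contract implied by DataWorker's
# overrides. The real base class is in caption_flow/workers/base.py, which
# this diff does not show; only the hook names below are confirmed.
import abc
from typing import Any, Dict


class BaseWorkerSketch(abc.ABC):
    def __init__(self, config: Dict[str, Any]):
        # Fields the overrides reference (self.token, self.name, ...).
        self.server_url = config["server"]
        self.token = config["token"]
        self.name = config.get("name", "worker")
        self._init_metrics()

    @abc.abstractmethod
    def _init_metrics(self) -> None:
        """Set subclass-specific counters (samples_sent, ...)."""

    @abc.abstractmethod
    def _get_auth_data(self) -> Dict[str, Any]:
        """Payload for the first websocket message, e.g. {"token": ..., "role": ...}."""

    @abc.abstractmethod
    async def _handle_welcome(self, welcome_data: Dict[str, Any]) -> None:
        """React to the orchestrator's post-auth welcome payload."""

    @abc.abstractmethod
    async def _handle_message(self, data: Dict[str, Any]) -> None:
        """React to subsequent JSON messages (backpressure, resume, ...)."""

    @abc.abstractmethod
    async def _create_tasks(self) -> list:
        """Tasks to run while connected; cancelled on disconnect."""

    @abc.abstractmethod
    async def _on_disconnect(self) -> None:
        """Cleanup when the connection drops."""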
@@ -108,6 +142,98 @@ class DataWorker:
             logger.error(f"Failed to setup S3 client: {e}")
             return None
 
+    async def _data_processor(self):
+        """Process data from source."""
+        try:
+            batch = []
+
+            async for sample in self._load_data_source():
+                # Get image data
+                if sample.image_data:
+                    image_data = sample.image_data
+                elif sample.image_url:
+                    image_data = await self._download_image(sample.image_url)
+                    if not image_data:
+                        self.samples_failed += 1
+                        continue
+                else:
+                    logger.warning(f"No image data for sample {sample.sample_id}")
+                    continue
+
+                # Store if configured
+                if self.storage_config.get("forward_to_orchestrator", True):
+                    # Add to send queue
+                    batch.append(
+                        {
+                            "sample_id": sample.sample_id,
+                            "image_data": image_data,
+                            "metadata": sample.metadata,
+                        }
+                    )
+
+                    if len(batch) >= self.batch_size:
+                        # Wait for backpressure clearance
+                        await asyncio.wait_for(self.can_send.wait(), timeout=300)
+
+                        # Add batch to send queue
+                        try:
+                            self.send_queue.put_nowait(batch)
+                            batch = []
+                        except:
+                            # Queue full, wait
+                            await asyncio.sleep(1)
+
+                # Store locally/S3 if configured
+                if self.storage_config.get("local", {}).get("enabled") or self.storage_config.get(
+                    "s3", {}
+                ).get("enabled"):
+                    if await self._store_sample(sample, image_data):
+                        self.samples_stored += 1
+
+            # Send remaining batch
+            if batch and self.storage_config.get("forward_to_orchestrator", True):
+                await asyncio.wait_for(self.can_send.wait(), timeout=300)
+                self.send_queue.put_nowait(batch)
+
+        except Exception as e:
+            logger.error(f"Data processing error: {e}")
+
+    async def _send_loop(self):
+        """Send data samples to orchestrator."""
+        while self.running and self.connected.is_set():
+            try:
+                # Get batch from queue
+                batch = await asyncio.get_event_loop().run_in_executor(
+                    None, self.send_queue.get, True, 1
+                )
+
+                if batch and self.websocket:
+                    # Send samples
+                    await self.websocket.send(
+                        json.dumps(
+                            {
+                                "type": "submit_samples",
+                                "samples": [
+                                    {"sample_id": s["sample_id"], "metadata": s["metadata"]}
+                                    for s in batch
+                                ],
+                                "batch_size": len(batch),
+                            }
+                        )
+                    )
+
+                    # Send actual image data separately
+                    for sample in batch:
+                        await self.websocket.send(sample["image_data"])
+
+                    self.samples_sent += len(batch)
+                    logger.info(f"Sent batch of {len(batch)} samples")
+
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(f"Send error: {e}")
+
     async def _load_data_source(self) -> AsyncIterator[DataSample]:
         """Load data from configured source."""
         source_type = self.source_type
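Two details in the added _send_loop are worth calling out: the thread-safe queue.Queue is bridged into asyncio by running the blocking get(block=True, timeout=1) on the default executor, and each batch crosses the websocket as one JSON header frame followed by one binary frame per image. A standalone sketch of the executor bridge under illustrative names (nothing below is package API):

# Demo of the blocking-queue-to-asyncio bridge used by _send_loop: Queue.get
# runs in a worker thread so the 1-second timeout never stalls the event loop,
# and the Empty raised on timeout propagates through the awaited future.
import asyncio
from queue import Empty, Queue

send_queue: Queue = Queue(maxsize=100)


async def drain(queue: Queue) -> None:
    loop = asyncio.get_running_loop()
    for _ in range(3):
        try:
            # Awaitable equivalent of queue.get(True, 1).
            batch = await loop.run_in_executor(None, queue.get, True, 1)
            print(f"got batch of {len(batch)} samples")
        except Empty:
            # Timeout expired; loop again so shutdown flags can be re-checked.
            continue


send_queue.put([{"sample_id": "s1"}, {"sample_id": "s2"}])
asyncio.run(drain(send_queue))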
@@ -282,201 +408,3 @@ class DataWorker:
             logger.error(f"Failed to store to S3: {e}")
 
         return stored
-
-    async def start(self):
-        """Start the data worker."""
-        self.running = True
-
-        # Connect and get configuration
-        while self.running:
-            try:
-                await self._connect_and_run()
-            except Exception as e:
-                logger.error(f"Connection error: {e}")
-                await asyncio.sleep(5)
-
-    async def _connect_and_run(self):
-        """Connect to orchestrator and process data."""
-        logger.info(f"Connecting to {self.server_url}")
-
-        async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
-            self.websocket = websocket
-            self.connected.set()
-            self.can_send.set()  # Start with ability to send
-
-            # Authenticate
-            await websocket.send(
-                json.dumps({"token": self.token, "name": self.name, "role": "data_worker"})
-            )
-
-            # Wait for welcome message with storage config
-            welcome = await websocket.recv()
-            welcome_data = json.loads(welcome)
-
-            if "error" in welcome_data:
-                logger.error(f"Authentication failed: {welcome_data['error']}")
-                self.running = False
-                return
-
-            self.worker_id = welcome_data.get("worker_id")
-            self.storage_config = welcome_data.get("storage_config", {})
-
-            # Setup S3 if configured
-            if self.storage_config.get("s3", {}).get("enabled"):
-                self._setup_s3_client(self.storage_config["s3"])
-
-            logger.info(f"Connected as {self.worker_id}")
-            logger.info(f"Storage config: {self.storage_config}")
-
-            # Start processing
-            tasks = [
-                asyncio.create_task(self._message_handler()),
-                asyncio.create_task(self._data_processor()),
-                asyncio.create_task(self._send_loop()),
-                asyncio.create_task(self._heartbeat_loop()),
-            ]
-
-            done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
-
-            for task in pending:
-                task.cancel()
-
-    async def _message_handler(self):
-        """Handle messages from orchestrator."""
-        async for message in self.websocket:
-            try:
-                data = json.loads(message)
-                msg_type = data.get("type")
-
-                if msg_type == "backpressure":
-                    # Orchestrator is overwhelmed
-                    self.can_send.clear()
-                    logger.info("Received backpressure signal")
-
-                elif msg_type == "resume":
-                    # Orchestrator ready for more
-                    self.can_send.set()
-                    logger.info("Received resume signal")
-
-            except Exception as e:
-                logger.error(f"Error handling message: {e}")
-
-    async def _data_processor(self):
-        """Process data from source."""
-        try:
-            batch = []
-
-            async for sample in self._load_data_source():
-                # Get image data
-                if sample.image_data:
-                    image_data = sample.image_data
-                elif sample.image_url:
-                    image_data = await self._download_image(sample.image_url)
-                    if not image_data:
-                        self.samples_failed += 1
-                        continue
-                else:
-                    logger.warning(f"No image data for sample {sample.sample_id}")
-                    continue
-
-                # Store if configured
-                if self.storage_config.get("forward_to_orchestrator", True):
-                    # Add to send queue
-                    batch.append(
-                        {
-                            "sample_id": sample.sample_id,
-                            "image_data": image_data,
-                            "metadata": sample.metadata,
-                        }
-                    )
-
-                    if len(batch) >= self.batch_size:
-                        # Wait for backpressure clearance
-                        await asyncio.wait_for(self.can_send.wait(), timeout=300)
-
-                        # Add batch to send queue
-                        try:
-                            self.send_queue.put_nowait(batch)
-                            batch = []
-                        except:
-                            # Queue full, wait
-                            await asyncio.sleep(1)
-
-                # Store locally/S3 if configured
-                if self.storage_config.get("local", {}).get("enabled") or self.storage_config.get(
-                    "s3", {}
-                ).get("enabled"):
-                    if await self._store_sample(sample, image_data):
-                        self.samples_stored += 1
-
-            # Send remaining batch
-            if batch and self.storage_config.get("forward_to_orchestrator", True):
-                await asyncio.wait_for(self.can_send.wait(), timeout=300)
-                self.send_queue.put_nowait(batch)
-
-        except Exception as e:
-            logger.error(f"Data processing error: {e}")
-
-    async def _send_loop(self):
-        """Send data samples to orchestrator."""
-        while self.running and self.connected.is_set():
-            try:
-                # Get batch from queue
-                batch = await asyncio.get_event_loop().run_in_executor(
-                    None, self.send_queue.get, True, 1
-                )
-
-                if batch and self.websocket:
-                    # Send samples
-                    await self.websocket.send(
-                        json.dumps(
-                            {
-                                "type": "submit_samples",
-                                "samples": [
-                                    {"sample_id": s["sample_id"], "metadata": s["metadata"]}
-                                    for s in batch
-                                ],
-                                "batch_size": len(batch),
-                            }
-                        )
-                    )
-
-                    # Send actual image data separately
-                    for sample in batch:
-                        await self.websocket.send(sample["image_data"])
-
-                    self.samples_sent += len(batch)
-                    logger.info(f"Sent batch of {len(batch)} samples")
-
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(f"Send error: {e}")
-
-    async def _heartbeat_loop(self):
-        """Send periodic heartbeats."""
-        while self.running and self.connected.is_set():
-            try:
-                await self.websocket.send(
-                    json.dumps(
-                        {
-                            "type": "heartbeat",
-                            "sent": self.samples_sent,
-                            "stored": self.samples_stored,
-                            "failed": self.samples_failed,
-                            "queue_size": self.send_queue.qsize(),
-                        }
-                    )
-                )
-                await asyncio.sleep(30)
-            except:
-                break
-
-    async def shutdown(self):
-        """Graceful shutdown."""
-        logger.info("Shutting down data worker...")
-        self.running = False
-        self.connected.clear()
-
-        if self.websocket:
-            await self.websocket.close()