caption_flow-0.1.0-py3-none-any.whl
This diff shows the contents of publicly available package versions as released to their respective public registries, and is provided for informational purposes only.
- caption_flow/__init__.py +9 -0
- caption_flow/cli.py +709 -0
- caption_flow/models.py +82 -0
- caption_flow/monitor.py +211 -0
- caption_flow/orchestrator.py +1301 -0
- caption_flow/storage.py +694 -0
- caption_flow/utils/__init__.py +4 -0
- caption_flow/utils/auth.py +67 -0
- caption_flow/utils/caption_utils.py +172 -0
- caption_flow/utils/certificates.py +140 -0
- caption_flow/utils/chunk_tracker.py +365 -0
- caption_flow/utils/dataset_loader.py +186 -0
- caption_flow/utils/image_processor.py +51 -0
- caption_flow/utils/job_queue.py +41 -0
- caption_flow/utils/json_utils.py +201 -0
- caption_flow/utils/vllm_config.py +164 -0
- caption_flow/worker.py +300 -0
- caption_flow/worker_data.py +482 -0
- caption_flow/worker_vllm.py +1028 -0
- caption_flow-0.1.0.dist-info/METADATA +427 -0
- caption_flow-0.1.0.dist-info/RECORD +25 -0
- caption_flow-0.1.0.dist-info/WHEEL +5 -0
- caption_flow-0.1.0.dist-info/entry_points.txt +2 -0
- caption_flow-0.1.0.dist-info/licenses/LICENSE +661 -0
- caption_flow-0.1.0.dist-info/top_level.txt +1 -0
caption_flow/worker_data.py
@@ -0,0 +1,482 @@
"""DataWorker for retrieving data from various sources and forwarding to orchestrator or storage."""

import asyncio
import json
import logging
import ssl
import io
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Any, Optional, List, AsyncIterator
from queue import Queue, Empty, Full
from threading import Thread, Event

import websockets
from websockets.client import WebSocketClientProtocol
import pandas as pd
import pyarrow.parquet as pq
from PIL import Image
import boto3
from botocore.config import Config

logger = logging.getLogger(__name__)

@dataclass
class DataSample:
    """A single data sample to process."""

    sample_id: str
    image_url: Optional[str] = None
    image_data: Optional[bytes] = None
    metadata: Optional[Dict[str, Any]] = None

class DataWorker:
    """Worker that retrieves data from various sources and forwards to orchestrator/storage."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.server_url = config["server"]
        self.token = config["token"]
        self.name = config.get("name", "data_worker")

        # Data source configuration
        self.data_source = config.get(
            "data_source"
        )  # Path to .jsonl, .csv, .parquet, or HF dataset
        self.source_type = config.get(
            "source_type", "auto"
        )  # auto, jsonl, csv, parquet, huggingface
        self.batch_size = config.get("batch_size", 10)

        # Storage configuration (will be updated from orchestrator)
        self.storage_config = None
        self.s3_client = None

        # SSL configuration
        self.ssl_context = self._setup_ssl()

        # State
        self.worker_id: Optional[str] = None
        self.websocket: Optional[WebSocketClientProtocol] = None
        self.running = False
        # asyncio.Event (not threading.Event) so coroutines can `await .wait()`
        self.connected = asyncio.Event()
        self.can_send = asyncio.Event()  # For backpressure

        # Queues
        self.send_queue = Queue(maxsize=100)

        # Metrics
        self.samples_sent = 0
        self.samples_stored = 0
        self.samples_failed = 0
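
    # A minimal config sketch for the constructor above. The keys mirror the
    # lookups in __init__; every value shown is a hypothetical placeholder.
    #
    #     worker = DataWorker({
    #         "server": "wss://orchestrator.example:8765",
    #         "token": "worker-token",
    #         "name": "data_worker_1",
    #         "data_source": "samples.jsonl",
    #         "source_type": "auto",
    #         "batch_size": 10,
    #         "verify_ssl": True,
    #     })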

    def _setup_ssl(self) -> Optional[ssl.SSLContext]:
        """Configure SSL context."""
        if self.server_url.startswith("ws://"):
            logger.warning("Using insecure WebSocket connection")
            return None

        if not self.config.get("verify_ssl", True):
            context = ssl.create_default_context()
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
            return context

        return ssl.create_default_context()

    def _setup_s3_client(self, s3_config: Dict[str, Any]):
        """Setup S3 client from config."""
        if not s3_config:
            return None

        try:
            self.s3_client = boto3.client(
                "s3",
                endpoint_url=s3_config.get("endpoint_url"),
                aws_access_key_id=s3_config.get("access_key"),
                aws_secret_access_key=s3_config.get("secret_key"),
                region_name=s3_config.get("region", "us-east-1"),
                config=Config(signature_version="s3v4"),
            )
            self.s3_bucket = s3_config.get("bucket")
            logger.info(f"S3 client configured for bucket: {self.s3_bucket}")
            return self.s3_client
        except Exception as e:
            logger.error(f"Failed to setup S3 client: {e}")
            return None
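
    # Shape of the s3_config dict consumed above, inferred from the .get()
    # calls in _setup_s3_client; endpoint and credentials are made up.
    #
    #     {
    #         "enabled": True,
    #         "endpoint_url": "https://s3.example.com",
    #         "access_key": "AKIA...",
    #         "secret_key": "...",
    #         "region": "us-east-1",
    #         "bucket": "caption-data",
    #     }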

    async def _load_data_source(self) -> AsyncIterator[DataSample]:
        """Load data from configured source."""
        source_type = self.source_type

        if source_type == "auto":
            # Auto-detect based on file extension
            if self.data_source.endswith(".jsonl"):
                source_type = "jsonl"
            elif self.data_source.endswith(".csv"):
                source_type = "csv"
            elif self.data_source.endswith(".parquet"):
                source_type = "parquet"
            elif self.data_source.startswith("hf://") or "/" in self.data_source:
                source_type = "huggingface"

        logger.info(f"Loading data from {source_type} source: {self.data_source}")

        if source_type == "jsonl":
            async for sample in self._load_jsonl():
                yield sample
        elif source_type == "csv":
            async for sample in self._load_csv():
                yield sample
        elif source_type == "parquet":
            async for sample in self._load_parquet():
                yield sample
        elif source_type == "huggingface":
            async for sample in self._load_huggingface():
                yield sample
        else:
            raise ValueError(f"Unknown source type: {source_type}")

    async def _load_jsonl(self) -> AsyncIterator[DataSample]:
        """Load data from JSONL file with URL list."""
        with open(self.data_source, "r") as f:
            for line_num, line in enumerate(f):
                try:
                    data = json.loads(line.strip())
                    sample = DataSample(
                        sample_id=data.get("id", f"sample_{line_num}"),
                        image_url=data.get("url") or data.get("image_url"),
                        metadata=data,
                    )
                    yield sample
                except Exception as e:
                    logger.error(f"Error loading line {line_num}: {e}")
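
    # A line this loader accepts might look like the following; the field
    # names match the .get() lookups above, the URL is hypothetical.
    #
    #     {"id": "img_0001", "url": "https://example.com/img_0001.jpg"}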

    async def _load_csv(self) -> AsyncIterator[DataSample]:
        """Load data from CSV file."""
        df = pd.read_csv(self.data_source)

        # Try to find URL column
        url_cols = [col for col in df.columns if "url" in col.lower() or "link" in col.lower()]
        url_col = url_cols[0] if url_cols else None

        for idx, row in df.iterrows():
            sample = DataSample(
                sample_id=str(row.get("id", idx)),
                image_url=row.get(url_col) if url_col else None,
                metadata=row.to_dict(),
            )
            yield sample

    async def _load_parquet(self) -> AsyncIterator[DataSample]:
        """Load data from Parquet file."""
        table = pq.read_table(self.data_source)
        df = table.to_pandas()

        # Try to find URL column
        url_cols = [col for col in df.columns if "url" in col.lower() or "link" in col.lower()]
        url_col = url_cols[0] if url_cols else None

        for idx, row in df.iterrows():
            sample = DataSample(
                sample_id=str(row.get("id", idx)),
                image_url=row.get(url_col) if url_col else None,
                metadata=row.to_dict(),
            )
            yield sample

    async def _load_huggingface(self) -> AsyncIterator[DataSample]:
        """Load data from HuggingFace dataset."""
        from datasets import load_dataset

        # Parse dataset path
        if self.data_source.startswith("hf://"):
            dataset_path = self.data_source[5:]
        else:
            dataset_path = self.data_source

        # Load dataset
        ds = load_dataset(dataset_path, split="train", streaming=True)

        for idx, item in enumerate(ds):
            # Try to find image data
            image_url = None
            image_data = None

            if "image" in item and hasattr(item["image"], "save"):
                # PIL Image
                buffer = io.BytesIO()
                item["image"].save(buffer, format="PNG")
                image_data = buffer.getvalue()
            elif "url" in item:
                image_url = item["url"]
            elif "image_url" in item:
                image_url = item["image_url"]

            sample = DataSample(
                sample_id=item.get("id", f"hf_{idx}"),
                image_url=image_url,
                image_data=image_data,
                metadata=item,
            )
            yield sample

    async def _download_image(self, url: str) -> Optional[bytes]:
        """Download image from URL."""
        try:
            import aiohttp

            async with aiohttp.ClientSession() as session:
                async with session.get(url, timeout=30) as response:
                    if response.status == 200:
                        return await response.read()
        except Exception as e:
            logger.error(f"Failed to download image from {url}: {e}")
        return None

    async def _store_sample(self, sample: DataSample, image_data: bytes) -> bool:
        """Store sample according to storage config."""
        stored = False

        # Store locally if configured
        if self.storage_config.get("local", {}).get("enabled"):
            local_dir = Path(self.storage_config["local"].get("path", "./data"))
            local_dir.mkdir(parents=True, exist_ok=True)

            try:
                # Save image
                image_path = local_dir / f"{sample.sample_id}.jpg"
                with open(image_path, "wb") as f:
                    f.write(image_data)

                # Save metadata
                meta_path = local_dir / f"{sample.sample_id}.json"
                with open(meta_path, "w") as f:
                    json.dump(sample.metadata or {}, f)

                stored = True
            except Exception as e:
                logger.error(f"Failed to store locally: {e}")

        # Store to S3 if configured
        if self.storage_config.get("s3", {}).get("enabled") and self.s3_client:
            try:
                # Upload image
                self.s3_client.put_object(
                    Bucket=self.s3_bucket, Key=f"images/{sample.sample_id}.jpg", Body=image_data
                )

                # Upload metadata
                if sample.metadata:
                    self.s3_client.put_object(
                        Bucket=self.s3_bucket,
                        Key=f"metadata/{sample.sample_id}.json",
                        Body=json.dumps(sample.metadata),
                    )

                stored = True
            except Exception as e:
                logger.error(f"Failed to store to S3: {e}")

        return stored
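
    # Sketch of the storage_config the orchestrator sends in its welcome
    # message, inferred from the keys read throughout this class; the values
    # are illustrative.
    #
    #     {
    #         "forward_to_orchestrator": True,
    #         "local": {"enabled": True, "path": "./data"},
    #         "s3": {"enabled": False},
    #     }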

    async def start(self):
        """Start the data worker."""
        self.running = True

        # Connect and get configuration
        while self.running:
            try:
                await self._connect_and_run()
            except Exception as e:
                logger.error(f"Connection error: {e}")
                await asyncio.sleep(5)

    async def _connect_and_run(self):
        """Connect to orchestrator and process data."""
        logger.info(f"Connecting to {self.server_url}")

        async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
            self.websocket = websocket
            self.connected.set()
            self.can_send.set()  # Start with ability to send

            # Authenticate
            await websocket.send(
                json.dumps({"token": self.token, "name": self.name, "role": "data_worker"})
            )

            # Wait for welcome message with storage config
            welcome = await websocket.recv()
            welcome_data = json.loads(welcome)

            if "error" in welcome_data:
                logger.error(f"Authentication failed: {welcome_data['error']}")
                self.running = False
                return

            self.worker_id = welcome_data.get("worker_id")
            self.storage_config = welcome_data.get("storage_config", {})

            # Setup S3 if configured
            if self.storage_config.get("s3", {}).get("enabled"):
                self._setup_s3_client(self.storage_config["s3"])

            logger.info(f"Connected as {self.worker_id}")
            logger.info(f"Storage config: {self.storage_config}")

            # Start processing
            tasks = [
                asyncio.create_task(self._message_handler()),
                asyncio.create_task(self._data_processor()),
                asyncio.create_task(self._send_loop()),
                asyncio.create_task(self._heartbeat_loop()),
            ]

            done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

            for task in pending:
                task.cancel()
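
    # Handshake sketch, as implemented above: one JSON frame out, one back.
    # Any welcome fields beyond worker_id/storage_config/error are assumptions.
    #
    #     -> {"token": "...", "name": "data_worker", "role": "data_worker"}
    #     <- {"worker_id": "dw-1", "storage_config": {...}}  # or {"error": "..."}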

    async def _message_handler(self):
        """Handle messages from orchestrator."""
        async for message in self.websocket:
            try:
                data = json.loads(message)
                msg_type = data.get("type")

                if msg_type == "backpressure":
                    # Orchestrator is overwhelmed
                    self.can_send.clear()
                    logger.info("Received backpressure signal")

                elif msg_type == "resume":
                    # Orchestrator ready for more
                    self.can_send.set()
                    logger.info("Received resume signal")

            except Exception as e:
                logger.error(f"Error handling message: {e}")

    async def _data_processor(self):
        """Process data from source."""
        try:
            batch = []

            async for sample in self._load_data_source():
                # Get image data
                if sample.image_data:
                    image_data = sample.image_data
                elif sample.image_url:
                    image_data = await self._download_image(sample.image_url)
                    if not image_data:
                        self.samples_failed += 1
                        continue
                else:
                    logger.warning(f"No image data for sample {sample.sample_id}")
                    continue

                # Forward to orchestrator if configured
                if self.storage_config.get("forward_to_orchestrator", True):
                    # Add to send queue
                    batch.append(
                        {
                            "sample_id": sample.sample_id,
                            "image_data": image_data,
                            "metadata": sample.metadata,
                        }
                    )

                    if len(batch) >= self.batch_size:
                        # Wait for backpressure clearance
                        await asyncio.wait_for(self.can_send.wait(), timeout=300)

                        # Add batch to send queue
                        try:
                            self.send_queue.put_nowait(batch)
                            batch = []
                        except Full:
                            # Queue full, wait
                            await asyncio.sleep(1)

                # Store locally/S3 if configured
                if self.storage_config.get("local", {}).get("enabled") or self.storage_config.get(
                    "s3", {}
                ).get("enabled"):
                    if await self._store_sample(sample, image_data):
                        self.samples_stored += 1

            # Send remaining batch
            if batch and self.storage_config.get("forward_to_orchestrator", True):
                await asyncio.wait_for(self.can_send.wait(), timeout=300)
                self.send_queue.put_nowait(batch)

        except Exception as e:
            logger.error(f"Data processing error: {e}")

    async def _send_loop(self):
        """Send data samples to orchestrator."""
        while self.running and self.connected.is_set():
            try:
                # Get batch from queue
                batch = await asyncio.get_event_loop().run_in_executor(
                    None, self.send_queue.get, True, 1
                )

                if batch and self.websocket:
                    # Send samples
                    await self.websocket.send(
                        json.dumps(
                            {
                                "type": "submit_samples",
                                "samples": [
                                    {"sample_id": s["sample_id"], "metadata": s["metadata"]}
                                    for s in batch
                                ],
                                "batch_size": len(batch),
                            }
                        )
                    )

                    # Send actual image data separately
                    for sample in batch:
                        await self.websocket.send(sample["image_data"])

                    self.samples_sent += len(batch)
                    logger.info(f"Sent batch of {len(batch)} samples")

            except Empty:
                continue
            except Exception as e:
                logger.error(f"Send error: {e}")
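
    # Wire format, as implemented above: one JSON header frame announcing the
    # batch, followed by one binary frame of raw image bytes per sample, in
    # the same order as the "samples" list.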

    async def _heartbeat_loop(self):
        """Send periodic heartbeats."""
        while self.running and self.connected.is_set():
            try:
                await self.websocket.send(
                    json.dumps(
                        {
                            "type": "heartbeat",
                            "sent": self.samples_sent,
                            "stored": self.samples_stored,
                            "failed": self.samples_failed,
                            "queue_size": self.send_queue.qsize(),
                        }
                    )
                )
                await asyncio.sleep(30)
            except Exception:
                break

    async def shutdown(self):
        """Graceful shutdown."""
        logger.info("Shutting down data worker...")
        self.running = False
        self.connected.clear()

        if self.websocket:
            await self.websocket.close()