caption-flow 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,482 @@
+ """DataWorker for retrieving data from various sources and forwarding to orchestrator or storage."""
+
+ import asyncio
+ import json
+ import logging
+ import ssl
+ import io
+ import time
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, Any, Optional, List, AsyncIterator
+ from queue import Queue, Empty, Full
+
+ import websockets
+ from websockets.client import WebSocketClientProtocol
+ import pandas as pd
+ import pyarrow.parquet as pq
+ from PIL import Image
+ import boto3
+ from botocore.config import Config
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class DataSample:
+     """A single data sample to process."""
+
+     sample_id: str
+     image_url: Optional[str] = None
+     image_data: Optional[bytes] = None
+     metadata: Optional[Dict[str, Any]] = None
+
+
+ class DataWorker:
+     """Worker that retrieves data from various sources and forwards to orchestrator/storage."""
+
+     def __init__(self, config: Dict[str, Any]):
+         self.config = config
+         self.server_url = config["server"]
+         self.token = config["token"]
+         self.name = config.get("name", "data_worker")
+
+         # Data source configuration
+         self.data_source = config.get("data_source")  # Path to .jsonl, .csv, .parquet, or HF dataset
+         self.source_type = config.get("source_type", "auto")  # auto, jsonl, csv, parquet, huggingface
+         self.batch_size = config.get("batch_size", 10)
+
+         # Storage configuration (will be updated from orchestrator)
+         self.storage_config = None
+         self.s3_client = None
+
+         # SSL configuration
+         self.ssl_context = self._setup_ssl()
+
+         # State (asyncio.Event, not threading.Event, so the events can be awaited)
+         self.worker_id: Optional[str] = None
+         self.websocket: Optional[WebSocketClientProtocol] = None
+         self.running = False
+         self.connected = asyncio.Event()
+         self.can_send = asyncio.Event()  # For backpressure
+
+         # Queues
+         self.send_queue = Queue(maxsize=100)
+
+         # Metrics
+         self.samples_sent = 0
+         self.samples_stored = 0
+         self.samples_failed = 0
+
+     def _setup_ssl(self) -> Optional[ssl.SSLContext]:
+         """Configure SSL context."""
+         if self.server_url.startswith("ws://"):
+             logger.warning("Using insecure WebSocket connection")
+             return None
+
+         if not self.config.get("verify_ssl", True):
+             context = ssl.create_default_context()
+             context.check_hostname = False
+             context.verify_mode = ssl.CERT_NONE
+             return context
+
+         return ssl.create_default_context()
+
+     def _setup_s3_client(self, s3_config: Dict[str, Any]):
+         """Setup S3 client from config."""
+         if not s3_config:
+             return None
+
+         try:
+             self.s3_client = boto3.client(
+                 "s3",
+                 endpoint_url=s3_config.get("endpoint_url"),
+                 aws_access_key_id=s3_config.get("access_key"),
+                 aws_secret_access_key=s3_config.get("secret_key"),
+                 region_name=s3_config.get("region", "us-east-1"),
+                 config=Config(signature_version="s3v4"),
+             )
+             self.s3_bucket = s3_config.get("bucket")
+             logger.info(f"S3 client configured for bucket: {self.s3_bucket}")
+             return self.s3_client
+         except Exception as e:
+             logger.error(f"Failed to setup S3 client: {e}")
+             return None
+
+     async def _load_data_source(self) -> AsyncIterator[DataSample]:
+         """Load data from configured source."""
+         source_type = self.source_type
+
+         if source_type == "auto":
+             # Auto-detect based on file extension
+             if self.data_source.endswith(".jsonl"):
+                 source_type = "jsonl"
+             elif self.data_source.endswith(".csv"):
+                 source_type = "csv"
+             elif self.data_source.endswith(".parquet"):
+                 source_type = "parquet"
+             elif self.data_source.startswith("hf://") or "/" in self.data_source:
+                 source_type = "huggingface"
+
+         logger.info(f"Loading data from {source_type} source: {self.data_source}")
+
+         if source_type == "jsonl":
+             async for sample in self._load_jsonl():
+                 yield sample
+         elif source_type == "csv":
+             async for sample in self._load_csv():
+                 yield sample
+         elif source_type == "parquet":
+             async for sample in self._load_parquet():
+                 yield sample
+         elif source_type == "huggingface":
+             async for sample in self._load_huggingface():
+                 yield sample
+         else:
+             raise ValueError(f"Unknown source type: {source_type}")
+
+     async def _load_jsonl(self) -> AsyncIterator[DataSample]:
+         """Load data from JSONL file with URL list."""
+         with open(self.data_source, "r") as f:
+             for line_num, line in enumerate(f):
+                 try:
+                     data = json.loads(line.strip())
+                     sample = DataSample(
+                         sample_id=str(data.get("id", f"sample_{line_num}")),
+                         image_url=data.get("url") or data.get("image_url"),
+                         metadata=data,
+                     )
+                     yield sample
+                 except Exception as e:
+                     logger.error(f"Error loading line {line_num}: {e}")
+
+     async def _load_csv(self) -> AsyncIterator[DataSample]:
+         """Load data from CSV file."""
+         df = pd.read_csv(self.data_source)
+
+         # Try to find URL column
+         url_cols = [col for col in df.columns if "url" in col.lower() or "link" in col.lower()]
+         url_col = url_cols[0] if url_cols else None
+
+         for idx, row in df.iterrows():
+             sample = DataSample(
+                 sample_id=str(row.get("id", idx)),
+                 image_url=row.get(url_col) if url_col else None,
+                 metadata=row.to_dict(),
+             )
+             yield sample
+
+     async def _load_parquet(self) -> AsyncIterator[DataSample]:
+         """Load data from Parquet file."""
+         table = pq.read_table(self.data_source)
+         df = table.to_pandas()
+
+         # Try to find URL column
+         url_cols = [col for col in df.columns if "url" in col.lower() or "link" in col.lower()]
+         url_col = url_cols[0] if url_cols else None
+
+         for idx, row in df.iterrows():
+             sample = DataSample(
+                 sample_id=str(row.get("id", idx)),
+                 image_url=row.get(url_col) if url_col else None,
+                 metadata=row.to_dict(),
+             )
+             yield sample
+
+     async def _load_huggingface(self) -> AsyncIterator[DataSample]:
+         """Load data from HuggingFace dataset."""
+         from datasets import load_dataset
+
+         # Parse dataset path
+         if self.data_source.startswith("hf://"):
+             dataset_path = self.data_source[5:]
+         else:
+             dataset_path = self.data_source
+
+         # Load dataset
+         ds = load_dataset(dataset_path, split="train", streaming=True)
+
+         for idx, item in enumerate(ds):
+             # Try to find image data
+             image_url = None
+             image_data = None
+
+             if "image" in item and hasattr(item["image"], "save"):
+                 # PIL Image
+                 buffer = io.BytesIO()
+                 item["image"].save(buffer, format="PNG")
+                 image_data = buffer.getvalue()
+             elif "url" in item:
+                 image_url = item["url"]
+             elif "image_url" in item:
+                 image_url = item["image_url"]
+
+             sample = DataSample(
+                 sample_id=str(item.get("id", f"hf_{idx}")),
+                 image_url=image_url,
+                 image_data=image_data,
+                 metadata=item,
+             )
+             yield sample
+
+     async def _download_image(self, url: str) -> Optional[bytes]:
+         """Download image from URL."""
+         try:
+             import aiohttp
+
+             async with aiohttp.ClientSession() as session:
+                 # aiohttp expects a ClientTimeout object, not a bare int
+                 async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
+                     if response.status == 200:
+                         return await response.read()
+         except Exception as e:
+             logger.error(f"Failed to download image from {url}: {e}")
+         return None
+
+     async def _store_sample(self, sample: DataSample, image_data: bytes) -> bool:
+         """Store sample according to storage config."""
+         stored = False
+
+         # Store locally if configured
+         if self.storage_config.get("local", {}).get("enabled"):
+             local_dir = Path(self.storage_config["local"].get("path", "./data"))
+             local_dir.mkdir(parents=True, exist_ok=True)
+
+             try:
+                 # Save image
+                 image_path = local_dir / f"{sample.sample_id}.jpg"
+                 with open(image_path, "wb") as f:
+                     f.write(image_data)
+
+                 # Save metadata; default=str keeps non-JSON values
+                 # (e.g. PIL images in HF metadata) from crashing the dump
+                 meta_path = local_dir / f"{sample.sample_id}.json"
+                 with open(meta_path, "w") as f:
+                     json.dump(sample.metadata or {}, f, default=str)
+
+                 stored = True
+             except Exception as e:
+                 logger.error(f"Failed to store locally: {e}")
+
+         # Store to S3 if configured
+         if self.storage_config.get("s3", {}).get("enabled") and self.s3_client:
+             try:
+                 # Upload image
+                 self.s3_client.put_object(
+                     Bucket=self.s3_bucket, Key=f"images/{sample.sample_id}.jpg", Body=image_data
+                 )
+
+                 # Upload metadata
+                 if sample.metadata:
+                     self.s3_client.put_object(
+                         Bucket=self.s3_bucket,
+                         Key=f"metadata/{sample.sample_id}.json",
+                         Body=json.dumps(sample.metadata, default=str),
+                     )
+
+                 stored = True
+             except Exception as e:
+                 logger.error(f"Failed to store to S3: {e}")
+
+         return stored
+
+     async def start(self):
+         """Start the data worker."""
+         self.running = True
+
+         # Connect and get configuration
+         while self.running:
+             try:
+                 await self._connect_and_run()
+             except Exception as e:
+                 logger.error(f"Connection error: {e}")
+                 await asyncio.sleep(5)
+
+     async def _connect_and_run(self):
+         """Connect to orchestrator and process data."""
+         logger.info(f"Connecting to {self.server_url}")
+
+         async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
+             self.websocket = websocket
+             self.connected.set()
+             self.can_send.set()  # Start with ability to send
+
+             # Authenticate
+             await websocket.send(
+                 json.dumps({"token": self.token, "name": self.name, "role": "data_worker"})
+             )
+
+             # Wait for welcome message with storage config
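+             # Assumed welcome shape, inferred from the fields read below:
+             #   {"worker_id": "...",
+             #    "storage_config": {"local": {"enabled": true, "path": "./data"},
+             #                       "s3": {"enabled": true, "bucket": "...", ...},
+             #                       "forward_to_orchestrator": true}}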
+             welcome = await websocket.recv()
+             welcome_data = json.loads(welcome)
+
+             if "error" in welcome_data:
+                 logger.error(f"Authentication failed: {welcome_data['error']}")
+                 self.running = False
+                 return
+
+             self.worker_id = welcome_data.get("worker_id")
+             self.storage_config = welcome_data.get("storage_config", {})
+
+             # Setup S3 if configured
+             if self.storage_config.get("s3", {}).get("enabled"):
+                 self._setup_s3_client(self.storage_config["s3"])
+
+             logger.info(f"Connected as {self.worker_id}")
+             logger.info(f"Storage config: {self.storage_config}")
+
+             # Start processing
+             tasks = [
+                 asyncio.create_task(self._message_handler()),
+                 asyncio.create_task(self._data_processor()),
+                 asyncio.create_task(self._send_loop()),
+                 asyncio.create_task(self._heartbeat_loop()),
+             ]
+
+             done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+
+             for task in pending:
+                 task.cancel()
+
+     async def _message_handler(self):
+         """Handle messages from orchestrator."""
+         async for message in self.websocket:
+             try:
+                 data = json.loads(message)
+                 msg_type = data.get("type")
+
+                 if msg_type == "backpressure":
+                     # Orchestrator is overwhelmed
+                     self.can_send.clear()
+                     logger.info("Received backpressure signal")
+
+                 elif msg_type == "resume":
+                     # Orchestrator ready for more
+                     self.can_send.set()
+                     logger.info("Received resume signal")
+
+             except Exception as e:
+                 logger.error(f"Error handling message: {e}")
+
+     async def _data_processor(self):
+         """Process data from source."""
+         try:
+             batch = []
+
+             async for sample in self._load_data_source():
+                 # Get image data
+                 if sample.image_data:
+                     image_data = sample.image_data
+                 elif sample.image_url:
+                     image_data = await self._download_image(sample.image_url)
+                     if not image_data:
+                         self.samples_failed += 1
+                         continue
+                 else:
+                     logger.warning(f"No image data for sample {sample.sample_id}")
+                     continue
+
+                 # Forward to orchestrator if configured
+                 if self.storage_config.get("forward_to_orchestrator", True):
+                     # Add to send queue
+                     batch.append(
+                         {
+                             "sample_id": sample.sample_id,
+                             "image_data": image_data,
+                             "metadata": sample.metadata,
+                         }
+                     )
+
+                     if len(batch) >= self.batch_size:
+                         # Wait for backpressure clearance
+                         await asyncio.wait_for(self.can_send.wait(), timeout=300)
+
+                         # Add batch to send queue
+                         try:
+                             self.send_queue.put_nowait(batch)
+                             batch = []
+                         except Full:
+                             # Queue full; keep the batch and retry on a later iteration
+                             await asyncio.sleep(1)
+
+                 # Store locally/S3 if configured
+                 if self.storage_config.get("local", {}).get("enabled") or self.storage_config.get(
+                     "s3", {}
+                 ).get("enabled"):
+                     if await self._store_sample(sample, image_data):
+                         self.samples_stored += 1
+
+             # Send remaining batch
+             if batch and self.storage_config.get("forward_to_orchestrator", True):
+                 await asyncio.wait_for(self.can_send.wait(), timeout=300)
+                 self.send_queue.put_nowait(batch)
+
+         except Exception as e:
+             logger.error(f"Data processing error: {e}")
+
+     async def _send_loop(self):
+         """Send data samples to orchestrator."""
+         while self.running and self.connected.is_set():
+             try:
+                 # Get batch from queue (in a thread so the event loop isn't blocked)
+                 batch = await asyncio.get_event_loop().run_in_executor(
+                     None, self.send_queue.get, True, 1
+                 )
+
+                 if batch and self.websocket:
+                     # Send samples
+                     await self.websocket.send(
+                         json.dumps(
+                             {
+                                 "type": "submit_samples",
+                                 "samples": [
+                                     {"sample_id": s["sample_id"], "metadata": s["metadata"]}
+                                     for s in batch
+                                 ],
+                                 "batch_size": len(batch),
+                             },
+                             default=str,
+                         )
+                     )
+
+                     # Send actual image data separately
+                     for sample in batch:
+                         await self.websocket.send(sample["image_data"])
+
+                     self.samples_sent += len(batch)
+                     logger.info(f"Sent batch of {len(batch)} samples")
+
+             except Empty:
+                 continue
+             except Exception as e:
+                 logger.error(f"Send error: {e}")
+
+     async def _heartbeat_loop(self):
+         """Send periodic heartbeats."""
+         while self.running and self.connected.is_set():
+             try:
+                 await self.websocket.send(
+                     json.dumps(
+                         {
+                             "type": "heartbeat",
+                             "sent": self.samples_sent,
+                             "stored": self.samples_stored,
+                             "failed": self.samples_failed,
+                             "queue_size": self.send_queue.qsize(),
+                         }
+                     )
+                 )
+                 await asyncio.sleep(30)
+             except Exception:
+                 # Connection gone; let the outer loop reconnect
+                 break
+
+     async def shutdown(self):
+         """Graceful shutdown."""
+         logger.info("Shutting down data worker...")
+         self.running = False
+         self.connected.clear()
+
+         if self.websocket:
+             await self.websocket.close()
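
A minimal usage sketch, not part of the diff above: the config keys mirror those read in __init__ and _setup_ssl; the import path, server URL, token, and data path are placeholders.

    import asyncio
    from caption_flow.data_worker import DataWorker  # assumed import path

    config = {
        "server": "wss://orchestrator.example.com:8765",
        "token": "worker-token",
        "name": "data_worker_1",
        "data_source": "samples.jsonl",  # or .csv, .parquet, or an HF dataset id
        "source_type": "auto",
        "batch_size": 10,
        "verify_ssl": True,
    }

    asyncio.run(DataWorker(config).start())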