caption-flow 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,7 +3,6 @@
 import asyncio
 import json
 import logging
-import ssl
 import io
 import time
 from dataclasses import dataclass
@@ -12,14 +11,14 @@ from typing import Dict, Any, Optional, List, AsyncIterator
 from queue import Queue, Empty
 from threading import Thread, Event
 
-import websockets
-from websockets.client import WebSocketClientProtocol
 import pandas as pd
 import pyarrow.parquet as pq
 from PIL import Image
 import boto3
 from botocore.config import Config
 
+from .base import BaseWorker
+
 logger = logging.getLogger(__name__)
 
 
@@ -33,59 +32,94 @@ class DataSample:
     metadata: Optional[Dict[str, Any]] = None
 
 
-class DataWorker:
+class DataWorker(BaseWorker):
     """Worker that retrieves data from various sources and forwards to orchestrator/storage."""
 
     def __init__(self, config: Dict[str, Any]):
-        self.config = config
-        self.server_url = config["server"]
-        self.token = config["token"]
-        self.name = config.get("name", "data_worker")
+        super().__init__(config)
 
         # Data source configuration
-        self.data_source = config.get(
-            "data_source"
-        )  # Path to .jsonl, .csv, .parquet, or HF dataset
-        self.source_type = config.get(
-            "source_type", "auto"
-        )  # auto, jsonl, csv, parquet, huggingface
+        self.data_source = config.get("data_source")
+        self.source_type = config.get("source_type", "auto")
         self.batch_size = config.get("batch_size", 10)
 
         # Storage configuration (will be updated from orchestrator)
         self.storage_config = None
        self.s3_client = None
 
-        # SSL configuration
-        self.ssl_context = self._setup_ssl()
-
-        # State
-        self.worker_id: Optional[str] = None
-        self.websocket: Optional[WebSocketClientProtocol] = None
-        self.running = False
-        self.connected = Event()
+        # State specific to data worker
         self.can_send = Event()  # For backpressure
 
         # Queues
         self.send_queue = Queue(maxsize=100)
 
-        # Metrics
+    def _init_metrics(self):
+        """Initialize data worker metrics."""
         self.samples_sent = 0
         self.samples_stored = 0
         self.samples_failed = 0
 
-    def _setup_ssl(self) -> Optional[ssl.SSLContext]:
-        """Configure SSL context."""
-        if self.server_url.startswith("ws://"):
-            logger.warning("Using insecure WebSocket connection")
-            return None
-
-        if not self.config.get("verify_ssl", True):
-            context = ssl.create_default_context()
-            context.check_hostname = False
-            context.verify_mode = ssl.CERT_NONE
-            return context
-
-        return ssl.create_default_context()
+    def _get_auth_data(self) -> Dict[str, Any]:
+        """Get authentication data."""
+        return {"token": self.token, "name": self.name, "role": "data_worker"}
+
+    async def _handle_welcome(self, welcome_data: Dict[str, Any]):
+        """Handle welcome message from orchestrator."""
+        self.storage_config = welcome_data.get("storage_config", {})
+
+        # Setup S3 if configured
+        if self.storage_config.get("s3", {}).get("enabled"):
+            self._setup_s3_client(self.storage_config["s3"])
+
+        logger.info(f"Storage config: {self.storage_config}")
+
+        # Start with ability to send
+        self.can_send.set()
+
+    async def _handle_message(self, data: Dict[str, Any]):
+        """Handle message from orchestrator."""
+        msg_type = data.get("type")
+
+        if msg_type == "backpressure":
+            # Orchestrator is overwhelmed
+            self.can_send.clear()
+            logger.info("Received backpressure signal")
+
+        elif msg_type == "resume":
+            # Orchestrator ready for more
+            self.can_send.set()
+            logger.info("Received resume signal")
+
+    def _get_heartbeat_data(self) -> Dict[str, Any]:
+        """Get heartbeat data."""
+        return {
+            "type": "heartbeat",
+            "sent": self.samples_sent,
+            "stored": self.samples_stored,
+            "failed": self.samples_failed,
+            "queue_size": self.send_queue.qsize(),
+        }
+
+    async def _create_tasks(self) -> list:
+        """Create async tasks to run."""
+        return [
+            asyncio.create_task(self._heartbeat_loop()),
+            asyncio.create_task(self._base_message_handler()),
+            asyncio.create_task(self._data_processor()),
+            asyncio.create_task(self._send_loop()),
+        ]
+
+    async def _on_disconnect(self):
+        """Handle disconnection."""
+        # Clear send capability
+        self.can_send.clear()
+
+        # Clear send queue
+        try:
+            while True:
+                self.send_queue.get_nowait()
+        except Empty:
+            pass
 
     def _setup_s3_client(self, s3_config: Dict[str, Any]):
         """Setup S3 client from config."""
@@ -108,6 +142,98 @@ class DataWorker:
         logger.error(f"Failed to setup S3 client: {e}")
         return None
 
+    async def _data_processor(self):
+        """Process data from source."""
+        try:
+            batch = []
+
+            async for sample in self._load_data_source():
+                # Get image data
+                if sample.image_data:
+                    image_data = sample.image_data
+                elif sample.image_url:
+                    image_data = await self._download_image(sample.image_url)
+                    if not image_data:
+                        self.samples_failed += 1
+                        continue
+                else:
+                    logger.warning(f"No image data for sample {sample.sample_id}")
+                    continue
+
+                # Store if configured
+                if self.storage_config.get("forward_to_orchestrator", True):
+                    # Add to send queue
+                    batch.append(
+                        {
+                            "sample_id": sample.sample_id,
+                            "image_data": image_data,
+                            "metadata": sample.metadata,
+                        }
+                    )
+
+                    if len(batch) >= self.batch_size:
+                        # Wait for backpressure clearance
+                        await asyncio.wait_for(self.can_send.wait(), timeout=300)
+
+                        # Add batch to send queue
+                        try:
+                            self.send_queue.put_nowait(batch)
+                            batch = []
+                        except:
+                            # Queue full, wait
+                            await asyncio.sleep(1)
+
+                # Store locally/S3 if configured
+                if self.storage_config.get("local", {}).get("enabled") or self.storage_config.get(
+                    "s3", {}
+                ).get("enabled"):
+                    if await self._store_sample(sample, image_data):
+                        self.samples_stored += 1
+
+            # Send remaining batch
+            if batch and self.storage_config.get("forward_to_orchestrator", True):
+                await asyncio.wait_for(self.can_send.wait(), timeout=300)
+                self.send_queue.put_nowait(batch)
+
+        except Exception as e:
+            logger.error(f"Data processing error: {e}")
+
+    async def _send_loop(self):
+        """Send data samples to orchestrator."""
+        while self.running and self.connected.is_set():
+            try:
+                # Get batch from queue
+                batch = await asyncio.get_event_loop().run_in_executor(
+                    None, self.send_queue.get, True, 1
+                )
+
+                if batch and self.websocket:
+                    # Send samples
+                    await self.websocket.send(
+                        json.dumps(
+                            {
+                                "type": "submit_samples",
+                                "samples": [
+                                    {"sample_id": s["sample_id"], "metadata": s["metadata"]}
+                                    for s in batch
+                                ],
+                                "batch_size": len(batch),
+                            }
+                        )
+                    )
+
+                    # Send actual image data separately
+                    for sample in batch:
+                        await self.websocket.send(sample["image_data"])
+
+                    self.samples_sent += len(batch)
+                    logger.info(f"Sent batch of {len(batch)} samples")
+
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(f"Send error: {e}")
+
     async def _load_data_source(self) -> AsyncIterator[DataSample]:
         """Load data from configured source."""
         source_type = self.source_type
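The relocated _send_loop defines a small two-part wire format: one JSON control frame of type submit_samples carrying per-sample metadata and batch_size, followed by exactly batch_size binary frames of raw image bytes. A receiver therefore has to pair each header with the binary frames that follow it. The orchestrator is not included in this diff; the handler below is a minimal sketch of that pairing, assuming a websockets-style server connection:

    # Hypothetical receiving end for the submit_samples framing used by
    # _send_loop; the real caption-flow orchestrator is not shown in this diff.
    import json


    async def handle_data_worker(websocket):
        async for message in websocket:
            if isinstance(message, bytes):
                continue  # binary frame with no pending header; skip it
            header = json.loads(message)
            if header.get("type") != "submit_samples":
                continue  # e.g. heartbeat frames
            # The next batch_size frames carry the raw image bytes, in order.
            for meta in header["samples"]:
                image_bytes = await websocket.recv()
                print(f"received {meta['sample_id']}: {len(image_bytes)} bytes")
            # Flow control back to the worker rides the same socket:
            # await websocket.send(json.dumps({"type": "backpressure"}))
            # await websocket.send(json.dumps({"type": "resume"}))
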
@@ -282,201 +408,3 @@ class DataWorker:
         logger.error(f"Failed to store to S3: {e}")
 
         return stored
-
-    async def start(self):
-        """Start the data worker."""
-        self.running = True
-
-        # Connect and get configuration
-        while self.running:
-            try:
-                await self._connect_and_run()
-            except Exception as e:
-                logger.error(f"Connection error: {e}")
-                await asyncio.sleep(5)
-
-    async def _connect_and_run(self):
-        """Connect to orchestrator and process data."""
-        logger.info(f"Connecting to {self.server_url}")
-
-        async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
-            self.websocket = websocket
-            self.connected.set()
-            self.can_send.set()  # Start with ability to send
-
-            # Authenticate
-            await websocket.send(
-                json.dumps({"token": self.token, "name": self.name, "role": "data_worker"})
-            )
-
-            # Wait for welcome message with storage config
-            welcome = await websocket.recv()
-            welcome_data = json.loads(welcome)
-
-            if "error" in welcome_data:
-                logger.error(f"Authentication failed: {welcome_data['error']}")
-                self.running = False
-                return
-
-            self.worker_id = welcome_data.get("worker_id")
-            self.storage_config = welcome_data.get("storage_config", {})
-
-            # Setup S3 if configured
-            if self.storage_config.get("s3", {}).get("enabled"):
-                self._setup_s3_client(self.storage_config["s3"])
-
-            logger.info(f"Connected as {self.worker_id}")
-            logger.info(f"Storage config: {self.storage_config}")
-
-            # Start processing
-            tasks = [
-                asyncio.create_task(self._message_handler()),
-                asyncio.create_task(self._data_processor()),
-                asyncio.create_task(self._send_loop()),
-                asyncio.create_task(self._heartbeat_loop()),
-            ]
-
-            done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
-
-            for task in pending:
-                task.cancel()
-
-    async def _message_handler(self):
-        """Handle messages from orchestrator."""
-        async for message in self.websocket:
-            try:
-                data = json.loads(message)
-                msg_type = data.get("type")
-
-                if msg_type == "backpressure":
-                    # Orchestrator is overwhelmed
-                    self.can_send.clear()
-                    logger.info("Received backpressure signal")
-
-                elif msg_type == "resume":
-                    # Orchestrator ready for more
-                    self.can_send.set()
-                    logger.info("Received resume signal")
-
-            except Exception as e:
-                logger.error(f"Error handling message: {e}")
-
-    async def _data_processor(self):
-        """Process data from source."""
-        try:
-            batch = []
-
-            async for sample in self._load_data_source():
-                # Get image data
-                if sample.image_data:
-                    image_data = sample.image_data
-                elif sample.image_url:
-                    image_data = await self._download_image(sample.image_url)
-                    if not image_data:
-                        self.samples_failed += 1
-                        continue
-                else:
-                    logger.warning(f"No image data for sample {sample.sample_id}")
-                    continue
-
-                # Store if configured
-                if self.storage_config.get("forward_to_orchestrator", True):
-                    # Add to send queue
-                    batch.append(
-                        {
-                            "sample_id": sample.sample_id,
-                            "image_data": image_data,
-                            "metadata": sample.metadata,
-                        }
-                    )
-
-                    if len(batch) >= self.batch_size:
-                        # Wait for backpressure clearance
-                        await asyncio.wait_for(self.can_send.wait(), timeout=300)
-
-                        # Add batch to send queue
-                        try:
-                            self.send_queue.put_nowait(batch)
-                            batch = []
-                        except:
-                            # Queue full, wait
-                            await asyncio.sleep(1)
-
-                # Store locally/S3 if configured
-                if self.storage_config.get("local", {}).get("enabled") or self.storage_config.get(
-                    "s3", {}
-                ).get("enabled"):
-                    if await self._store_sample(sample, image_data):
-                        self.samples_stored += 1
-
-            # Send remaining batch
-            if batch and self.storage_config.get("forward_to_orchestrator", True):
-                await asyncio.wait_for(self.can_send.wait(), timeout=300)
-                self.send_queue.put_nowait(batch)
-
-        except Exception as e:
-            logger.error(f"Data processing error: {e}")
-
-    async def _send_loop(self):
-        """Send data samples to orchestrator."""
-        while self.running and self.connected.is_set():
-            try:
-                # Get batch from queue
-                batch = await asyncio.get_event_loop().run_in_executor(
-                    None, self.send_queue.get, True, 1
-                )
-
-                if batch and self.websocket:
-                    # Send samples
-                    await self.websocket.send(
-                        json.dumps(
-                            {
-                                "type": "submit_samples",
-                                "samples": [
-                                    {"sample_id": s["sample_id"], "metadata": s["metadata"]}
-                                    for s in batch
-                                ],
-                                "batch_size": len(batch),
-                            }
-                        )
-                    )
-
-                    # Send actual image data separately
-                    for sample in batch:
-                        await self.websocket.send(sample["image_data"])
-
-                    self.samples_sent += len(batch)
-                    logger.info(f"Sent batch of {len(batch)} samples")
-
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(f"Send error: {e}")
-
-    async def _heartbeat_loop(self):
-        """Send periodic heartbeats."""
-        while self.running and self.connected.is_set():
-            try:
-                await self.websocket.send(
-                    json.dumps(
-                        {
-                            "type": "heartbeat",
-                            "sent": self.samples_sent,
-                            "stored": self.samples_stored,
-                            "failed": self.samples_failed,
-                            "queue_size": self.send_queue.qsize(),
-                        }
-                    )
-                )
-                await asyncio.sleep(30)
-            except:
-                break
-
-    async def shutdown(self):
-        """Graceful shutdown."""
-        logger.info("Shutting down data worker...")
-        self.running = False
-        self.connected.clear()
-
-        if self.websocket:
-            await self.websocket.close()
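With start(), _connect_and_run(), and shutdown() deleted here, the connect/retry/authenticate lifecycle presumably moved into BaseWorker alongside the other removed plumbing. Assuming the public entry points kept their 0.1.0 names, launching a worker would still look roughly like this; the import path and the inherited methods are assumptions, since .base is not part of this diff:

    import asyncio

    from caption_flow.workers.data import DataWorker  # hypothetical import path


    async def main():
        worker = DataWorker(
            {
                "server": "wss://orchestrator.example.com:8765",  # placeholder URL
                "token": "worker-token",
                "name": "data_worker",
                "data_source": "samples.parquet",
                "source_type": "auto",  # auto, jsonl, csv, parquet, huggingface
                "batch_size": 10,
            }
        )
        try:
            await worker.start()  # assumed inherited from BaseWorker in 0.2.1
        finally:
            await worker.shutdown()


    if __name__ == "__main__":
        asyncio.run(main())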