caption_flow-0.1.0-py3-none-any.whl → caption_flow-0.2.0-py3-none-any.whl

@@ -0,0 +1,369 @@
1
+ Metadata-Version: 2.4
2
+ Name: caption-flow
3
+ Version: 0.2.0
4
+ Summary: Self-contained distributed community captioning system
5
+ Author-email: bghira <bghira@users.github.com>
6
+ License: MIT
7
+ Keywords: captioning,distributed,vllm,dataset,community
8
+ Classifier: Development Status :: 4 - Beta
9
+ Classifier: Intended Audience :: Developers
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Requires-Python: <3.13,>=3.10
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+ Requires-Dist: websockets>=12.0
19
+ Requires-Dist: pyarrow>=14.0.0
20
+ Requires-Dist: click>=8.1.0
21
+ Requires-Dist: pydantic>=2.0.0
22
+ Requires-Dist: aiofiles>=23.0.0
23
+ Requires-Dist: rich>=13.0.0
24
+ Requires-Dist: cryptography>=41.0.0
25
+ Requires-Dist: pyyaml>=6.0
26
+ Requires-Dist: certbot>=2.0.0
27
+ Requires-Dist: numpy>=1.24.0
28
+ Requires-Dist: pillow>=10.0.0
29
+ Requires-Dist: vllm<0.11.0,>=0.10.0
30
+ Requires-Dist: webdataset<2.0.0,>=1.0.2
31
+ Requires-Dist: pandas<3.0.0,>=2.3.1
32
+ Requires-Dist: arrow<2.0.0,>=1.3.0
33
+ Requires-Dist: datasets<5.0.0,>=4.0.0
34
+ Requires-Dist: boto3<2.0.0,>=1.40.11
35
+ Provides-Extra: dev
36
+ Requires-Dist: pytest>=7.4.0; extra == "dev"
37
+ Requires-Dist: pytest-asyncio>=0.21.0; extra == "dev"
38
+ Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
39
+ Requires-Dist: black>=23.0.0; extra == "dev"
40
+ Requires-Dist: ruff>=0.1.0; extra == "dev"
41
+ Requires-Dist: mypy>=1.5.0; extra == "dev"
42
+ Dynamic: license-file
43
+
44
+ # CaptionFlow
45
+
46
+ scalable, fault-tolerant **vLLM-powered image captioning**. this "first round" focuses on a fast websocket orchestrator plus lightweight gpu workers that batch requests through vLLM.
47
+
48
+ * **orchestrator**: hands out work in chunked shards, collects captions, checkpoints progress, and keeps simple stats.
49
+ * **workers (vLLM)**: connect to the orchestrator, stream in image samples, batch them, and generate 1..N captions per image using prompts supplied by the orchestrator.
50
+ * **config-driven**: all components read YAML config; flags can override.
51
+ * **tui monitor (optional)**: a monitor client is wired into the CLI; ship a `monitor` module to enable it.
52
+
53
+ > no conda. just `venv` + `pip`.
54
+
55
+ ---
56
+
57
+ ## install
58
+
59
+ ```bash
60
+ python -m venv .venv
61
+ source .venv/bin/activate # windows: .venv\Scripts\activate
62
+ pip install --upgrade pip
63
+ pip install -e . # installs the `caption-flow` command
64
+ ```
65
+
66
+ ## quickstart (single box)
67
+
68
+ 1. copy + edit the sample configs
69
+
70
+ ```bash
71
+ cp orchestrator.yaml my-orchestrator.yaml
72
+ cp worker.yaml my-worker.yaml
73
+ cp monitor.yaml my-monitor.yaml # optional; requires a monitor module
74
+ ```
75
+
76
+ set a unique shared token in both `my-orchestrator.yaml` and `my-worker.yaml` (see `auth.worker_tokens` in the orchestrator config and `worker.token` in the worker config). if you use private hugging face datasets/models, export `HUGGINGFACE_HUB_TOKEN` before starting workers.
77
+
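+ for example, a matching pair might look like this (values are illustrative; use your own secret):
+
+ ```yaml
+ # my-orchestrator.yaml
+ auth:
+   worker_tokens:
+     - { token: "example-worker-token", name: "Example Worker" }
+
+ # my-worker.yaml
+ worker:
+   token: example-worker-token
+ ```
+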
78
+ 2. start the orchestrator
79
+
80
+ ```bash
81
+ caption-flow orchestrator --config my-orchestrator.yaml
82
+ ```
83
+
84
+ 3. start one or more vLLM workers
85
+
86
+ ```bash
87
+ # gpu 0 on the same host
88
+ caption-flow worker --config my-worker.yaml --gpu-id 0
89
+
90
+ # gpu 1 on the same host
91
+ caption-flow worker --config my-worker.yaml --gpu-id 1
92
+ ```
93
+
94
+ 4. (optional) start the monitor
95
+
96
+ ```bash
97
+ caption-flow monitor --config my-monitor.yaml
98
+ ```
99
+
100
+ 5. (optional) scan/fix chunks on disk if you had crashes
101
+
102
+ ```bash
103
+ caption-flow scan_chunks --data-dir ./caption_data --checkpoint-dir ./checkpoints --fix
104
+ ```
105
+
106
+ ---
107
+
108
+ ## how it’s wired
109
+
110
+ ### orchestrator
111
+
112
+ * **websocket server** (default `0.0.0.0:8765`) with three client roles: workers, data-feeders, and admins. a minimal handshake sketch follows this list.
113
+ * **dataset control**: the orchestrator centrally defines the dataset (`huggingface` or `local`) and version/name. it chunk-slices shards and assigns work.
114
+ * **vLLM config broadcast**: model, tensor-parallel size, dtype, max seq len, memory targets, batching, sampling params, and **inference prompts** are all pushed to workers; workers can apply many changes without a model reload.
115
+ * **storage + checkpoints**: captions buffer to disk with periodic checkpoints. chunk state is tracked so restarts don’t double-work.
116
+ * **auth**: token lists for `worker`, `monitor`, and `admin` roles.
117
+
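+ the handshake itself is plain JSON over the socket. a minimal probe, assuming the message shape of the legacy `worker.py` shown at the bottom of this diff (the current protocol may differ):
+
+ ```python
+ import asyncio
+ import json
+
+ import websockets
+
+ async def probe(server="ws://localhost:8765", token="example-worker-token"):
+     async with websockets.connect(server) as ws:
+         # authenticate, then print the welcome payload (or {"error": ...})
+         await ws.send(json.dumps({"token": token, "name": "probe"}))
+         print(json.loads(await ws.recv()))
+
+ asyncio.run(probe())
+ ```
+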
118
+ start flags you’ll likely use:
119
+
120
+ ```text
121
+ --config PATH # yaml config for the orchestrator
122
+ --port INT, --host STR # bind controls
123
+ --data-dir PATH # overrides storage.data_dir
124
+ --cert PATH, --key PATH # enable TLS (or use --no-ssl for ws:// in dev)
125
+ --vllm # use the vLLM-style orchestrator (webdataset/hf)
126
+ ```
127
+
128
+ ### vLLM worker
129
+
130
+ * **one process per gpu**. select the device with `--gpu-id` (or `worker.gpu_id` in YAML).
131
+ * **gets its marching orders** from the orchestrator: dataset info, model, prompts, batch size, and sampling.
132
+ * **resilient**: detects disconnects, abandons the current chunk cleanly, clears queues, reconnects, and resumes.
133
+ * **batched generate()**: images are resized down for consistent batching; each image can get multiple captions (one per prompt). a sketch follows this list.
134
+
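+ a rough sketch of that batching step (not the actual worker code), assuming vLLM's offline `LLM.generate` API with per-request `multi_modal_data`; chat-template formatting is model-specific and elided, and the resize target is illustrative:
+
+ ```python
+ from PIL import Image
+ from vllm import LLM, SamplingParams
+
+ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", limit_mm_per_prompt={"image": 1})
+ sampling = SamplingParams(temperature=0.7, top_p=0.95, max_tokens=256)
+
+ def caption_batch(images: list[Image.Image], prompts: list[str]) -> list[str]:
+     requests = []
+     for img in images:
+         img.thumbnail((512, 512))  # resize down for consistent batching
+         for prompt in prompts:     # one request per (image, prompt) pair
+             requests.append({"prompt": prompt, "multi_modal_data": {"image": img}})
+     outputs = llm.generate(requests, sampling)
+     # keep only non-empty generations
+     return [o.outputs[0].text.strip() for o in outputs if o.outputs[0].text.strip()]
+ ```
+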
135
+ start flags you’ll likely use:
136
+
137
+ ```text
138
+ --config PATH # yaml for the worker
139
+ --server URL # ws(s)://host:port
140
+ --token STR # must match an allowed worker token on the orchestrator
141
+ --name STR # display name
142
+ --batch-size INT # override vLLM batch size
143
+ --vllm # use the vLLM worker implementation
144
+ --gpu-id INT # which gpu to use
145
+ --precision STR, --model STR # optional overrides for dtype/model
146
+ --no-verify-ssl # accept self-signed certs in dev
147
+ ```
148
+
149
+ ### (optional) monitor
150
+
151
+ * a CLI entry exists for a TUI monitor; wire in a `monitor` module to enable it. config lives in `monitor.yaml` or inside `orchestrator.yaml` under `monitor:`.
152
+
153
+ ---
154
+
155
+ ## configuration
156
+
157
+ ### config discovery order
158
+
159
+ for any component, the CLI looks for config in this order (first match wins; a sketch of the lookup follows the list):
160
+
161
+ 1. `--config /path/to/file.yaml`
162
+ 2. `./<component>.yaml` (current directory)
163
+ 3. `~/.caption-flow/<component>.yaml`
164
+ 4. `$XDG_CONFIG_HOME/caption-flow/<component>.yaml`
165
+ 5. `/etc/caption-flow/<component>.yaml`
166
+ 6. any `$XDG_CONFIG_DIRS` entries under `caption-flow/`
167
+ 7. `./examples/<component>.yaml` (fallback)
168
+
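+ in code, the lookup is roughly this (a sketch of the same order, not the actual implementation; step 1 is handled by the `--config` flag itself):
+
+ ```python
+ import os
+ from pathlib import Path
+
+ def find_config(component: str) -> Path | None:
+     """Return the first existing <component>.yaml, or None."""
+     xdg_home = Path(os.environ.get("XDG_CONFIG_HOME", "~/.config")).expanduser()
+     candidates = [
+         Path.cwd() / f"{component}.yaml",
+         Path.home() / ".caption-flow" / f"{component}.yaml",
+         xdg_home / "caption-flow" / f"{component}.yaml",
+         Path("/etc/caption-flow") / f"{component}.yaml",
+     ]
+     for d in os.environ.get("XDG_CONFIG_DIRS", "").split(":"):
+         if d:
+             candidates.append(Path(d) / "caption-flow" / f"{component}.yaml")
+     candidates.append(Path("examples") / f"{component}.yaml")
+     return next((p for p in candidates if p.is_file()), None)
+ ```
+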
169
+ ### orchestrator.yaml (highlights)
170
+
171
+ ```yaml
172
+ orchestrator:
173
+ host: 0.0.0.0
174
+ port: 8765
175
+ # ssl:
176
+ # cert: /path/fullchain.pem
177
+ # key: /path/privkey.pem
178
+
179
+ dataset:
180
+ type: huggingface # or "local"
181
+ path: <hf-dataset-or-local-path>
182
+ name: <logical-name>
183
+ version: "1.0"
184
+
185
+ vllm:
186
+ model: Qwen/Qwen2.5-VL-3B-Instruct
187
+ tensor_parallel_size: 1
188
+ max_model_len: 16384
189
+ dtype: float16
190
+ gpu_memory_utilization: 0.92
191
+ enforce_eager: true
192
+ disable_mm_preprocessor_cache: true
193
+ limit_mm_per_prompt: { image: 1 }
194
+
195
+ batch_size: 8
196
+
197
+ sampling:
198
+ temperature: 0.7
199
+ top_p: 0.95
200
+ max_tokens: 256
201
+ repetition_penalty: 1.05
202
+ skip_special_tokens: true
203
+ stop: ["<|end|>", "<|endoftext|>", "<|im_end|>"]
204
+
205
+ inference_prompts:
206
+ - "describe this image in detail"
207
+ - "provide a comprehensive description of the visual content"
208
+ - "what are the key elements in this image?"
209
+
210
+ storage:
211
+ data_dir: ./caption_data
212
+ checkpoint_dir: ./checkpoints
213
+ caption_buffer_size: 100
214
+ checkpoint_interval: 1000
215
+
216
+ # chunking/queueing
217
+ chunk_size: 1000
218
+ chunks_per_request: 2
219
+ chunk_buffer_multiplier: 3
220
+ min_chunk_buffer: 10
221
+
222
+ auth:
223
+ worker_tokens:
224
+ - { token: "example-worker-token", name: "Example Worker" }
225
+ monitor_tokens:
226
+ - { token: "letmein", name: "Default monitor" }
227
+ admin_tokens:
228
+ - { token: "admin-secret-2024", name: "Admin" }
229
+ ```
230
+
231
+ ### worker.yaml (highlights)
232
+
233
+ ```yaml
234
+ worker:
235
+ server: ws://localhost:8765 # use wss:// in prod
236
+ token: example-worker-token
237
+ name: local-gpu
238
+ gpu_id: 0
239
+ vllm: true
240
+
241
+ # local queues
242
+ readahead_size: 256
243
+ inference_queue_size: 128
244
+ ```
245
+
246
+ ### monitor.yaml (optional)
247
+
248
+ ```yaml
249
+ monitor:
250
+ server: ws://localhost:8765
251
+ token: letmein
252
+ refresh_rate: 1.0
253
+ show_contributors: true
254
+ show_quality_metrics: true
255
+ max_activity_items: 20
256
+ show_chunk_progress: true
257
+ show_worker_queues: true
258
+ show_throughput_graph: true
259
+ ```
260
+
261
+ ---
262
+
263
+ ## tls / certificates
264
+
265
+ use the built-in helpers during development:
266
+
267
+ ```bash
268
+ # self-signed certs for quick local testing
269
+ caption-flow generate_cert --self-signed --domain localhost --output-dir ./certs
270
+
271
+ # inspect any certificate file
272
+ caption-flow inspect_cert ./certs/fullchain.pem
273
+ ```
274
+
275
+ then point the orchestrator at the resulting cert/key (or run `--no-ssl` for dev-only ws://).
276
+
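+ in yaml, that's the (normally commented-out) `ssl:` block in `orchestrator.yaml`; the paths below assume the `./certs` output dir used above (check the generated file names):
+
+ ```yaml
+ orchestrator:
+   ssl:
+     cert: ./certs/fullchain.pem
+     key: ./certs/privkey.pem
+ ```
+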
277
+ ---
278
+
279
+ ## tips & notes
280
+
281
+ * **multi-gpu**: start one worker process per gpu (set `--gpu-id` or `worker.gpu_id`); see the launch snippet after this list.
282
+ * **throughput**: tune `vllm.batch_size` in the orchestrator config (or override with `--batch-size` at worker start). higher isn’t always better; watch VRAM.
283
+ * **prompts**: add more strings under `vllm.inference_prompts` to get multiple captions per image; the worker returns only non-empty generations.
284
+ * **private HF**: if your dataset/model needs auth, export `HUGGINGFACE_HUB_TOKEN` before `caption-flow worker ...`.
285
+ * **self-signed ssl**: pass `--no-verify-ssl` to workers/monitors in dev.
286
+ * **recovery**: if you hard-crash mid-run, `caption-flow scan_chunks --fix` can reset abandoned chunks so the orchestrator can reissue them cleanly.
287
+
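+ for the multi-gpu tip, a simple launch loop works (assumes a 4-gpu box and the quickstart config names):
+
+ ```bash
+ # one worker process per gpu; adjust the range to your device count
+ for gpu in 0 1 2 3; do
+   caption-flow worker --config my-worker.yaml --gpu-id "$gpu" &
+ done
+ wait
+ ```
+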
288
+ ---
289
+
290
+ ## roadmap
291
+
292
+ * hot config reload via the admin websocket path.
293
+ * dedicated data-feeder clients (separate from gpu workers) that push samples into the orchestrator.
294
+ * richer monitor TUI.
295
+
296
+ PRs welcome. keep it simple and fast.
297
+
298
+ ## architecture
299
+
300
+ ```
301
+ ┌─────────────┐     WebSocket      ┌─────────────┐
302
+ │   Worker    │◄──────────────────►│             │
303
+ └─────────────┘                    │             │     ┌──────────────┐
304
+                                    │ Orchestrator│────►│Arrow/Parquet │
305
+ ┌─────────────┐                    │             │     │   Storage    │
306
+ │   Worker    │◄──────────────────►│             │     └──────────────┘
307
+ └─────────────┘                    └─────────────┘
308
+
309
+ ┌─────────────┐                           │
310
+ │   Monitor   │◄──────────────────────────┘
311
+ └─────────────┘
312
+ ```
313
+
314
+ ## Storage Schema
315
+
316
+ ### captions.parquet
317
+ - `job_id`: Unique job identifier
318
+ - `dataset`: Dataset name
319
+ - `shard`: Shard identifier
320
+ - `item_key`: Item within shard
321
+ - `caption`: Generated caption text
322
+ - `contributor_id`: Worker who generated it
323
+ - `timestamp`: Generation time
324
+ - `quality_score`: Optional quality metric
325
+
326
+ ### jobs.parquet
327
+ - `job_id`: Unique identifier
328
+ - `dataset`: Dataset name
329
+ - `shard`: Shard identifier
330
+ - `status`: pending/processing/completed/failed
331
+ - `assigned_to`: Worker ID
332
+ - `timestamp`: Status change time
333
+
334
+ ### contributors.parquet
335
+ - `contributor_id`: Unique identifier
336
+ - `name`: Display name
337
+ - `total_captions`: Lifetime count
338
+ - `trust_level`: Quality tier (0-5)
339
+
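+ A quick way to sanity-check output locally (a sketch; assumes the default `./caption_data` directory and the column names above):
+
+ ```python
+ import pyarrow.parquet as pq
+
+ df = pq.read_table("./caption_data/captions.parquet").to_pandas()
+ print(df[["job_id", "item_key", "caption", "contributor_id"]].head())
+ print(f"{len(df)} captions from {df['contributor_id'].nunique()} contributors")
+ ```
+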
340
+ ## Development
341
+
342
+ ```bash
343
+ # Install with dev dependencies
344
+ pip install -e ".[dev]"
345
+
346
+ # Run tests
347
+ pytest
348
+
349
+ # Format code
350
+ black src/
351
+ ruff check --fix src/
352
+
353
+ # Type checking
354
+ mypy src/
355
+ ```
356
+
357
+ ## Community Contribution
358
+
359
+ To contribute compute:
360
+
361
+ 1. Install caption-flow: `pip install caption-flow`
362
+ 2. Get a worker token from the project maintainer
363
+ 3. Run: `caption-flow worker --server wss://project.domain.com:8765 --token YOUR_TOKEN`
364
+
365
+ Your contributions will be tracked and attributed in the final dataset!
366
+
367
+ ## License
368
+
369
+ MIT
@@ -0,0 +1,29 @@
1
+ caption_flow/__init__.py,sha256=NLPJ25lRN7xHqncXweINDNwbt0q8lgjZ30G21zlPdRs,303
2
+ caption_flow/cli.py,sha256=DVVN4e4uL0jL0gRTaIC5BL0DBU2IU_2yUOi4lg6-lEw,28639
3
+ caption_flow/models.py,sha256=qo6lQiO10UISbaBVr6Cs-fSW_pmjwE6kmiTmmU_l3Wk,2140
4
+ caption_flow/monitor.py,sha256=MltOwBqcFwni1XEPWu5dIO-os5NKDbH_LInOBXUWHAY,7870
5
+ caption_flow/orchestrator.py,sha256=vLW_w5KuRn9Asy_343DxZDRxiUs0xYgbfuuNGgqIf7k,76403
6
+ caption_flow/storage.py,sha256=hC6ZHT_PHFoUVjqD5JUwy3_79oAD1e1H30neA_xsz7s,40748
7
+ caption_flow/utils/__init__.py,sha256=F1BChVoCsj9zn1GJRBOLHET1kLW6xrAmsbzcR7hHy6Y,202
8
+ caption_flow/utils/auth.py,sha256=UrxX2n8OEEcfMD1Ey27TxGfrJFmUCpC59x-SCrQJoVE,2253
9
+ caption_flow/utils/caption_utils.py,sha256=esUMAdcCkNjRroZ0Bhxv0_yKlLtMf0XeDCTt-5k6bik,5309
10
+ caption_flow/utils/certificates.py,sha256=eu4blQZEkL9NRaY1ynQWg1asvDorRYhGRZea7STonJE,4635
11
+ caption_flow/utils/checkpoint_tracker.py,sha256=8tsTFF-HcygitK92YcS-QWzeg-qRm9AuCpQoQRfC8M0,3335
12
+ caption_flow/utils/chunk_tracker.py,sha256=hKn8CN6ubErc9kuCWZMj12ZCZKxVlqXqAEocbzjfa-k,17296
13
+ caption_flow/utils/dataset_loader.py,sha256=qjoRuPnCv_2nGPfrdqf45AgBXlthw1HwqZ1IqwIXzH4,20792
14
+ caption_flow/utils/image_processor.py,sha256=Zl8TAv9gYPdAYat3UiTuuNdIb2fXNfZ35AxsxuovJTs,5650
15
+ caption_flow/utils/job_queue.py,sha256=itdfXcrkvGjmXn4qtpgMF63k1ufRBaejDe4V6WcxzgU,1104
16
+ caption_flow/utils/json_utils.py,sha256=IiZYn8uCM-3pYmyIbX2fmaOIyutArn67SqAyp0ggNpU,5396
17
+ caption_flow/utils/prompt_template.py,sha256=AKp0diSZqNBMwZkpiTNjw8-bbQwHStr7QZTOJ7o1dC4,4345
18
+ caption_flow/utils/shard_processor.py,sha256=CRda6M4xh4U0vwvYlzq9nJEzz4d_4yzUBosYAeBcPEA,10854
19
+ caption_flow/utils/shard_tracker.py,sha256=Wt2oE-O85F2FxSnqIocJiaYeFn00OVVjIiklZIZRGL8,3233
20
+ caption_flow/utils/vllm_config.py,sha256=TC7Rmjk0zRKbBXbWUXrFL4Z58hzax_-4L0pXZn09hdM,6019
21
+ caption_flow/workers/base.py,sha256=jPm_Xw4Lxd0cnrPs-biBqKRQKkTOJLvHLolmp0Gb1CI,7530
22
+ caption_flow/workers/caption.py,sha256=NZ9kTjk2uOoNwyyNSkB_arYk213vLr5mowHN-OjiFkk,54631
23
+ caption_flow/workers/data.py,sha256=0Tg8NE0wdONeMlivYQ4nvbcfWdLuU51O7vR8_YSnJgo,14813
24
+ caption_flow-0.2.0.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
25
+ caption_flow-0.2.0.dist-info/METADATA,sha256=6qwt05U0S23Omjz1yR6VzLq_wRHbRx_xl3YzhwHyDLc,11900
26
+ caption_flow-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
+ caption_flow-0.2.0.dist-info/entry_points.txt,sha256=KnVlyrGKZj6p2zNyuEnCx4Y6jvJ4V-mcfN0lddPKTlQ,55
28
+ caption_flow-0.2.0.dist-info/top_level.txt,sha256=_bXpKRutqded0FQ80dCChIz26ETV7tL4d4e2E_Y1FXs,13
29
+ caption_flow-0.2.0.dist-info/RECORD,,
caption_flow/worker.py DELETED
@@ -1,300 +0,0 @@
1
- """Worker node for distributed captioning."""
2
-
3
- import asyncio
4
- import json
5
- import logging
6
- import ssl
7
- from typing import Dict, Any, Optional
8
- from pathlib import Path
9
-
10
- import websockets
11
- import websockets.exceptions
12
- from websockets.client import WebSocketClientProtocol
13
-
14
- from .models import Job, JobStatus
15
- from .utils.image_processor import ImageProcessor
16
-
17
- logger = logging.getLogger(__name__)
18
-
19
-
20
- class Worker:
21
- """Worker node that processes captioning jobs."""
22
-
23
- def __init__(self, config: Dict[str, Any]):
24
- self.config = config
25
- self.server_url = config["server"]
26
- self.token = config["token"]
27
- self.name = config.get("name", "worker")
28
- self.batch_size = config.get("batch_size", 32)
29
-
30
- # Dataset configuration will be received from orchestrator
31
- self.dataset_config = None
32
- self.dataset_type = None
33
- self.dataset_path = None
34
-
35
- # SSL configuration
36
- self.ssl_context = self._setup_ssl()
37
-
38
- # Components
39
- self.image_processor = ImageProcessor()
40
-
41
- # State
42
- self.worker_id: Optional[str] = None
43
- self.websocket: Optional[WebSocketClientProtocol] = None
44
- self.running = False
45
- self.current_job: Optional[Job] = None
46
-
47
- # Metrics
48
- self.processed_count = 0
49
- self.error_count = 0
50
-
51
- def _setup_ssl(self) -> Optional[ssl.SSLContext]:
52
- """Configure SSL context."""
53
- # Check if URL is WSS (requires SSL)
54
- if self.server_url.startswith("ws://"):
55
- logger.warning(
56
- "Using insecure WebSocket connection (ws://). Consider using wss:// for production."
57
- )
58
- return None # No SSL for ws://
59
-
60
- if not self.config.get("verify_ssl", True):
61
- # Disable SSL verification for development
62
- context = ssl.create_default_context()
63
- context.check_hostname = False
64
- context.verify_mode = ssl.CERT_NONE
65
- return context
66
-
67
- return ssl.create_default_context()
68
-
69
- async def start(self):
70
- """Start the worker and connect to orchestrator."""
71
- self.running = True
72
-
73
- while self.running:
74
- try:
75
- await self._connect_and_run()
76
- except websockets.exceptions.ConnectionClosed as e:
77
- logger.warning(f"Connection closed: {e}")
78
- if self.running:
79
- logger.info("Reconnecting in 5 seconds...")
80
- await asyncio.sleep(5)
81
- except Exception as e:
82
- logger.error(f"Connection error: {e}")
83
- if self.running:
84
- logger.info("Reconnecting in 5 seconds...")
85
- await asyncio.sleep(5)
86
-
87
- async def _connect_and_run(self):
88
- """Connect to orchestrator and process jobs."""
89
- logger.info(f"Connecting to {self.server_url}")
90
-
91
- try:
92
- async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
93
- self.websocket = websocket
94
-
95
- # Authenticate
96
- await websocket.send(json.dumps({"token": self.token, "name": self.name}))
97
-
98
- # Wait for welcome message with dataset configuration
99
- welcome = await websocket.recv()
100
- welcome_data = json.loads(welcome)
101
-
102
- if "error" in welcome_data:
103
- logger.error(f"Authentication failed: {welcome_data['error']}")
104
- self.running = False
105
- return
106
-
107
- self.worker_id = welcome_data.get("worker_id")
108
-
109
- # Extract and store dataset configuration from orchestrator
110
- if "dataset_config" in welcome_data:
111
- self.dataset_config = welcome_data["dataset_config"]
112
- self.dataset_type = self.dataset_config.get("dataset_type")
113
- self.dataset_path = self.dataset_config.get("dataset_path")
114
- logger.info(
115
- f"Received dataset configuration from orchestrator: "
116
- f"type={self.dataset_type}, path={self.dataset_path}"
117
- )
118
- else:
119
- logger.warning("No dataset configuration received from orchestrator")
120
-
121
- logger.info(f"Connected as {self.worker_id}")
122
-
123
- # Create tasks for concurrent operations
124
- tasks = [
125
- asyncio.create_task(self._heartbeat_loop()),
126
- asyncio.create_task(self._job_processing_loop()),
127
- asyncio.create_task(self._message_handler()),
128
- ]
129
-
130
- try:
131
- # Wait for any task to complete (usually due to connection close)
132
- done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
133
-
134
- # Cancel remaining tasks
135
- for task in pending:
136
- task.cancel()
137
- try:
138
- await task
139
- except asyncio.CancelledError:
140
- pass
141
-
142
- # Check if we had an error in completed tasks
143
- for task in done:
144
- try:
145
- task.result()
146
- except websockets.exceptions.ConnectionClosed:
147
- logger.info("WebSocket connection closed")
148
- except Exception as e:
149
- logger.error(f"Task error: {e}")
150
-
151
- except websockets.exceptions.ConnectionClosed:
152
- logger.info("Connection closed by orchestrator")
153
-
154
- except websockets.exceptions.ConnectionClosed as e:
155
- logger.info(f"Failed to connect: {e}")
156
- raise
157
- except Exception as e:
158
- logger.error(f"Unexpected error in connection: {e}")
159
- raise
160
- finally:
161
- self.websocket = None
162
- self.current_job = None
163
-
164
- async def _job_processing_loop(self):
165
- """Main loop for requesting and processing jobs."""
166
- while self.running and self.websocket:
167
- try:
168
- # Request a job
169
- await self.websocket.send(json.dumps({"type": "request_job"}))
170
-
171
- # Wait a bit for response
172
- await asyncio.sleep(1)
173
-
174
- if self.current_job:
175
- await self._process_job(self.current_job)
176
- self.current_job = None
177
- else:
178
- # No job available, wait before requesting again
179
- await asyncio.sleep(5)
180
-
181
- except websockets.exceptions.ConnectionClosed:
182
- logger.info("Connection closed during job processing")
183
- break
184
- except Exception as e:
185
- logger.error(f"Job processing error: {e}")
186
- self.error_count += 1
187
- await asyncio.sleep(1)
188
-
189
- async def _message_handler(self):
190
- """Handle incoming messages from orchestrator."""
191
- try:
192
- async for message in self.websocket:
193
- try:
194
- data = json.loads(message)
195
- msg_type = data.get("type")
196
-
197
- if msg_type == "job":
198
- job_data = data["job"]
199
- self.current_job = Job(**job_data)
200
- logger.info(f"Received job {self.current_job.job_id}")
201
-
202
- elif msg_type == "no_jobs":
203
- logger.debug("No jobs available")
204
-
205
- elif msg_type == "ack":
206
- logger.debug(f"Job {data['job_id']} acknowledged")
207
- self.processed_count += 1
208
-
209
- except json.JSONDecodeError as e:
210
- logger.error(f"Invalid message format: {e}")
211
- except Exception as e:
212
- logger.error(f"Error handling message: {e}")
213
-
214
- except websockets.exceptions.ConnectionClosed:
215
- logger.info("Connection closed while waiting for messages")
216
- except Exception as e:
217
- logger.error(f"Message handler error: {e}")
218
-
219
- async def _process_job(self, job: Job):
220
- """Process a single captioning job."""
221
- if not self.websocket:
222
- logger.warning(f"No websocket connection, skipping job {job.job_id}")
223
- return
224
-
225
- logger.info(f"Processing job {job.job_id}")
226
-
227
- try:
228
- # Load and preprocess images
229
- images = await self._load_images(job)
230
-
231
- # TODO: Here you would integrate your captioning model
232
- # For now, using placeholder
233
- caption = f"[Generated caption for {job.item_key}]"
234
-
235
- # Submit result
236
- await self.websocket.send(
237
- json.dumps(
238
- {
239
- "type": "submit_caption",
240
- "job_id": job.job_id,
241
- "dataset": job.dataset,
242
- "shard": job.shard,
243
- "item_key": job.item_key,
244
- "caption": caption,
245
- }
246
- )
247
- )
248
-
249
- logger.info(f"Completed job {job.job_id}")
250
-
251
- except websockets.exceptions.ConnectionClosed:
252
- logger.warning(f"Connection lost while processing job {job.job_id}")
253
- raise # Re-raise to trigger reconnection
254
- except Exception as e:
255
- logger.error(f"Failed to process job {job.job_id}: {e}")
256
-
257
- # Report failure if still connected
258
- if self.websocket:
259
- try:
260
- await self.websocket.send(
261
- json.dumps({"type": "job_failed", "job_id": job.job_id, "error": str(e)})
262
- )
263
- except:
264
- pass # Connection might be closed
265
-
266
- async def _load_images(self, job: Job):
267
- """Load and preprocess images for a job."""
268
- # This would load actual images from the dataset
269
- # Now can use self.dataset_type and self.dataset_path received from orchestrator
270
- # For now, returning placeholder
271
- return []
272
-
273
- async def _heartbeat_loop(self):
274
- """Send periodic heartbeats to orchestrator."""
275
- while self.running and self.websocket:
276
- try:
277
- await self.websocket.send(
278
- json.dumps(
279
- {
280
- "type": "heartbeat",
281
- "processed": self.processed_count,
282
- "errors": self.error_count,
283
- }
284
- )
285
- )
286
- await asyncio.sleep(30)
287
- except websockets.exceptions.ConnectionClosed:
288
- logger.info("Connection closed during heartbeat")
289
- break
290
- except Exception as e:
291
- logger.error(f"Heartbeat error: {e}")
292
- break
293
-
294
- async def shutdown(self):
295
- """Graceful shutdown."""
296
- logger.info("Shutting down worker...")
297
- self.running = False
298
-
299
- if self.websocket:
300
- await self.websocket.close()