caption-flow 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,370 @@
1
+ Metadata-Version: 2.4
2
+ Name: caption-flow
3
+ Version: 0.2.1
4
+ Summary: Self-contained distributed community captioning system
5
+ Author-email: bghira <bghira@users.github.com>
6
+ License: MIT
7
+ Keywords: captioning,distributed,vllm,dataset,community
8
+ Classifier: Development Status :: 4 - Beta
9
+ Classifier: Intended Audience :: Developers
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Requires-Python: <3.13,>=3.10
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+ Requires-Dist: websockets>=12.0
19
+ Requires-Dist: pyarrow>=14.0.0
20
+ Requires-Dist: click>=8.1.0
21
+ Requires-Dist: pydantic>=2.0.0
22
+ Requires-Dist: aiofiles>=23.0.0
23
+ Requires-Dist: rich>=13.0.0
24
+ Requires-Dist: cryptography>=41.0.0
25
+ Requires-Dist: pyyaml>=6.0
26
+ Requires-Dist: certbot>=2.0.0
27
+ Requires-Dist: numpy>=1.24.0
28
+ Requires-Dist: pillow>=10.0.0
29
+ Requires-Dist: vllm<0.11.0,>=0.10.0
30
+ Requires-Dist: webdataset<2.0.0,>=1.0.2
31
+ Requires-Dist: pandas<3.0.0,>=2.3.1
32
+ Requires-Dist: arrow<2.0.0,>=1.3.0
33
+ Requires-Dist: datasets<5.0.0,>=4.0.0
34
+ Requires-Dist: boto3<2.0.0,>=1.40.11
35
+ Requires-Dist: torchdata<0.12.0,>=0.11.0
36
+ Provides-Extra: dev
37
+ Requires-Dist: pytest>=7.4.0; extra == "dev"
38
+ Requires-Dist: pytest-asyncio>=0.21.0; extra == "dev"
39
+ Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
40
+ Requires-Dist: black>=23.0.0; extra == "dev"
41
+ Requires-Dist: ruff>=0.1.0; extra == "dev"
42
+ Requires-Dist: mypy>=1.5.0; extra == "dev"
43
+ Dynamic: license-file
44
+
45
+ # CaptionFlow
46
+
47
+ scalable, fault-tolerant **vLLM-powered image captioning**. this "first round" focuses on a fast websocket orchestrator plus lightweight gpu workers that batch requests through vLLM.
48
+
49
+ * **orchestrator**: hands out work in chunked shards, collects captions, checkpoints progress, and keeps simple stats.
50
+ * **workers (vLLM)**: connect to the orchestrator, stream in image samples, batch them, and generate 1..N captions per image using prompts supplied by the orchestrator.
51
+ * **config-driven**: all components read YAML config; flags can override.
52
+ * **tui monitor (optional)**: a monitor client is wired into the CLI; ship a `monitor` module to enable it.
53
+
54
+ > no conda. just `venv` + `pip`.
55
+
56
+ ---
57
+
58
+ ## install
59
+
60
+ ```bash
61
+ python -m venv .venv
62
+ source .venv/bin/activate # windows: .venv\Scripts\activate
63
+ pip install --upgrade pip
64
+ pip install -e . # installs the `caption-flow` command
65
+ ```
66
+
67
+ ## quickstart (single box)
68
+
69
+ 1. copy + edit the sample configs
70
+
71
+ ```bash
72
+ cp orchestrator.yaml my-orchestrator.yaml
73
+ cp worker.yaml my-worker.yaml
74
+ cp monitor.yaml my-monitor.yaml # optional; requires a monitor module
75
+ ```
76
+
77
+ set a unique shared token in both `my-orchestrator.yaml` and `my-worker.yaml` (see `auth.worker_tokens` in the orchestrator config and `worker.token` in the worker config). if you use private hugging face datasets/models, export `HUGGINGFACE_HUB_TOKEN` before starting workers.
78
+
79
+ 2. start the orchestrator
80
+
81
+ ```bash
82
+ caption-flow orchestrator --config my-orchestrator.yaml
83
+ ```
84
+
85
+ 3. start one or more vLLM workers
86
+
87
+ ```bash
88
+ # gpu 0 on the same host
89
+ caption-flow worker --config my-worker.yaml --gpu-id 0
90
+
91
+ # your second GPU
92
+ caption-flow worker --config my-worker.yaml --gpu-id 1
93
+ ```
94
+
95
+ 4. (optional) start the monitor
96
+
97
+ ```bash
98
+ caption-flow monitor --config my-monitor.yaml
99
+ ```
100
+
101
+ 5. (optional) scan/fix chunks on disk if you had crashes
102
+
103
+ ```bash
104
+ caption-flow scan_chunks --data-dir ./caption_data --checkpoint-dir ./checkpoints --fix
105
+ ```
106
+
107
+ ---
108
+
109
+ ## how it’s wired
110
+
111
+ ### orchestrator
112
+
113
+ * **websocket server** (default `0.0.0.0:8765`) with three client roles: workers, data-feeders, and admin.
114
+ * **dataset control**: the orchestrator centrally defines the dataset (`huggingface` or `local`) and version/name. it chunk-slices shards and assigns work.
115
+ * **vLLM config broadcast**: model, tp size, dtype, max seq len, memory targets, batching, sampling params, and **inference prompts** are all pushed to workers; workers can apply many changes without a model reload.
116
+ * **storage + checkpoints**: captions buffer to disk with periodic checkpoints. chunk state is tracked so restarts don’t double-work.
117
+ * **auth**: token lists for `worker`, `monitor`, and `admin` roles.
118
+
119
+ start flags you’ll likely use:
120
+
121
+ ```text
122
+ --config PATH # yaml config for the orchestrator
123
+ --port INT, --host STR # bind controls
124
+ --data-dir PATH # overrides storage.data_dir
125
+ --cert PATH, --key PATH # enable TLS (or use --no-ssl for ws:// in dev)
126
+ --vllm # use the vLLM-style orchestrator (webdataset/hf)
127
+ ```
128
+
129
+ ### vLLM worker
130
+
131
+ * **one process per gpu**. select the device with `--gpu-id` (or `worker.gpu_id` in YAML).
132
+ * **gets its marching orders** from the orchestrator: dataset info, model, prompts, batch size, and sampling.
133
+ * **resilient**: detects disconnects, abandons the current chunk cleanly, clears queues, reconnects, and resumes.
134
+ * **batched generate()**: images are resized down for consistent batching; each image can get multiple captions (one per prompt).
135
+
136
+ start flags you’ll likely use:
137
+
138
+ ```text
139
+ --config PATH # yaml for the worker
140
+ --server URL # ws(s)://host:port
141
+ --token STR # must match an allowed worker token on the orchestrator
142
+ --name STR # display name
143
+ --batch-size INT # override vLLM batch size
144
+ --vllm # use the vLLM worker implementation
145
+ --gpu-id INT # which gpu to use
146
+ --precision STR, --model STR # optional overrides for dtype/model
147
+ --no-verify-ssl # accept self-signed certs in dev
148
+ ```
149
+
150
+ ### (optional) monitor
151
+
152
+ * a CLI entry exists for a TUI monitor; wire in a `monitor` module to enable it. config lives in `monitor.yaml` or inside `orchestrator.yaml` under `monitor:`.
153
+
154
+ ---
155
+
156
+ ## configuration
157
+
158
+ ### config discovery order
159
+
160
+ for any component, the CLI looks for config in this order (first match wins):
161
+
162
+ 1. `--config /path/to/file.yaml`
163
+ 2. `./<component>.yaml` (current directory)
164
+ 3. `~/.caption-flow/<component>.yaml`
165
+ 4. `$XDG_CONFIG_HOME/caption-flow/<component>.yaml`
166
+ 5. `/etc/caption-flow/<component>.yaml`
167
+ 6. any `$XDG_CONFIG_DIRS` entries under `caption-flow/`
168
+ 7. `./examples/<component>.yaml` (fallback)
169
+
170
+ ### orchestrator.yaml (highlights)
171
+
172
+ ```yaml
173
+ orchestrator:
174
+ host: 0.0.0.0
175
+ port: 8765
176
+ # ssl:
177
+ # cert: /path/fullchain.pem
178
+ # key: /path/privkey.pem
179
+
180
+ dataset:
181
+ type: huggingface # or "local"
182
+ path: <hf-dataset-or-local-path>
183
+ name: <logical-name>
184
+ version: "1.0"
185
+
186
+ vllm:
187
+ model: Qwen/Qwen2.5-VL-3B-Instruct
188
+ tensor_parallel_size: 1
189
+ max_model_len: 16384
190
+ dtype: float16
191
+ gpu_memory_utilization: 0.92
192
+ enforce_eager: true
193
+ disable_mm_preprocessor_cache: true
194
+ limit_mm_per_prompt: { image: 1 }
195
+
196
+ batch_size: 8
197
+
198
+ sampling:
199
+ temperature: 0.7
200
+ top_p: 0.95
201
+ max_tokens: 256
202
+ repetition_penalty: 1.05
203
+ skip_special_tokens: true
204
+ stop: ["<|end|>", "<|endoftext|>", "<|im_end|>"]
205
+
206
+ inference_prompts:
207
+ - "describe this image in detail"
208
+ - "provide a comprehensive description of the visual content"
209
+ - "what are the key elements in this image?"
210
+
211
+ storage:
212
+ data_dir: ./caption_data
213
+ checkpoint_dir: ./checkpoints
214
+ caption_buffer_size: 100
215
+ checkpoint_interval: 1000
216
+
217
+ # chunking/queueing
218
+ chunk_size: 1000
219
+ chunks_per_request: 2
220
+ chunk_buffer_multiplier: 3
221
+ min_chunk_buffer: 10
222
+
223
+ auth:
224
+ worker_tokens:
225
+ - { token: "example-worker-token", name: "Example Worker" }
226
+ monitor_tokens:
227
+ - { token: "letmein", name: "Default monitor" }
228
+ admin_tokens:
229
+ - { token: "admin-secret-2024", name: "Admin" }
230
+ ```
231
+
232
+ ### worker.yaml (highlights)
233
+
234
+ ```yaml
235
+ worker:
236
+ server: ws://localhost:8765 # use wss:// in prod
237
+ token: example-worker-token
238
+ name: local-gpu
239
+ gpu_id: 0
240
+ vllm: true
241
+
242
+ # local queues
243
+ readahead_size: 256
244
+ inference_queue_size: 128
245
+ ```
246
+
247
+ ### monitor.yaml (optional)
248
+
249
+ ```yaml
250
+ monitor:
251
+ server: ws://localhost:8765
252
+ token: letmein
253
+ refresh_rate: 1.0
254
+ show_contributors: true
255
+ show_quality_metrics: true
256
+ max_activity_items: 20
257
+ show_chunk_progress: true
258
+ show_worker_queues: true
259
+ show_throughput_graph: true
260
+ ```
261
+
262
+ ---
263
+
264
+ ## tls / certificates
265
+
266
+ use the built-in helpers during development:
267
+
268
+ ```bash
269
+ # self-signed certs for quick local testing
270
+ caption-flow generate_cert --self-signed --domain localhost --output-dir ./certs
271
+
272
+ # inspect any certificate file
273
+ caption-flow inspect_cert ./certs/fullchain.pem
274
+ ```
275
+
276
+ then point the orchestrator at the resulting cert/key (or run `--no-ssl` for dev-only `ws://`).
277
+
278
+ ---
279
+
280
+ ## tips & notes
281
+
282
+ * **multi-gpu**: start one worker process per gpu (set `--gpu-id` or `worker.gpu_id`).
283
+ * **throughput**: tune `vllm.batch_size` in the orchestrator config (or override with `--batch-size` at worker start). higher isn’t always better; watch VRAM.
284
+ * **prompts**: add more strings under `vllm.inference_prompts` to get multiple captions per image; the worker returns only non-empty generations.
285
+ * **private HF**: if your dataset/model needs auth, export `HUGGINGFACE_HUB_TOKEN` before `caption-flow worker ...`.
286
+ * **self-signed ssl**: pass `--no-verify-ssl` to workers/monitors in dev.
287
+ * **recovery**: if you hard-crash mid-run, `caption-flow scan_chunks --fix` can reset abandoned chunks so the orchestrator can reissue them cleanly.
288
+
289
+ ---
290
+
291
+ ## roadmap
292
+
293
+ * hot config reload via the admin websocket path.
294
+ * dedicated data-feeder clients (separate from gpu workers) that push samples into the orchestrator.
295
+ * richer monitor TUI.
296
+
297
+ PRs welcome. keep it simple and fast.
298
+
299
+ ## architecture
300
+
301
+ ```
302
+ ┌─────────────┐ WebSocket ┌─────────────┐
303
+ │ Worker │◄──────────────────►│ │
304
+ └─────────────┘ │ │ ┌──────────────┐
305
+ │ Orchestrator│────►│Arrow/Parquet │
306
+ ┌─────────────┐ │ │ │ Storage │
307
+ │ Worker │◄──────────────────►│ │ └──────────────┘
308
+ └─────────────┘ └─────────────┘
309
+
310
+ ┌─────────────┐ │
311
+ │ Monitor │◄──────────────────────────┘
312
+ └─────────────┘
313
+ ```
314
+
315
+ ## Storage Schema
316
+
317
+ ### captions.parquet
318
+ - `job_id`: Unique job identifier
319
+ - `dataset`: Dataset name
320
+ - `shard`: Shard identifier
321
+ - `item_key`: Item within shard
322
+ - `caption`: Generated caption text
323
+ - `contributor_id`: Worker who generated it
324
+ - `timestamp`: Generation time
325
+ - `quality_score`: Optional quality metric
326
+
327
+ ### jobs.parquet
328
+ - `job_id`: Unique identifier
329
+ - `dataset`: Dataset name
330
+ - `shard`: Shard identifier
331
+ - `status`: pending/processing/completed/failed
332
+ - `assigned_to`: Worker ID
333
+ - `timestamp`: Status change time
334
+
335
+ ### contributors.parquet
336
+ - `contributor_id`: Unique identifier
337
+ - `name`: Display name
338
+ - `total_captions`: Lifetime count
339
+ - `trust_level`: Quality tier (0-5)
340
+
341
+ ## Development
342
+
343
+ ```bash
344
+ # Install with dev dependencies
345
+ pip install -e ".[dev]"
346
+
347
+ # Run tests
348
+ pytest
349
+
350
+ # Format code
351
+ black src/
352
+ ruff --fix src/
353
+
354
+ # Type checking
355
+ mypy src/
356
+ ```
357
+
358
+ ## Community Contribution
359
+
360
+ To contribute compute:
361
+
362
+ 1. Install caption-flow: `pip install caption-flow`
363
+ 2. Get a worker token from the project maintainer
364
+ 3. Run: `caption-flow worker --server wss://project.domain.com:8765 --token YOUR_TOKEN`
365
+
366
+ Your contributions will be tracked and attributed in the final dataset!
367
+
368
+ ## License
369
+
370
+ MIT
@@ -0,0 +1,29 @@
1
+ caption_flow/__init__.py,sha256=NLPJ25lRN7xHqncXweINDNwbt0q8lgjZ30G21zlPdRs,303
2
+ caption_flow/cli.py,sha256=bHxx66CPsCmSieaH3pw8NZBojIIbniRTdU9mEBHMmWA,28832
3
+ caption_flow/models.py,sha256=qo6lQiO10UISbaBVr6Cs-fSW_pmjwE6kmiTmmU_l3Wk,2140
4
+ caption_flow/monitor.py,sha256=ZZCSasYLKJ-UzA3-RoAtytv-tbNA-m3h5YjlZg_vukg,7870
5
+ caption_flow/orchestrator.py,sha256=bZ8NnGdqoXSmu7Nq-_7cOSH1DLHkBT88cne0uDyPeNY,89112
6
+ caption_flow/storage.py,sha256=hC6ZHT_PHFoUVjqD5JUwy3_79oAD1e1H30neA_xsz7s,40748
7
+ caption_flow/utils/__init__.py,sha256=F1BChVoCsj9zn1GJRBOLHET1kLW6xrAmsbzcR7hHy6Y,202
8
+ caption_flow/utils/auth.py,sha256=UrxX2n8OEEcfMD1Ey27TxGfrJFmUCpC59x-SCrQJoVE,2253
9
+ caption_flow/utils/caption_utils.py,sha256=esUMAdcCkNjRroZ0Bhxv0_yKlLtMf0XeDCTt-5k6bik,5309
10
+ caption_flow/utils/certificates.py,sha256=eu4blQZEkL9NRaY1ynQWg1asvDorRYhGRZea7STonJE,4635
11
+ caption_flow/utils/checkpoint_tracker.py,sha256=8tsTFF-HcygitK92YcS-QWzeg-qRm9AuCpQoQRfC8M0,3335
12
+ caption_flow/utils/chunk_tracker.py,sha256=hKn8CN6ubErc9kuCWZMj12ZCZKxVlqXqAEocbzjfa-k,17296
13
+ caption_flow/utils/dataset_loader.py,sha256=ZplJv655ZMyUbaZC4BBiL5II18sBy4JSJhxGZtK_VmA,29107
14
+ caption_flow/utils/image_processor.py,sha256=Zl8TAv9gYPdAYat3UiTuuNdIb2fXNfZ35AxsxuovJTs,5650
15
+ caption_flow/utils/job_queue.py,sha256=itdfXcrkvGjmXn4qtpgMF63k1ufRBaejDe4V6WcxzgU,1104
16
+ caption_flow/utils/json_utils.py,sha256=IiZYn8uCM-3pYmyIbX2fmaOIyutArn67SqAyp0ggNpU,5396
17
+ caption_flow/utils/prompt_template.py,sha256=AKp0diSZqNBMwZkpiTNjw8-bbQwHStr7QZTOJ7o1dC4,4345
18
+ caption_flow/utils/shard_processor.py,sha256=CRda6M4xh4U0vwvYlzq9nJEzz4d_4yzUBosYAeBcPEA,10854
19
+ caption_flow/utils/shard_tracker.py,sha256=Wt2oE-O85F2FxSnqIocJiaYeFn00OVVjIiklZIZRGL8,3233
20
+ caption_flow/utils/vllm_config.py,sha256=TC7Rmjk0zRKbBXbWUXrFL4Z58hzax_-4L0pXZn09hdM,6019
21
+ caption_flow/workers/base.py,sha256=jPm_Xw4Lxd0cnrPs-biBqKRQKkTOJLvHLolmp0Gb1CI,7530
22
+ caption_flow/workers/caption.py,sha256=NZ9kTjk2uOoNwyyNSkB_arYk213vLr5mowHN-OjiFkk,54631
23
+ caption_flow/workers/data.py,sha256=0Tg8NE0wdONeMlivYQ4nvbcfWdLuU51O7vR8_YSnJgo,14813
24
+ caption_flow-0.2.1.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
25
+ caption_flow-0.2.1.dist-info/METADATA,sha256=fxNfSOqkCklb96aq3ZFU7SvRuXEBUQ11xbjkQn7Yzuo,11941
26
+ caption_flow-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
+ caption_flow-0.2.1.dist-info/entry_points.txt,sha256=KnVlyrGKZj6p2zNyuEnCx4Y6jvJ4V-mcfN0lddPKTlQ,55
28
+ caption_flow-0.2.1.dist-info/top_level.txt,sha256=_bXpKRutqded0FQ80dCChIz26ETV7tL4d4e2E_Y1FXs,13
29
+ caption_flow-0.2.1.dist-info/RECORD,,
caption_flow/worker.py DELETED
@@ -1,300 +0,0 @@
1
- """Worker node for distributed captioning."""
2
-
3
- import asyncio
4
- import json
5
- import logging
6
- import ssl
7
- from typing import Dict, Any, Optional
8
- from pathlib import Path
9
-
10
- import websockets
11
- import websockets.exceptions
12
- from websockets.client import WebSocketClientProtocol
13
-
14
- from .models import Job, JobStatus
15
- from .utils.image_processor import ImageProcessor
16
-
17
- logger = logging.getLogger(__name__)
18
-
19
-
20
- class Worker:
21
- """Worker node that processes captioning jobs."""
22
-
23
- def __init__(self, config: Dict[str, Any]):
24
- self.config = config
25
- self.server_url = config["server"]
26
- self.token = config["token"]
27
- self.name = config.get("name", "worker")
28
- self.batch_size = config.get("batch_size", 32)
29
-
30
- # Dataset configuration will be received from orchestrator
31
- self.dataset_config = None
32
- self.dataset_type = None
33
- self.dataset_path = None
34
-
35
- # SSL configuration
36
- self.ssl_context = self._setup_ssl()
37
-
38
- # Components
39
- self.image_processor = ImageProcessor()
40
-
41
- # State
42
- self.worker_id: Optional[str] = None
43
- self.websocket: Optional[WebSocketClientProtocol] = None
44
- self.running = False
45
- self.current_job: Optional[Job] = None
46
-
47
- # Metrics
48
- self.processed_count = 0
49
- self.error_count = 0
50
-
51
- def _setup_ssl(self) -> Optional[ssl.SSLContext]:
52
- """Configure SSL context."""
53
- # Check if URL is WSS (requires SSL)
54
- if self.server_url.startswith("ws://"):
55
- logger.warning(
56
- "Using insecure WebSocket connection (ws://). Consider using wss:// for production."
57
- )
58
- return None # No SSL for ws://
59
-
60
- if not self.config.get("verify_ssl", True):
61
- # Disable SSL verification for development
62
- context = ssl.create_default_context()
63
- context.check_hostname = False
64
- context.verify_mode = ssl.CERT_NONE
65
- return context
66
-
67
- return ssl.create_default_context()
68
-
69
- async def start(self):
70
- """Start the worker and connect to orchestrator."""
71
- self.running = True
72
-
73
- while self.running:
74
- try:
75
- await self._connect_and_run()
76
- except websockets.exceptions.ConnectionClosed as e:
77
- logger.warning(f"Connection closed: {e}")
78
- if self.running:
79
- logger.info("Reconnecting in 5 seconds...")
80
- await asyncio.sleep(5)
81
- except Exception as e:
82
- logger.error(f"Connection error: {e}")
83
- if self.running:
84
- logger.info("Reconnecting in 5 seconds...")
85
- await asyncio.sleep(5)
86
-
87
- async def _connect_and_run(self):
88
- """Connect to orchestrator and process jobs."""
89
- logger.info(f"Connecting to {self.server_url}")
90
-
91
- try:
92
- async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
93
- self.websocket = websocket
94
-
95
- # Authenticate
96
- await websocket.send(json.dumps({"token": self.token, "name": self.name}))
97
-
98
- # Wait for welcome message with dataset configuration
99
- welcome = await websocket.recv()
100
- welcome_data = json.loads(welcome)
101
-
102
- if "error" in welcome_data:
103
- logger.error(f"Authentication failed: {welcome_data['error']}")
104
- self.running = False
105
- return
106
-
107
- self.worker_id = welcome_data.get("worker_id")
108
-
109
- # Extract and store dataset configuration from orchestrator
110
- if "dataset_config" in welcome_data:
111
- self.dataset_config = welcome_data["dataset_config"]
112
- self.dataset_type = self.dataset_config.get("dataset_type")
113
- self.dataset_path = self.dataset_config.get("dataset_path")
114
- logger.info(
115
- f"Received dataset configuration from orchestrator: "
116
- f"type={self.dataset_type}, path={self.dataset_path}"
117
- )
118
- else:
119
- logger.warning("No dataset configuration received from orchestrator")
120
-
121
- logger.info(f"Connected as {self.worker_id}")
122
-
123
- # Create tasks for concurrent operations
124
- tasks = [
125
- asyncio.create_task(self._heartbeat_loop()),
126
- asyncio.create_task(self._job_processing_loop()),
127
- asyncio.create_task(self._message_handler()),
128
- ]
129
-
130
- try:
131
- # Wait for any task to complete (usually due to connection close)
132
- done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
133
-
134
- # Cancel remaining tasks
135
- for task in pending:
136
- task.cancel()
137
- try:
138
- await task
139
- except asyncio.CancelledError:
140
- pass
141
-
142
- # Check if we had an error in completed tasks
143
- for task in done:
144
- try:
145
- task.result()
146
- except websockets.exceptions.ConnectionClosed:
147
- logger.info("WebSocket connection closed")
148
- except Exception as e:
149
- logger.error(f"Task error: {e}")
150
-
151
- except websockets.exceptions.ConnectionClosed:
152
- logger.info("Connection closed by orchestrator")
153
-
154
- except websockets.exceptions.ConnectionClosed as e:
155
- logger.info(f"Failed to connect: {e}")
156
- raise
157
- except Exception as e:
158
- logger.error(f"Unexpected error in connection: {e}")
159
- raise
160
- finally:
161
- self.websocket = None
162
- self.current_job = None
163
-
164
- async def _job_processing_loop(self):
165
- """Main loop for requesting and processing jobs."""
166
- while self.running and self.websocket:
167
- try:
168
- # Request a job
169
- await self.websocket.send(json.dumps({"type": "request_job"}))
170
-
171
- # Wait a bit for response
172
- await asyncio.sleep(1)
173
-
174
- if self.current_job:
175
- await self._process_job(self.current_job)
176
- self.current_job = None
177
- else:
178
- # No job available, wait before requesting again
179
- await asyncio.sleep(5)
180
-
181
- except websockets.exceptions.ConnectionClosed:
182
- logger.info("Connection closed during job processing")
183
- break
184
- except Exception as e:
185
- logger.error(f"Job processing error: {e}")
186
- self.error_count += 1
187
- await asyncio.sleep(1)
188
-
189
- async def _message_handler(self):
190
- """Handle incoming messages from orchestrator."""
191
- try:
192
- async for message in self.websocket:
193
- try:
194
- data = json.loads(message)
195
- msg_type = data.get("type")
196
-
197
- if msg_type == "job":
198
- job_data = data["job"]
199
- self.current_job = Job(**job_data)
200
- logger.info(f"Received job {self.current_job.job_id}")
201
-
202
- elif msg_type == "no_jobs":
203
- logger.debug("No jobs available")
204
-
205
- elif msg_type == "ack":
206
- logger.debug(f"Job {data['job_id']} acknowledged")
207
- self.processed_count += 1
208
-
209
- except json.JSONDecodeError as e:
210
- logger.error(f"Invalid message format: {e}")
211
- except Exception as e:
212
- logger.error(f"Error handling message: {e}")
213
-
214
- except websockets.exceptions.ConnectionClosed:
215
- logger.info("Connection closed while waiting for messages")
216
- except Exception as e:
217
- logger.error(f"Message handler error: {e}")
218
-
219
- async def _process_job(self, job: Job):
220
- """Process a single captioning job."""
221
- if not self.websocket:
222
- logger.warning(f"No websocket connection, skipping job {job.job_id}")
223
- return
224
-
225
- logger.info(f"Processing job {job.job_id}")
226
-
227
- try:
228
- # Load and preprocess images
229
- images = await self._load_images(job)
230
-
231
- # TODO: Here you would integrate your captioning model
232
- # For now, using placeholder
233
- caption = f"[Generated caption for {job.item_key}]"
234
-
235
- # Submit result
236
- await self.websocket.send(
237
- json.dumps(
238
- {
239
- "type": "submit_caption",
240
- "job_id": job.job_id,
241
- "dataset": job.dataset,
242
- "shard": job.shard,
243
- "item_key": job.item_key,
244
- "caption": caption,
245
- }
246
- )
247
- )
248
-
249
- logger.info(f"Completed job {job.job_id}")
250
-
251
- except websockets.exceptions.ConnectionClosed:
252
- logger.warning(f"Connection lost while processing job {job.job_id}")
253
- raise # Re-raise to trigger reconnection
254
- except Exception as e:
255
- logger.error(f"Failed to process job {job.job_id}: {e}")
256
-
257
- # Report failure if still connected
258
- if self.websocket:
259
- try:
260
- await self.websocket.send(
261
- json.dumps({"type": "job_failed", "job_id": job.job_id, "error": str(e)})
262
- )
263
- except:
264
- pass # Connection might be closed
265
-
266
- async def _load_images(self, job: Job):
267
- """Load and preprocess images for a job."""
268
- # This would load actual images from the dataset
269
- # Now can use self.dataset_type and self.dataset_path received from orchestrator
270
- # For now, returning placeholder
271
- return []
272
-
273
- async def _heartbeat_loop(self):
274
- """Send periodic heartbeats to orchestrator."""
275
- while self.running and self.websocket:
276
- try:
277
- await self.websocket.send(
278
- json.dumps(
279
- {
280
- "type": "heartbeat",
281
- "processed": self.processed_count,
282
- "errors": self.error_count,
283
- }
284
- )
285
- )
286
- await asyncio.sleep(30)
287
- except websockets.exceptions.ConnectionClosed:
288
- logger.info("Connection closed during heartbeat")
289
- break
290
- except Exception as e:
291
- logger.error(f"Heartbeat error: {e}")
292
- break
293
-
294
- async def shutdown(self):
295
- """Graceful shutdown."""
296
- logger.info("Shutting down worker...")
297
- self.running = False
298
-
299
- if self.websocket:
300
- await self.websocket.close()