caption-flow 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -2
- caption_flow/cli.py +65 -42
- caption_flow/models.py +6 -4
- caption_flow/monitor.py +13 -3
- caption_flow/orchestrator.py +1049 -264
- caption_flow/storage.py +579 -222
- caption_flow/utils/__init__.py +3 -1
- caption_flow/utils/auth.py +24 -25
- caption_flow/utils/checkpoint_tracker.py +92 -0
- caption_flow/utils/chunk_tracker.py +278 -194
- caption_flow/utils/dataset_loader.py +567 -73
- caption_flow/utils/image_processor.py +121 -1
- caption_flow/utils/prompt_template.py +137 -0
- caption_flow/utils/shard_processor.py +315 -0
- caption_flow/utils/shard_tracker.py +87 -0
- caption_flow/workers/base.py +228 -0
- caption_flow/workers/caption.py +1321 -0
- caption_flow/{worker_data.py → workers/data.py} +162 -234
- caption_flow-0.2.1.dist-info/METADATA +370 -0
- caption_flow-0.2.1.dist-info/RECORD +29 -0
- caption_flow/worker.py +0 -300
- caption_flow/worker_vllm.py +0 -1028
- caption_flow-0.1.0.dist-info/METADATA +0 -427
- caption_flow-0.1.0.dist-info/RECORD +0 -25
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.1.0.dist-info → caption_flow-0.2.1.dist-info}/top_level.txt +0 -0
caption_flow-0.2.1.dist-info/METADATA
ADDED
@@ -0,0 +1,370 @@
Metadata-Version: 2.4
Name: caption-flow
Version: 0.2.1
Summary: Self-contained distributed community captioning system
Author-email: bghira <bghira@users.github.com>
License: MIT
Keywords: captioning,distributed,vllm,dataset,community
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Requires-Python: <3.13,>=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: websockets>=12.0
Requires-Dist: pyarrow>=14.0.0
Requires-Dist: click>=8.1.0
Requires-Dist: pydantic>=2.0.0
Requires-Dist: aiofiles>=23.0.0
Requires-Dist: rich>=13.0.0
Requires-Dist: cryptography>=41.0.0
Requires-Dist: pyyaml>=6.0
Requires-Dist: certbot>=2.0.0
Requires-Dist: numpy>=1.24.0
Requires-Dist: pillow>=10.0.0
Requires-Dist: vllm<0.11.0,>=0.10.0
Requires-Dist: webdataset<2.0.0,>=1.0.2
Requires-Dist: pandas<3.0.0,>=2.3.1
Requires-Dist: arrow<2.0.0,>=1.3.0
Requires-Dist: datasets<5.0.0,>=4.0.0
Requires-Dist: boto3<2.0.0,>=1.40.11
Requires-Dist: torchdata<0.12.0,>=0.11.0
Provides-Extra: dev
Requires-Dist: pytest>=7.4.0; extra == "dev"
Requires-Dist: pytest-asyncio>=0.21.0; extra == "dev"
Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
Requires-Dist: black>=23.0.0; extra == "dev"
Requires-Dist: ruff>=0.1.0; extra == "dev"
Requires-Dist: mypy>=1.5.0; extra == "dev"
Dynamic: license-file

# CaptionFlow

scalable, fault-tolerant **vLLM-powered image captioning**. this "first round" focuses on a fast websocket orchestrator plus lightweight gpu workers that batch requests through vLLM.

* **orchestrator**: hands out work in chunked shards, collects captions, checkpoints progress, and keeps simple stats.
* **workers (vLLM)**: connect to the orchestrator, stream in image samples, batch them, and generate 1..N captions per image using prompts supplied by the orchestrator.
* **config-driven**: all components read YAML config; flags can override.
* **tui monitor (optional)**: a monitor client is wired into the CLI; ship a `monitor` module to enable it.

> no conda. just `venv` + `pip`.

---

## install

```bash
python -m venv .venv
source .venv/bin/activate   # windows: .venv\Scripts\activate
pip install --upgrade pip
pip install -e .            # installs the `caption-flow` command
```

## quickstart (single box)

1. copy + edit the sample configs

```bash
cp orchestrator.yaml my-orchestrator.yaml
cp worker.yaml my-worker.yaml
cp monitor.yaml my-monitor.yaml   # optional; requires a monitor module
```

set a unique shared token in both `my-orchestrator.yaml` and `my-worker.yaml` (see `auth.worker_tokens` in the orchestrator config and `worker.token` in the worker config). if you use private hugging face datasets/models, export `HUGGINGFACE_HUB_TOKEN` before starting workers.
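
for example, a matching pair could look like this (the token value is illustrative; the keys mirror the config highlights further down):

```yaml
# my-orchestrator.yaml
auth:
  worker_tokens:
    - { token: "my-shared-secret", name: "my worker" }

# my-worker.yaml
worker:
  token: my-shared-secret
```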

2. start the orchestrator

```bash
caption-flow orchestrator --config my-orchestrator.yaml
```

3. start one or more vLLM workers

```bash
# gpu 0 on the same host
caption-flow worker --config my-worker.yaml --gpu-id 0

# gpu 1 on the same host
caption-flow worker --config my-worker.yaml --gpu-id 1
```

4. (optional) start the monitor

```bash
caption-flow monitor --config my-monitor.yaml
```

5. (optional) scan/fix chunks on disk if you had crashes

```bash
caption-flow scan_chunks --data-dir ./caption_data --checkpoint-dir ./checkpoints --fix
```

---

## how it’s wired

### orchestrator

* **websocket server** (default `0.0.0.0:8765`) with three client roles: workers, data-feeders, and admin.
* **dataset control**: the orchestrator centrally defines the dataset (`huggingface` or `local`) and version/name. it chunk-slices shards and assigns work.
* **vLLM config broadcast**: model, tp size, dtype, max seq len, memory targets, batching, sampling params, and **inference prompts** are all pushed to workers; workers can apply many changes without a model reload.
* **storage + checkpoints**: captions buffer to disk with periodic checkpoints. chunk state is tracked so restarts don’t double-work.
* **auth**: token lists for `worker`, `monitor`, and `admin` roles.

start flags you’ll likely use:

```text
--config PATH            # yaml config for the orchestrator
--port INT, --host STR   # bind controls
--data-dir PATH          # overrides storage.data_dir
--cert PATH, --key PATH  # enable TLS (or use --no-ssl for ws:// in dev)
--vllm                   # use the vLLM-style orchestrator (webdataset/hf)
```

### vLLM worker

* **one process per gpu**. select the device with `--gpu-id` (or `worker.gpu_id` in YAML).
* **gets its marching orders** from the orchestrator: dataset info, model, prompts, batch size, and sampling.
* **resilient**: detects disconnects, abandons the current chunk cleanly, clears queues, reconnects, and resumes.
* **batched generate()**: images are resized down for consistent batching; each image can get multiple captions (one per prompt). see the sketch after this list.
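
a rough sketch of that multi-caption batching idea (illustrative only, not the worker's actual code; the prompt template is a placeholder, and real models need their own chat template):

```python
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", max_model_len=16384)
params = SamplingParams(temperature=0.7, top_p=0.95, max_tokens=256)

prompts = ["describe this image in detail", "what are the key elements in this image?"]
images = [Image.open(p).resize((512, 512)) for p in ["a.jpg", "b.jpg"]]

# one request per (image, prompt) pair -> up to len(prompts) captions per image
requests = [
    {"prompt": f"USER: <image>\n{p}\nASSISTANT:", "multi_modal_data": {"image": img}}
    for img in images
    for p in prompts
]
outputs = llm.generate(requests, params)
# keep only non-empty generations, as the worker does
captions = [o.outputs[0].text.strip() for o in outputs if o.outputs[0].text.strip()]
```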

start flags you’ll likely use:

```text
--config PATH                 # yaml for the worker
--server URL                  # ws(s)://host:port
--token STR                   # must match an allowed worker token on the orchestrator
--name STR                    # display name
--batch-size INT              # override vLLM batch size
--vllm                        # use the vLLM worker implementation
--gpu-id INT                  # which gpu to use
--precision STR, --model STR  # optional overrides for dtype/model
--no-verify-ssl               # accept self-signed certs in dev
```

### (optional) monitor

* a CLI entry exists for a TUI monitor; wire in a `monitor` module to enable it. config lives in `monitor.yaml` or inside `orchestrator.yaml` under `monitor:`.

---

## configuration

### config discovery order

for any component, the CLI looks for config in this order (first match wins; a sketch of the lookup follows the list):

1. `--config /path/to/file.yaml`
2. `./<component>.yaml` (current directory)
3. `~/.caption-flow/<component>.yaml`
4. `$XDG_CONFIG_HOME/caption-flow/<component>.yaml`
5. `/etc/caption-flow/<component>.yaml`
6. any `$XDG_CONFIG_DIRS` entries under `caption-flow/`
7. `./examples/<component>.yaml` (fallback)
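
a minimal sketch of that first-match-wins lookup (illustrative; `find_config` is not the CLI's actual helper):

```python
import os
from pathlib import Path

def find_config(component: str, explicit: str | None = None) -> Path | None:
    """Return the first existing config path, mirroring the order above."""
    candidates = [Path(explicit)] if explicit else []
    xdg_home = Path(os.environ.get("XDG_CONFIG_HOME", "~/.config")).expanduser()
    candidates += [
        Path(f"{component}.yaml"),
        Path.home() / ".caption-flow" / f"{component}.yaml",
        xdg_home / "caption-flow" / f"{component}.yaml",
        Path("/etc/caption-flow") / f"{component}.yaml",
        *(Path(d) / "caption-flow" / f"{component}.yaml"
          for d in os.environ.get("XDG_CONFIG_DIRS", "").split(":") if d),
        Path("examples") / f"{component}.yaml",
    ]
    return next((p for p in candidates if p.is_file()), None)
```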

### orchestrator.yaml (highlights)

```yaml
orchestrator:
  host: 0.0.0.0
  port: 8765
  # ssl:
  #   cert: /path/fullchain.pem
  #   key: /path/privkey.pem

dataset:
  type: huggingface   # or "local"
  path: <hf-dataset-or-local-path>
  name: <logical-name>
  version: "1.0"

vllm:
  model: Qwen/Qwen2.5-VL-3B-Instruct
  tensor_parallel_size: 1
  max_model_len: 16384
  dtype: float16
  gpu_memory_utilization: 0.92
  enforce_eager: true
  disable_mm_preprocessor_cache: true
  limit_mm_per_prompt: { image: 1 }

  batch_size: 8

  sampling:
    temperature: 0.7
    top_p: 0.95
    max_tokens: 256
    repetition_penalty: 1.05
    skip_special_tokens: true
    stop: ["<|end|>", "<|endoftext|>", "<|im_end|>"]

  inference_prompts:
    - "describe this image in detail"
    - "provide a comprehensive description of the visual content"
    - "what are the key elements in this image?"

storage:
  data_dir: ./caption_data
  checkpoint_dir: ./checkpoints
  caption_buffer_size: 100
  checkpoint_interval: 1000

# chunking/queueing
chunk_size: 1000
chunks_per_request: 2
chunk_buffer_multiplier: 3
min_chunk_buffer: 10

auth:
  worker_tokens:
    - { token: "example-worker-token", name: "Example Worker" }
  monitor_tokens:
    - { token: "letmein", name: "Default monitor" }
  admin_tokens:
    - { token: "admin-secret-2024", name: "Admin" }
```

### worker.yaml (highlights)

```yaml
worker:
  server: ws://localhost:8765   # use wss:// in prod
  token: example-worker-token
  name: local-gpu
  gpu_id: 0
  vllm: true

  # local queues
  readahead_size: 256
  inference_queue_size: 128
```
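
with one process per gpu, a tiny launch script can fan workers out (a sketch; adjust the gpu list and config path to your box):

```bash
#!/usr/bin/env bash
# start one caption-flow worker per gpu and wait on all of them
for gpu in 0 1; do
  caption-flow worker --config my-worker.yaml --gpu-id "$gpu" &
done
wait
```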

### monitor.yaml (optional)

```yaml
monitor:
  server: ws://localhost:8765
  token: letmein
  refresh_rate: 1.0
  show_contributors: true
  show_quality_metrics: true
  max_activity_items: 20
  show_chunk_progress: true
  show_worker_queues: true
  show_throughput_graph: true
```

---

## tls / certificates

use the built-in helpers during development:

```bash
# self-signed certs for quick local testing
caption-flow generate_cert --self-signed --domain localhost --output-dir ./certs

# inspect any certificate file
caption-flow inspect_cert ./certs/fullchain.pem
```

then point the orchestrator at the resulting cert/key (or run `--no-ssl` for dev-only ws://).
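
for example, using the `--cert`/`--key` flags listed under the orchestrator start flags:

```bash
caption-flow orchestrator --config my-orchestrator.yaml \
  --cert ./certs/fullchain.pem --key ./certs/privkey.pem
```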

---

## tips & notes

* **multi-gpu**: start one worker process per gpu (set `--gpu-id` or `worker.gpu_id`).
* **throughput**: tune `vllm.batch_size` in the orchestrator config (or override with `--batch-size` at worker start). higher isn’t always better; watch VRAM.
* **prompts**: add more strings under `vllm.inference_prompts` to get multiple captions per image; the worker returns only non-empty generations.
* **private HF**: if your dataset/model needs auth, export `HUGGINGFACE_HUB_TOKEN` before `caption-flow worker ...`.
* **self-signed ssl**: pass `--no-verify-ssl` to workers/monitors in dev.
* **recovery**: if you hard-crash mid-run, `caption-flow scan_chunks --fix` can reset abandoned chunks so the orchestrator can reissue them cleanly.

---

## roadmap

* hot config reload via the admin websocket path.
* dedicated data-feeder clients (separate from gpu workers) that push samples into the orchestrator.
* richer monitor TUI.

PRs welcome. keep it simple and fast.

## architecture

```
┌─────────────┐      WebSocket     ┌─────────────┐
│   Worker    │◄──────────────────►│             │
└─────────────┘                    │             │    ┌──────────────┐
                                   │ Orchestrator│───►│Arrow/Parquet │
┌─────────────┐                    │             │    │   Storage    │
│   Worker    │◄──────────────────►│             │    └──────────────┘
└─────────────┘                    └─────────────┘
                                          ▲
┌─────────────┐                           │
│   Monitor   │◄──────────────────────────┘
└─────────────┘
```

## Storage Schema

### captions.parquet
- `job_id`: Unique job identifier
- `dataset`: Dataset name
- `shard`: Shard identifier
- `item_key`: Item within shard
- `caption`: Generated caption text
- `contributor_id`: Worker who generated it
- `timestamp`: Generation time
- `quality_score`: Optional quality metric

### jobs.parquet
- `job_id`: Unique identifier
- `dataset`: Dataset name
- `shard`: Shard identifier
- `status`: pending/processing/completed/failed
- `assigned_to`: Worker ID
- `timestamp`: Status change time

### contributors.parquet
- `contributor_id`: Unique identifier
- `name`: Display name
- `total_captions`: Lifetime count
- `trust_level`: Quality tier (0-5)
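
a quick way to peek at these tables (a sketch using pandas, which is already a dependency; the path assumes the default `storage.data_dir`, and the exact on-disk layout may differ):

```python
import pandas as pd

# captions.parquet lives under storage.data_dir (./caption_data by default)
df = pd.read_parquet("caption_data/captions.parquet")
print(df[["job_id", "item_key", "caption", "contributor_id"]].head())
print(f"{len(df)} captions from {df['contributor_id'].nunique()} contributors")
```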

## Development

```bash
# Install with dev dependencies
pip install -e ".[dev]"

# Run tests
pytest

# Format code
black src/
ruff --fix src/

# Type checking
mypy src/
```

## Community Contribution

To contribute compute:

1. Install caption-flow: `pip install caption-flow`
2. Get a worker token from the project maintainer
3. Run: `caption-flow worker --server wss://project.domain.com:8765 --token YOUR_TOKEN`

Your contributions will be tracked and attributed in the final dataset!

## License

MIT
caption_flow-0.2.1.dist-info/RECORD
ADDED
@@ -0,0 +1,29 @@
caption_flow/__init__.py,sha256=NLPJ25lRN7xHqncXweINDNwbt0q8lgjZ30G21zlPdRs,303
caption_flow/cli.py,sha256=bHxx66CPsCmSieaH3pw8NZBojIIbniRTdU9mEBHMmWA,28832
caption_flow/models.py,sha256=qo6lQiO10UISbaBVr6Cs-fSW_pmjwE6kmiTmmU_l3Wk,2140
caption_flow/monitor.py,sha256=ZZCSasYLKJ-UzA3-RoAtytv-tbNA-m3h5YjlZg_vukg,7870
caption_flow/orchestrator.py,sha256=bZ8NnGdqoXSmu7Nq-_7cOSH1DLHkBT88cne0uDyPeNY,89112
caption_flow/storage.py,sha256=hC6ZHT_PHFoUVjqD5JUwy3_79oAD1e1H30neA_xsz7s,40748
caption_flow/utils/__init__.py,sha256=F1BChVoCsj9zn1GJRBOLHET1kLW6xrAmsbzcR7hHy6Y,202
caption_flow/utils/auth.py,sha256=UrxX2n8OEEcfMD1Ey27TxGfrJFmUCpC59x-SCrQJoVE,2253
caption_flow/utils/caption_utils.py,sha256=esUMAdcCkNjRroZ0Bhxv0_yKlLtMf0XeDCTt-5k6bik,5309
caption_flow/utils/certificates.py,sha256=eu4blQZEkL9NRaY1ynQWg1asvDorRYhGRZea7STonJE,4635
caption_flow/utils/checkpoint_tracker.py,sha256=8tsTFF-HcygitK92YcS-QWzeg-qRm9AuCpQoQRfC8M0,3335
caption_flow/utils/chunk_tracker.py,sha256=hKn8CN6ubErc9kuCWZMj12ZCZKxVlqXqAEocbzjfa-k,17296
caption_flow/utils/dataset_loader.py,sha256=ZplJv655ZMyUbaZC4BBiL5II18sBy4JSJhxGZtK_VmA,29107
caption_flow/utils/image_processor.py,sha256=Zl8TAv9gYPdAYat3UiTuuNdIb2fXNfZ35AxsxuovJTs,5650
caption_flow/utils/job_queue.py,sha256=itdfXcrkvGjmXn4qtpgMF63k1ufRBaejDe4V6WcxzgU,1104
caption_flow/utils/json_utils.py,sha256=IiZYn8uCM-3pYmyIbX2fmaOIyutArn67SqAyp0ggNpU,5396
caption_flow/utils/prompt_template.py,sha256=AKp0diSZqNBMwZkpiTNjw8-bbQwHStr7QZTOJ7o1dC4,4345
caption_flow/utils/shard_processor.py,sha256=CRda6M4xh4U0vwvYlzq9nJEzz4d_4yzUBosYAeBcPEA,10854
caption_flow/utils/shard_tracker.py,sha256=Wt2oE-O85F2FxSnqIocJiaYeFn00OVVjIiklZIZRGL8,3233
caption_flow/utils/vllm_config.py,sha256=TC7Rmjk0zRKbBXbWUXrFL4Z58hzax_-4L0pXZn09hdM,6019
caption_flow/workers/base.py,sha256=jPm_Xw4Lxd0cnrPs-biBqKRQKkTOJLvHLolmp0Gb1CI,7530
caption_flow/workers/caption.py,sha256=NZ9kTjk2uOoNwyyNSkB_arYk213vLr5mowHN-OjiFkk,54631
caption_flow/workers/data.py,sha256=0Tg8NE0wdONeMlivYQ4nvbcfWdLuU51O7vR8_YSnJgo,14813
caption_flow-0.2.1.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
caption_flow-0.2.1.dist-info/METADATA,sha256=fxNfSOqkCklb96aq3ZFU7SvRuXEBUQ11xbjkQn7Yzuo,11941
caption_flow-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
caption_flow-0.2.1.dist-info/entry_points.txt,sha256=KnVlyrGKZj6p2zNyuEnCx4Y6jvJ4V-mcfN0lddPKTlQ,55
caption_flow-0.2.1.dist-info/top_level.txt,sha256=_bXpKRutqded0FQ80dCChIz26ETV7tL4d4e2E_Y1FXs,13
caption_flow-0.2.1.dist-info/RECORD,,
caption_flow/worker.py
DELETED
@@ -1,300 +0,0 @@
"""Worker node for distributed captioning."""

import asyncio
import json
import logging
import ssl
from typing import Dict, Any, Optional
from pathlib import Path

import websockets
import websockets.exceptions
from websockets.client import WebSocketClientProtocol

from .models import Job, JobStatus
from .utils.image_processor import ImageProcessor

logger = logging.getLogger(__name__)


class Worker:
    """Worker node that processes captioning jobs."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.server_url = config["server"]
        self.token = config["token"]
        self.name = config.get("name", "worker")
        self.batch_size = config.get("batch_size", 32)

        # Dataset configuration will be received from orchestrator
        self.dataset_config = None
        self.dataset_type = None
        self.dataset_path = None

        # SSL configuration
        self.ssl_context = self._setup_ssl()

        # Components
        self.image_processor = ImageProcessor()

        # State
        self.worker_id: Optional[str] = None
        self.websocket: Optional[WebSocketClientProtocol] = None
        self.running = False
        self.current_job: Optional[Job] = None

        # Metrics
        self.processed_count = 0
        self.error_count = 0

    def _setup_ssl(self) -> Optional[ssl.SSLContext]:
        """Configure SSL context."""
        # Check if URL is WSS (requires SSL)
        if self.server_url.startswith("ws://"):
            logger.warning(
                "Using insecure WebSocket connection (ws://). Consider using wss:// for production."
            )
            return None  # No SSL for ws://

        if not self.config.get("verify_ssl", True):
            # Disable SSL verification for development
            context = ssl.create_default_context()
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
            return context

        return ssl.create_default_context()

    async def start(self):
        """Start the worker and connect to orchestrator."""
        self.running = True

        while self.running:
            try:
                await self._connect_and_run()
            except websockets.exceptions.ConnectionClosed as e:
                logger.warning(f"Connection closed: {e}")
                if self.running:
                    logger.info("Reconnecting in 5 seconds...")
                    await asyncio.sleep(5)
            except Exception as e:
                logger.error(f"Connection error: {e}")
                if self.running:
                    logger.info("Reconnecting in 5 seconds...")
                    await asyncio.sleep(5)

    async def _connect_and_run(self):
        """Connect to orchestrator and process jobs."""
        logger.info(f"Connecting to {self.server_url}")

        try:
            async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
                self.websocket = websocket

                # Authenticate
                await websocket.send(json.dumps({"token": self.token, "name": self.name}))

                # Wait for welcome message with dataset configuration
                welcome = await websocket.recv()
                welcome_data = json.loads(welcome)

                if "error" in welcome_data:
                    logger.error(f"Authentication failed: {welcome_data['error']}")
                    self.running = False
                    return

                self.worker_id = welcome_data.get("worker_id")

                # Extract and store dataset configuration from orchestrator
                if "dataset_config" in welcome_data:
                    self.dataset_config = welcome_data["dataset_config"]
                    self.dataset_type = self.dataset_config.get("dataset_type")
                    self.dataset_path = self.dataset_config.get("dataset_path")
                    logger.info(
                        f"Received dataset configuration from orchestrator: "
                        f"type={self.dataset_type}, path={self.dataset_path}"
                    )
                else:
                    logger.warning("No dataset configuration received from orchestrator")

                logger.info(f"Connected as {self.worker_id}")

                # Create tasks for concurrent operations
                tasks = [
                    asyncio.create_task(self._heartbeat_loop()),
                    asyncio.create_task(self._job_processing_loop()),
                    asyncio.create_task(self._message_handler()),
                ]

                try:
                    # Wait for any task to complete (usually due to connection close)
                    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

                    # Cancel remaining tasks
                    for task in pending:
                        task.cancel()
                        try:
                            await task
                        except asyncio.CancelledError:
                            pass

                    # Check if we had an error in completed tasks
                    for task in done:
                        try:
                            task.result()
                        except websockets.exceptions.ConnectionClosed:
                            logger.info("WebSocket connection closed")
                        except Exception as e:
                            logger.error(f"Task error: {e}")

                except websockets.exceptions.ConnectionClosed:
                    logger.info("Connection closed by orchestrator")

        except websockets.exceptions.ConnectionClosed as e:
            logger.info(f"Failed to connect: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in connection: {e}")
            raise
        finally:
            self.websocket = None
            self.current_job = None

    async def _job_processing_loop(self):
        """Main loop for requesting and processing jobs."""
        while self.running and self.websocket:
            try:
                # Request a job
                await self.websocket.send(json.dumps({"type": "request_job"}))

                # Wait a bit for response
                await asyncio.sleep(1)

                if self.current_job:
                    await self._process_job(self.current_job)
                    self.current_job = None
                else:
                    # No job available, wait before requesting again
                    await asyncio.sleep(5)

            except websockets.exceptions.ConnectionClosed:
                logger.info("Connection closed during job processing")
                break
            except Exception as e:
                logger.error(f"Job processing error: {e}")
                self.error_count += 1
                await asyncio.sleep(1)

    async def _message_handler(self):
        """Handle incoming messages from orchestrator."""
        try:
            async for message in self.websocket:
                try:
                    data = json.loads(message)
                    msg_type = data.get("type")

                    if msg_type == "job":
                        job_data = data["job"]
                        self.current_job = Job(**job_data)
                        logger.info(f"Received job {self.current_job.job_id}")

                    elif msg_type == "no_jobs":
                        logger.debug("No jobs available")

                    elif msg_type == "ack":
                        logger.debug(f"Job {data['job_id']} acknowledged")
                        self.processed_count += 1

                except json.JSONDecodeError as e:
                    logger.error(f"Invalid message format: {e}")
                except Exception as e:
                    logger.error(f"Error handling message: {e}")

        except websockets.exceptions.ConnectionClosed:
            logger.info("Connection closed while waiting for messages")
        except Exception as e:
            logger.error(f"Message handler error: {e}")

    async def _process_job(self, job: Job):
        """Process a single captioning job."""
        if not self.websocket:
            logger.warning(f"No websocket connection, skipping job {job.job_id}")
            return

        logger.info(f"Processing job {job.job_id}")

        try:
            # Load and preprocess images
            images = await self._load_images(job)

            # TODO: Here you would integrate your captioning model
            # For now, using placeholder
            caption = f"[Generated caption for {job.item_key}]"

            # Submit result
            await self.websocket.send(
                json.dumps(
                    {
                        "type": "submit_caption",
                        "job_id": job.job_id,
                        "dataset": job.dataset,
                        "shard": job.shard,
                        "item_key": job.item_key,
                        "caption": caption,
                    }
                )
            )

            logger.info(f"Completed job {job.job_id}")

        except websockets.exceptions.ConnectionClosed:
            logger.warning(f"Connection lost while processing job {job.job_id}")
            raise  # Re-raise to trigger reconnection
        except Exception as e:
            logger.error(f"Failed to process job {job.job_id}: {e}")

            # Report failure if still connected
            if self.websocket:
                try:
                    await self.websocket.send(
                        json.dumps({"type": "job_failed", "job_id": job.job_id, "error": str(e)})
                    )
                except:
                    pass  # Connection might be closed

    async def _load_images(self, job: Job):
        """Load and preprocess images for a job."""
        # This would load actual images from the dataset
        # Now can use self.dataset_type and self.dataset_path received from orchestrator
        # For now, returning placeholder
        return []

    async def _heartbeat_loop(self):
        """Send periodic heartbeats to orchestrator."""
        while self.running and self.websocket:
            try:
                await self.websocket.send(
                    json.dumps(
                        {
                            "type": "heartbeat",
                            "processed": self.processed_count,
                            "errors": self.error_count,
                        }
                    )
                )
                await asyncio.sleep(30)
            except websockets.exceptions.ConnectionClosed:
                logger.info("Connection closed during heartbeat")
                break
            except Exception as e:
                logger.error(f"Heartbeat error: {e}")
                break

    async def shutdown(self):
        """Graceful shutdown."""
        logger.info("Shutting down worker...")
        self.running = False

        if self.websocket:
            await self.websocket.close()