caption-flow 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -3
- caption_flow/cli.py +937 -416
- caption_flow/models.py +45 -3
- caption_flow/monitor.py +5 -3
- caption_flow/orchestrator.py +186 -116
- caption_flow/processors/__init__.py +3 -3
- caption_flow/processors/base.py +8 -7
- caption_flow/processors/huggingface.py +440 -68
- caption_flow/processors/local_filesystem.py +24 -28
- caption_flow/processors/webdataset.py +66 -25
- caption_flow/storage/exporter.py +420 -339
- caption_flow/storage/manager.py +636 -756
- caption_flow/utils/__init__.py +1 -1
- caption_flow/utils/auth.py +1 -1
- caption_flow/utils/caption_utils.py +1 -1
- caption_flow/utils/certificates.py +15 -8
- caption_flow/utils/checkpoint_tracker.py +41 -19
- caption_flow/utils/chunk_tracker.py +200 -65
- caption_flow/utils/image_processor.py +9 -9
- caption_flow/utils/json_utils.py +37 -20
- caption_flow/utils/prompt_template.py +24 -16
- caption_flow/utils/vllm_config.py +5 -4
- caption_flow/viewer.py +4 -12
- caption_flow/workers/base.py +12 -6
- caption_flow/workers/caption.py +272 -91
- caption_flow/workers/data.py +6 -8
- {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/METADATA +5 -4
- caption_flow-0.4.0.dist-info/RECORD +33 -0
- caption_flow-0.3.3.dist-info/RECORD +0 -33
- {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/WHEEL +0 -0
- {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/top_level.txt +0 -0
@@ -1,27 +1,27 @@
|
|
1
1
|
"""Local filesystem datasets processor implementation."""
|
2
2
|
|
3
|
+
import asyncio
|
4
|
+
import io
|
3
5
|
import logging
|
4
|
-
import
|
6
|
+
import mimetypes
|
5
7
|
import os
|
6
|
-
|
7
|
-
from collections import
|
8
|
+
import threading
|
9
|
+
from collections import defaultdict, deque
|
8
10
|
from pathlib import Path
|
9
|
-
import
|
10
|
-
|
11
|
-
import mimetypes
|
12
|
-
from datetime import datetime
|
13
|
-
from PIL import Image
|
11
|
+
from typing import Any, Deque, Dict, Iterator, List, Optional, Set, Tuple
|
12
|
+
|
14
13
|
import aiofiles
|
15
|
-
from fastapi import FastAPI, HTTPException, Response
|
16
|
-
from fastapi.responses import StreamingResponse
|
17
|
-
import uvicorn
|
18
|
-
import asyncio
|
19
14
|
import requests
|
15
|
+
import uvicorn
|
16
|
+
from fastapi import FastAPI, HTTPException
|
17
|
+
from fastapi.responses import StreamingResponse
|
18
|
+
from PIL import Image
|
20
19
|
|
21
20
|
from caption_flow.storage import StorageManager
|
22
|
-
|
23
|
-
from ..utils import ChunkTracker
|
21
|
+
|
24
22
|
from ..models import JobId
|
23
|
+
from ..utils import ChunkTracker
|
24
|
+
from .base import OrchestratorProcessor, ProcessorConfig, WorkerProcessor, WorkResult, WorkUnit
|
25
25
|
|
26
26
|
logger = logging.getLogger(__name__)
|
27
27
|
logger.setLevel(logging.DEBUG)
|
@@ -217,23 +217,19 @@ class LocalFilesystemOrchestratorProcessor(OrchestratorProcessor):
|
|
217
217
|
if not self.chunk_tracker:
|
218
218
|
return
|
219
219
|
|
220
|
-
|
220
|
+
storage.get_all_processed_job_ids()
|
221
221
|
|
222
222
|
with self.lock:
|
223
223
|
for chunk_id, chunk_state in self.chunk_tracker.chunks.items():
|
224
|
-
#
|
225
|
-
|
226
|
-
chunk_state.start_index,
|
227
|
-
chunk_state.start_index + chunk_state.chunk_size - 1,
|
228
|
-
)
|
229
|
-
|
230
|
-
# Get processed indices for this chunk
|
231
|
-
processed_ranges = self.chunk_tracker.get_processed_indices_for_chunk(
|
232
|
-
chunk_id, all_processed_jobs
|
233
|
-
)
|
224
|
+
# Get unprocessed ranges (relative coordinates from ChunkTracker)
|
225
|
+
relative_unprocessed_ranges = chunk_state.get_unprocessed_ranges()
|
234
226
|
|
235
|
-
#
|
236
|
-
unprocessed_ranges =
|
227
|
+
# Convert relative ranges to absolute ranges
|
228
|
+
unprocessed_ranges = []
|
229
|
+
for start, end in relative_unprocessed_ranges:
|
230
|
+
abs_start = chunk_state.start_index + start
|
231
|
+
abs_end = chunk_state.start_index + end
|
232
|
+
unprocessed_ranges.append((abs_start, abs_end))
|
237
233
|
|
238
234
|
if unprocessed_ranges:
|
239
235
|
# Create work unit for unprocessed items
|
@@ -588,7 +584,7 @@ class LocalFilesystemWorkerProcessor(WorkerProcessor):
|
|
588
584
|
processed_indices = []
|
589
585
|
|
590
586
|
# Get orchestrator info if we need HTTP
|
591
|
-
|
587
|
+
context.get("orchestrator")
|
592
588
|
|
593
589
|
for idx in sorted(indices_to_process):
|
594
590
|
try:
|
@@ -1,28 +1,27 @@
|
|
1
1
|
"""WebDataset processor implementation using webshart TarDataLoader."""
|
2
2
|
|
3
|
-
import logging
|
4
|
-
import threading
|
5
3
|
import gc
|
4
|
+
import io
|
5
|
+
import logging
|
6
6
|
import os
|
7
|
-
|
8
|
-
from collections import
|
7
|
+
import threading
|
8
|
+
from collections import defaultdict, deque
|
9
9
|
from pathlib import Path
|
10
|
-
import
|
11
|
-
|
10
|
+
from typing import Any, Deque, Dict, Iterator, List, Optional, Set
|
11
|
+
|
12
|
+
import cv2
|
13
|
+
import numpy as np
|
14
|
+
import webshart
|
12
15
|
from PIL import Image
|
13
|
-
import io
|
14
16
|
|
15
17
|
from caption_flow.models import JobId
|
16
18
|
from caption_flow.storage import StorageManager
|
17
|
-
from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
|
18
|
-
from ..utils import ChunkTracker
|
19
19
|
|
20
|
-
import
|
21
|
-
import
|
22
|
-
import numpy as np
|
20
|
+
from ..utils import ChunkTracker
|
21
|
+
from .base import OrchestratorProcessor, ProcessorConfig, WorkerProcessor, WorkResult, WorkUnit
|
23
22
|
|
24
23
|
logger = logging.getLogger(__name__)
|
25
|
-
logger.setLevel(
|
24
|
+
logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
|
26
25
|
|
27
26
|
|
28
27
|
class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
|
@@ -217,7 +216,7 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
|
|
217
216
|
while units_created < units_needed and not self.stop_creation.is_set():
|
218
217
|
# Get current shard info
|
219
218
|
if current_shard_idx >= self.dataset.num_shards:
|
220
|
-
|
219
|
+
threading.Event().wait(5)
|
221
220
|
break
|
222
221
|
|
223
222
|
shard_info = self._get_shard_info_cached(current_shard_idx)
|
@@ -240,8 +239,8 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
|
|
240
239
|
self.current_chunk_index = current_file_idx // self.chunk_size
|
241
240
|
job_id_obj = JobId(
|
242
241
|
shard_id=shard_name,
|
243
|
-
chunk_id=self.current_chunk_index,
|
244
|
-
sample_id=current_file_idx,
|
242
|
+
chunk_id=str(self.current_chunk_index),
|
243
|
+
sample_id=str(current_file_idx),
|
245
244
|
)
|
246
245
|
chunk_id = job_id_obj.get_chunk_str()
|
247
246
|
|
@@ -306,8 +305,15 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
|
|
306
305
|
assigned = []
|
307
306
|
|
308
307
|
with self.lock:
|
309
|
-
|
308
|
+
units_checked = 0
|
309
|
+
max_units_to_check = len(self.pending_units)
|
310
|
+
|
311
|
+
while len(assigned) < count and units_checked < max_units_to_check:
|
312
|
+
if not self.pending_units:
|
313
|
+
break
|
314
|
+
|
310
315
|
unit_id = self.pending_units.popleft()
|
316
|
+
units_checked += 1
|
311
317
|
unit = self.work_units.get(unit_id)
|
312
318
|
|
313
319
|
if unit:
|
@@ -316,6 +322,16 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
|
|
316
322
|
chunk_state = self.chunk_tracker.chunks[unit_id]
|
317
323
|
relative_unprocessed = chunk_state.get_unprocessed_ranges()
|
318
324
|
|
325
|
+
# If no unprocessed ranges, mark as completed and skip
|
326
|
+
if not relative_unprocessed:
|
327
|
+
logger.info(
|
328
|
+
f"Chunk {unit_id} has no unprocessed ranges, marking as completed"
|
329
|
+
)
|
330
|
+
self.chunk_tracker.mark_completed(unit_id)
|
331
|
+
# Remove from work units
|
332
|
+
del self.work_units[unit_id]
|
333
|
+
continue
|
334
|
+
|
319
335
|
# Convert relative to absolute indices
|
320
336
|
absolute_ranges = []
|
321
337
|
for start, end in relative_unprocessed:
|
@@ -335,6 +351,9 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
|
|
335
351
|
|
336
352
|
if self.chunk_tracker:
|
337
353
|
self.chunk_tracker.mark_assigned(unit_id, worker_id)
|
354
|
+
else:
|
355
|
+
# Put it back if we couldn't get the unit
|
356
|
+
self.pending_units.append(unit_id)
|
338
357
|
|
339
358
|
logger.debug(f"Assigned {len(assigned)} units to worker {worker_id}")
|
340
359
|
return assigned
|
@@ -394,8 +413,20 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
|
|
394
413
|
logger.info(f"Released {len(unit_ids)} assignments from {worker_id}")
|
395
414
|
|
396
415
|
def handle_result(self, result: WorkResult) -> Dict[str, Any]:
|
397
|
-
"""Handle result from worker."""
|
398
|
-
#
|
416
|
+
"""Handle result from worker and update chunk tracker."""
|
417
|
+
# Extract the actual item index from the metadata
|
418
|
+
item_index = result.metadata.get("_item_index", None)
|
419
|
+
|
420
|
+
# If we have an item index, mark it as processed in the chunk tracker
|
421
|
+
if self.chunk_tracker and item_index is not None and result.chunk_id:
|
422
|
+
try:
|
423
|
+
# Mark single item as processed
|
424
|
+
self.chunk_tracker.mark_items_processed(result.chunk_id, item_index, item_index)
|
425
|
+
# logger.debug(f"Marked item {item_index} as processed in chunk {result.chunk_id}")
|
426
|
+
except Exception as e:
|
427
|
+
logger.error(f"Error marking item {item_index} as processed: {e}")
|
428
|
+
|
429
|
+
# Also handle batch results if present (backward compatibility)
|
399
430
|
if self.chunk_tracker and "item_indices" in result.metadata:
|
400
431
|
indices = result.metadata["item_indices"]
|
401
432
|
|
@@ -419,6 +450,9 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
|
|
419
450
|
# Mark ranges as processed
|
420
451
|
for start_idx, end_idx in ranges:
|
421
452
|
self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
|
453
|
+
logger.debug(
|
454
|
+
f"Marked range {start_idx}-{end_idx} as processed in chunk {result.chunk_id}"
|
455
|
+
)
|
422
456
|
|
423
457
|
return {
|
424
458
|
"source_id": result.source_id,
|
@@ -499,8 +533,8 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
|
|
499
533
|
for start_idx, end_idx in ranges:
|
500
534
|
self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
|
501
535
|
|
502
|
-
#
|
503
|
-
self.chunk_tracker.
|
536
|
+
# Flush checkpoint after major update
|
537
|
+
self.chunk_tracker.flush()
|
504
538
|
|
505
539
|
def get_stats(self) -> Dict[str, Any]:
|
506
540
|
"""Get processor statistics."""
|
@@ -537,9 +571,9 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
|
|
537
571
|
if self.unit_creation_thread:
|
538
572
|
self.unit_creation_thread.join(timeout=5)
|
539
573
|
|
540
|
-
#
|
574
|
+
# Flush final checkpoint on cleanup
|
541
575
|
if self.chunk_tracker:
|
542
|
-
self.chunk_tracker.
|
576
|
+
self.chunk_tracker.flush()
|
543
577
|
|
544
578
|
|
545
579
|
class WebDatasetWorkerProcessor(WorkerProcessor):
|
@@ -559,6 +593,9 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
|
|
559
593
|
self.dataset_path = dataset_cfg.get("dataset_path")
|
560
594
|
metadata_path = dataset_cfg.get("metadata_path", None)
|
561
595
|
self.mock_results = dataset_cfg.get("mock_results", False)
|
596
|
+
split_worker_cache = dataset_cfg.get(
|
597
|
+
"split_worker_cache", True
|
598
|
+
) # multiple workers get their own cache by default
|
562
599
|
|
563
600
|
# Cache configuration
|
564
601
|
cache_dir = Path(cfg.get("cache_dir", "./webshart_cache"))
|
@@ -574,7 +611,11 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
|
|
574
611
|
# Enable caching
|
575
612
|
self.dataset.enable_metadata_cache(location=str(cache_dir / "metadata_cache"))
|
576
613
|
self.dataset.enable_shard_cache(
|
577
|
-
location=
|
614
|
+
location=(
|
615
|
+
str(cache_dir / "shard_cache" / str(self.gpu_id))
|
616
|
+
if split_worker_cache
|
617
|
+
else str(cache_dir / "shard_cache")
|
618
|
+
),
|
578
619
|
cache_limit_gb=cfg.get("shard_cache_gb", 10.0),
|
579
620
|
)
|
580
621
|
|
@@ -646,7 +687,7 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
|
|
646
687
|
# Iterate through the range
|
647
688
|
for idx in range(start_idx, end_idx + 1):
|
648
689
|
try:
|
649
|
-
entry =
|
690
|
+
entry = webshart.next_with_cache_wait(self.loader)
|
650
691
|
|
651
692
|
# Decode image
|
652
693
|
image = None
|