caption-flow 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +1 -1
- caption_flow/cli.py +4 -2
- caption_flow/monitor.py +3 -0
- caption_flow/orchestrator.py +53 -32
- caption_flow/processors/huggingface.py +2 -2
- caption_flow/processors/webdataset.py +39 -4
- caption_flow/utils/checkpoint_tracker.py +41 -21
- caption_flow/utils/chunk_tracker.py +85 -47
- caption_flow/workers/base.py +7 -2
- caption_flow/workers/caption.py +7 -1
- {caption_flow-0.3.3.dist-info → caption_flow-0.3.4.dist-info}/METADATA +1 -1
- {caption_flow-0.3.3.dist-info → caption_flow-0.3.4.dist-info}/RECORD +16 -16
- {caption_flow-0.3.3.dist-info → caption_flow-0.3.4.dist-info}/WHEEL +0 -0
- {caption_flow-0.3.3.dist-info → caption_flow-0.3.4.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.3.3.dist-info → caption_flow-0.3.4.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.3.3.dist-info → caption_flow-0.3.4.dist-info}/top_level.txt +0 -0
caption_flow/__init__.py
CHANGED
caption_flow/cli.py
CHANGED
@@ -124,7 +124,7 @@ def setup_logging(verbose: bool = False):
     level = logging.DEBUG if verbose else logging.INFO
     logging.basicConfig(
         level=level,
-        format="%(message)s",
+        format="%(name)s: %(message)s",
         datefmt="[%Y-%m-%d %H:%M:%S]",
         handlers=[
             RichHandler(
@@ -490,7 +490,9 @@ def reload_config(
 
     async def send_reload():
         try:
-            async with websockets.connect(
+            async with websockets.connect(
+                server, ssl=ssl_context, ping_interval=20, ping_timeout=60, close_timeout=10
+            ) as websocket:
                 # Authenticate as admin
                 await websocket.send(json.dumps({"token": token, "role": "admin"}))
 
caption_flow/monitor.py
CHANGED
@@ -73,6 +73,9 @@ class Monitor:
         async with websockets.connect(
             self.server_url,
             ssl=self.ssl_context if self.server_url.startswith("wss://") else None,
+            ping_interval=20,
+            ping_timeout=60,
+            close_timeout=10,
         ) as websocket:
             # Authenticate
             await websocket.send(json.dumps({"token": self.token}))
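The same keepalive trio (ping_interval=20, ping_timeout=60, close_timeout=10) is applied to every websockets.connect call in this release: cli.py, monitor.py, and both workers below. A minimal sketch of what these parameters do in the websockets library; the URL, token, and function name here are placeholders, not values from the package:

import json
import ssl
from typing import Optional

import websockets

async def connect_with_keepalive(url: str, token: str, ssl_context: Optional[ssl.SSLContext] = None):
    # ping_interval=20: send a protocol-level ping every 20 seconds.
    # ping_timeout=60: treat the connection as dead if no pong arrives within 60 seconds.
    # close_timeout=10: wait at most 10 seconds for the closing handshake before dropping.
    async with websockets.connect(
        url, ssl=ssl_context, ping_interval=20, ping_timeout=60, close_timeout=10
    ) as websocket:
        await websocket.send(json.dumps({"token": token}))
        return await websocket.recv()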
caption_flow/orchestrator.py
CHANGED
@@ -377,16 +377,17 @@ class Orchestrator:
         """Process results submission from worker."""
         # Extract user from worker_id
         worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
+
         # Create work result
         _job_id = data.get("job_id")
         job_id = JobId.from_str(_job_id)
-        shard_name = job_id.shard_id
-        chunk_name = job_id.chunk_id
-
+        shard_name = job_id.shard_id
+        chunk_name = job_id.chunk_id
+
         result = WorkResult(
             unit_id=data["unit_id"],
             source_id=shard_name,
-            chunk_id=job_id.get_chunk_str(),
+            chunk_id=job_id.get_chunk_str(),
             sample_id=data["sample_id"],
             dataset=data["dataset"],
             outputs=data["outputs"],
@@ -394,7 +395,9 @@ class Orchestrator:
             processing_time_ms=data.get("processing_time_ms", 0),
         )
 
-        # Let processor handle any custom processing
+        # Let processor handle any custom processing - this updates chunk tracker
+        # IMPORTANT: Call this BEFORE saving to storage so chunk tracker is updated
+        # regardless of whether the item is a duplicate
         processed = self.processor.handle_result(result)
 
         # Create caption record for storage
@@ -412,6 +415,7 @@ class Orchestrator:
         for key in to_delete_metadata_keys:
             if key in result.metadata:
                 del result.metadata[key]
+
         caption = Caption(
             job_id=job_id,
             dataset=result.dataset,
@@ -433,14 +437,15 @@ class Orchestrator:
             image_format=image_format,
         )
 
-        # Save to storage
-        await self.storage.save_caption(caption)
+        # Save to storage (might skip if duplicate)
+        saved = await self.storage.save_caption(caption)
 
-        # Update contributor stats
-
-
-        contributor
-
+        # Update contributor stats only if actually saved
+        if saved:
+            contributor = await self.storage.get_contributor(worker_user)
+            if contributor:
+                contributor.total_captions += total_outputs
+                await self.storage.save_contributor(contributor)
 
     async def _handle_monitor(self, websocket: WebSocketServerProtocol):
         """Handle monitor connection."""
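The ordering in the two hunks above matters: handle_result() updates the chunk tracker before save_caption() runs its duplicate check, so a resubmitted item still advances chunk progress while contributor totals only move on a genuine insert. A condensed sketch of that control flow; the method names mirror the diff, but the assumption that save_caption returns False for skipped duplicates is inferred from the surrounding comments:

async def process_submission(self, result, caption, worker_user, total_outputs):
    # Always update the chunk tracker, duplicate or not.
    self.processor.handle_result(result)

    # Assumed to return False when the row is a duplicate and is skipped.
    saved = await self.storage.save_caption(caption)

    # Credit the contributor only for rows that were actually stored.
    if saved:
        contributor = await self.storage.get_contributor(worker_user)
        if contributor:
            contributor.total_captions += total_outputs
            await self.storage.save_contributor(contributor)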
@@ -840,39 +845,55 @@ class Orchestrator:
         self.monitors -= disconnected
 
     async def _heartbeat_loop(self):
-        """
+        """Collect and log worker status periodically."""
         while True:
             await asyncio.sleep(30)
 
-
+            # Just collect status - no ping/pong
+            active_workers = []
             for worker_id, ws in list(self.workers.items()):
-
-
-
-
-
-                # Clean up disconnected workers
-                for worker_id in disconnected:
-                    logger.warning(f"Worker {worker_id} did not respond to ping, disconnecting")
-                    if worker_id in self.workers:
+                # Check if WebSocket is still open (don't ping)
+                if ws.state == websockets.protocol.State.OPEN:
+                    active_workers.append(worker_id)
+                else:
+                    # Clean up closed connections
+                    logger.info(f"Worker {worker_id} connection closed")
                     del self.workers[worker_id]
-                    logger.warning(
-                        f"Releasing assignments for worker {worker_id} because it did not respond to ping"
-                    )
                     self.processor.release_assignments(worker_id)
-
+
+            # Log status
+            if active_workers:
+                logger.debug(
+                    f"Active workers: {len(active_workers)} - {', '.join(active_workers[:5])}"
+                )
+            logger.debug(f"Inactive workers: {len(self.workers) - len(active_workers)}")
+            # add to self.stats
+            self.stats["active_workers"] = len(active_workers)
+            self.stats["inactive_workers"] = len(self.workers) - len(active_workers)
 
     async def _checkpoint_loop(self):
-        """Periodically checkpoint storage."""
+        """Periodically checkpoint storage and chunk tracker."""
         interval = self.config.get("storage", {}).get("checkpoint_interval", 60)
 
         while True:
             await asyncio.sleep(interval)
 
-
-
-
+            try:
+                # Checkpoint storage
+                await self.storage.checkpoint()
+
+                # Also checkpoint the chunk tracker if using webdataset processor
+                if hasattr(self.processor, "chunk_tracker") and self.processor.chunk_tracker:
+                    # Save checkpoint in thread pool to avoid blocking
+                    await asyncio.get_event_loop().run_in_executor(
+                        None, self.processor.chunk_tracker.save
+                    )
+                    logger.debug("Saved chunk tracker checkpoint")
+
+                self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
+                logger.info("Storage and chunk tracker checkpoint complete")
+            except Exception as e:
+                logger.error(f"Error during checkpoint: {e}", exc_info=True)
 
     async def _stats_update_loop(self):
         """Periodically update and broadcast stats."""
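A note on the checkpoint loop: chunk_tracker.save is synchronous, disk-bound code, so the new loop hands it to the event loop's default thread pool rather than calling it inline. A minimal sketch of that pattern, with save_fn standing in for any blocking callable:

import asyncio

async def checkpoint_without_blocking(save_fn) -> None:
    loop = asyncio.get_event_loop()
    # Runs save_fn in the default ThreadPoolExecutor; the event loop stays
    # free to service websocket traffic while the file is written.
    await loop.run_in_executor(None, save_fn)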
caption_flow/processors/huggingface.py
CHANGED
@@ -551,7 +551,7 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
 
         # Force checkpoint save if needed
         if self.chunk_tracker:
-            self.chunk_tracker.
+            self.chunk_tracker.save()
 
     def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
         """Get available work units for a worker."""
@@ -717,7 +717,7 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
 
         # Save final state
         if self.chunk_tracker:
-            self.chunk_tracker.
+            self.chunk_tracker.save()
 
 
 class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
caption_flow/processors/webdataset.py
CHANGED
@@ -306,8 +306,15 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         assigned = []
 
         with self.lock:
-
+            units_checked = 0
+            max_units_to_check = len(self.pending_units)
+
+            while len(assigned) < count and units_checked < max_units_to_check:
+                if not self.pending_units:
+                    break
+
                 unit_id = self.pending_units.popleft()
+                units_checked += 1
                 unit = self.work_units.get(unit_id)
 
                 if unit:
@@ -316,6 +323,16 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
                     chunk_state = self.chunk_tracker.chunks[unit_id]
                     relative_unprocessed = chunk_state.get_unprocessed_ranges()
 
+                    # If no unprocessed ranges, mark as completed and skip
+                    if not relative_unprocessed:
+                        logger.info(
+                            f"Chunk {unit_id} has no unprocessed ranges, marking as completed"
+                        )
+                        self.chunk_tracker.mark_completed(unit_id)
+                        # Remove from work units
+                        del self.work_units[unit_id]
+                        continue
+
                     # Convert relative to absolute indices
                     absolute_ranges = []
                     for start, end in relative_unprocessed:
@@ -335,6 +352,9 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
 
                     if self.chunk_tracker:
                         self.chunk_tracker.mark_assigned(unit_id, worker_id)
+                else:
+                    # Put it back if we couldn't get the unit
+                    self.pending_units.append(unit_id)
 
         logger.debug(f"Assigned {len(assigned)} units to worker {worker_id}")
         return assigned
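The assignment loop is now bounded by the queue length at entry, so units that get requeued in the `else` branch above cannot be re-examined in the same pass and the `while` cannot spin forever on an unassignable queue. A reduced sketch of the traversal, assuming a deque of unit IDs and a dict of units (names here are illustrative, not the processor's own):

from collections import deque

def assign_bounded(pending: deque, units: dict, count: int) -> list:
    assigned, checked = [], 0
    limit = len(pending)  # snapshot: each queued ID is examined at most once
    while len(assigned) < count and checked < limit and pending:
        unit_id = pending.popleft()
        checked += 1
        unit = units.get(unit_id)
        if unit:
            assigned.append(unit)
        else:
            pending.append(unit_id)  # put it back for a later pass
    return assigned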
@@ -394,8 +414,20 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         logger.info(f"Released {len(unit_ids)} assignments from {worker_id}")
 
     def handle_result(self, result: WorkResult) -> Dict[str, Any]:
-        """Handle result from worker."""
-        #
+        """Handle result from worker and update chunk tracker."""
+        # Extract the actual item index from the metadata
+        item_index = result.metadata.get("_item_index", None)
+
+        # If we have an item index, mark it as processed in the chunk tracker
+        if self.chunk_tracker and item_index is not None and result.chunk_id:
+            try:
+                # Mark single item as processed
+                self.chunk_tracker.mark_items_processed(result.chunk_id, item_index, item_index)
+                # logger.debug(f"Marked item {item_index} as processed in chunk {result.chunk_id}")
+            except Exception as e:
+                logger.error(f"Error marking item {item_index} as processed: {e}")
+
+        # Also handle batch results if present (backward compatibility)
         if self.chunk_tracker and "item_indices" in result.metadata:
             indices = result.metadata["item_indices"]
 
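Results can now carry a single `_item_index` or, for backward compatibility, an `item_indices` list; both paths end at mark_items_processed(chunk_id, start, end) with inclusive bounds. A hypothetical helper (not a function from the package) showing how a sorted index list collapses into such ranges before the batch path marks them:

def indices_to_ranges(indices):
    """Collapse item indices into inclusive (start, end) ranges."""
    ranges = []
    for i in sorted(indices):
        if ranges and i == ranges[-1][1] + 1:
            # Extends the previous run by one.
            ranges[-1] = (ranges[-1][0], i)
        else:
            ranges.append((i, i))
    return ranges

# e.g. indices_to_ranges([3, 1, 2, 7]) == [(1, 3), (7, 7)]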
@@ -419,6 +451,9 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         # Mark ranges as processed
         for start_idx, end_idx in ranges:
             self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
+            logger.debug(
+                f"Marked range {start_idx}-{end_idx} as processed in chunk {result.chunk_id}"
+            )
 
         return {
             "source_id": result.source_id,
@@ -539,7 +574,7 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
 
         # Save checkpoint
         if self.chunk_tracker:
-            self.chunk_tracker.
+            self.chunk_tracker.save()
 
 
 class WebDatasetWorkerProcessor(WorkerProcessor):
caption_flow/utils/checkpoint_tracker.py
CHANGED
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Dict, Any, Optional
 from datetime import datetime
+from concurrent.futures import ThreadPoolExecutor
 
 logger = logging.getLogger(__name__)
 
@@ -52,35 +53,54 @@ class CheckpointTracker(ABC):
 
     def save(self) -> None:
         """Save checkpoint to disk atomically."""
-
-
-
-
+        with self.lock:
+            try:
+                # Prepare data with metadata
+                data = self._serialize_state()
+                data["updated_at"] = datetime.utcnow().isoformat()
+
+                # Write atomically using temp file
+                tmp_file = self.checkpoint_path.with_suffix(".tmp")
+                # If a save is already in progress, let it finish.
+                # This prevents race conditions if save() is called rapidly.
+                if (
+                    hasattr(self, "_save_future")
+                    and self._save_future
+                    and not self._save_future.done()
+                ):
+                    self._save_future.result()  # Wait for the previous save to complete
+
+                # Use an executor to run the save operation in a background thread.
+                # This makes the save call non-blocking.
+                with ThreadPoolExecutor(max_workers=1) as executor:
+                    data_to_save = data.copy()
+                    self._save_future = executor.submit(self._write_to_disk, data_to_save, tmp_file)
+            except Exception as e:
+                logger.error(f"Failed to submit save task: {e}", exc_info=True)
 
-
-
+    def _write_to_disk(self, data: Dict[str, Any]) -> None:
+        """Write checkpoint data to disk atomically."""
+        # Create a temporary file in the same directory as the checkpoint
+        tmp_file = self.checkpoint_path.with_suffix(".tmp")
+
+        try:
+            # Ensure the parent directory exists
+            self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
 
             with open(tmp_file, "w") as f:
                 json.dump(data, f, indent=2)
 
-            #
-            if not tmp_file.exists():
-                raise IOError(f"Failed to create temporary file: {tmp_file}")
-
-            # Move atomically
+            # Atomically replace the checkpoint file
             tmp_file.replace(self.checkpoint_path)
-
             logger.debug(f"Saved checkpoint to {self.checkpoint_path}")
-
         except Exception as e:
-
-            # Try
-
-
-
-
-            logger.error(f"Fallback save also failed: {fallback_error}")
+            logger.error(f"Failed to save checkpoint atomically: {e}", exc_info=True)
+            # Try to clean up the temp file if it exists
+            if tmp_file.exists():
+                try:
+                    tmp_file.unlink()
+                except:
+                    pass
 
     def get_stats(self) -> Dict[str, Any]:
         """Get statistics about tracked items. Override for custom stats."""
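The write path relies on the temp-file-then-replace idiom: Path.replace maps to an atomic rename on POSIX filesystems, so a crash mid-write leaves either the old checkpoint or the new one, never a torn file. A self-contained sketch of the same idiom, separate from the class above:

import json
from pathlib import Path

def atomic_json_dump(data: dict, target: Path) -> None:
    tmp = target.with_suffix(".tmp")
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(tmp, "w") as f:
        json.dump(data, f, indent=2)
    # Atomic on POSIX: readers see the old file or the new one, never a mix.
    tmp.replace(target)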
caption_flow/utils/chunk_tracker.py
CHANGED
@@ -8,6 +8,7 @@ from datetime import datetime, timedelta
 from dataclasses import dataclass, asdict, field
 
 from .checkpoint_tracker import CheckpointTracker
+from threading import Lock
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -60,12 +61,12 @@ class ChunkState:
         self.status = "completed"
         self.completed_at = datetime.utcnow()
         # Clear processed_ranges since we don't need them after completion
-        self.processed_ranges = []
-        self.assigned_to = None
-        self.assigned_at = None
+        # self.processed_ranges = []
+        # self.assigned_to = None
+        # self.assigned_at = None
 
     def get_unprocessed_ranges(self) -> List[Tuple[int, int]]:
-        """Get ranges
+        """Get ranges of unprocessed items within the chunk (relative indices)."""
         if self.status == "completed":
             return []
 
@@ -73,22 +74,57 @@ class ChunkState:
             logger.info(f"Chunk {self.chunk_id} has no processed ranges, returning full range")
             return [(0, self.chunk_size - 1)]
 
+        # Merge ranges first to ensure no overlaps
+        merged_ranges = self._merge_ranges(self.processed_ranges)
+
         unprocessed = []
-
+        current_pos = 0
 
-
-
-
-
-            if current < start:
-                unprocessed.append((current, start - 1))
-            current = max(current, end + 1)
+        for start, end in merged_ranges:
+            if current_pos < start:
+                unprocessed.append((current_pos, start - 1))
+            current_pos = max(current_pos, end + 1)
 
-
-
+        # Add any remaining range
+        if current_pos < self.chunk_size:
+            unprocessed.append((current_pos, self.chunk_size - 1))
+
+        # Log for debugging
+        if not unprocessed:
+            logger.info(
+                f"Chunk {self.chunk_id} has processed ranges {merged_ranges} covering entire chunk size {self.chunk_size}"
+            )
+        else:
+            total_processed = sum(end - start + 1 for start, end in merged_ranges)
+            total_unprocessed = sum(end - start + 1 for start, end in unprocessed)
+            logger.debug(
+                f"Chunk {self.chunk_id}: {total_processed} processed, {total_unprocessed} unprocessed"
+            )
 
         return unprocessed
 
+    def _merge_ranges(self, ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
+        """Merge overlapping or adjacent ranges."""
+        if not ranges:
+            return []
+
+        # Sort ranges by start index, ensuring all are tuples
+        sorted_ranges = sorted([tuple(r) for r in ranges])
+        merged = [sorted_ranges[0]]
+
+        for current_start, current_end in sorted_ranges[1:]:
+            last_start, last_end = merged[-1]
+
+            # Check if ranges overlap or are adjacent
+            if current_start <= last_end + 1:
+                # Merge the ranges
+                merged[-1] = (last_start, max(last_end, current_end))
+            else:
+                # Add as new range
+                merged.append((current_start, current_end))
+
+        return merged
+
     def to_dict(self):
         """Convert to dictionary for JSON serialization."""
         d = asdict(self)
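_merge_ranges is what makes the single left-to-right sweep in get_unprocessed_ranges valid: after sorting, any ranges that overlap or touch are fused, so the sweep only ever sees disjoint, ordered intervals. The same logic in isolation, with a worked example:

from typing import List, Tuple

def merge_ranges(ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    merged: List[Tuple[int, int]] = []
    for start, end in sorted(tuple(r) for r in ranges):
        if merged and start <= merged[-1][1] + 1:
            # Overlapping or adjacent: extend the previous range.
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged

# (5,9), (0,3), (4,4) sort to (0,3),(4,4),(5,9); each pair is adjacent, so all fuse:
assert merge_ranges([(5, 9), (0, 3), (4, 4)]) == [(0, 9)]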
@@ -124,6 +160,7 @@ class ChunkTracker(CheckpointTracker):
         self.max_completed_chunks_in_memory = max_completed_chunks_in_memory
         self.archive_after_hours = archive_after_hours
         self._completed_count = 0  # Track count without storing all IDs
+        self.lock = Lock()
         super().__init__(checkpoint_file)
 
     def _get_default_state(self) -> Dict[str, Any]:
@@ -132,16 +169,17 @@ class ChunkTracker(CheckpointTracker):
 
     def _deserialize_state(self, data: Dict[str, Any]) -> None:
         """Deserialize loaded data into instance state."""
-        self.
-
-
-
-
-
-
-
-
+        with self.lock:
+            self.chunks = {}
+            self._completed_count = data.get("completed_count", 0)
+
+            # Load chunk states
+            completed_chunks = 0
+            for chunk_id, chunk_data in data.get("chunks", {}).items():
+                chunk_state = ChunkState.from_dict(chunk_data)
+                self.chunks[chunk_id] = chunk_state
+                if chunk_state.status == "completed":
+                    completed_chunks += 1
 
         logger.info(
             f"Loaded {len(self.chunks)} chunks from checkpoint, "
@@ -494,40 +532,40 @@ class ChunkTracker(CheckpointTracker):
         for start_idx, end_idx in ranges:
             chunk.add_processed_range(start_idx, end_idx)
 
-    def mark_items_processed(self, chunk_id: str, start_idx: int, end_idx: int):
-        """Mark a range of items as processed within a chunk
+    def mark_items_processed(self, chunk_id: str, start_idx: int, end_idx: int) -> None:
+        """Mark a range of items as processed within a chunk."""
         if chunk_id not in self.chunks:
-            logger.
+            logger.warning(f"Chunk {chunk_id} not found in tracker")
             return
 
-
+        chunk_state = self.chunks[chunk_id]
 
-        # Convert absolute indices to chunk-relative
-        relative_start = start_idx -
-        relative_end = end_idx -
+        # Convert absolute indices to chunk-relative indices
+        relative_start = start_idx - chunk_state.start_index
+        relative_end = end_idx - chunk_state.start_index
 
-        #
-
-
-                f"Invalid indices for chunk {chunk_id}: "
-                f"absolute {start_idx}-{end_idx} (relative {relative_start}-{relative_end}) "
-                f"outside chunk bounds [{chunk.start_index}, {chunk.start_index + chunk.chunk_size - 1}]"
-            )
-            return
+        # Ensure indices are within chunk bounds
+        relative_start = max(0, relative_start)
+        relative_end = min(chunk_state.chunk_size - 1, relative_end)
 
-        # Add
-
+        # Add to processed ranges
+        chunk_state.processed_ranges.append((relative_start, relative_end))
 
-        #
-
-        self._completed_count += 1
+        # Merge overlapping ranges
+        chunk_state.processed_ranges = chunk_state._merge_ranges(chunk_state.processed_ranges)
 
-        self.save()
         logger.debug(
-            f"Marked items {start_idx}-{end_idx} as processed in chunk {chunk_id} "
-            f"(relative indices: {relative_start}-{relative_end})"
+            f"Marked items {start_idx}-{end_idx} as processed in chunk {chunk_id} (relative indices: {relative_start}-{relative_end})"
         )
 
+        # Check if chunk is now complete
+        if chunk_state.get_unprocessed_ranges() == []:
+            logger.info(f"Chunk {chunk_id} is now complete")
+            chunk_state.status = "completed"
+
+        # Save checkpoint after updating
+        self.save()
+
     def get_chunk_with_unprocessed_items(self, chunk_id: str) -> Optional[Dict[str, Any]]:
         """Get chunk info with unprocessed item ranges."""
         chunk_state = self.chunks.get(chunk_id)
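Workers report absolute dataset indices while ChunkState stores chunk-relative ranges; the rewritten method clamps out-of-bounds reports to the chunk's edges instead of rejecting them. The index arithmetic in isolation, with a worked example (the chunk start and size are hypothetical values, not from the package):

def to_relative(start_idx: int, end_idx: int, chunk_start: int, chunk_size: int):
    # Shift absolute indices into chunk-local coordinates...
    relative_start = max(0, start_idx - chunk_start)
    # ...and clamp to the last valid slot in the chunk.
    relative_end = min(chunk_size - 1, end_idx - chunk_start)
    return relative_start, relative_end

# A chunk starting at absolute index 1000 with 500 items: 1800 overshoots the
# chunk, so the end is clamped to relative index 499.
assert to_relative(1200, 1800, 1000, 500) == (200, 499)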
caption_flow/workers/base.py
CHANGED
@@ -89,8 +89,13 @@ class BaseWorker(ABC):
     async def _connect_and_run(self):
         """Connect to orchestrator and run main loop."""
         logger.info(f"Connecting to {self.server_url}")
-
-
+        async with websockets.connect(
+            self.server_url,
+            ssl=self.ssl_context,
+            ping_interval=20,
+            ping_timeout=60,
+            close_timeout=10,
+        ) as websocket:
             self.websocket = websocket
             self.connected.set()
 
caption_flow/workers/caption.py
CHANGED
@@ -248,7 +248,13 @@ class CaptionWorker(BaseWorker):
     async def _initial_connect_for_config(self):
         """Connect initially just to get configuration."""
         logger.info(f"Connecting to {self.server_url}")
-        async with websockets.connect(
+        async with websockets.connect(
+            self.server_url,
+            ssl=self.ssl_context,
+            ping_interval=20,
+            ping_timeout=60,
+            close_timeout=10,
+        ) as websocket:
             await websocket.send(json.dumps(self._get_auth_data()))
 
             welcome = await websocket.recv()
{caption_flow-0.3.3.dist-info → caption_flow-0.3.4.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
-caption_flow/__init__.py,sha256=
-caption_flow/cli.py,sha256=
+caption_flow/__init__.py,sha256=2M1VLvkVjUmTHXuJFMLnZKqVYni5A0HJfxcnjz53K7c,303
+caption_flow/cli.py,sha256=K3lML3WIYjD7OluGltHGP4N98S5w-KyhDUlQZudDQXE,41464
 caption_flow/models.py,sha256=2n6iphTEL62xK2FFcJM6axMsaE8KwsUv5Ak_cCF-TdQ,5652
-caption_flow/monitor.py,sha256=
-caption_flow/orchestrator.py,sha256=
+caption_flow/monitor.py,sha256=z2HakZSG799HvTJgjgG7u_MHvhq9-JL1LXzxBwP3WQc,7998
+caption_flow/orchestrator.py,sha256=3XKZXFE1Aw1kCqb_Vw9loYpkmJ5LTLyZZf9pj4k6ldA,37175
 caption_flow/viewer.py,sha256=HxO98eHR1xtivG0dEdYC2U9T_RgeRfJqqTK-37u9bNM,20471
 caption_flow/processors/__init__.py,sha256=hvq-OuAJWQe6hFglKe7QmkS8473k20FmxZDSxfXpCrg,423
 caption_flow/processors/base.py,sha256=IAEr0pqHRuSkXunvDWk1vf2IKeYQ-2YERqej9iSQm94,6931
-caption_flow/processors/huggingface.py,sha256=
+caption_flow/processors/huggingface.py,sha256=t_dklhmNRAyk2jISu4FqmNecjg9hfY47omOiRVkbhvA,41215
 caption_flow/processors/local_filesystem.py,sha256=OuNNDemy0sdtpBBC_5GbI-c1vMqp8OIz983Cq85gdb8,27964
-caption_flow/processors/webdataset.py,sha256=
+caption_flow/processors/webdataset.py,sha256=tUBCUKunqooHibTWtQ1wljuRI55Wc6M1WrI2hOZgt7g,33858
 caption_flow/storage/__init__.py,sha256=IVnzcSCPpPuyp-QLlgJirRZ9Sb3tR0F4sfuF5u2cNMk,36
 caption_flow/storage/exporter.py,sha256=mFJqMDQ61cP-qcXe118_-oL1TUqULdQZ8LdjSTym44I,19697
 caption_flow/storage/manager.py,sha256=KPExcKPuFVQSsBnfCBdne5PO4PwN4NTfd-EJQk13OY0,47459
@@ -16,18 +16,18 @@ caption_flow/utils/__init__.py,sha256=bDcO5uR455TKCQ2hX-_XcdTnRXDBaT8Yn4jWqWzfFs
 caption_flow/utils/auth.py,sha256=UrxX2n8OEEcfMD1Ey27TxGfrJFmUCpC59x-SCrQJoVE,2253
 caption_flow/utils/caption_utils.py,sha256=esUMAdcCkNjRroZ0Bhxv0_yKlLtMf0XeDCTt-5k6bik,5309
 caption_flow/utils/certificates.py,sha256=eu4blQZEkL9NRaY1ynQWg1asvDorRYhGRZea7STonJE,4635
-caption_flow/utils/checkpoint_tracker.py,sha256
-caption_flow/utils/chunk_tracker.py,sha256=
+caption_flow/utils/checkpoint_tracker.py,sha256=nOZIIGsXTRUj09tFSnWtRgj_zoa8Og_-rutkr2GFz8Y,4417
+caption_flow/utils/chunk_tracker.py,sha256=JZIFvaHS5AYaVOzsSJKrnNlS4E3BdzV64cRkQa_65g0,21508
 caption_flow/utils/image_processor.py,sha256=wmOExkVfM7OeuLfX3AwMefsH-TxL8TNcn22gp0NmJKY,1541
 caption_flow/utils/json_utils.py,sha256=IiZYn8uCM-3pYmyIbX2fmaOIyutArn67SqAyp0ggNpU,5396
 caption_flow/utils/prompt_template.py,sha256=AKp0diSZqNBMwZkpiTNjw8-bbQwHStr7QZTOJ7o1dC4,4345
 caption_flow/utils/vllm_config.py,sha256=TC7Rmjk0zRKbBXbWUXrFL4Z58hzax_-4L0pXZn09hdM,6019
-caption_flow/workers/base.py,sha256=
-caption_flow/workers/caption.py,sha256=
+caption_flow/workers/base.py,sha256=nEWohozFZ0Bw3_8U8xirnKLeZsGR5k69rSu4j-oDitc,7698
+caption_flow/workers/caption.py,sha256=swE4pYg4ZYAAtMxvyvlETa3wv4yKWUPXXulCAwPhPiQ,39477
 caption_flow/workers/data.py,sha256=0Tg8NE0wdONeMlivYQ4nvbcfWdLuU51O7vR8_YSnJgo,14813
-caption_flow-0.3.
-caption_flow-0.3.
-caption_flow-0.3.
-caption_flow-0.3.
-caption_flow-0.3.
-caption_flow-0.3.
+caption_flow-0.3.4.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+caption_flow-0.3.4.dist-info/METADATA,sha256=dfB40EF_Zgz2Ux8qvdBbfLdhzY85_MUFRX-904I-qb4,9708
+caption_flow-0.3.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+caption_flow-0.3.4.dist-info/entry_points.txt,sha256=KnVlyrGKZj6p2zNyuEnCx4Y6jvJ4V-mcfN0lddPKTlQ,55
+caption_flow-0.3.4.dist-info/top_level.txt,sha256=_bXpKRutqded0FQ80dCChIz26ETV7tL4d4e2E_Y1FXs,13
+caption_flow-0.3.4.dist-info/RECORD,,
{caption_flow-0.3.3.dist-info → caption_flow-0.3.4.dist-info}/WHEEL
File without changes
{caption_flow-0.3.3.dist-info → caption_flow-0.3.4.dist-info}/entry_points.txt
File without changes
{caption_flow-0.3.3.dist-info → caption_flow-0.3.4.dist-info}/licenses/LICENSE
File without changes
{caption_flow-0.3.3.dist-info → caption_flow-0.3.4.dist-info}/top_level.txt
File without changes