caption-flow 0.3.2__tar.gz → 0.3.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {caption_flow-0.3.2/src/caption_flow.egg-info → caption_flow-0.3.4}/PKG-INFO +1 -1
- {caption_flow-0.3.2 → caption_flow-0.3.4}/pyproject.toml +1 -1
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/__init__.py +1 -1
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/cli.py +4 -2
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/monitor.py +3 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/orchestrator.py +55 -33
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/processors/huggingface.py +2 -2
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/processors/webdataset.py +156 -59
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/utils/checkpoint_tracker.py +41 -21
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/utils/chunk_tracker.py +85 -47
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/workers/base.py +7 -2
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/workers/caption.py +7 -1
- {caption_flow-0.3.2 → caption_flow-0.3.4/src/caption_flow.egg-info}/PKG-INFO +1 -1
- {caption_flow-0.3.2 → caption_flow-0.3.4}/LICENSE +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/README.md +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/setup.cfg +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/models.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/processors/__init__.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/processors/base.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/processors/local_filesystem.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/storage/__init__.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/storage/exporter.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/storage/manager.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/utils/__init__.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/utils/auth.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/utils/caption_utils.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/utils/certificates.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/utils/image_processor.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/utils/json_utils.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/utils/prompt_template.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/utils/vllm_config.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/viewer.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/workers/data.py +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow.egg-info/SOURCES.txt +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow.egg-info/dependency_links.txt +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow.egg-info/entry_points.txt +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow.egg-info/requires.txt +0 -0
- {caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow.egg-info/top_level.txt +0 -0
{caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/cli.py

@@ -124,7 +124,7 @@ def setup_logging(verbose: bool = False):
     level = logging.DEBUG if verbose else logging.INFO
     logging.basicConfig(
         level=level,
-        format="%(message)s",
+        format="%(name)s: %(message)s",
         datefmt="[%Y-%m-%d %H:%M:%S]",
         handlers=[
             RichHandler(

@@ -490,7 +490,9 @@ def reload_config(

     async def send_reload():
         try:
-            async with websockets.connect(
+            async with websockets.connect(
+                server, ssl=ssl_context, ping_interval=20, ping_timeout=60, close_timeout=10
+            ) as websocket:
                 # Authenticate as admin
                 await websocket.send(json.dumps({"token": token, "role": "admin"}))

{caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/monitor.py

@@ -73,6 +73,9 @@ class Monitor:
         async with websockets.connect(
             self.server_url,
             ssl=self.ssl_context if self.server_url.startswith("wss://") else None,
+            ping_interval=20,
+            ping_timeout=60,
+            close_timeout=10,
         ) as websocket:
             # Authenticate
             await websocket.send(json.dumps({"token": self.token}))
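The keepalive arguments added above (ping_interval, ping_timeout, close_timeout) are standard options of the websockets client API, and the same trio is applied to every connection in this release (cli, monitor, and both workers). A minimal standalone sketch of the pattern; the URL and token are placeholders, not values from the package:

import asyncio
import json

import websockets


async def run_client(url: str = "wss://localhost:8765", token: str = "example-token") -> None:
    # Rely on protocol-level pings instead of application-level heartbeats.
    async with websockets.connect(
        url,
        ping_interval=20,  # send a ping every 20 seconds
        ping_timeout=60,   # drop the connection if no pong arrives within 60 seconds
        close_timeout=10,  # bound how long the closing handshake may take
    ) as ws:
        await ws.send(json.dumps({"token": token}))
        async for message in ws:
            print(message)


if __name__ == "__main__":
    asyncio.run(run_client())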
{caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/orchestrator.py

@@ -124,13 +124,14 @@ class Orchestrator:

         # Initialize storage
         await self.storage.initialize()
-        await self.update_unprocessed_ranges()

         # Start background tasks
         asyncio.create_task(self._heartbeat_loop())
         asyncio.create_task(self._checkpoint_loop())
         asyncio.create_task(self._stats_update_loop())

+        await self.update_unprocessed_ranges()
+
         # Start WebSocket server
         websocket_logger = logging.getLogger("websockets")
         websocket_logger.setLevel(logging.WARNING)
@@ -376,16 +377,17 @@ class Orchestrator:
         """Process results submission from worker."""
         # Extract user from worker_id
         worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
+
         # Create work result
         _job_id = data.get("job_id")
         job_id = JobId.from_str(_job_id)
-        shard_name = job_id.shard_id
-        chunk_name = job_id.chunk_id
-
+        shard_name = job_id.shard_id
+        chunk_name = job_id.chunk_id
+
         result = WorkResult(
             unit_id=data["unit_id"],
             source_id=shard_name,
-            chunk_id=job_id.get_chunk_str(),
+            chunk_id=job_id.get_chunk_str(),
             sample_id=data["sample_id"],
             dataset=data["dataset"],
             outputs=data["outputs"],

@@ -393,7 +395,9 @@ class Orchestrator:
             processing_time_ms=data.get("processing_time_ms", 0),
         )

-        # Let processor handle any custom processing
+        # Let processor handle any custom processing - this updates chunk tracker
+        # IMPORTANT: Call this BEFORE saving to storage so chunk tracker is updated
+        # regardless of whether the item is a duplicate
         processed = self.processor.handle_result(result)

         # Create caption record for storage

@@ -411,6 +415,7 @@ class Orchestrator:
         for key in to_delete_metadata_keys:
             if key in result.metadata:
                 del result.metadata[key]
+
         caption = Caption(
             job_id=job_id,
             dataset=result.dataset,
@@ -432,14 +437,15 @@ class Orchestrator:
             image_format=image_format,
         )

-        # Save to storage
-        await self.storage.save_caption(caption)
+        # Save to storage (might skip if duplicate)
+        saved = await self.storage.save_caption(caption)

-        # Update contributor stats
-
-
-        contributor
-
+        # Update contributor stats only if actually saved
+        if saved:
+            contributor = await self.storage.get_contributor(worker_user)
+            if contributor:
+                contributor.total_captions += total_outputs
+                await self.storage.save_contributor(contributor)

     async def _handle_monitor(self, websocket: WebSocketServerProtocol):
         """Handle monitor connection."""
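The change above makes save_caption report whether the caption was actually written, so contributor counters no longer grow when a duplicate is re-submitted. A hedged sketch of that pattern against an in-memory stand-in; the names and types here are illustrative, not the package's storage API:

from dataclasses import dataclass
from typing import Dict


@dataclass
class Contributor:
    name: str
    total_captions: int = 0


class InMemoryStorage:
    def __init__(self) -> None:
        self.captions: Dict[str, dict] = {}
        self.contributors: Dict[str, Contributor] = {}

    def save_caption(self, job_id: str, caption: dict) -> bool:
        """Return True only if this job_id was not stored before."""
        if job_id in self.captions:
            return False
        self.captions[job_id] = caption
        return True


def record_result(storage: InMemoryStorage, job_id: str, caption: dict, user: str, outputs: int) -> None:
    saved = storage.save_caption(job_id, caption)
    if saved:  # duplicates do not inflate contributor stats
        contributor = storage.contributors.setdefault(user, Contributor(user))
        contributor.total_captions += outputs


storage = InMemoryStorage()
record_result(storage, "job-123", {"caption": "a cat"}, "alice", 1)
record_result(storage, "job-123", {"caption": "a cat"}, "alice", 1)  # duplicate, ignored
assert storage.contributors["alice"].total_captions == 1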
{caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/orchestrator.py (continued)

@@ -839,39 +845,55 @@ class Orchestrator:
             self.monitors -= disconnected

     async def _heartbeat_loop(self):
-        """
+        """Collect and log worker status periodically."""
         while True:
             await asyncio.sleep(30)

-
+            # Just collect status - no ping/pong
+            active_workers = []
             for worker_id, ws in list(self.workers.items()):
-
-
-
-
-
-
-                # Clean up disconnected workers
-                for worker_id in disconnected:
-                    logger.warning(f"Worker {worker_id} did not respond to ping, disconnecting")
-                    if worker_id in self.workers:
+                # Check if WebSocket is still open (don't ping)
+                if ws.state == websockets.protocol.State.OPEN:
+                    active_workers.append(worker_id)
+                else:
+                    # Clean up closed connections
+                    logger.info(f"Worker {worker_id} connection closed")
                     del self.workers[worker_id]
-                        logger.warning(
-                            f"Releasing assignments for worker {worker_id} because it did not respond to ping"
-                        )
                     self.processor.release_assignments(worker_id)
-
+
+            # Log status
+            if active_workers:
+                logger.debug(
+                    f"Active workers: {len(active_workers)} - {', '.join(active_workers[:5])}"
+                )
+            logger.debug(f"Inactive workers: {len(self.workers) - len(active_workers)}")
+            # add to self.stats
+            self.stats["active_workers"] = len(active_workers)
+            self.stats["inactive_workers"] = len(self.workers) - len(active_workers)

     async def _checkpoint_loop(self):
-        """Periodically checkpoint storage."""
+        """Periodically checkpoint storage and chunk tracker."""
         interval = self.config.get("storage", {}).get("checkpoint_interval", 60)

         while True:
             await asyncio.sleep(interval)

-
-
-
+            try:
+                # Checkpoint storage
+                await self.storage.checkpoint()
+
+                # Also checkpoint the chunk tracker if using webdataset processor
+                if hasattr(self.processor, "chunk_tracker") and self.processor.chunk_tracker:
+                    # Save checkpoint in thread pool to avoid blocking
+                    await asyncio.get_event_loop().run_in_executor(
+                        None, self.processor.chunk_tracker.save
+                    )
+                    logger.debug("Saved chunk tracker checkpoint")
+
+                self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
+                logger.info("Storage and chunk tracker checkpoint complete")
+            except Exception as e:
+                logger.error(f"Error during checkpoint: {e}", exc_info=True)

     async def _stats_update_loop(self):
         """Periodically update and broadcast stats."""
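The new _checkpoint_loop runs the blocking chunk-tracker save through run_in_executor so the event loop stays responsive. A self-contained sketch of that idea, assuming a tracker with a synchronous save() method; intervals and names are invented for the example:

import asyncio
import logging
import time

logger = logging.getLogger("checkpoint-demo")


class BlockingTracker:
    def save(self) -> None:
        # Pretend this writes a large JSON file to disk.
        time.sleep(0.1)


async def checkpoint_loop(tracker: BlockingTracker, interval: float, rounds: int = 3) -> None:
    loop = asyncio.get_running_loop()
    for _ in range(rounds):
        await asyncio.sleep(interval)
        try:
            # Run the blocking save in the default thread pool, off the event loop.
            await loop.run_in_executor(None, tracker.save)
            logger.info("checkpoint complete")
        except Exception as exc:  # keep the loop alive on failure
            logger.error("checkpoint failed: %s", exc, exc_info=True)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(checkpoint_loop(BlockingTracker(), interval=0.2))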
{caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/processors/huggingface.py

@@ -551,7 +551,7 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):

         # Force checkpoint save if needed
         if self.chunk_tracker:
-            self.chunk_tracker.
+            self.chunk_tracker.save()

     def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
         """Get available work units for a worker."""

@@ -717,7 +717,7 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):

         # Save final state
         if self.chunk_tracker:
-            self.chunk_tracker.
+            self.chunk_tracker.save()


 class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
{caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/processors/webdataset.py

@@ -110,58 +110,86 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         return self.shard_info_cache[shard_idx]

     def _restore_state(self, storage: StorageManager) -> None:
-        """Restore state from chunk tracker."""
-        logger.
+        """Restore state from chunk tracker and synchronize with storage."""
+        logger.info("Restoring state from chunk tracker and synchronizing with storage")
         if not self.chunk_tracker:
             return

+        # First, update chunk tracker from storage
+        processed_job_ids = storage.get_all_processed_job_ids()
+        if processed_job_ids:
+            logger.info(
+                f"Synchronizing chunk tracker with {len(processed_job_ids)} processed items from storage"
+            )
+            self.update_from_storage(processed_job_ids)
+
+        # Then restore work units from chunk tracker
         shards_summary = self.chunk_tracker.get_shards_summary()
-        logger.
+        logger.info(f"Restoring work units from chunk tracker: {len(shards_summary)} shards")

         with self.lock:
+            restored_count = 0
             for shard_name, shard_info in shards_summary.items():
                 chunks = shard_info.get("chunks", [])
-                logger.debug(f"Existing job ids: {storage.get_all_processed_job_ids()}")
                 for chunk_state in chunks:
                     # Only add incomplete chunks
-                    if chunk_state.status
-                        logger.debug(f"
+                    if chunk_state.status == "completed":
+                        logger.debug(f"Skipping completed chunk {chunk_state.chunk_id}")
+                        continue

-
-
+                    # Get unprocessed ranges
+                    unprocessed_ranges = chunk_state.get_unprocessed_ranges()
+                    if not unprocessed_ranges:
                         logger.debug(
-                            f"Chunk {chunk_state.chunk_id} unprocessed ranges
+                            f"Chunk {chunk_state.chunk_id} has no unprocessed ranges, marking as completed"
                         )
-
-
+                        self.chunk_tracker.mark_completed(chunk_state.chunk_id)
+                        continue

-
-
-
-                        abs_start = chunk_state.start_index + start
-                        abs_end = chunk_state.start_index + end
-                        absolute_ranges.append((abs_start, abs_end))
+                    logger.info(
+                        f"Restoring chunk {chunk_state.chunk_id} with unprocessed ranges: {unprocessed_ranges}"
+                    )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                        },
-                    )
+                    # Convert relative ranges to absolute file indices
+                    absolute_ranges = []
+                    for start, end in unprocessed_ranges:
+                        abs_start = chunk_state.start_index + start
+                        abs_end = chunk_state.start_index + end
+                        absolute_ranges.append((abs_start, abs_end))
+
+                    # Get shard index if available
+                    shard_idx = None
+                    if self.dataset:
+                        for idx in range(self.dataset.num_shards):
+                            shard_info = self._get_shard_info_cached(idx)
+                            if shard_info and shard_info["name"] == shard_name:
+                                shard_idx = idx
+                                break

-
-
+                    unit = WorkUnit(
+                        unit_id=chunk_state.chunk_id,
+                        chunk_id=chunk_state.chunk_id,
+                        source_id=shard_name,
+                        unit_size=chunk_state.chunk_size,
+                        data={
+                            "shard_url": chunk_state.shard_url,
+                            "shard_name": shard_name,
+                            "shard_idx": shard_idx,
+                            "start_index": chunk_state.start_index,
+                            "chunk_size": chunk_state.chunk_size,
+                            "unprocessed_ranges": absolute_ranges,
+                        },
+                        metadata={
+                            "shard_name": shard_name,
+                            "chunk_index": chunk_state.start_index // self.chunk_size,
+                        },
+                    )
+
+                    self.work_units[unit.unit_id] = unit
+                    self.pending_units.append(unit.unit_id)
+                    restored_count += 1
+
+            logger.info(f"Restored {restored_count} incomplete work units")

     def _create_units_background(self) -> None:
         """Background thread to create work units on demand."""
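The restore path above converts the tracker's chunk-relative unprocessed ranges into absolute sample indices by offsetting them with the chunk's start_index. A tiny illustration of that arithmetic with invented numbers:

from typing import List, Tuple


def to_absolute(start_index: int, relative_ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """Shift chunk-relative (start, end) pairs by the chunk's first absolute index."""
    return [(start_index + s, start_index + e) for s, e in relative_ranges]


# A chunk that begins at absolute index 2000 with two unprocessed stretches.
print(to_absolute(2000, [(0, 9), (50, 99)]))  # [(2000, 2009), (2050, 2099)]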
{caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/processors/webdataset.py (continued)

@@ -278,8 +306,15 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         assigned = []

         with self.lock:
-
+            units_checked = 0
+            max_units_to_check = len(self.pending_units)
+
+            while len(assigned) < count and units_checked < max_units_to_check:
+                if not self.pending_units:
+                    break
+
                 unit_id = self.pending_units.popleft()
+                units_checked += 1
                 unit = self.work_units.get(unit_id)

                 if unit:

@@ -288,6 +323,16 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
                     chunk_state = self.chunk_tracker.chunks[unit_id]
                     relative_unprocessed = chunk_state.get_unprocessed_ranges()

+                    # If no unprocessed ranges, mark as completed and skip
+                    if not relative_unprocessed:
+                        logger.info(
+                            f"Chunk {unit_id} has no unprocessed ranges, marking as completed"
+                        )
+                        self.chunk_tracker.mark_completed(unit_id)
+                        # Remove from work units
+                        del self.work_units[unit_id]
+                        continue
+
                     # Convert relative to absolute indices
                     absolute_ranges = []
                     for start, end in relative_unprocessed:

@@ -307,6 +352,9 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):

                 if self.chunk_tracker:
                     self.chunk_tracker.mark_assigned(unit_id, worker_id)
+                else:
+                    # Put it back if we couldn't get the unit
+                    self.pending_units.append(unit_id)

         logger.debug(f"Assigned {len(assigned)} units to worker {worker_id}")
         return assigned
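The reworked assignment loop bounds its scan by the initial queue length, so units that are skipped or pushed back cannot cause an endless loop within a single call. A simplified sketch of the same pattern over a plain deque; the unit dictionaries are stand-ins, not the package's WorkUnit objects:

from collections import deque
from typing import Deque, Dict, List, Optional


def assign_units(pending: Deque[str], units: Dict[str, dict], count: int) -> List[dict]:
    assigned: List[dict] = []
    units_checked = 0
    max_units_to_check = len(pending)  # bound the scan so re-queued ids are not revisited

    while len(assigned) < count and units_checked < max_units_to_check:
        if not pending:
            break
        unit_id = pending.popleft()
        units_checked += 1
        unit: Optional[dict] = units.get(unit_id)
        if unit is None:
            pending.append(unit_id)  # could not resolve it right now, put it back
            continue
        if not unit["unprocessed_ranges"]:
            units.pop(unit_id, None)  # nothing left to do, retire the unit
            continue
        assigned.append(unit)
    return assigned


pending = deque(["done", "ready"])
units = {
    "done": {"unprocessed_ranges": []},
    "ready": {"unprocessed_ranges": [(0, 9)]},
}
print(assign_units(pending, units, count=2))  # [{'unprocessed_ranges': [(0, 9)]}]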
{caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/processors/webdataset.py (continued)

@@ -366,8 +414,20 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         logger.info(f"Released {len(unit_ids)} assignments from {worker_id}")

     def handle_result(self, result: WorkResult) -> Dict[str, Any]:
-        """Handle result from worker."""
-        #
+        """Handle result from worker and update chunk tracker."""
+        # Extract the actual item index from the metadata
+        item_index = result.metadata.get("_item_index", None)
+
+        # If we have an item index, mark it as processed in the chunk tracker
+        if self.chunk_tracker and item_index is not None and result.chunk_id:
+            try:
+                # Mark single item as processed
+                self.chunk_tracker.mark_items_processed(result.chunk_id, item_index, item_index)
+                # logger.debug(f"Marked item {item_index} as processed in chunk {result.chunk_id}")
+            except Exception as e:
+                logger.error(f"Error marking item {item_index} as processed: {e}")
+
+        # Also handle batch results if present (backward compatibility)
         if self.chunk_tracker and "item_indices" in result.metadata:
             indices = result.metadata["item_indices"]

@@ -391,6 +451,9 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
             # Mark ranges as processed
             for start_idx, end_idx in ranges:
                 self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
+                logger.debug(
+                    f"Marked range {start_idx}-{end_idx} as processed in chunk {result.chunk_id}"
+                )

         return {
             "source_id": result.source_id,
@@ -407,22 +470,46 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         # Group by chunk
         processed_by_chunk = defaultdict(set)

-        for
-
-
-
-
-
-
-
-
-
-                continue
+        for job_id_str in processed_job_ids:
+            try:
+                # Use JobId to parse the job ID string
+                job_id = JobId.from_str(job_id_str)
+                chunk_id = job_id.get_chunk_str()
+                sample_idx = int(job_id.sample_id)
+                processed_by_chunk[chunk_id].add(sample_idx)
+            except ValueError as e:
+                logger.warning(f"Invalid job ID format: {job_id_str} - {e}")
+                continue

         # Update chunk tracker with processed items
         if self.chunk_tracker:
             for chunk_id, indices in processed_by_chunk.items():
                 if indices:
+                    # Get or create chunk state
+                    chunk_state = self.chunk_tracker.chunks.get(chunk_id)
+                    if not chunk_state:
+                        # Parse chunk_id using JobId to get shard info
+                        try:
+                            # chunk_id format: "shard_id:chunk:chunk_idx"
+                            parts = chunk_id.split(":")
+                            if len(parts) >= 3:
+                                shard_name = parts[0]
+                                chunk_idx = int(parts[2])
+                                # Infer start index from chunk index and size
+                                start_index = chunk_idx * self.chunk_size
+                                # Create chunk state
+                                self.chunk_tracker.add_chunk(
+                                    chunk_id,
+                                    shard_name,
+                                    f"{shard_name}.tar",
+                                    start_index,
+                                    self.chunk_size,
+                                )
+                                logger.info(f"Created missing chunk state for {chunk_id}")
+                        except (ValueError, IndexError) as e:
+                            logger.error(f"Failed to create chunk state for {chunk_id}: {e}")
+                            continue
+
                     # Sort indices and convert to ranges
                     sorted_indices = sorted(indices)
                     if not sorted_indices:

@@ -443,10 +530,13 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
                        ranges.append((start_range, end_range))

                    # Mark each contiguous range as processed
-                   logger.
+                   logger.info(f"Marking ranges {ranges} as processed in chunk {chunk_id}")
                    for start_idx, end_idx in ranges:
                        self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)

+            # Save checkpoint after updating
+            self.chunk_tracker.save()
+
     def get_stats(self) -> Dict[str, Any]:
         """Get processor statistics."""
         with self.lock:

@@ -484,7 +574,7 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):

         # Save checkpoint
         if self.chunk_tracker:
-            self.chunk_tracker.
+            self.chunk_tracker.save()


 class WebDatasetWorkerProcessor(WorkerProcessor):
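update_from_storage groups processed sample indices per chunk and then folds each sorted set into contiguous (start, end) ranges before handing them to the chunk tracker. A standalone sketch of that folding step:

from typing import Iterable, List, Tuple


def to_contiguous_ranges(indices: Iterable[int]) -> List[Tuple[int, int]]:
    """Collapse a set of integers into sorted, contiguous (start, end) ranges."""
    sorted_indices = sorted(set(indices))
    if not sorted_indices:
        return []
    ranges: List[Tuple[int, int]] = []
    start = end = sorted_indices[0]
    for idx in sorted_indices[1:]:
        if idx == end + 1:
            end = idx            # extend the current run
        else:
            ranges.append((start, end))
            start = end = idx    # start a new run
    ranges.append((start, end))
    return ranges


print(to_contiguous_ranges({0, 1, 2, 5, 6, 9}))  # [(0, 2), (5, 6), (9, 9)]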
{caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/processors/webdataset.py (continued)

@@ -555,7 +645,11 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
         # Generate mock results for unprocessed ranges
         for start_idx, end_idx in unprocessed_ranges:
             for idx in range(start_idx, end_idx + 1):
-
+                # Use JobId to create consistent job ID
+                job_id = JobId.from_values(
+                    shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(idx)
+                )
+                job_id_str = job_id.get_sample_str()

                 yield {
                     "image": self._create_mock_image(idx),

@@ -565,11 +659,11 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
                     "metadata": {
                         "_item_index": idx,
                         "_chunk_relative_index": idx - unit.data["start_index"],
-                        "_job_id":
+                        "_job_id": job_id_str,
                         "_mock": True,
                         "_processed_indices": processed_indices,
                     },
-                    "job_id":
+                    "job_id": job_id_str,
                 }

                 processed_indices.append(idx)

@@ -614,8 +708,11 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
                             f"Error decoding image {entry.path} with cv2: {img_e}"
                         )

-                # Generate job ID
-                job_id =
+                # Generate job ID using JobId class
+                job_id = JobId.from_values(
+                    shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(idx)
+                )
+                job_id_str = job_id.get_sample_str()

                 yield {
                     "image": image,

@@ -625,12 +722,12 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
                     "metadata": {
                         "_item_index": idx,
                         "_chunk_relative_index": idx - unit.data["start_index"],
-                        "_job_id":
+                        "_job_id": job_id_str,
                         "_filename": entry.path,
                         "_file_size": entry.size,
                         "_processed_indices": processed_indices,
                     },
-                    "job_id":
+                    "job_id": job_id_str,
                 }

                 processed_indices.append(idx)
{caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/utils/checkpoint_tracker.py

@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Dict, Any, Optional
 from datetime import datetime
+from concurrent.futures import ThreadPoolExecutor

 logger = logging.getLogger(__name__)


@@ -52,35 +53,54 @@ class CheckpointTracker(ABC):

     def save(self) -> None:
         """Save checkpoint to disk atomically."""
-
-
-
-
+        with self.lock:
+            try:
+                # Prepare data with metadata
+                data = self._serialize_state()
+                data["updated_at"] = datetime.utcnow().isoformat()
+
+                # Write atomically using temp file
+                tmp_file = self.checkpoint_path.with_suffix(".tmp")
+                # If a save is already in progress, let it finish.
+                # This prevents race conditions if save() is called rapidly.
+                if (
+                    hasattr(self, "_save_future")
+                    and self._save_future
+                    and not self._save_future.done()
+                ):
+                    self._save_future.result()  # Wait for the previous save to complete
+
+                # Use an executor to run the save operation in a background thread.
+                # This makes the save call non-blocking.
+                with ThreadPoolExecutor(max_workers=1) as executor:
+                    data_to_save = data.copy()
+                    self._save_future = executor.submit(self._write_to_disk, data_to_save, tmp_file)
+            except Exception as e:
+                logger.error(f"Failed to submit save task: {e}", exc_info=True)

-
-
+    def _write_to_disk(self, data: Dict[str, Any]) -> None:
+        """Write checkpoint data to disk atomically."""
+        # Create a temporary file in the same directory as the checkpoint
+        tmp_file = self.checkpoint_path.with_suffix(".tmp")
+
+        try:
+            # Ensure the parent directory exists
+            self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

             with open(tmp_file, "w") as f:
                 json.dump(data, f, indent=2)

-            #
-            if not tmp_file.exists():
-                raise IOError(f"Failed to create temporary file: {tmp_file}")
-
-            # Move atomically
+            # Atomically replace the checkpoint file
             tmp_file.replace(self.checkpoint_path)
-
             logger.debug(f"Saved checkpoint to {self.checkpoint_path}")
-
         except Exception as e:
-
-            # Try
-
-
-
-
-
-            logger.error(f"Fallback save also failed: {fallback_error}")
+            logger.error(f"Failed to save checkpoint atomically: {e}", exc_info=True)
+            # Try to clean up the temp file if it exists
+            if tmp_file.exists():
+                try:
+                    tmp_file.unlink()
+                except:
+                    pass

     def get_stats(self) -> Dict[str, Any]:
         """Get statistics about tracked items. Override for custom stats."""
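The rewritten save path splits the work into a submit step and a _write_to_disk step that writes to a temp file and then replaces the real checkpoint. A minimal sketch of the atomic-write half, independent of the class above; paths are illustrative:

import json
import tempfile
from pathlib import Path
from typing import Any, Dict


def write_checkpoint_atomically(checkpoint_path: Path, data: Dict[str, Any]) -> None:
    """Dump JSON next to the target file, then swap it in with an atomic replace."""
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_file = checkpoint_path.with_suffix(".tmp")
    try:
        with open(tmp_file, "w") as f:
            json.dump(data, f, indent=2)
        tmp_file.replace(checkpoint_path)  # atomic on POSIX when on the same filesystem
    except Exception:
        tmp_file.unlink(missing_ok=True)   # don't leave stale temp files behind
        raise


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "chunks.json"
    write_checkpoint_atomically(path, {"chunks": {}, "completed_count": 0})
    print(json.loads(path.read_text()))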
{caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/utils/chunk_tracker.py

@@ -8,6 +8,7 @@ from datetime import datetime, timedelta
 from dataclasses import dataclass, asdict, field

 from .checkpoint_tracker import CheckpointTracker
+from threading import Lock

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)

@@ -60,12 +61,12 @@ class ChunkState:
         self.status = "completed"
         self.completed_at = datetime.utcnow()
         # Clear processed_ranges since we don't need them after completion
-        self.processed_ranges = []
-        self.assigned_to = None
-        self.assigned_at = None
+        # self.processed_ranges = []
+        # self.assigned_to = None
+        # self.assigned_at = None

     def get_unprocessed_ranges(self) -> List[Tuple[int, int]]:
-        """Get ranges
+        """Get ranges of unprocessed items within the chunk (relative indices)."""
         if self.status == "completed":
             return []
@@ -73,22 +74,57 @@ class ChunkState:
             logger.info(f"Chunk {self.chunk_id} has no processed ranges, returning full range")
             return [(0, self.chunk_size - 1)]

+        # Merge ranges first to ensure no overlaps
+        merged_ranges = self._merge_ranges(self.processed_ranges)
+
         unprocessed = []
-
+        current_pos = 0

-
-
-
-
-            if current < start:
-                unprocessed.append((current, start - 1))
-            current = max(current, end + 1)
+        for start, end in merged_ranges:
+            if current_pos < start:
+                unprocessed.append((current_pos, start - 1))
+            current_pos = max(current_pos, end + 1)

-
-
+        # Add any remaining range
+        if current_pos < self.chunk_size:
+            unprocessed.append((current_pos, self.chunk_size - 1))
+
+        # Log for debugging
+        if not unprocessed:
+            logger.info(
+                f"Chunk {self.chunk_id} has processed ranges {merged_ranges} covering entire chunk size {self.chunk_size}"
+            )
+        else:
+            total_processed = sum(end - start + 1 for start, end in merged_ranges)
+            total_unprocessed = sum(end - start + 1 for start, end in unprocessed)
+            logger.debug(
+                f"Chunk {self.chunk_id}: {total_processed} processed, {total_unprocessed} unprocessed"
+            )

         return unprocessed

+    def _merge_ranges(self, ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
+        """Merge overlapping or adjacent ranges."""
+        if not ranges:
+            return []
+
+        # Sort ranges by start index, ensuring all are tuples
+        sorted_ranges = sorted([tuple(r) for r in ranges])
+        merged = [sorted_ranges[0]]
+
+        for current_start, current_end in sorted_ranges[1:]:
+            last_start, last_end = merged[-1]
+
+            # Check if ranges overlap or are adjacent
+            if current_start <= last_end + 1:
+                # Merge the ranges
+                merged[-1] = (last_start, max(last_end, current_end))
+            else:
+                # Add as new range
+                merged.append((current_start, current_end))
+
+        return merged
+
     def to_dict(self):
         """Convert to dictionary for JSON serialization."""
         d = asdict(self)
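The new _merge_ranges helper and the reworked get_unprocessed_ranges together treat the processed list as a set of possibly overlapping intervals. A standalone version of both pieces, with a small worked example:

from typing import List, Tuple


def merge_ranges(ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """Merge overlapping or adjacent (start, end) intervals."""
    if not ranges:
        return []
    sorted_ranges = sorted(tuple(r) for r in ranges)
    merged = [sorted_ranges[0]]
    for start, end in sorted_ranges[1:]:
        last_start, last_end = merged[-1]
        if start <= last_end + 1:                      # overlapping or adjacent
            merged[-1] = (last_start, max(last_end, end))
        else:
            merged.append((start, end))
    return merged


def unprocessed_ranges(processed: List[Tuple[int, int]], chunk_size: int) -> List[Tuple[int, int]]:
    """Walk the chunk and report the gaps not covered by the processed ranges."""
    gaps: List[Tuple[int, int]] = []
    current = 0
    for start, end in merge_ranges(processed):
        if current < start:
            gaps.append((current, start - 1))
        current = max(current, end + 1)
    if current < chunk_size:
        gaps.append((current, chunk_size - 1))
    return gaps


print(merge_ranges([(0, 4), (3, 7), (9, 9)]))            # [(0, 7), (9, 9)]
print(unprocessed_ranges([(0, 4), (3, 7), (9, 9)], 15))  # [(8, 8), (10, 14)]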
{caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/utils/chunk_tracker.py (continued)

@@ -124,6 +160,7 @@ class ChunkTracker(CheckpointTracker):
         self.max_completed_chunks_in_memory = max_completed_chunks_in_memory
         self.archive_after_hours = archive_after_hours
         self._completed_count = 0  # Track count without storing all IDs
+        self.lock = Lock()
         super().__init__(checkpoint_file)

     def _get_default_state(self) -> Dict[str, Any]:

@@ -132,16 +169,17 @@ class ChunkTracker(CheckpointTracker):

     def _deserialize_state(self, data: Dict[str, Any]) -> None:
         """Deserialize loaded data into instance state."""
-        self.
-
-
-
-
-
-
-
-
-
+        with self.lock:
+            self.chunks = {}
+            self._completed_count = data.get("completed_count", 0)
+
+            # Load chunk states
+            completed_chunks = 0
+            for chunk_id, chunk_data in data.get("chunks", {}).items():
+                chunk_state = ChunkState.from_dict(chunk_data)
+                self.chunks[chunk_id] = chunk_state
+                if chunk_state.status == "completed":
+                    completed_chunks += 1

         logger.info(
             f"Loaded {len(self.chunks)} chunks from checkpoint, "
@@ -494,40 +532,40 @@ class ChunkTracker(CheckpointTracker):
             for start_idx, end_idx in ranges:
                 chunk.add_processed_range(start_idx, end_idx)

-    def mark_items_processed(self, chunk_id: str, start_idx: int, end_idx: int):
-        """Mark a range of items as processed within a chunk
+    def mark_items_processed(self, chunk_id: str, start_idx: int, end_idx: int) -> None:
+        """Mark a range of items as processed within a chunk."""
         if chunk_id not in self.chunks:
-            logger.
+            logger.warning(f"Chunk {chunk_id} not found in tracker")
             return

-
+        chunk_state = self.chunks[chunk_id]

-        # Convert absolute indices to chunk-relative
-        relative_start = start_idx -
-        relative_end = end_idx -
+        # Convert absolute indices to chunk-relative indices
+        relative_start = start_idx - chunk_state.start_index
+        relative_end = end_idx - chunk_state.start_index

-        #
-
-
-                f"Invalid indices for chunk {chunk_id}: "
-                f"absolute {start_idx}-{end_idx} (relative {relative_start}-{relative_end}) "
-                f"outside chunk bounds [{chunk.start_index}, {chunk.start_index + chunk.chunk_size - 1}]"
-            )
-            return
+        # Ensure indices are within chunk bounds
+        relative_start = max(0, relative_start)
+        relative_end = min(chunk_state.chunk_size - 1, relative_end)

-        # Add
-
+        # Add to processed ranges
+        chunk_state.processed_ranges.append((relative_start, relative_end))

-        #
-
-        self._completed_count += 1
+        # Merge overlapping ranges
+        chunk_state.processed_ranges = chunk_state._merge_ranges(chunk_state.processed_ranges)

-        self.save()
         logger.debug(
-            f"Marked items {start_idx}-{end_idx} as processed in chunk {chunk_id} "
-            f"(relative indices: {relative_start}-{relative_end})"
+            f"Marked items {start_idx}-{end_idx} as processed in chunk {chunk_id} (relative indices: {relative_start}-{relative_end})"
         )

+        # Check if chunk is now complete
+        if chunk_state.get_unprocessed_ranges() == []:
+            logger.info(f"Chunk {chunk_id} is now complete")
+            chunk_state.status = "completed"
+
+        # Save checkpoint after updating
+        self.save()
+
     def get_chunk_with_unprocessed_items(self, chunk_id: str) -> Optional[Dict[str, Any]]:
         """Get chunk info with unprocessed item ranges."""
         chunk_state = self.chunks.get(chunk_id)
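mark_items_processed now clamps out-of-bounds indices to the chunk instead of rejecting the whole range. A small sketch of that translation and clamping, using invented chunk bounds:

from typing import Optional, Tuple


def to_relative_range(start_idx: int, end_idx: int, chunk_start: int, chunk_size: int) -> Optional[Tuple[int, int]]:
    """Translate absolute indices to chunk-relative ones and clip them to the chunk."""
    relative_start = max(0, start_idx - chunk_start)
    relative_end = min(chunk_size - 1, end_idx - chunk_start)
    if relative_start > relative_end:
        return None  # range lies entirely outside this chunk
    return (relative_start, relative_end)


# Chunk covering absolute indices 1000-1999 (start 1000, size 1000):
print(to_relative_range(990, 1009, chunk_start=1000, chunk_size=1000))   # (0, 9)
print(to_relative_range(1995, 2010, chunk_start=1000, chunk_size=1000))  # (995, 999)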
{caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/workers/base.py

@@ -89,8 +89,13 @@ class BaseWorker(ABC):
     async def _connect_and_run(self):
         """Connect to orchestrator and run main loop."""
         logger.info(f"Connecting to {self.server_url}")
-
-
+        async with websockets.connect(
+            self.server_url,
+            ssl=self.ssl_context,
+            ping_interval=20,
+            ping_timeout=60,
+            close_timeout=10,
+        ) as websocket:
             self.websocket = websocket
             self.connected.set()

{caption_flow-0.3.2 → caption_flow-0.3.4}/src/caption_flow/workers/caption.py

@@ -248,7 +248,13 @@ class CaptionWorker(BaseWorker):
     async def _initial_connect_for_config(self):
         """Connect initially just to get configuration."""
         logger.info(f"Connecting to {self.server_url}")
-        async with websockets.connect(
+        async with websockets.connect(
+            self.server_url,
+            ssl=self.ssl_context,
+            ping_interval=20,
+            ping_timeout=60,
+            close_timeout=10,
+        ) as websocket:
             await websocket.send(json.dumps(self._get_auth_data()))

             welcome = await websocket.recv()