caption-flow 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl

caption_flow/__init__.py CHANGED
@@ -1,6 +1,6 @@
 """CaptionFlow - Distributed community captioning system."""
 
-__version__ = "0.3.2"
+__version__ = "0.3.4"
 
 from .orchestrator import Orchestrator
 from .workers.data import DataWorker
caption_flow/cli.py CHANGED
@@ -124,7 +124,7 @@ def setup_logging(verbose: bool = False):
     level = logging.DEBUG if verbose else logging.INFO
     logging.basicConfig(
         level=level,
-        format="%(message)s",
+        format="%(name)s: %(message)s",
         datefmt="[%Y-%m-%d %H:%M:%S]",
         handlers=[
             RichHandler(
@@ -490,7 +490,9 @@ def reload_config(
 
     async def send_reload():
        try:
-            async with websockets.connect(server, ssl=ssl_context) as websocket:
+            async with websockets.connect(
+                server, ssl=ssl_context, ping_interval=20, ping_timeout=60, close_timeout=10
+            ) as websocket:
                # Authenticate as admin
                await websocket.send(json.dumps({"token": token, "role": "admin"}))
 
caption_flow/monitor.py CHANGED
@@ -73,6 +73,9 @@ class Monitor:
         async with websockets.connect(
             self.server_url,
             ssl=self.ssl_context if self.server_url.startswith("wss://") else None,
+            ping_interval=20,
+            ping_timeout=60,
+            close_timeout=10,
         ) as websocket:
             # Authenticate
             await websocket.send(json.dumps({"token": self.token}))
caption_flow/orchestrator.py CHANGED
@@ -124,13 +124,14 @@ class Orchestrator:
 
         # Initialize storage
         await self.storage.initialize()
-        await self.update_unprocessed_ranges()
 
         # Start background tasks
         asyncio.create_task(self._heartbeat_loop())
         asyncio.create_task(self._checkpoint_loop())
         asyncio.create_task(self._stats_update_loop())
 
+        await self.update_unprocessed_ranges()
+
         # Start WebSocket server
         websocket_logger = logging.getLogger("websockets")
         websocket_logger.setLevel(logging.WARNING)
@@ -376,16 +377,17 @@ class Orchestrator:
         """Process results submission from worker."""
         # Extract user from worker_id
         worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
+
         # Create work result
         _job_id = data.get("job_id")
         job_id = JobId.from_str(_job_id)
-        shard_name = job_id.shard_id  # >data-0000<
-        chunk_name = job_id.chunk_id  # data-0000:chunk:>0<
-        # logger.debug(f"({job_id}) Worker result: {data}")
+        shard_name = job_id.shard_id
+        chunk_name = job_id.chunk_id
+
         result = WorkResult(
             unit_id=data["unit_id"],
             source_id=shard_name,
-            chunk_id=job_id.get_chunk_str(),  # we want the full string here
+            chunk_id=job_id.get_chunk_str(),
             sample_id=data["sample_id"],
             dataset=data["dataset"],
             outputs=data["outputs"],
@@ -393,7 +395,9 @@ class Orchestrator:
             processing_time_ms=data.get("processing_time_ms", 0),
         )
 
-        # Let processor handle any custom processing
+        # Let processor handle any custom processing - this updates chunk tracker
+        # IMPORTANT: Call this BEFORE saving to storage so chunk tracker is updated
+        # regardless of whether the item is a duplicate
         processed = self.processor.handle_result(result)
 
         # Create caption record for storage
@@ -411,6 +415,7 @@ class Orchestrator:
         for key in to_delete_metadata_keys:
             if key in result.metadata:
                 del result.metadata[key]
+
         caption = Caption(
             job_id=job_id,
             dataset=result.dataset,
@@ -432,14 +437,15 @@ class Orchestrator:
             image_format=image_format,
         )
 
-        # Save to storage
-        await self.storage.save_caption(caption)
+        # Save to storage (might skip if duplicate)
+        saved = await self.storage.save_caption(caption)
 
-        # Update contributor stats
-        contributor = await self.storage.get_contributor(worker_user)
-        if contributor:
-            contributor.total_captions += total_outputs
-            await self.storage.save_contributor(contributor)
+        # Update contributor stats only if actually saved
+        if saved:
+            contributor = await self.storage.get_contributor(worker_user)
+            if contributor:
+                contributor.total_captions += total_outputs
+                await self.storage.save_contributor(contributor)
 
     async def _handle_monitor(self, websocket: WebSocketServerProtocol):
         """Handle monitor connection."""
@@ -839,39 +845,55 @@ class Orchestrator:
         self.monitors -= disconnected
 
     async def _heartbeat_loop(self):
-        """Send periodic heartbeats to maintain connections."""
+        """Collect and log worker status periodically."""
         while True:
             await asyncio.sleep(30)
 
-            disconnected = []
+            # Just collect status - no ping/pong
+            active_workers = []
             for worker_id, ws in list(self.workers.items()):
-                try:
-                    pong_waiter = await ws.ping()
-                    await asyncio.wait_for(pong_waiter, timeout=10)
-                except:
-                    disconnected.append(worker_id)
-
-            # Clean up disconnected workers
-            for worker_id in disconnected:
-                logger.warning(f"Worker {worker_id} did not respond to ping, disconnecting")
-                if worker_id in self.workers:
+                # Check if WebSocket is still open (don't ping)
+                if ws.state == websockets.protocol.State.OPEN:
+                    active_workers.append(worker_id)
+                else:
+                    # Clean up closed connections
+                    logger.info(f"Worker {worker_id} connection closed")
                     del self.workers[worker_id]
-                    logger.warning(
-                        f"Releasing assignments for worker {worker_id} because it did not respond to ping"
-                    )
                     self.processor.release_assignments(worker_id)
-            self.stats["connected_workers"] = len(self.workers)
+
+            # Log status
+            if active_workers:
+                logger.debug(
+                    f"Active workers: {len(active_workers)} - {', '.join(active_workers[:5])}"
+                )
+            logger.debug(f"Inactive workers: {len(self.workers) - len(active_workers)}")
+            # add to self.stats
+            self.stats["active_workers"] = len(active_workers)
+            self.stats["inactive_workers"] = len(self.workers) - len(active_workers)
 
     async def _checkpoint_loop(self):
-        """Periodically checkpoint storage."""
+        """Periodically checkpoint storage and chunk tracker."""
         interval = self.config.get("storage", {}).get("checkpoint_interval", 60)
 
         while True:
             await asyncio.sleep(interval)
 
-            await self.storage.checkpoint()
-            self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
-            logger.info("Storage checkpoint complete")
+            try:
+                # Checkpoint storage
+                await self.storage.checkpoint()
+
+                # Also checkpoint the chunk tracker if using webdataset processor
+                if hasattr(self.processor, "chunk_tracker") and self.processor.chunk_tracker:
+                    # Save checkpoint in thread pool to avoid blocking
+                    await asyncio.get_event_loop().run_in_executor(
+                        None, self.processor.chunk_tracker.save
+                    )
+                    logger.debug("Saved chunk tracker checkpoint")
+
+                self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
+                logger.info("Storage and chunk tracker checkpoint complete")
+            except Exception as e:
+                logger.error(f"Error during checkpoint: {e}", exc_info=True)
 
     async def _stats_update_loop(self):
         """Periodically update and broadcast stats."""
caption_flow/processors/huggingface.py CHANGED
@@ -551,7 +551,7 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
 
         # Force checkpoint save if needed
         if self.chunk_tracker:
-            self.chunk_tracker.save_checkpoint()
+            self.chunk_tracker.save()
 
     def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
         """Get available work units for a worker."""
@@ -717,7 +717,7 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
 
         # Save final state
         if self.chunk_tracker:
-            self.chunk_tracker.save_checkpoint()
+            self.chunk_tracker.save()
 
 
 class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
caption_flow/processors/webdataset.py CHANGED
@@ -110,58 +110,86 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         return self.shard_info_cache[shard_idx]
 
     def _restore_state(self, storage: StorageManager) -> None:
-        """Restore state from chunk tracker."""
-        logger.debug("Restoring state from chunk tracker")
+        """Restore state from chunk tracker and synchronize with storage."""
+        logger.info("Restoring state from chunk tracker and synchronizing with storage")
         if not self.chunk_tracker:
             return
 
+        # First, update chunk tracker from storage
+        processed_job_ids = storage.get_all_processed_job_ids()
+        if processed_job_ids:
+            logger.info(
+                f"Synchronizing chunk tracker with {len(processed_job_ids)} processed items from storage"
+            )
+            self.update_from_storage(processed_job_ids)
+
+        # Then restore work units from chunk tracker
         shards_summary = self.chunk_tracker.get_shards_summary()
-        logger.debug(f"Restoring state: {shards_summary}")
+        logger.info(f"Restoring work units from chunk tracker: {len(shards_summary)} shards")
 
         with self.lock:
+            restored_count = 0
             for shard_name, shard_info in shards_summary.items():
                 chunks = shard_info.get("chunks", [])
-                logger.debug(f"Existing job ids: {storage.get_all_processed_job_ids()}")
                 for chunk_state in chunks:
                     # Only add incomplete chunks
-                    if chunk_state.status != "completed":
-                        logger.debug(f"Restoring incomplete chunk {chunk_state}")
+                    if chunk_state.status == "completed":
+                        logger.debug(f"Skipping completed chunk {chunk_state.chunk_id}")
+                        continue
 
-                        # Get unprocessed ranges
-                        unprocessed_ranges = chunk_state.get_unprocessed_ranges()
+                    # Get unprocessed ranges
+                    unprocessed_ranges = chunk_state.get_unprocessed_ranges()
+                    if not unprocessed_ranges:
                         logger.debug(
-                            f"Chunk {chunk_state.chunk_id} unprocessed ranges: {unprocessed_ranges}"
+                            f"Chunk {chunk_state.chunk_id} has no unprocessed ranges, marking as completed"
                         )
-                        if not unprocessed_ranges:
-                            continue
+                        self.chunk_tracker.mark_completed(chunk_state.chunk_id)
+                        continue
 
-                        # Convert relative ranges to absolute file indices
-                        absolute_ranges = []
-                        for start, end in unprocessed_ranges:
-                            abs_start = chunk_state.start_index + start
-                            abs_end = chunk_state.start_index + end
-                            absolute_ranges.append((abs_start, abs_end))
+                    logger.info(
+                        f"Restoring chunk {chunk_state.chunk_id} with unprocessed ranges: {unprocessed_ranges}"
+                    )
 
-                        unit = WorkUnit(
-                            unit_id=chunk_state.chunk_id,
-                            chunk_id=chunk_state.chunk_id,
-                            source_id=shard_name,
-                            unit_size=chunk_state.chunk_size,
-                            data={
-                                "shard_url": chunk_state.shard_url,
-                                "shard_name": shard_name,
-                                "start_index": chunk_state.start_index,
-                                "chunk_size": chunk_state.chunk_size,
-                                "unprocessed_ranges": absolute_ranges,
-                            },
-                            metadata={
-                                "shard_name": shard_name,
-                                "chunk_index": chunk_state.start_index // self.chunk_size,
-                            },
-                        )
+                    # Convert relative ranges to absolute file indices
+                    absolute_ranges = []
+                    for start, end in unprocessed_ranges:
+                        abs_start = chunk_state.start_index + start
+                        abs_end = chunk_state.start_index + end
+                        absolute_ranges.append((abs_start, abs_end))
+
+                    # Get shard index if available
+                    shard_idx = None
+                    if self.dataset:
+                        for idx in range(self.dataset.num_shards):
+                            shard_info = self._get_shard_info_cached(idx)
+                            if shard_info and shard_info["name"] == shard_name:
+                                shard_idx = idx
+                                break
 
-                        self.work_units[unit.unit_id] = unit
-                        self.pending_units.append(unit.unit_id)
+                    unit = WorkUnit(
+                        unit_id=chunk_state.chunk_id,
+                        chunk_id=chunk_state.chunk_id,
+                        source_id=shard_name,
+                        unit_size=chunk_state.chunk_size,
+                        data={
+                            "shard_url": chunk_state.shard_url,
+                            "shard_name": shard_name,
+                            "shard_idx": shard_idx,
+                            "start_index": chunk_state.start_index,
+                            "chunk_size": chunk_state.chunk_size,
+                            "unprocessed_ranges": absolute_ranges,
+                        },
+                        metadata={
+                            "shard_name": shard_name,
+                            "chunk_index": chunk_state.start_index // self.chunk_size,
+                        },
+                    )
+
+                    self.work_units[unit.unit_id] = unit
+                    self.pending_units.append(unit.unit_id)
+                    restored_count += 1
+
+            logger.info(f"Restored {restored_count} incomplete work units")
 
     def _create_units_background(self) -> None:
         """Background thread to create work units on demand."""
@@ -278,8 +306,15 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         assigned = []
 
         with self.lock:
-            while len(assigned) < count and self.pending_units:
+            units_checked = 0
+            max_units_to_check = len(self.pending_units)
+
+            while len(assigned) < count and units_checked < max_units_to_check:
+                if not self.pending_units:
+                    break
+
                 unit_id = self.pending_units.popleft()
+                units_checked += 1
                 unit = self.work_units.get(unit_id)
 
                 if unit:
@@ -288,6 +323,16 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
                     chunk_state = self.chunk_tracker.chunks[unit_id]
                     relative_unprocessed = chunk_state.get_unprocessed_ranges()
 
+                    # If no unprocessed ranges, mark as completed and skip
+                    if not relative_unprocessed:
+                        logger.info(
+                            f"Chunk {unit_id} has no unprocessed ranges, marking as completed"
+                        )
+                        self.chunk_tracker.mark_completed(unit_id)
+                        # Remove from work units
+                        del self.work_units[unit_id]
+                        continue
+
                     # Convert relative to absolute indices
                     absolute_ranges = []
                     for start, end in relative_unprocessed:
@@ -307,6 +352,9 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
 
                     if self.chunk_tracker:
                         self.chunk_tracker.mark_assigned(unit_id, worker_id)
+                else:
+                    # Put it back if we couldn't get the unit
+                    self.pending_units.append(unit_id)
 
         logger.debug(f"Assigned {len(assigned)} units to worker {worker_id}")
         return assigned
@@ -366,8 +414,20 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         logger.info(f"Released {len(unit_ids)} assignments from {worker_id}")
 
     def handle_result(self, result: WorkResult) -> Dict[str, Any]:
-        """Handle result from worker."""
-        # Track processed items if we have chunk tracker
+        """Handle result from worker and update chunk tracker."""
+        # Extract the actual item index from the metadata
+        item_index = result.metadata.get("_item_index", None)
+
+        # If we have an item index, mark it as processed in the chunk tracker
+        if self.chunk_tracker and item_index is not None and result.chunk_id:
+            try:
+                # Mark single item as processed
+                self.chunk_tracker.mark_items_processed(result.chunk_id, item_index, item_index)
+                # logger.debug(f"Marked item {item_index} as processed in chunk {result.chunk_id}")
+            except Exception as e:
+                logger.error(f"Error marking item {item_index} as processed: {e}")
+
+        # Also handle batch results if present (backward compatibility)
         if self.chunk_tracker and "item_indices" in result.metadata:
             indices = result.metadata["item_indices"]
 
@@ -391,6 +451,9 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
             # Mark ranges as processed
             for start_idx, end_idx in ranges:
                 self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
+                logger.debug(
+                    f"Marked range {start_idx}-{end_idx} as processed in chunk {result.chunk_id}"
+                )
 
         return {
             "source_id": result.source_id,
@@ -407,22 +470,46 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         # Group by chunk
         processed_by_chunk = defaultdict(set)
 
-        for job_id in processed_job_ids:
-            # Parse job_id to extract chunk and index
-            # Expected format: "shard:chunk:X:idx:Y"
-            parts = job_id.split(":")
-            if len(parts) >= 5 and parts[3] == "idx":
-                chunk_id = ":".join(parts[:3])  # "shard:chunk:X"
-                try:
-                    idx = int(parts[4])
-                    processed_by_chunk[chunk_id].add(idx)
-                except ValueError:
-                    continue
+        for job_id_str in processed_job_ids:
+            try:
+                # Use JobId to parse the job ID string
+                job_id = JobId.from_str(job_id_str)
+                chunk_id = job_id.get_chunk_str()
+                sample_idx = int(job_id.sample_id)
+                processed_by_chunk[chunk_id].add(sample_idx)
+            except ValueError as e:
+                logger.warning(f"Invalid job ID format: {job_id_str} - {e}")
+                continue
 
         # Update chunk tracker with processed items
         if self.chunk_tracker:
             for chunk_id, indices in processed_by_chunk.items():
                 if indices:
+                    # Get or create chunk state
+                    chunk_state = self.chunk_tracker.chunks.get(chunk_id)
+                    if not chunk_state:
+                        # Parse chunk_id using JobId to get shard info
+                        try:
+                            # chunk_id format: "shard_id:chunk:chunk_idx"
+                            parts = chunk_id.split(":")
+                            if len(parts) >= 3:
+                                shard_name = parts[0]
+                                chunk_idx = int(parts[2])
+                                # Infer start index from chunk index and size
+                                start_index = chunk_idx * self.chunk_size
+                                # Create chunk state
+                                self.chunk_tracker.add_chunk(
+                                    chunk_id,
+                                    shard_name,
+                                    f"{shard_name}.tar",
+                                    start_index,
+                                    self.chunk_size,
+                                )
+                                logger.info(f"Created missing chunk state for {chunk_id}")
+                        except (ValueError, IndexError) as e:
+                            logger.error(f"Failed to create chunk state for {chunk_id}: {e}")
+                            continue
+
                     # Sort indices and convert to ranges
                     sorted_indices = sorted(indices)
                     if not sorted_indices:
@@ -443,10 +530,13 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
                         ranges.append((start_range, end_range))
 
                     # Mark each contiguous range as processed
-                    logger.debug(f"Marking ranges {ranges} as processed in chunk {chunk_id}")
+                    logger.info(f"Marking ranges {ranges} as processed in chunk {chunk_id}")
                     for start_idx, end_idx in ranges:
                         self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
 
+            # Save checkpoint after updating
+            self.chunk_tracker.save()
+
     def get_stats(self) -> Dict[str, Any]:
         """Get processor statistics."""
         with self.lock:
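Note: `update_from_storage` compresses the set of processed sample indices into contiguous inclusive ranges before calling `mark_items_processed`. A standalone sketch of that grouping step:

```python
from typing import List, Set, Tuple


def indices_to_ranges(indices: Set[int]) -> List[Tuple[int, int]]:
    """Collapse a set of indices into sorted, contiguous inclusive ranges."""
    ranges: List[Tuple[int, int]] = []
    sorted_indices = sorted(indices)
    if not sorted_indices:
        return ranges

    start = end = sorted_indices[0]
    for idx in sorted_indices[1:]:
        if idx == end + 1:
            end = idx          # extend the current run
        else:
            ranges.append((start, end))
            start = end = idx  # begin a new run
    ranges.append((start, end))
    return ranges


assert indices_to_ranges({0, 1, 2, 5, 6, 9}) == [(0, 2), (5, 6), (9, 9)]
```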
@@ -484,7 +574,7 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
 
         # Save checkpoint
         if self.chunk_tracker:
-            self.chunk_tracker.save_checkpoint()
+            self.chunk_tracker.save()
 
 
 class WebDatasetWorkerProcessor(WorkerProcessor):
@@ -555,7 +645,11 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
         # Generate mock results for unprocessed ranges
         for start_idx, end_idx in unprocessed_ranges:
             for idx in range(start_idx, end_idx + 1):
-                job_id = f"{shard_name}:chunk:{chunk_index}:idx:{idx}"
+                # Use JobId to create consistent job ID
+                job_id = JobId.from_values(
+                    shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(idx)
+                )
+                job_id_str = job_id.get_sample_str()
 
                 yield {
                     "image": self._create_mock_image(idx),
@@ -565,11 +659,11 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
                     "metadata": {
                         "_item_index": idx,
                         "_chunk_relative_index": idx - unit.data["start_index"],
-                        "_job_id": job_id,
+                        "_job_id": job_id_str,
                         "_mock": True,
                         "_processed_indices": processed_indices,
                     },
-                    "job_id": job_id,
+                    "job_id": job_id_str,
                 }
 
                 processed_indices.append(idx)
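Note: both worker paths now construct IDs through `JobId` instead of hand-formatted f-strings; the removed code documents the wire format as `shard:chunk:X:idx:Y`. A hedged sketch of what that round trip implies; this is an illustration of the format, not caption_flow's actual `JobId` implementation:

```python
from dataclasses import dataclass


@dataclass
class SampleJobId:
    """Illustrative model of the 'shard:chunk:X:idx:Y' job ID format."""

    shard_id: str
    chunk_id: str
    sample_id: str

    def get_chunk_str(self) -> str:
        return f"{self.shard_id}:chunk:{self.chunk_id}"

    def get_sample_str(self) -> str:
        return f"{self.get_chunk_str()}:idx:{self.sample_id}"

    @classmethod
    def from_str(cls, s: str) -> "SampleJobId":
        shard_id, _, chunk_id, _, sample_id = s.split(":")
        return cls(shard_id, chunk_id, sample_id)


jid = SampleJobId("data-0000", "3", "42")
assert jid.get_sample_str() == "data-0000:chunk:3:idx:42"
assert SampleJobId.from_str(jid.get_sample_str()) == jid
```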
@@ -614,8 +708,11 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
                             f"Error decoding image {entry.path} with cv2: {img_e}"
                         )
 
-                # Generate job ID compatible with chunk tracker
-                job_id = f"{shard_name}:chunk:{chunk_index}:idx:{idx}"
+                # Generate job ID using JobId class
+                job_id = JobId.from_values(
+                    shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(idx)
+                )
+                job_id_str = job_id.get_sample_str()
 
                 yield {
                     "image": image,
@@ -625,12 +722,12 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
                     "metadata": {
                         "_item_index": idx,
                         "_chunk_relative_index": idx - unit.data["start_index"],
-                        "_job_id": job_id,
+                        "_job_id": job_id_str,
                         "_filename": entry.path,
                         "_file_size": entry.size,
                         "_processed_indices": processed_indices,
                     },
-                    "job_id": job_id,
+                    "job_id": job_id_str,
                 }
 
                 processed_indices.append(idx)
caption_flow/utils/checkpoint_tracker.py CHANGED
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Dict, Any, Optional
 from datetime import datetime
+from concurrent.futures import ThreadPoolExecutor
 
 logger = logging.getLogger(__name__)
 
@@ -52,35 +53,54 @@ class CheckpointTracker(ABC):
 
     def save(self) -> None:
         """Save checkpoint to disk atomically."""
-        try:
-            # Prepare data with metadata
-            data = self._serialize_state()
-            data["updated_at"] = datetime.utcnow().isoformat()
+        with self.lock:
+            try:
+                # Prepare data with metadata
+                data = self._serialize_state()
+                data["updated_at"] = datetime.utcnow().isoformat()
+
+                # Write atomically using temp file
+                tmp_file = self.checkpoint_path.with_suffix(".tmp")
+                # If a save is already in progress, let it finish.
+                # This prevents race conditions if save() is called rapidly.
+                if (
+                    hasattr(self, "_save_future")
+                    and self._save_future
+                    and not self._save_future.done()
+                ):
+                    self._save_future.result()  # Wait for the previous save to complete
+
+                # Use an executor to run the save operation in a background thread.
+                # This makes the save call non-blocking.
+                with ThreadPoolExecutor(max_workers=1) as executor:
+                    data_to_save = data.copy()
+                    self._save_future = executor.submit(self._write_to_disk, data_to_save, tmp_file)
+            except Exception as e:
+                logger.error(f"Failed to submit save task: {e}", exc_info=True)
 
-            # Write atomically using temp file
-            tmp_file = self.checkpoint_path.with_suffix(".tmp")
+    def _write_to_disk(self, data: Dict[str, Any]) -> None:
+        """Write checkpoint data to disk atomically."""
+        # Create a temporary file in the same directory as the checkpoint
+        tmp_file = self.checkpoint_path.with_suffix(".tmp")
+
+        try:
+            # Ensure the parent directory exists
+            self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
 
             with open(tmp_file, "w") as f:
                 json.dump(data, f, indent=2)
 
-            # Ensure temp file was created
-            if not tmp_file.exists():
-                raise IOError(f"Failed to create temporary file: {tmp_file}")
-
-            # Move atomically
+            # Atomically replace the checkpoint file
             tmp_file.replace(self.checkpoint_path)
-
             logger.debug(f"Saved checkpoint to {self.checkpoint_path}")
-
         except Exception as e:
-            # logger.error(f"Error saving checkpoint: {e}", exc_info=True)
-            # Try direct write as fallback
-            try:
-                with open(self.checkpoint_path, "w") as f:
-                    json.dump(data, f, indent=2)
-                # logger.info("Saved checkpoint using fallback direct write")
-            except Exception as fallback_error:
-                logger.error(f"Fallback save also failed: {fallback_error}")
+            logger.error(f"Failed to save checkpoint atomically: {e}", exc_info=True)
+            # Try to clean up the temp file if it exists
+            if tmp_file.exists():
+                try:
+                    tmp_file.unlink()
+                except:
+                    pass
 
     def get_stats(self) -> Dict[str, Any]:
         """Get statistics about tracked items. Override for custom stats."""
caption_flow/utils/chunk_tracker.py CHANGED
@@ -8,6 +8,7 @@ from datetime import datetime, timedelta
 from dataclasses import dataclass, asdict, field
 
 from .checkpoint_tracker import CheckpointTracker
+from threading import Lock
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -60,12 +61,12 @@ class ChunkState:
         self.status = "completed"
         self.completed_at = datetime.utcnow()
         # Clear processed_ranges since we don't need them after completion
-        self.processed_ranges = []
-        self.assigned_to = None
-        self.assigned_at = None
+        # self.processed_ranges = []
+        # self.assigned_to = None
+        # self.assigned_at = None
 
     def get_unprocessed_ranges(self) -> List[Tuple[int, int]]:
-        """Get ranges that haven't been processed yet."""
+        """Get ranges of unprocessed items within the chunk (relative indices)."""
         if self.status == "completed":
             return []
 
@@ -73,22 +74,57 @@ class ChunkState:
             logger.info(f"Chunk {self.chunk_id} has no processed ranges, returning full range")
             return [(0, self.chunk_size - 1)]
 
+        # Merge ranges first to ensure no overlaps
+        merged_ranges = self._merge_ranges(self.processed_ranges)
+
         unprocessed = []
-        current = 0
+        current_pos = 0
 
-        logger.info(
-            f"Processing {len(self.processed_ranges)} processed ranges for chunk {self.chunk_id}"
-        )
-        for start, end in self.processed_ranges:
-            if current < start:
-                unprocessed.append((current, start - 1))
-            current = max(current, end + 1)
+        for start, end in merged_ranges:
+            if current_pos < start:
+                unprocessed.append((current_pos, start - 1))
+            current_pos = max(current_pos, end + 1)
 
-        if current < self.chunk_size:
-            unprocessed.append((current, self.chunk_size - 1))
+        # Add any remaining range
+        if current_pos < self.chunk_size:
+            unprocessed.append((current_pos, self.chunk_size - 1))
+
+        # Log for debugging
+        if not unprocessed:
+            logger.info(
+                f"Chunk {self.chunk_id} has processed ranges {merged_ranges} covering entire chunk size {self.chunk_size}"
+            )
+        else:
+            total_processed = sum(end - start + 1 for start, end in merged_ranges)
+            total_unprocessed = sum(end - start + 1 for start, end in unprocessed)
+            logger.debug(
+                f"Chunk {self.chunk_id}: {total_processed} processed, {total_unprocessed} unprocessed"
+            )
 
         return unprocessed
 
+    def _merge_ranges(self, ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
+        """Merge overlapping or adjacent ranges."""
+        if not ranges:
+            return []
+
+        # Sort ranges by start index, ensuring all are tuples
+        sorted_ranges = sorted([tuple(r) for r in ranges])
+        merged = [sorted_ranges[0]]
+
+        for current_start, current_end in sorted_ranges[1:]:
+            last_start, last_end = merged[-1]
+
+            # Check if ranges overlap or are adjacent
+            if current_start <= last_end + 1:
+                # Merge the ranges
+                merged[-1] = (last_start, max(last_end, current_end))
+            else:
+                # Add as new range
+                merged.append((current_start, current_end))
+
+        return merged
+
     def to_dict(self):
         """Convert to dictionary for JSON serialization."""
         d = asdict(self)
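Note: merging makes repeated `mark_items_processed` calls idempotent, since overlapping and adjacent ranges collapse into one. A standalone version of the same logic with a worked example:

```python
from typing import List, Tuple


def merge_ranges(ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """Merge overlapping or adjacent inclusive ranges (mirrors ChunkState._merge_ranges)."""
    if not ranges:
        return []
    sorted_ranges = sorted(tuple(r) for r in ranges)
    merged = [sorted_ranges[0]]
    for start, end in sorted_ranges[1:]:
        last_start, last_end = merged[-1]
        if start <= last_end + 1:  # overlap or adjacency
            merged[-1] = (last_start, max(last_end, end))
        else:
            merged.append((start, end))
    return merged


# (0, 4) and (5, 9) are adjacent; (8, 12) overlaps the merged result:
assert merge_ranges([(5, 9), (0, 4), (8, 12)]) == [(0, 12)]
```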
@@ -124,6 +160,7 @@ class ChunkTracker(CheckpointTracker):
         self.max_completed_chunks_in_memory = max_completed_chunks_in_memory
         self.archive_after_hours = archive_after_hours
         self._completed_count = 0  # Track count without storing all IDs
+        self.lock = Lock()
         super().__init__(checkpoint_file)
 
     def _get_default_state(self) -> Dict[str, Any]:
@@ -132,16 +169,17 @@ class ChunkTracker(CheckpointTracker):
 
     def _deserialize_state(self, data: Dict[str, Any]) -> None:
         """Deserialize loaded data into instance state."""
-        self.chunks = {}
-        self._completed_count = data.get("completed_count", 0)
-
-        # Load chunk states
-        completed_chunks = 0
-        for chunk_id, chunk_data in data.get("chunks", {}).items():
-            chunk_state = ChunkState.from_dict(chunk_data)
-            self.chunks[chunk_id] = chunk_state
-            if chunk_state.status == "completed":
-                completed_chunks += 1
+        with self.lock:
+            self.chunks = {}
+            self._completed_count = data.get("completed_count", 0)
+
+            # Load chunk states
+            completed_chunks = 0
+            for chunk_id, chunk_data in data.get("chunks", {}).items():
+                chunk_state = ChunkState.from_dict(chunk_data)
+                self.chunks[chunk_id] = chunk_state
+                if chunk_state.status == "completed":
+                    completed_chunks += 1
 
         logger.info(
             f"Loaded {len(self.chunks)} chunks from checkpoint, "
494
532
  for start_idx, end_idx in ranges:
495
533
  chunk.add_processed_range(start_idx, end_idx)
496
534
 
497
- def mark_items_processed(self, chunk_id: str, start_idx: int, end_idx: int):
498
- """Mark a range of items as processed within a chunk (expects ABSOLUTE indices)."""
535
+ def mark_items_processed(self, chunk_id: str, start_idx: int, end_idx: int) -> None:
536
+ """Mark a range of items as processed within a chunk."""
499
537
  if chunk_id not in self.chunks:
500
- logger.error(f"Unknown chunk: {chunk_id}")
538
+ logger.warning(f"Chunk {chunk_id} not found in tracker")
501
539
  return
502
540
 
503
- chunk = self.chunks[chunk_id]
541
+ chunk_state = self.chunks[chunk_id]
504
542
 
505
- # Convert absolute indices to chunk-relative
506
- relative_start = start_idx - chunk.start_index
507
- relative_end = end_idx - chunk.start_index
543
+ # Convert absolute indices to chunk-relative indices
544
+ relative_start = start_idx - chunk_state.start_index
545
+ relative_end = end_idx - chunk_state.start_index
508
546
 
509
- # Validate boundaries
510
- if relative_start < 0 or relative_end >= chunk.chunk_size:
511
- logger.error(
512
- f"Invalid indices for chunk {chunk_id}: "
513
- f"absolute {start_idx}-{end_idx} (relative {relative_start}-{relative_end}) "
514
- f"outside chunk bounds [{chunk.start_index}, {chunk.start_index + chunk.chunk_size - 1}]"
515
- )
516
- return
547
+ # Ensure indices are within chunk bounds
548
+ relative_start = max(0, relative_start)
549
+ relative_end = min(chunk_state.chunk_size - 1, relative_end)
517
550
 
518
- # Add the relative range
519
- chunk.add_processed_range(relative_start, relative_end)
551
+ # Add to processed ranges
552
+ chunk_state.processed_ranges.append((relative_start, relative_end))
520
553
 
521
- # If chunk is now complete, increment counter
522
- if chunk.status == "completed":
523
- self._completed_count += 1
554
+ # Merge overlapping ranges
555
+ chunk_state.processed_ranges = chunk_state._merge_ranges(chunk_state.processed_ranges)
524
556
 
525
- self.save()
526
557
  logger.debug(
527
- f"Marked items {start_idx}-{end_idx} as processed in chunk {chunk_id} "
528
- f"(relative indices: {relative_start}-{relative_end})"
558
+ f"Marked items {start_idx}-{end_idx} as processed in chunk {chunk_id} (relative indices: {relative_start}-{relative_end})"
529
559
  )
530
560
 
561
+ # Check if chunk is now complete
562
+ if chunk_state.get_unprocessed_ranges() == []:
563
+ logger.info(f"Chunk {chunk_id} is now complete")
564
+ chunk_state.status = "completed"
565
+
566
+ # Save checkpoint after updating
567
+ self.save()
568
+
531
569
  def get_chunk_with_unprocessed_items(self, chunk_id: str) -> Optional[Dict[str, Any]]:
532
570
  """Get chunk info with unprocessed item ranges."""
533
571
  chunk_state = self.chunks.get(chunk_id)
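Note: `mark_items_processed` now clamps out-of-range indices to the chunk bounds instead of rejecting the whole range. A small worked example of the absolute-to-relative conversion, assuming a chunk starting at index 2000 with size 1000:

```python
start_index, chunk_size = 2000, 1000


def to_relative_clamped(abs_start: int, abs_end: int) -> tuple:
    rel_start = max(0, abs_start - start_index)
    rel_end = min(chunk_size - 1, abs_end - start_index)
    return rel_start, rel_end


# A range spilling past both chunk edges is clipped to the chunk:
assert to_relative_clamped(1990, 3050) == (0, 999)
```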
caption_flow/workers/base.py CHANGED
@@ -89,8 +89,13 @@ class BaseWorker(ABC):
     async def _connect_and_run(self):
         """Connect to orchestrator and run main loop."""
         logger.info(f"Connecting to {self.server_url}")
-
-        async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
+        async with websockets.connect(
+            self.server_url,
+            ssl=self.ssl_context,
+            ping_interval=20,
+            ping_timeout=60,
+            close_timeout=10,
+        ) as websocket:
             self.websocket = websocket
             self.connected.set()
 
caption_flow/workers/caption.py CHANGED
@@ -248,7 +248,13 @@ class CaptionWorker(BaseWorker):
     async def _initial_connect_for_config(self):
         """Connect initially just to get configuration."""
         logger.info(f"Connecting to {self.server_url}")
-        async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
+        async with websockets.connect(
+            self.server_url,
+            ssl=self.ssl_context,
+            ping_interval=20,
+            ping_timeout=60,
+            close_timeout=10,
+        ) as websocket:
             await websocket.send(json.dumps(self._get_auth_data()))
 
             welcome = await websocket.recv()
caption_flow-0.3.2.dist-info/METADATA → caption_flow-0.3.4.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: caption-flow
-Version: 0.3.2
+Version: 0.3.4
 Summary: Self-contained distributed community captioning system
 Author-email: bghira <bghira@users.github.com>
 License: MIT
caption_flow-0.3.2.dist-info/RECORD → caption_flow-0.3.4.dist-info/RECORD CHANGED
@@ -1,14 +1,14 @@
-caption_flow/__init__.py,sha256=09Vyr0RqKrKe1caUhXq9beficJkmclryjT6BNiASUxQ,303
-caption_flow/cli.py,sha256=t_cYCxJE7f5UtB3br2Es51JjO5KPsWM1JTdDXAxM_Lw,41371
+caption_flow/__init__.py,sha256=2M1VLvkVjUmTHXuJFMLnZKqVYni5A0HJfxcnjz53K7c,303
+caption_flow/cli.py,sha256=K3lML3WIYjD7OluGltHGP4N98S5w-KyhDUlQZudDQXE,41464
 caption_flow/models.py,sha256=2n6iphTEL62xK2FFcJM6axMsaE8KwsUv5Ak_cCF-TdQ,5652
-caption_flow/monitor.py,sha256=bAt9EJqfPgT_KdbknGdCxwBRH002pRDgyUmYIj6Dyso,7885
-caption_flow/orchestrator.py,sha256=34gZvaW14YZ7a7LagYOO3VKKwlbuS4aw0yoP1L8gwf0,36192
+caption_flow/monitor.py,sha256=z2HakZSG799HvTJgjgG7u_MHvhq9-JL1LXzxBwP3WQc,7998
+caption_flow/orchestrator.py,sha256=3XKZXFE1Aw1kCqb_Vw9loYpkmJ5LTLyZZf9pj4k6ldA,37175
 caption_flow/viewer.py,sha256=HxO98eHR1xtivG0dEdYC2U9T_RgeRfJqqTK-37u9bNM,20471
 caption_flow/processors/__init__.py,sha256=hvq-OuAJWQe6hFglKe7QmkS8473k20FmxZDSxfXpCrg,423
 caption_flow/processors/base.py,sha256=IAEr0pqHRuSkXunvDWk1vf2IKeYQ-2YERqej9iSQm94,6931
-caption_flow/processors/huggingface.py,sha256=w0j7PRosXYyJXZ0A0Y-J6_n-aHCGVW8tbt8lcvguO_Y,41237
+caption_flow/processors/huggingface.py,sha256=t_dklhmNRAyk2jISu4FqmNecjg9hfY47omOiRVkbhvA,41215
 caption_flow/processors/local_filesystem.py,sha256=OuNNDemy0sdtpBBC_5GbI-c1vMqp8OIz983Cq85gdb8,27964
-caption_flow/processors/webdataset.py,sha256=TkC6xZO6m2FcwiBQGJsSQcrshBKcLdr4edFVtnBOd3U,28999
+caption_flow/processors/webdataset.py,sha256=tUBCUKunqooHibTWtQ1wljuRI55Wc6M1WrI2hOZgt7g,33858
 caption_flow/storage/__init__.py,sha256=IVnzcSCPpPuyp-QLlgJirRZ9Sb3tR0F4sfuF5u2cNMk,36
 caption_flow/storage/exporter.py,sha256=mFJqMDQ61cP-qcXe118_-oL1TUqULdQZ8LdjSTym44I,19697
 caption_flow/storage/manager.py,sha256=KPExcKPuFVQSsBnfCBdne5PO4PwN4NTfd-EJQk13OY0,47459
@@ -16,18 +16,18 @@ caption_flow/utils/__init__.py,sha256=bDcO5uR455TKCQ2hX-_XcdTnRXDBaT8Yn4jWqWzfFs
 caption_flow/utils/auth.py,sha256=UrxX2n8OEEcfMD1Ey27TxGfrJFmUCpC59x-SCrQJoVE,2253
 caption_flow/utils/caption_utils.py,sha256=esUMAdcCkNjRroZ0Bhxv0_yKlLtMf0XeDCTt-5k6bik,5309
 caption_flow/utils/certificates.py,sha256=eu4blQZEkL9NRaY1ynQWg1asvDorRYhGRZea7STonJE,4635
-caption_flow/utils/checkpoint_tracker.py,sha256=-nN5gLvXyMdKOCT2SNNL2Km6UYm2Hii9wuXeezWhwx4,3339
-caption_flow/utils/chunk_tracker.py,sha256=HntWeINTbJmIERsW21p4q4FK8D9-4xKbZQUsj24DIqo,19975
+caption_flow/utils/checkpoint_tracker.py,sha256=nOZIIGsXTRUj09tFSnWtRgj_zoa8Og_-rutkr2GFz8Y,4417
+caption_flow/utils/chunk_tracker.py,sha256=JZIFvaHS5AYaVOzsSJKrnNlS4E3BdzV64cRkQa_65g0,21508
 caption_flow/utils/image_processor.py,sha256=wmOExkVfM7OeuLfX3AwMefsH-TxL8TNcn22gp0NmJKY,1541
 caption_flow/utils/json_utils.py,sha256=IiZYn8uCM-3pYmyIbX2fmaOIyutArn67SqAyp0ggNpU,5396
 caption_flow/utils/prompt_template.py,sha256=AKp0diSZqNBMwZkpiTNjw8-bbQwHStr7QZTOJ7o1dC4,4345
 caption_flow/utils/vllm_config.py,sha256=TC7Rmjk0zRKbBXbWUXrFL4Z58hzax_-4L0pXZn09hdM,6019
-caption_flow/workers/base.py,sha256=2AGWERC5hbmO-0V_A1MUbgRVvRNN3blqGPyDokvvzmM,7575
-caption_flow/workers/caption.py,sha256=X4BEmb6C1c73hvgJDMsHtgCUlCuECtnloWSVolVpa4s,39353
+caption_flow/workers/base.py,sha256=nEWohozFZ0Bw3_8U8xirnKLeZsGR5k69rSu4j-oDitc,7698
+caption_flow/workers/caption.py,sha256=swE4pYg4ZYAAtMxvyvlETa3wv4yKWUPXXulCAwPhPiQ,39477
 caption_flow/workers/data.py,sha256=0Tg8NE0wdONeMlivYQ4nvbcfWdLuU51O7vR8_YSnJgo,14813
-caption_flow-0.3.2.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
-caption_flow-0.3.2.dist-info/METADATA,sha256=8bHECzNi4R6_FlbHWSHMx9TDo4uTVKWWgVbqAe5cCIs,9708
-caption_flow-0.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-caption_flow-0.3.2.dist-info/entry_points.txt,sha256=KnVlyrGKZj6p2zNyuEnCx4Y6jvJ4V-mcfN0lddPKTlQ,55
-caption_flow-0.3.2.dist-info/top_level.txt,sha256=_bXpKRutqded0FQ80dCChIz26ETV7tL4d4e2E_Y1FXs,13
-caption_flow-0.3.2.dist-info/RECORD,,
+caption_flow-0.3.4.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+caption_flow-0.3.4.dist-info/METADATA,sha256=dfB40EF_Zgz2Ux8qvdBbfLdhzY85_MUFRX-904I-qb4,9708
+caption_flow-0.3.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+caption_flow-0.3.4.dist-info/entry_points.txt,sha256=KnVlyrGKZj6p2zNyuEnCx4Y6jvJ4V-mcfN0lddPKTlQ,55
+caption_flow-0.3.4.dist-info/top_level.txt,sha256=_bXpKRutqded0FQ80dCChIz26ETV7tL4d4e2E_Y1FXs,13
+caption_flow-0.3.4.dist-info/RECORD,,