caption-flow 0.3.3__tar.gz → 0.3.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {caption_flow-0.3.3/src/caption_flow.egg-info → caption_flow-0.3.4}/PKG-INFO +1 -1
  2. {caption_flow-0.3.3 → caption_flow-0.3.4}/pyproject.toml +1 -1
  3. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/__init__.py +1 -1
  4. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/cli.py +4 -2
  5. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/monitor.py +3 -0
  6. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/orchestrator.py +53 -32
  7. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/processors/huggingface.py +2 -2
  8. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/processors/webdataset.py +39 -4
  9. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/utils/checkpoint_tracker.py +41 -21
  10. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/utils/chunk_tracker.py +85 -47
  11. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/workers/base.py +7 -2
  12. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/workers/caption.py +7 -1
  13. {caption_flow-0.3.3 → caption_flow-0.3.4/src/caption_flow.egg-info}/PKG-INFO +1 -1
  14. {caption_flow-0.3.3 → caption_flow-0.3.4}/LICENSE +0 -0
  15. {caption_flow-0.3.3 → caption_flow-0.3.4}/README.md +0 -0
  16. {caption_flow-0.3.3 → caption_flow-0.3.4}/setup.cfg +0 -0
  17. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/models.py +0 -0
  18. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/processors/__init__.py +0 -0
  19. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/processors/base.py +0 -0
  20. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/processors/local_filesystem.py +0 -0
  21. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/storage/__init__.py +0 -0
  22. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/storage/exporter.py +0 -0
  23. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/storage/manager.py +0 -0
  24. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/utils/__init__.py +0 -0
  25. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/utils/auth.py +0 -0
  26. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/utils/caption_utils.py +0 -0
  27. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/utils/certificates.py +0 -0
  28. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/utils/image_processor.py +0 -0
  29. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/utils/json_utils.py +0 -0
  30. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/utils/prompt_template.py +0 -0
  31. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/utils/vllm_config.py +0 -0
  32. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/viewer.py +0 -0
  33. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow/workers/data.py +0 -0
  34. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow.egg-info/SOURCES.txt +0 -0
  35. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow.egg-info/dependency_links.txt +0 -0
  36. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow.egg-info/entry_points.txt +0 -0
  37. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow.egg-info/requires.txt +0 -0
  38. {caption_flow-0.3.3 → caption_flow-0.3.4}/src/caption_flow.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: caption-flow
3
- Version: 0.3.3
3
+ Version: 0.3.4
4
4
  Summary: Self-contained distributed community captioning system
5
5
  Author-email: bghira <bghira@users.github.com>
6
6
  License: MIT
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "caption-flow"
3
- version = "0.3.3"
3
+ version = "0.3.4"
4
4
  description = "Self-contained distributed community captioning system"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10,<3.13"
@@ -1,6 +1,6 @@
1
1
  """CaptionFlow - Distributed community captioning system."""
2
2
 
3
- __version__ = "0.3.3"
3
+ __version__ = "0.3.4"
4
4
 
5
5
  from .orchestrator import Orchestrator
6
6
  from .workers.data import DataWorker
@@ -124,7 +124,7 @@ def setup_logging(verbose: bool = False):
124
124
  level = logging.DEBUG if verbose else logging.INFO
125
125
  logging.basicConfig(
126
126
  level=level,
127
- format="%(message)s",
127
+ format="%(name)s: %(message)s",
128
128
  datefmt="[%Y-%m-%d %H:%M:%S]",
129
129
  handlers=[
130
130
  RichHandler(
@@ -490,7 +490,9 @@ def reload_config(
490
490
 
491
491
  async def send_reload():
492
492
  try:
493
- async with websockets.connect(server, ssl=ssl_context) as websocket:
493
+ async with websockets.connect(
494
+ server, ssl=ssl_context, ping_interval=20, ping_timeout=60, close_timeout=10
495
+ ) as websocket:
494
496
  # Authenticate as admin
495
497
  await websocket.send(json.dumps({"token": token, "role": "admin"}))
496
498
 
@@ -73,6 +73,9 @@ class Monitor:
73
73
  async with websockets.connect(
74
74
  self.server_url,
75
75
  ssl=self.ssl_context if self.server_url.startswith("wss://") else None,
76
+ ping_interval=20,
77
+ ping_timeout=60,
78
+ close_timeout=10,
76
79
  ) as websocket:
77
80
  # Authenticate
78
81
  await websocket.send(json.dumps({"token": self.token}))
@@ -377,16 +377,17 @@ class Orchestrator:
377
377
  """Process results submission from worker."""
378
378
  # Extract user from worker_id
379
379
  worker_user = worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
380
+
380
381
  # Create work result
381
382
  _job_id = data.get("job_id")
382
383
  job_id = JobId.from_str(_job_id)
383
- shard_name = job_id.shard_id # >data-0000<
384
- chunk_name = job_id.chunk_id # data-0000:chunk:>0<
385
- # logger.debug(f"({job_id}) Worker result: {data}")
384
+ shard_name = job_id.shard_id
385
+ chunk_name = job_id.chunk_id
386
+
386
387
  result = WorkResult(
387
388
  unit_id=data["unit_id"],
388
389
  source_id=shard_name,
389
- chunk_id=job_id.get_chunk_str(), # we want the full string here
390
+ chunk_id=job_id.get_chunk_str(),
390
391
  sample_id=data["sample_id"],
391
392
  dataset=data["dataset"],
392
393
  outputs=data["outputs"],
@@ -394,7 +395,9 @@ class Orchestrator:
394
395
  processing_time_ms=data.get("processing_time_ms", 0),
395
396
  )
396
397
 
397
- # Let processor handle any custom processing
398
+ # Let processor handle any custom processing - this updates chunk tracker
399
+ # IMPORTANT: Call this BEFORE saving to storage so chunk tracker is updated
400
+ # regardless of whether the item is a duplicate
398
401
  processed = self.processor.handle_result(result)
399
402
 
400
403
  # Create caption record for storage
@@ -412,6 +415,7 @@ class Orchestrator:
412
415
  for key in to_delete_metadata_keys:
413
416
  if key in result.metadata:
414
417
  del result.metadata[key]
418
+
415
419
  caption = Caption(
416
420
  job_id=job_id,
417
421
  dataset=result.dataset,
@@ -433,14 +437,15 @@ class Orchestrator:
433
437
  image_format=image_format,
434
438
  )
435
439
 
436
- # Save to storage
437
- await self.storage.save_caption(caption)
440
+ # Save to storage (might skip if duplicate)
441
+ saved = await self.storage.save_caption(caption)
438
442
 
439
- # Update contributor stats
440
- contributor = await self.storage.get_contributor(worker_user)
441
- if contributor:
442
- contributor.total_captions += total_outputs
443
- await self.storage.save_contributor(contributor)
443
+ # Update contributor stats only if actually saved
444
+ if saved:
445
+ contributor = await self.storage.get_contributor(worker_user)
446
+ if contributor:
447
+ contributor.total_captions += total_outputs
448
+ await self.storage.save_contributor(contributor)
444
449
 
445
450
  async def _handle_monitor(self, websocket: WebSocketServerProtocol):
446
451
  """Handle monitor connection."""
@@ -840,39 +845,55 @@ class Orchestrator:
840
845
  self.monitors -= disconnected
841
846
 
842
847
  async def _heartbeat_loop(self):
843
- """Send periodic heartbeats to maintain connections."""
848
+ """Collect and log worker status periodically."""
844
849
  while True:
845
850
  await asyncio.sleep(30)
846
851
 
847
- disconnected = []
852
+ # Just collect status - no ping/pong
853
+ active_workers = []
848
854
  for worker_id, ws in list(self.workers.items()):
849
- try:
850
- pong_waiter = await ws.ping()
851
- await asyncio.wait_for(pong_waiter, timeout=10)
852
- except:
853
- disconnected.append(worker_id)
854
-
855
- # Clean up disconnected workers
856
- for worker_id in disconnected:
857
- logger.warning(f"Worker {worker_id} did not respond to ping, disconnecting")
858
- if worker_id in self.workers:
855
+ # Check if WebSocket is still open (don't ping)
856
+ if ws.state == websockets.protocol.State.OPEN:
857
+ active_workers.append(worker_id)
858
+ else:
859
+ # Clean up closed connections
860
+ logger.info(f"Worker {worker_id} connection closed")
859
861
  del self.workers[worker_id]
860
- logger.warning(
861
- f"Releasing assignments for worker {worker_id} because it did not respond to ping"
862
- )
863
862
  self.processor.release_assignments(worker_id)
864
- self.stats["connected_workers"] = len(self.workers)
863
+
864
+ # Log status
865
+ if active_workers:
866
+ logger.debug(
867
+ f"Active workers: {len(active_workers)} - {', '.join(active_workers[:5])}"
868
+ )
869
+ logger.debug(f"Inactive workers: {len(self.workers) - len(active_workers)}")
870
+ # add to self.stats
871
+ self.stats["active_workers"] = len(active_workers)
872
+ self.stats["inactive_workers"] = len(self.workers) - len(active_workers)
865
873
 
866
874
  async def _checkpoint_loop(self):
867
- """Periodically checkpoint storage."""
875
+ """Periodically checkpoint storage and chunk tracker."""
868
876
  interval = self.config.get("storage", {}).get("checkpoint_interval", 60)
869
877
 
870
878
  while True:
871
879
  await asyncio.sleep(interval)
872
880
 
873
- await self.storage.checkpoint()
874
- self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
875
- logger.info("Storage checkpoint complete")
881
+ try:
882
+ # Checkpoint storage
883
+ await self.storage.checkpoint()
884
+
885
+ # Also checkpoint the chunk tracker if using webdataset processor
886
+ if hasattr(self.processor, "chunk_tracker") and self.processor.chunk_tracker:
887
+ # Save checkpoint in thread pool to avoid blocking
888
+ await asyncio.get_event_loop().run_in_executor(
889
+ None, self.processor.chunk_tracker.save
890
+ )
891
+ logger.debug("Saved chunk tracker checkpoint")
892
+
893
+ self.stats["last_checkpoint"] = datetime.utcnow().isoformat()
894
+ logger.info("Storage and chunk tracker checkpoint complete")
895
+ except Exception as e:
896
+ logger.error(f"Error during checkpoint: {e}", exc_info=True)
876
897
 
877
898
  async def _stats_update_loop(self):
878
899
  """Periodically update and broadcast stats."""
@@ -551,7 +551,7 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
551
551
 
552
552
  # Force checkpoint save if needed
553
553
  if self.chunk_tracker:
554
- self.chunk_tracker.save_checkpoint()
554
+ self.chunk_tracker.save()
555
555
 
556
556
  def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
557
557
  """Get available work units for a worker."""
@@ -717,7 +717,7 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
717
717
 
718
718
  # Save final state
719
719
  if self.chunk_tracker:
720
- self.chunk_tracker.save_checkpoint()
720
+ self.chunk_tracker.save()
721
721
 
722
722
 
723
723
  class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
@@ -306,8 +306,15 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
306
306
  assigned = []
307
307
 
308
308
  with self.lock:
309
- while len(assigned) < count and self.pending_units:
309
+ units_checked = 0
310
+ max_units_to_check = len(self.pending_units)
311
+
312
+ while len(assigned) < count and units_checked < max_units_to_check:
313
+ if not self.pending_units:
314
+ break
315
+
310
316
  unit_id = self.pending_units.popleft()
317
+ units_checked += 1
311
318
  unit = self.work_units.get(unit_id)
312
319
 
313
320
  if unit:
@@ -316,6 +323,16 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
316
323
  chunk_state = self.chunk_tracker.chunks[unit_id]
317
324
  relative_unprocessed = chunk_state.get_unprocessed_ranges()
318
325
 
326
+ # If no unprocessed ranges, mark as completed and skip
327
+ if not relative_unprocessed:
328
+ logger.info(
329
+ f"Chunk {unit_id} has no unprocessed ranges, marking as completed"
330
+ )
331
+ self.chunk_tracker.mark_completed(unit_id)
332
+ # Remove from work units
333
+ del self.work_units[unit_id]
334
+ continue
335
+
319
336
  # Convert relative to absolute indices
320
337
  absolute_ranges = []
321
338
  for start, end in relative_unprocessed:
@@ -335,6 +352,9 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
335
352
 
336
353
  if self.chunk_tracker:
337
354
  self.chunk_tracker.mark_assigned(unit_id, worker_id)
355
+ else:
356
+ # Put it back if we couldn't get the unit
357
+ self.pending_units.append(unit_id)
338
358
 
339
359
  logger.debug(f"Assigned {len(assigned)} units to worker {worker_id}")
340
360
  return assigned
@@ -394,8 +414,20 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
394
414
  logger.info(f"Released {len(unit_ids)} assignments from {worker_id}")
395
415
 
396
416
  def handle_result(self, result: WorkResult) -> Dict[str, Any]:
397
- """Handle result from worker."""
398
- # Track processed items if we have chunk tracker
417
+ """Handle result from worker and update chunk tracker."""
418
+ # Extract the actual item index from the metadata
419
+ item_index = result.metadata.get("_item_index", None)
420
+
421
+ # If we have an item index, mark it as processed in the chunk tracker
422
+ if self.chunk_tracker and item_index is not None and result.chunk_id:
423
+ try:
424
+ # Mark single item as processed
425
+ self.chunk_tracker.mark_items_processed(result.chunk_id, item_index, item_index)
426
+ # logger.debug(f"Marked item {item_index} as processed in chunk {result.chunk_id}")
427
+ except Exception as e:
428
+ logger.error(f"Error marking item {item_index} as processed: {e}")
429
+
430
+ # Also handle batch results if present (backward compatibility)
399
431
  if self.chunk_tracker and "item_indices" in result.metadata:
400
432
  indices = result.metadata["item_indices"]
401
433
 
@@ -419,6 +451,9 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
419
451
  # Mark ranges as processed
420
452
  for start_idx, end_idx in ranges:
421
453
  self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
454
+ logger.debug(
455
+ f"Marked range {start_idx}-{end_idx} as processed in chunk {result.chunk_id}"
456
+ )
422
457
 
423
458
  return {
424
459
  "source_id": result.source_id,
@@ -539,7 +574,7 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
539
574
 
540
575
  # Save checkpoint
541
576
  if self.chunk_tracker:
542
- self.chunk_tracker.save_checkpoint()
577
+ self.chunk_tracker.save()
543
578
 
544
579
 
545
580
  class WebDatasetWorkerProcessor(WorkerProcessor):
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
6
6
  from pathlib import Path
7
7
  from typing import Dict, Any, Optional
8
8
  from datetime import datetime
9
+ from concurrent.futures import ThreadPoolExecutor
9
10
 
10
11
  logger = logging.getLogger(__name__)
11
12
 
@@ -52,35 +53,54 @@ class CheckpointTracker(ABC):
52
53
 
53
54
  def save(self) -> None:
54
55
  """Save checkpoint to disk atomically."""
55
- try:
56
- # Prepare data with metadata
57
- data = self._serialize_state()
58
- data["updated_at"] = datetime.utcnow().isoformat()
56
+ with self.lock:
57
+ try:
58
+ # Prepare data with metadata
59
+ data = self._serialize_state()
60
+ data["updated_at"] = datetime.utcnow().isoformat()
61
+
62
+ # Write atomically using temp file
63
+ tmp_file = self.checkpoint_path.with_suffix(".tmp")
64
+ # If a save is already in progress, let it finish.
65
+ # This prevents race conditions if save() is called rapidly.
66
+ if (
67
+ hasattr(self, "_save_future")
68
+ and self._save_future
69
+ and not self._save_future.done()
70
+ ):
71
+ self._save_future.result() # Wait for the previous save to complete
72
+
73
+ # Use an executor to run the save operation in a background thread.
74
+ # This makes the save call non-blocking.
75
+ with ThreadPoolExecutor(max_workers=1) as executor:
76
+ data_to_save = data.copy()
77
+ self._save_future = executor.submit(self._write_to_disk, data_to_save, tmp_file)
78
+ except Exception as e:
79
+ logger.error(f"Failed to submit save task: {e}", exc_info=True)
59
80
 
60
- # Write atomically using temp file
61
- tmp_file = self.checkpoint_path.with_suffix(".tmp")
81
+ def _write_to_disk(self, data: Dict[str, Any]) -> None:
82
+ """Write checkpoint data to disk atomically."""
83
+ # Create a temporary file in the same directory as the checkpoint
84
+ tmp_file = self.checkpoint_path.with_suffix(".tmp")
85
+
86
+ try:
87
+ # Ensure the parent directory exists
88
+ self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
62
89
 
63
90
  with open(tmp_file, "w") as f:
64
91
  json.dump(data, f, indent=2)
65
92
 
66
- # Ensure temp file was created
67
- if not tmp_file.exists():
68
- raise IOError(f"Failed to create temporary file: {tmp_file}")
69
-
70
- # Move atomically
93
+ # Atomically replace the checkpoint file
71
94
  tmp_file.replace(self.checkpoint_path)
72
-
73
95
  logger.debug(f"Saved checkpoint to {self.checkpoint_path}")
74
-
75
96
  except Exception as e:
76
- # logger.error(f"Error saving checkpoint: {e}", exc_info=True)
77
- # Try direct write as fallback
78
- try:
79
- with open(self.checkpoint_path, "w") as f:
80
- json.dump(data, f, indent=2)
81
- # logger.info("Saved checkpoint using fallback direct write")
82
- except Exception as fallback_error:
83
- logger.error(f"Fallback save also failed: {fallback_error}")
97
+ logger.error(f"Failed to save checkpoint atomically: {e}", exc_info=True)
98
+ # Try to clean up the temp file if it exists
99
+ if tmp_file.exists():
100
+ try:
101
+ tmp_file.unlink()
102
+ except:
103
+ pass
84
104
 
85
105
  def get_stats(self) -> Dict[str, Any]:
86
106
  """Get statistics about tracked items. Override for custom stats."""
@@ -8,6 +8,7 @@ from datetime import datetime, timedelta
8
8
  from dataclasses import dataclass, asdict, field
9
9
 
10
10
  from .checkpoint_tracker import CheckpointTracker
11
+ from threading import Lock
11
12
 
12
13
  logger = logging.getLogger(__name__)
13
14
  logger.setLevel(logging.DEBUG)
@@ -60,12 +61,12 @@ class ChunkState:
60
61
  self.status = "completed"
61
62
  self.completed_at = datetime.utcnow()
62
63
  # Clear processed_ranges since we don't need them after completion
63
- self.processed_ranges = []
64
- self.assigned_to = None
65
- self.assigned_at = None
64
+ # self.processed_ranges = []
65
+ # self.assigned_to = None
66
+ # self.assigned_at = None
66
67
 
67
68
  def get_unprocessed_ranges(self) -> List[Tuple[int, int]]:
68
- """Get ranges that haven't been processed yet."""
69
+ """Get ranges of unprocessed items within the chunk (relative indices)."""
69
70
  if self.status == "completed":
70
71
  return []
71
72
 
@@ -73,22 +74,57 @@ class ChunkState:
73
74
  logger.info(f"Chunk {self.chunk_id} has no processed ranges, returning full range")
74
75
  return [(0, self.chunk_size - 1)]
75
76
 
77
+ # Merge ranges first to ensure no overlaps
78
+ merged_ranges = self._merge_ranges(self.processed_ranges)
79
+
76
80
  unprocessed = []
77
- current = 0
81
+ current_pos = 0
78
82
 
79
- logger.info(
80
- f"Processing {len(self.processed_ranges)} processed ranges for chunk {self.chunk_id}"
81
- )
82
- for start, end in self.processed_ranges:
83
- if current < start:
84
- unprocessed.append((current, start - 1))
85
- current = max(current, end + 1)
83
+ for start, end in merged_ranges:
84
+ if current_pos < start:
85
+ unprocessed.append((current_pos, start - 1))
86
+ current_pos = max(current_pos, end + 1)
86
87
 
87
- if current < self.chunk_size:
88
- unprocessed.append((current, self.chunk_size - 1))
88
+ # Add any remaining range
89
+ if current_pos < self.chunk_size:
90
+ unprocessed.append((current_pos, self.chunk_size - 1))
91
+
92
+ # Log for debugging
93
+ if not unprocessed:
94
+ logger.info(
95
+ f"Chunk {self.chunk_id} has processed ranges {merged_ranges} covering entire chunk size {self.chunk_size}"
96
+ )
97
+ else:
98
+ total_processed = sum(end - start + 1 for start, end in merged_ranges)
99
+ total_unprocessed = sum(end - start + 1 for start, end in unprocessed)
100
+ logger.debug(
101
+ f"Chunk {self.chunk_id}: {total_processed} processed, {total_unprocessed} unprocessed"
102
+ )
89
103
 
90
104
  return unprocessed
91
105
 
106
+ def _merge_ranges(self, ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
107
+ """Merge overlapping or adjacent ranges."""
108
+ if not ranges:
109
+ return []
110
+
111
+ # Sort ranges by start index, ensuring all are tuples
112
+ sorted_ranges = sorted([tuple(r) for r in ranges])
113
+ merged = [sorted_ranges[0]]
114
+
115
+ for current_start, current_end in sorted_ranges[1:]:
116
+ last_start, last_end = merged[-1]
117
+
118
+ # Check if ranges overlap or are adjacent
119
+ if current_start <= last_end + 1:
120
+ # Merge the ranges
121
+ merged[-1] = (last_start, max(last_end, current_end))
122
+ else:
123
+ # Add as new range
124
+ merged.append((current_start, current_end))
125
+
126
+ return merged
127
+
92
128
  def to_dict(self):
93
129
  """Convert to dictionary for JSON serialization."""
94
130
  d = asdict(self)
@@ -124,6 +160,7 @@ class ChunkTracker(CheckpointTracker):
124
160
  self.max_completed_chunks_in_memory = max_completed_chunks_in_memory
125
161
  self.archive_after_hours = archive_after_hours
126
162
  self._completed_count = 0 # Track count without storing all IDs
163
+ self.lock = Lock()
127
164
  super().__init__(checkpoint_file)
128
165
 
129
166
  def _get_default_state(self) -> Dict[str, Any]:
@@ -132,16 +169,17 @@ class ChunkTracker(CheckpointTracker):
132
169
 
133
170
  def _deserialize_state(self, data: Dict[str, Any]) -> None:
134
171
  """Deserialize loaded data into instance state."""
135
- self.chunks = {}
136
- self._completed_count = data.get("completed_count", 0)
137
-
138
- # Load chunk states
139
- completed_chunks = 0
140
- for chunk_id, chunk_data in data.get("chunks", {}).items():
141
- chunk_state = ChunkState.from_dict(chunk_data)
142
- self.chunks[chunk_id] = chunk_state
143
- if chunk_state.status == "completed":
144
- completed_chunks += 1
172
+ with self.lock:
173
+ self.chunks = {}
174
+ self._completed_count = data.get("completed_count", 0)
175
+
176
+ # Load chunk states
177
+ completed_chunks = 0
178
+ for chunk_id, chunk_data in data.get("chunks", {}).items():
179
+ chunk_state = ChunkState.from_dict(chunk_data)
180
+ self.chunks[chunk_id] = chunk_state
181
+ if chunk_state.status == "completed":
182
+ completed_chunks += 1
145
183
 
146
184
  logger.info(
147
185
  f"Loaded {len(self.chunks)} chunks from checkpoint, "
@@ -494,40 +532,40 @@ class ChunkTracker(CheckpointTracker):
494
532
  for start_idx, end_idx in ranges:
495
533
  chunk.add_processed_range(start_idx, end_idx)
496
534
 
497
- def mark_items_processed(self, chunk_id: str, start_idx: int, end_idx: int):
498
- """Mark a range of items as processed within a chunk (expects ABSOLUTE indices)."""
535
+ def mark_items_processed(self, chunk_id: str, start_idx: int, end_idx: int) -> None:
536
+ """Mark a range of items as processed within a chunk."""
499
537
  if chunk_id not in self.chunks:
500
- logger.error(f"Unknown chunk: {chunk_id}")
538
+ logger.warning(f"Chunk {chunk_id} not found in tracker")
501
539
  return
502
540
 
503
- chunk = self.chunks[chunk_id]
541
+ chunk_state = self.chunks[chunk_id]
504
542
 
505
- # Convert absolute indices to chunk-relative
506
- relative_start = start_idx - chunk.start_index
507
- relative_end = end_idx - chunk.start_index
543
+ # Convert absolute indices to chunk-relative indices
544
+ relative_start = start_idx - chunk_state.start_index
545
+ relative_end = end_idx - chunk_state.start_index
508
546
 
509
- # Validate boundaries
510
- if relative_start < 0 or relative_end >= chunk.chunk_size:
511
- logger.error(
512
- f"Invalid indices for chunk {chunk_id}: "
513
- f"absolute {start_idx}-{end_idx} (relative {relative_start}-{relative_end}) "
514
- f"outside chunk bounds [{chunk.start_index}, {chunk.start_index + chunk.chunk_size - 1}]"
515
- )
516
- return
547
+ # Ensure indices are within chunk bounds
548
+ relative_start = max(0, relative_start)
549
+ relative_end = min(chunk_state.chunk_size - 1, relative_end)
517
550
 
518
- # Add the relative range
519
- chunk.add_processed_range(relative_start, relative_end)
551
+ # Add to processed ranges
552
+ chunk_state.processed_ranges.append((relative_start, relative_end))
520
553
 
521
- # If chunk is now complete, increment counter
522
- if chunk.status == "completed":
523
- self._completed_count += 1
554
+ # Merge overlapping ranges
555
+ chunk_state.processed_ranges = chunk_state._merge_ranges(chunk_state.processed_ranges)
524
556
 
525
- self.save()
526
557
  logger.debug(
527
- f"Marked items {start_idx}-{end_idx} as processed in chunk {chunk_id} "
528
- f"(relative indices: {relative_start}-{relative_end})"
558
+ f"Marked items {start_idx}-{end_idx} as processed in chunk {chunk_id} (relative indices: {relative_start}-{relative_end})"
529
559
  )
530
560
 
561
+ # Check if chunk is now complete
562
+ if chunk_state.get_unprocessed_ranges() == []:
563
+ logger.info(f"Chunk {chunk_id} is now complete")
564
+ chunk_state.status = "completed"
565
+
566
+ # Save checkpoint after updating
567
+ self.save()
568
+
531
569
  def get_chunk_with_unprocessed_items(self, chunk_id: str) -> Optional[Dict[str, Any]]:
532
570
  """Get chunk info with unprocessed item ranges."""
533
571
  chunk_state = self.chunks.get(chunk_id)
@@ -89,8 +89,13 @@ class BaseWorker(ABC):
89
89
  async def _connect_and_run(self):
90
90
  """Connect to orchestrator and run main loop."""
91
91
  logger.info(f"Connecting to {self.server_url}")
92
-
93
- async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
92
+ async with websockets.connect(
93
+ self.server_url,
94
+ ssl=self.ssl_context,
95
+ ping_interval=20,
96
+ ping_timeout=60,
97
+ close_timeout=10,
98
+ ) as websocket:
94
99
  self.websocket = websocket
95
100
  self.connected.set()
96
101
 
@@ -248,7 +248,13 @@ class CaptionWorker(BaseWorker):
248
248
  async def _initial_connect_for_config(self):
249
249
  """Connect initially just to get configuration."""
250
250
  logger.info(f"Connecting to {self.server_url}")
251
- async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
251
+ async with websockets.connect(
252
+ self.server_url,
253
+ ssl=self.ssl_context,
254
+ ping_interval=20,
255
+ ping_timeout=60,
256
+ close_timeout=10,
257
+ ) as websocket:
252
258
  await websocket.send(json.dumps(self._get_auth_data()))
253
259
 
254
260
  welcome = await websocket.recv()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: caption-flow
3
- Version: 0.3.3
3
+ Version: 0.3.4
4
4
  Summary: Self-contained distributed community captioning system
5
5
  Author-email: bghira <bghira@users.github.com>
6
6
  License: MIT
File without changes
File without changes
File without changes