caption-flow 0.2.0__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34):
  1. {caption_flow-0.2.0/src/caption_flow.egg-info → caption_flow-0.2.1}/PKG-INFO +2 -1
  2. {caption_flow-0.2.0 → caption_flow-0.2.1}/pyproject.toml +2 -1
  3. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/cli.py +9 -3
  4. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/monitor.py +1 -1
  5. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/orchestrator.py +357 -84
  6. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/utils/dataset_loader.py +179 -4
  7. {caption_flow-0.2.0 → caption_flow-0.2.1/src/caption_flow.egg-info}/PKG-INFO +2 -1
  8. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow.egg-info/requires.txt +1 -0
  9. {caption_flow-0.2.0 → caption_flow-0.2.1}/LICENSE +0 -0
  10. {caption_flow-0.2.0 → caption_flow-0.2.1}/README.md +0 -0
  11. {caption_flow-0.2.0 → caption_flow-0.2.1}/setup.cfg +0 -0
  12. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/__init__.py +0 -0
  13. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/models.py +0 -0
  14. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/storage.py +0 -0
  15. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/utils/__init__.py +0 -0
  16. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/utils/auth.py +0 -0
  17. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/utils/caption_utils.py +0 -0
  18. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/utils/certificates.py +0 -0
  19. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/utils/checkpoint_tracker.py +0 -0
  20. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/utils/chunk_tracker.py +0 -0
  21. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/utils/image_processor.py +0 -0
  22. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/utils/job_queue.py +0 -0
  23. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/utils/json_utils.py +0 -0
  24. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/utils/prompt_template.py +0 -0
  25. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/utils/shard_processor.py +0 -0
  26. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/utils/shard_tracker.py +0 -0
  27. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/utils/vllm_config.py +0 -0
  28. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/workers/base.py +0 -0
  29. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/workers/caption.py +0 -0
  30. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow/workers/data.py +0 -0
  31. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow.egg-info/SOURCES.txt +0 -0
  32. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow.egg-info/dependency_links.txt +0 -0
  33. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow.egg-info/entry_points.txt +0 -0
  34. {caption_flow-0.2.0 → caption_flow-0.2.1}/src/caption_flow.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: caption-flow
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: Self-contained distributed community captioning system
5
5
  Author-email: bghira <bghira@users.github.com>
6
6
  License: MIT
@@ -32,6 +32,7 @@ Requires-Dist: pandas<3.0.0,>=2.3.1
32
32
  Requires-Dist: arrow<2.0.0,>=1.3.0
33
33
  Requires-Dist: datasets<5.0.0,>=4.0.0
34
34
  Requires-Dist: boto3<2.0.0,>=1.40.11
35
+ Requires-Dist: torchdata<0.12.0,>=0.11.0
35
36
  Provides-Extra: dev
36
37
  Requires-Dist: pytest>=7.4.0; extra == "dev"
37
38
  Requires-Dist: pytest-asyncio>=0.21.0; extra == "dev"
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "caption-flow"
3
- version = "0.2.0"
3
+ version = "0.2.1"
4
4
  description = "Self-contained distributed community captioning system"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10,<3.13"
@@ -37,6 +37,7 @@ dependencies = [
37
37
  "arrow (>=1.3.0,<2.0.0)",
38
38
  "datasets (>=4.0.0,<5.0.0)",
39
39
  "boto3 (>=1.40.11,<2.0.0)",
40
+ "torchdata (>=0.11.0,<0.12.0)",
40
41
  ]
41
42
 
42
43
  [project.optional-dependencies]
@@ -120,13 +120,19 @@ class ConfigManager:
120
120
 
121
121
 
122
122
  def setup_logging(verbose: bool = False):
123
- """Configure logging with rich handler."""
123
+ """Configure logging with rich handler, including timestamp."""
124
124
  level = logging.DEBUG if verbose else logging.INFO
125
125
  logging.basicConfig(
126
126
  level=level,
127
- format="%(message)s",
127
+ format="%(asctime)s %(message)s",
128
+ datefmt="[%Y-%m-%d %H:%M:%S]",
128
129
  handlers=[
129
- RichHandler(console=console, rich_tracebacks=True, show_path=False, show_time=False)
130
+ RichHandler(
131
+ console=console,
132
+ rich_tracebacks=True,
133
+ show_path=False,
134
+ show_time=True, # Enables timestamp in RichHandler output
135
+ )
130
136
  ],
131
137
  )
132
138
 
@@ -107,7 +107,7 @@ class Monitor:
107
107
  """Main display update loop."""
108
108
  layout = self._create_layout()
109
109
 
110
- with Live(layout, console=self.console, refresh_per_second=4, screen=True) as live:
110
+ with Live(layout, console=self.console, refresh_per_second=1, screen=True) as live:
111
111
  while self.running:
112
112
  self._update_layout(layout)
113
113
  await asyncio.sleep(0.25)
@@ -363,6 +363,7 @@ class Orchestrator:
363
363
  self.ssl_context = self._setup_ssl()
364
364
 
365
365
  # Statistics
366
+ self.is_generating_stats = False
366
367
  self.stats = {
367
368
  "total_chunks": 0,
368
369
  "completed_chunks": 0,
@@ -1409,28 +1410,36 @@ class Orchestrator:
1409
1410
  finally:
1410
1411
  del self.data_workers[worker_id]
1411
1412
 
1412
- async def _handle_monitor(self, websocket: WebSocketServerProtocol):
1413
- """Handle monitor connection."""
1414
- self.monitors.add(websocket)
1415
- logger.info("Monitor connected")
1416
-
1413
+ async def _send_leaderboard_to_monitor(self, websocket: WebSocketServerProtocol):
1414
+ """Send leaderboard data to a specific monitor."""
1415
+ total_start = time.time()
1417
1416
  try:
1418
- # Send initial stats
1419
- await websocket.send(safe_json_dumps({"type": "stats", "data": self.stats}))
1420
-
1421
- # Send chunk stats
1422
- chunk_stats = self.chunk_manager.get_stats()
1423
- await websocket.send(safe_json_dumps({"type": "chunk_stats", "data": chunk_stats}))
1417
+ if websocket not in self.monitors:
1418
+ return
1424
1419
 
1425
- # Send contributor leaderboard with active worker counts
1420
+ # Get contributors asynchronously
1421
+ contributors_start = time.time()
1426
1422
  contributors = await self.storage.get_top_contributors(10)
1423
+ logger.debug(
1424
+ f"Contributors retrieved in {(time.time() - contributors_start)*1000:.1f}ms"
1425
+ )
1427
1426
 
1428
- # Enhance contributor data with active worker counts
1429
- enhanced_contributors = []
1430
- worker_counts = (
1431
- self.get_workers_by_user_stats() if hasattr(self, "workers_by_user") else {}
1427
+ # Get worker counts in thread pool
1428
+ worker_counts_start = time.time()
1429
+ loop = asyncio.get_event_loop()
1430
+ worker_counts = await loop.run_in_executor(
1431
+ None,
1432
+ lambda: (
1433
+ self.get_workers_by_user_stats() if hasattr(self, "workers_by_user") else {}
1434
+ ),
1435
+ )
1436
+ logger.debug(
1437
+ f"Worker counts retrieved in {(time.time() - worker_counts_start)*1000:.1f}ms"
1432
1438
  )
1433
1439
 
1440
+ # Build enhanced contributors list
1441
+ build_start = time.time()
1442
+ enhanced_contributors = []
1434
1443
  for contributor in contributors:
1435
1444
  contrib_dict = {
1436
1445
  "contributor_id": contributor.contributor_id,
@@ -1442,40 +1451,157 @@ class Orchestrator:
1442
1451
  ),
1443
1452
  }
1444
1453
  enhanced_contributors.append(contrib_dict)
1454
+ logger.debug(f"Enhanced contributors built in {(time.time() - build_start)*1000:.1f}ms")
1445
1455
 
1446
- await websocket.send(
1447
- safe_json_dumps({"type": "leaderboard", "data": enhanced_contributors})
1456
+ # Cache for future monitors
1457
+ self._cached_leaderboard = enhanced_contributors
1458
+
1459
+ # Send if still connected
1460
+ if websocket in self.monitors:
1461
+ send_start = time.time()
1462
+ await websocket.send(
1463
+ safe_json_dumps({"type": "leaderboard", "data": enhanced_contributors})
1464
+ )
1465
+ logger.debug(
1466
+ f"Leaderboard sent to monitor in {(time.time() - send_start)*1000:.1f}ms"
1467
+ )
1468
+
1469
+ logger.debug(
1470
+ f"Leaderboard send to monitor completed in {(time.time() - total_start)*1000:.1f}ms"
1471
+ )
1472
+
1473
+ except websockets.exceptions.ConnectionClosed:
1474
+ logger.debug("Monitor disconnected during leaderboard send")
1475
+ except Exception as e:
1476
+ logger.error(f"Error sending leaderboard to monitor: {e}")
1477
+
1478
+ async def _send_initial_monitor_data(self, websocket: WebSocketServerProtocol):
1479
+ """Send initial data to monitor in a separate task to avoid blocking."""
1480
+ total_start = time.time()
1481
+ try:
1482
+ # Check if websocket is still in monitors set
1483
+ if websocket not in self.monitors:
1484
+ logger.debug("Monitor disconnected before initial data send")
1485
+ return
1486
+
1487
+ # Send current stats (already in memory)
1488
+ stats_start = time.time()
1489
+ await websocket.send(safe_json_dumps({"type": "stats", "data": self.stats}))
1490
+ logger.debug(f"Monitor stats sent in {(time.time() - stats_start)*1000:.1f}ms")
1491
+
1492
+ # Get chunk stats asynchronously
1493
+ chunk_stats_start = time.time()
1494
+ loop = asyncio.get_event_loop()
1495
+ chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
1496
+ logger.debug(f"Chunk stats retrieved in {(time.time() - chunk_stats_start)*1000:.1f}ms")
1497
+
1498
+ if websocket not in self.monitors:
1499
+ return
1500
+
1501
+ chunk_send_start = time.time()
1502
+ await websocket.send(safe_json_dumps({"type": "chunk_stats", "data": chunk_stats}))
1503
+ logger.debug(f"Chunk stats sent in {(time.time() - chunk_send_start)*1000:.1f}ms")
1504
+
1505
+ # For leaderboard, check if we have a cached version first
1506
+ if hasattr(self, "_cached_leaderboard") and self._cached_leaderboard:
1507
+ # Use cached leaderboard if available
1508
+ cache_send_start = time.time()
1509
+ await websocket.send(
1510
+ safe_json_dumps({"type": "leaderboard", "data": self._cached_leaderboard})
1511
+ )
1512
+ logger.debug(
1513
+ f"Cached leaderboard sent in {(time.time() - cache_send_start)*1000:.1f}ms"
1514
+ )
1515
+ else:
1516
+ # Schedule leaderboard update separately
1517
+ leaderboard_task_start = time.time()
1518
+ asyncio.create_task(self._send_leaderboard_to_monitor(websocket))
1519
+ logger.debug(
1520
+ f"Leaderboard task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
1521
+ )
1522
+
1523
+ logger.debug(
1524
+ f"Monitor initial data send completed in {(time.time() - total_start)*1000:.1f}ms"
1525
+ )
1526
+
1527
+ except websockets.exceptions.ConnectionClosed:
1528
+ logger.debug("Monitor disconnected during initial data send")
1529
+ except Exception as e:
1530
+ logger.error(f"Error sending initial monitor data: {e}")
1531
+
1532
+ async def _handle_monitor(self, websocket: WebSocketServerProtocol):
1533
+ """Handle monitor connection - truly non-blocking version."""
1534
+ monitor_start = time.time()
1535
+ self.monitors.add(websocket)
1536
+ logger.info(f"Monitor connected (total monitors: {len(self.monitors)})")
1537
+
1538
+ try:
1539
+ # Send welcome message immediately
1540
+ welcome_start = time.time()
1541
+ await websocket.send(safe_json_dumps({"type": "welcome", "role": "monitor"}))
1542
+ logger.debug(f"Monitor welcome sent in {(time.time() - welcome_start)*1000:.1f}ms")
1543
+
1544
+ # Schedule initial data send as a separate task to avoid blocking
1545
+ task_create_start = time.time()
1546
+ asyncio.create_task(self._send_initial_monitor_data(websocket))
1547
+ logger.debug(
1548
+ f"Monitor initial data task created in {(time.time() - task_create_start)*1000:.1f}ms"
1448
1549
  )
1449
1550
 
1450
- # Keep connection alive
1451
- async for _ in websocket:
1452
- pass
1551
+ # Just keep the connection alive - no blocking work here
1552
+ try:
1553
+ async for message in websocket:
1554
+ # Handle any incoming messages from monitor if needed
1555
+ # For now, just ignore them
1556
+ pass
1557
+ except websockets.exceptions.ConnectionClosed:
1558
+ pass # Normal disconnection
1453
1559
 
1454
1560
  except websockets.exceptions.ConnectionClosed:
1455
1561
  logger.info("Monitor disconnected")
1562
+ except Exception as e:
1563
+ logger.error(f"Error in monitor handler: {e}")
1456
1564
  finally:
1457
1565
  self.monitors.discard(websocket)
1566
+ logger.debug(f"Monitor handler completed in {(time.time() - monitor_start)*1000:.1f}ms")
1458
1567
 
1459
1568
  async def _broadcast_stats(self):
1460
- """Broadcast statistics to all monitors - enhanced for multi-stage."""
1569
+ """Broadcast statistics to all monitors - truly non-blocking version."""
1461
1570
  if not self.monitors:
1462
1571
  return
1572
+ if self.is_generating_stats:
1573
+ return # Already generating stats, skip this call
1574
+ self.is_generating_stats = True
1575
+ total_start = time.time()
1576
+
1577
+ # Prepare all the data first
1578
+ data_prep_start = time.time()
1579
+ loop = asyncio.get_event_loop()
1463
1580
 
1464
- # Get storage stats
1581
+ # Get storage stats (already async)
1582
+ storage_stats_start = time.time()
1465
1583
  storage_stats = await self.storage.get_storage_stats()
1584
+ logger.debug(f"Storage stats retrieved in {(time.time() - storage_stats_start)*1000:.1f}ms")
1585
+
1586
+ caption_stats_start = time.time()
1466
1587
  caption_stats = await self.storage.get_caption_stats()
1588
+ logger.debug(f"Caption stats retrieved in {(time.time() - caption_stats_start)*1000:.1f}ms")
1467
1589
 
1468
- # Include chunk stats
1469
- chunk_stats = self.chunk_manager.get_stats()
1470
- self.stats.update({f"chunks_{k}": v for k, v in chunk_stats.items()})
1590
+ # Get chunk stats in thread pool
1591
+ chunk_stats_start = time.time()
1592
+ chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
1593
+ logger.debug(f"Chunk stats retrieved in {(time.time() - chunk_stats_start)*1000:.1f}ms")
1471
1594
 
1472
- # Merge storage stats
1473
- self.stats.update(storage_stats)
1474
- self.stats["field_breakdown"] = caption_stats.get("field_stats", {})
1475
- self.stats["output_fields_list"] = caption_stats.get("output_fields", [])
1595
+ # Build stats dict
1596
+ build_stats_start = time.time()
1597
+ stats_update = self.stats.copy()
1598
+ stats_update.update({f"chunks_{k}": v for k, v in chunk_stats.items()})
1599
+ stats_update.update(storage_stats)
1600
+ stats_update["field_breakdown"] = caption_stats.get("field_stats", {})
1601
+ stats_update["output_fields_list"] = caption_stats.get("output_fields", [])
1476
1602
 
1477
1603
  # Add rate information
1478
- self.stats.update(
1604
+ stats_update.update(
1479
1605
  {
1480
1606
  "current_rate": self.rate_tracker["current_rate"],
1481
1607
  "average_rate": self.rate_tracker["average_rate"],
@@ -1483,41 +1609,106 @@ class Orchestrator:
1483
1609
  }
1484
1610
  )
1485
1611
 
1486
- # Add vLLM info - now includes stage count
1487
- self.stats["vllm_model"] = self.vllm_config.get("model", "unknown")
1488
- self.stats["vllm_batch_size"] = self.vllm_config.get("batch_size", 0)
1612
+ # Add vLLM info
1613
+ stats_update["vllm_model"] = self.vllm_config.get("model", "unknown")
1614
+ stats_update["vllm_batch_size"] = self.vllm_config.get("batch_size", 0)
1489
1615
 
1490
- # NEW: Add stage information
1616
+ # Add stage information
1491
1617
  stages = self.vllm_config.get("stages", [])
1492
1618
  if stages:
1493
- self.stats["stage_count"] = len(stages)
1494
- self.stats["stage_names"] = [s.get("name", "unnamed") for s in stages]
1619
+ stats_update["stage_count"] = len(stages)
1620
+ stats_update["stage_names"] = [s.get("name", "unnamed") for s in stages]
1495
1621
  else:
1496
- self.stats["stage_count"] = 1 # Backward compatibility
1497
- self.stats["stage_names"] = ["default"]
1622
+ stats_update["stage_count"] = 1
1623
+ stats_update["stage_names"] = ["default"]
1498
1624
 
1625
+ # Get field stats
1626
+ field_stats_start = time.time()
1499
1627
  field_stats = await self.storage.get_output_field_stats()
1500
- self.stats["output_fields"] = field_stats
1628
+ stats_update["output_fields"] = field_stats
1629
+ logger.debug(f"Field stats retrieved in {(time.time() - field_stats_start)*1000:.1f}ms")
1501
1630
 
1502
- message = safe_json_dumps({"type": "stats", "data": self.stats})
1631
+ # Update our internal stats
1632
+ self.stats = stats_update
1633
+ logger.debug(f"Stats prepared in {(time.time() - build_stats_start)*1000:.1f}ms")
1503
1634
 
1504
- # Send to all monitors
1505
- disconnected = set()
1506
- _monitors = self.monitors.copy()
1507
- for monitor in _monitors:
1635
+ logger.debug(f"Total data preparation took {(time.time() - data_prep_start)*1000:.1f}ms")
1636
+
1637
+ # Create message once
1638
+ message_create_start = time.time()
1639
+ stats_message = safe_json_dumps({"type": "stats", "data": self.stats})
1640
+ logger.debug(f"Stats message created in {(time.time() - message_create_start)*1000:.1f}ms")
1641
+
1642
+ # Send to all monitors asynchronously in parallel
1643
+ send_start = time.time()
1644
+
1645
+ async def send_to_monitor(monitor):
1508
1646
  try:
1509
- await monitor.send(message)
1647
+ await monitor.send(stats_message)
1510
1648
  except websockets.exceptions.ConnectionClosed:
1511
- disconnected.add(monitor)
1649
+ return monitor # Return for removal
1650
+ except Exception as e:
1651
+ logger.debug(f"Error sending stats to monitor: {e}")
1652
+ return monitor # Return for removal
1653
+ return None
1512
1654
 
1513
- # send updated leaderboard
1655
+ # Send to all monitors in parallel
1656
+ monitors_copy = self.monitors.copy()
1657
+ results = await asyncio.gather(
1658
+ *[send_to_monitor(m) for m in monitors_copy], return_exceptions=True
1659
+ )
1660
+
1661
+ # Remove disconnected monitors
1662
+ disconnected = {
1663
+ m
1664
+ for m, r in zip(monitors_copy, results)
1665
+ if r is not None and not isinstance(r, Exception)
1666
+ }
1667
+ self.monitors -= disconnected
1668
+
1669
+ logger.debug(
1670
+ f"Stats sent to {len(monitors_copy)} monitors in {(time.time() - send_start)*1000:.1f}ms"
1671
+ )
1672
+
1673
+ # Send leaderboard update in a separate task to avoid blocking
1674
+ leaderboard_task_start = time.time()
1675
+ asyncio.create_task(self._broadcast_leaderboard())
1676
+ self.is_generating_stats = False
1677
+ logger.debug(
1678
+ f"Leaderboard broadcast task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
1679
+ )
1680
+ logger.debug(f"Stats broadcast completed in {(time.time() - total_start)*1000:.1f}ms")
1681
+
1682
+ async def _broadcast_leaderboard(self):
1683
+ """Send leaderboard updates to monitors - separate from stats to avoid blocking."""
1684
+ if not self.monitors:
1685
+ return
1686
+
1687
+ total_start = time.time()
1514
1688
  try:
1689
+ # Get contributors
1690
+ contributors_start = time.time()
1515
1691
  contributors = await self.storage.get_top_contributors(10)
1516
- enhanced_contributors = []
1517
- worker_counts = (
1518
- self.get_workers_by_user_stats() if hasattr(self, "workers_by_user") else {}
1692
+ logger.debug(
1693
+ f"Contributors retrieved for broadcast in {(time.time() - contributors_start)*1000:.1f}ms"
1694
+ )
1695
+
1696
+ # Get worker counts
1697
+ worker_counts_start = time.time()
1698
+ loop = asyncio.get_event_loop()
1699
+ worker_counts = await loop.run_in_executor(
1700
+ None,
1701
+ lambda: (
1702
+ self.get_workers_by_user_stats() if hasattr(self, "workers_by_user") else {}
1703
+ ),
1704
+ )
1705
+ logger.debug(
1706
+ f"Worker counts retrieved for broadcast in {(time.time() - worker_counts_start)*1000:.1f}ms"
1519
1707
  )
1520
1708
 
1709
+ # Build enhanced contributors list
1710
+ build_start = time.time()
1711
+ enhanced_contributors = []
1521
1712
  for contributor in contributors:
1522
1713
  contrib_dict = {
1523
1714
  "contributor_id": contributor.contributor_id,
@@ -1529,26 +1720,64 @@ class Orchestrator:
1529
1720
  ),
1530
1721
  }
1531
1722
  enhanced_contributors.append(contrib_dict)
1723
+ logger.debug(
1724
+ f"Enhanced contributors built for broadcast in {(time.time() - build_start)*1000:.1f}ms"
1725
+ )
1726
+
1727
+ # Cache it
1728
+ self._cached_leaderboard = enhanced_contributors
1532
1729
 
1730
+ # Create message once
1731
+ message_create_start = time.time()
1533
1732
  leaderboard_message = safe_json_dumps(
1534
1733
  {"type": "leaderboard", "data": enhanced_contributors}
1535
1734
  )
1735
+ logger.debug(
1736
+ f"Leaderboard message created in {(time.time() - message_create_start)*1000:.1f}ms"
1737
+ )
1536
1738
 
1537
- # Send to all monitors
1538
- disconnected = set()
1539
- for monitor in self.monitors.copy():
1739
+ # Send to all monitors in parallel
1740
+ send_start = time.time()
1741
+
1742
+ async def send_leaderboard(monitor):
1540
1743
  try:
1541
1744
  await monitor.send(leaderboard_message)
1542
- except websockets.exceptions.ConnectionClosed:
1543
- disconnected.add(monitor)
1745
+ except:
1746
+ return monitor # Mark for removal
1747
+ return None
1748
+
1749
+ monitors_copy = self.monitors.copy()
1750
+ results = await asyncio.gather(
1751
+ *[send_leaderboard(m) for m in monitors_copy], return_exceptions=True
1752
+ )
1544
1753
 
1754
+ # Remove disconnected
1755
+ disconnected = {
1756
+ m
1757
+ for m, r in zip(monitors_copy, results)
1758
+ if r is not None and not isinstance(r, Exception)
1759
+ }
1545
1760
  self.monitors -= disconnected
1546
1761
 
1547
- except Exception as e:
1548
- logger.error(f"Error sending leaderboard update: {e}")
1762
+ logger.debug(
1763
+ f"Leaderboard sent to {len(monitors_copy)} monitors in {(time.time() - send_start)*1000:.1f}ms"
1764
+ )
1765
+ logger.debug(
1766
+ f"Leaderboard broadcast completed in {(time.time() - total_start)*1000:.1f}ms"
1767
+ )
1549
1768
 
1550
- # Clean up disconnected monitors
1551
- self.monitors -= disconnected
1769
+ except Exception as e:
1770
+ logger.error(f"Error broadcasting leaderboard: {e}")
1771
+
1772
+ def _get_queue_stats(self) -> Dict[str, int]:
1773
+ """Get queue statistics - synchronous helper for thread pool."""
1774
+ with self.chunk_manager.lock:
1775
+ return {
1776
+ "pending_chunks": len(self.chunk_manager.pending_chunks),
1777
+ "assigned_chunks": sum(
1778
+ len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
1779
+ ),
1780
+ }
1552
1781
 
1553
1782
  async def _flush_processed_items(self):
1554
1783
  """Flush batched processed items to chunk tracker."""
@@ -1591,12 +1820,14 @@ class Orchestrator:
1591
1820
  self.last_item_batch_flush = time.time()
1592
1821
 
1593
1822
  def get_workers_by_user_stats(self) -> Dict[str, Any]:
1594
- """Get statistics about workers grouped by user/token."""
1823
+ """Get statistics about workers grouped by user/token - thread-safe version."""
1595
1824
  if not hasattr(self, "workers_by_user"):
1596
1825
  return {}
1597
1826
 
1827
+ # Create a copy to avoid issues with concurrent modification
1598
1828
  stats = {}
1599
- for user, worker_ids in self.workers_by_user.items():
1829
+ workers_snapshot = dict(self.workers_by_user)
1830
+ for user, worker_ids in workers_snapshot.items():
1600
1831
  stats[user] = {"worker_count": len(worker_ids), "worker_ids": list(worker_ids)}
1601
1832
  return stats
1602
1833
 
@@ -1621,21 +1852,63 @@ class Orchestrator:
1621
1852
  async def _heartbeat_loop(self):
1622
1853
  """Send periodic heartbeats to maintain connections."""
1623
1854
  while True:
1624
- await asyncio.sleep(30)
1855
+ try:
1856
+ await asyncio.sleep(30)
1625
1857
 
1626
- # Ping workers
1627
- disconnected = []
1628
- for worker_id, ws in self.workers.items():
1629
- try:
1630
- await ws.ping()
1631
- except:
1632
- disconnected.append(worker_id)
1858
+ # Create a copy of worker items to avoid modification during iteration
1859
+ worker_items = list(self.workers.items())
1860
+ disconnected = []
1633
1861
 
1634
- # Clean up disconnected workers
1635
- for worker_id in disconnected:
1636
- if worker_id in self.workers:
1637
- del self.workers[worker_id]
1638
- self.chunk_manager.release_worker_chunks(worker_id)
1862
+ for worker_id, ws in worker_items:
1863
+ try:
1864
+ # Check if worker still exists before pinging
1865
+ if worker_id not in self.workers:
1866
+ continue
1867
+
1868
+ # Send ping with timeout
1869
+ pong_waiter = await ws.ping()
1870
+ try:
1871
+ await asyncio.wait_for(pong_waiter, timeout=10)
1872
+ except asyncio.TimeoutError:
1873
+ logger.warning(f"Worker {worker_id} failed to respond to ping")
1874
+ disconnected.append(worker_id)
1875
+ except websockets.exceptions.ConnectionClosed:
1876
+ logger.info(f"Worker {worker_id} connection already closed")
1877
+ disconnected.append(worker_id)
1878
+ except Exception as e:
1879
+ logger.error(f"Error pinging worker {worker_id}: {e}")
1880
+ disconnected.append(worker_id)
1881
+
1882
+ # Clean up disconnected workers
1883
+ for worker_id in disconnected:
1884
+ if worker_id in self.workers:
1885
+ logger.info(f"Removing unresponsive worker {worker_id}")
1886
+ del self.workers[worker_id]
1887
+ self.chunk_manager.release_worker_chunks(worker_id)
1888
+
1889
+ # Update stats
1890
+ self.stats["connected_workers"] = len(self.workers)
1891
+
1892
+ # Also clean up from workers_by_user if it exists
1893
+ if hasattr(self, "workers_by_user"):
1894
+ worker_user = (
1895
+ worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
1896
+ )
1897
+ if worker_user in self.workers_by_user:
1898
+ self.workers_by_user[worker_user].discard(worker_id)
1899
+ if not self.workers_by_user[worker_user]:
1900
+ del self.workers_by_user[worker_user]
1901
+
1902
+ # Notify monitors
1903
+ await self._broadcast_stats()
1904
+ await self._send_activity(
1905
+ f"Worker {worker_id} removed due to heartbeat timeout"
1906
+ )
1907
+
1908
+ except Exception as e:
1909
+ logger.error(f"Error in heartbeat loop: {e}", exc_info=True)
1910
+ # Continue the loop even if there's an error
1911
+ await asyncio.sleep(5)
1639
1912
 
1640
1913
  async def _checkpoint_loop(self):
1641
1914
  """Periodically checkpoint storage."""
@@ -1663,7 +1936,10 @@ class Orchestrator:
1663
1936
  )
1664
1937
 
1665
1938
  async def _stats_update_loop(self):
1666
- """Periodically update and broadcast stats."""
1939
+ """Periodically update and broadcast stats - non-blocking version."""
1940
+ # Get the event loop for running blocking operations
1941
+ loop = asyncio.get_event_loop()
1942
+
1667
1943
  # Track session start values
1668
1944
  storage_stats = await self.storage.get_storage_stats()
1669
1945
  session_start_outputs = storage_stats["total_captions"] # This now counts ALL outputs
@@ -1675,8 +1951,8 @@ class Orchestrator:
1675
1951
  while True:
1676
1952
  await asyncio.sleep(10)
1677
1953
 
1678
- # Update chunk stats
1679
- chunk_stats = self.chunk_manager.get_stats()
1954
+ # Update chunk stats in thread pool to avoid blocking
1955
+ chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
1680
1956
  storage_stats = await self.storage.get_storage_stats()
1681
1957
  current_total_outputs = storage_stats["total_captions"] # ALL outputs
1682
1958
  if self.chunk_tracker:
@@ -1690,12 +1966,9 @@ class Orchestrator:
1690
1966
  self.stats["total_outputs"] = current_total_outputs
1691
1967
  self.stats["total_captions"] = current_total_outputs # Keep for backward compatibility
1692
1968
 
1693
- # Add queue information
1694
- with self.chunk_manager.lock:
1695
- self.stats["pending_chunks"] = len(self.chunk_manager.pending_chunks)
1696
- self.stats["assigned_chunks"] = sum(
1697
- len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
1698
- )
1969
+ # Get queue stats in thread pool to avoid blocking
1970
+ queue_stats = await loop.run_in_executor(None, self._get_queue_stats)
1971
+ self.stats.update(queue_stats)
1699
1972
 
1700
1973
  # Calculate if we need more chunks
1701
1974
  worker_count = self.stats.get("connected_workers", 0)
@@ -234,6 +234,54 @@ class DatasetLoader:
234
234
  for key, url, image_data in ds:
235
235
  yield key, url, image_data
236
236
 
237
+ def _create_dataset_at_position(self, dataset_path: str, split: str, start_idx: int):
238
+ """Create a dataset iterator positioned at start_idx using state_dict if available."""
239
+ try:
240
+ # Load dataset in streaming mode
241
+ dataset = load_dataset(
242
+ dataset_path,
243
+ split=split,
244
+ streaming=True,
245
+ token=self.token,
246
+ )
247
+
248
+ # Check if the dataset supports state_dict (newer versions of datasets library)
249
+ if hasattr(dataset, "load_state_dict") and hasattr(dataset, "state_dict"):
250
+ # Try to use the dataset's native state management
251
+ try:
252
+ # Get current state
253
+ state = dataset.state_dict()
254
+
255
+ # Modify the state to skip to start_idx
256
+ if "epoch" in state:
257
+ state["epoch"] = 0
258
+ if "num_examples_since_previous_state" in state:
259
+ state["num_examples_since_previous_state"] = start_idx
260
+
261
+ # For newer datasets with examples_iterable state
262
+ if "examples_iterable" in state:
263
+ if isinstance(state["examples_iterable"], dict):
264
+ if "shard_example_idx" in state["examples_iterable"]:
265
+ state["examples_iterable"]["shard_example_idx"] = start_idx
266
+
267
+ # Load the modified state
268
+ dataset.load_state_dict(state)
269
+ logger.info(f"Positioned dataset at index {start_idx} using state_dict")
270
+ return dataset
271
+ except Exception as e:
272
+ logger.debug(f"Could not use state_dict approach: {e}")
273
+
274
+ # Fall back to skip() for large skips
275
+ if start_idx > 0:
276
+ logger.info(f"Using skip() to position dataset at index {start_idx}")
277
+ dataset = dataset.skip(start_idx)
278
+
279
+ return dataset
280
+
281
+ except Exception as e:
282
+ logger.warning(f"Error creating positioned dataset: {e}")
283
+ return None
284
+
237
285
  def _iterate_hf_dataset_shard_with_metadata(
238
286
  self, shard_url: str, processed_keys: Optional[set] = None
239
287
  ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
@@ -248,7 +296,65 @@ class DatasetLoader:
248
296
  )
249
297
 
250
298
  try:
251
- # Load dataset in streaming mode
299
+ # Try optimized approach for large skips
300
+ if start_idx > 100:
301
+ dataset = self._create_dataset_at_position(dataset_path, self.split, start_idx)
302
+ if dataset:
303
+ items_processed = 0
304
+
305
+ for item in dataset:
306
+ # Stop after processing chunk_size items
307
+ if items_processed >= chunk_size:
308
+ break
309
+
310
+ # Generate a unique key for this item
311
+ key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
312
+
313
+ if key in processed_keys:
314
+ items_processed += 1
315
+ continue
316
+
317
+ try:
318
+ # Extract image data
319
+ if self.image_column in item:
320
+ img_data = item[self.image_column]
321
+
322
+ # Process image to bytes
323
+ image_bytes = ImageProcessor.process_image_data(img_data)
324
+
325
+ if image_bytes:
326
+ # Extract all metadata (excluding the image column)
327
+ metadata = {
328
+ k: v for k, v in item.items() if k != self.image_column
329
+ }
330
+
331
+ # URL is virtual for HF datasets
332
+ url = f"hf://{dataset_path}#{start_idx + items_processed}"
333
+ items_processed += 1
334
+ yield key, url, image_bytes, metadata
335
+ else:
336
+ logger.warning(
337
+ f"Failed to process image for item at index {start_idx + items_processed}"
338
+ )
339
+ items_processed += 1
340
+ continue
341
+ else:
342
+ logger.warning(
343
+ f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
344
+ f"Available columns: {list(item.keys())}"
345
+ )
346
+ items_processed += 1
347
+
348
+ except Exception as e:
349
+ logger.error(
350
+ f"Error processing item at index {start_idx + items_processed}: {e}"
351
+ )
352
+ items_processed += 1
353
+ continue
354
+
355
+ return
356
+
357
+ # Fall back to regular approach for small skips or if StatefulDataLoader not available
252
358
  dataset = load_dataset(
253
359
  dataset_path,
254
360
  split=self.split,
@@ -256,7 +362,7 @@ class DatasetLoader:
256
362
  token=self.token,
257
363
  )
258
364
 
259
- # Skip to start index if needed - CONSISTENT WITH OTHER METHOD
365
+ # Skip to start index if needed
260
366
  if start_idx > 0:
261
367
  dataset = dataset.skip(start_idx)
262
368
 
@@ -267,7 +373,7 @@ class DatasetLoader:
267
373
  if items_processed >= chunk_size:
268
374
  break
269
375
 
270
- # Generate a unique key for this item - CONSISTENT FORMAT
376
+ # Generate a unique key for this item
271
377
  key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
272
378
 
273
379
  if key in processed_keys:
@@ -337,7 +443,76 @@ class DatasetLoader:
337
443
  )
338
444
 
339
445
  try:
340
- # Load dataset in streaming mode
446
+ # Try optimized approach for large skips
447
+ if start_idx > 100:
448
+ dataset = self._create_dataset_at_position(dataset_path, self.split, start_idx)
449
+ if dataset:
450
+ items_processed = 0
451
+
452
+ for item in dataset:
453
+ # Stop after processing chunk_size items
454
+ if items_processed >= chunk_size:
455
+ logger.info(f"Completed chunk: processed {items_processed} items")
456
+ break
457
+
458
+ # Also stop if we've reached the dataset end
459
+ if (
460
+ self._hf_total_items
461
+ and (start_idx + items_processed) >= self._hf_total_items
462
+ ):
463
+ logger.info(
464
+ f"Reached dataset end at item {start_idx + items_processed} "
465
+ f"(total: {self._hf_total_items})"
466
+ )
467
+ break
468
+
469
+ # Generate a unique key for this item
470
+ key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
471
+
472
+ if key in processed_keys:
473
+ items_processed += 1
474
+ continue
475
+
476
+ try:
477
+ # Extract image data
478
+ if self.image_column in item:
479
+ img_data = item[self.image_column]
480
+
481
+ # Delegate image processing to ImageProcessor
482
+ image_bytes = ImageProcessor.process_image_data(img_data)
483
+
484
+ if image_bytes:
485
+ # URL is virtual for HF datasets
486
+ url = f"hf://{dataset_path}#{start_idx + items_processed}"
487
+ items_processed += 1
488
+ yield key, url, image_bytes
489
+ else:
490
+ logger.warning(
491
+ f"Failed to process image for item at index {start_idx + items_processed}"
492
+ )
493
+ items_processed += 1
494
+ continue
495
+ else:
496
+ logger.warning(
497
+ f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
498
+ f"Available columns: {list(item.keys())}"
499
+ )
500
+ items_processed += 1
501
+
502
+ except Exception as e:
503
+ logger.error(
504
+ f"Error processing item at index {start_idx + items_processed}: {e}"
505
+ )
506
+ items_processed += 1
507
+ continue
508
+
509
+ logger.info(
510
+ f"Virtual shard complete: processed {items_processed} items "
511
+ f"(start_idx: {start_idx})"
512
+ )
513
+ return
514
+
515
+ # Fall back to regular approach for small skips or if StatefulDataLoader not available
341
516
  dataset = load_dataset(
342
517
  dataset_path,
343
518
  split=self.split,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: caption-flow
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: Self-contained distributed community captioning system
5
5
  Author-email: bghira <bghira@users.github.com>
6
6
  License: MIT
@@ -32,6 +32,7 @@ Requires-Dist: pandas<3.0.0,>=2.3.1
32
32
  Requires-Dist: arrow<2.0.0,>=1.3.0
33
33
  Requires-Dist: datasets<5.0.0,>=4.0.0
34
34
  Requires-Dist: boto3<2.0.0,>=1.40.11
35
+ Requires-Dist: torchdata<0.12.0,>=0.11.0
35
36
  Provides-Extra: dev
36
37
  Requires-Dist: pytest>=7.4.0; extra == "dev"
37
38
  Requires-Dist: pytest-asyncio>=0.21.0; extra == "dev"
@@ -15,6 +15,7 @@ pandas<3.0.0,>=2.3.1
15
15
  arrow<2.0.0,>=1.3.0
16
16
  datasets<5.0.0,>=4.0.0
17
17
  boto3<2.0.0,>=1.40.11
18
+ torchdata<0.12.0,>=0.11.0
18
19
 
19
20
  [dev]
20
21
  pytest>=7.4.0
File without changes
File without changes
File without changes