caption-flow 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/cli.py +9 -3
- caption_flow/monitor.py +1 -1
- caption_flow/orchestrator.py +357 -84
- caption_flow/utils/dataset_loader.py +179 -4
- {caption_flow-0.2.0.dist-info → caption_flow-0.2.1.dist-info}/METADATA +2 -1
- {caption_flow-0.2.0.dist-info → caption_flow-0.2.1.dist-info}/RECORD +10 -10
- {caption_flow-0.2.0.dist-info → caption_flow-0.2.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.0.dist-info → caption_flow-0.2.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.0.dist-info → caption_flow-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.0.dist-info → caption_flow-0.2.1.dist-info}/top_level.txt +0 -0
caption_flow/cli.py
CHANGED
@@ -120,13 +120,19 @@ class ConfigManager:
 
 
 def setup_logging(verbose: bool = False):
-    """Configure logging with rich handler."""
+    """Configure logging with rich handler, including timestamp."""
    level = logging.DEBUG if verbose else logging.INFO
    logging.basicConfig(
        level=level,
-        format="%(message)s",
+        format="%(asctime)s %(message)s",
+        datefmt="[%Y-%m-%d %H:%M:%S]",
        handlers=[
-            RichHandler(
+            RichHandler(
+                console=console,
+                rich_tracebacks=True,
+                show_path=False,
+                show_time=True,  # Enables timestamp in RichHandler output
+            )
        ],
    )
 
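For reference, the new configuration produces Rich-formatted log lines prefixed with a timestamp. A minimal standalone sketch of the same pattern (the console object is illustrative here; cli.py presumably defines its own):

import logging

from rich.console import Console
from rich.logging import RichHandler

console = Console()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="[%Y-%m-%d %H:%M:%S]",
    handlers=[
        RichHandler(
            console=console,
            rich_tracebacks=True,  # pretty-print exception tracebacks
            show_path=False,       # hide the "file.py:123" column
            show_time=True,        # RichHandler adds its own time column
        )
    ],
)

logging.getLogger("caption_flow").info("logging configured")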
caption_flow/monitor.py
CHANGED
@@ -107,7 +107,7 @@ class Monitor:
        """Main display update loop."""
        layout = self._create_layout()
 
-        with Live(layout, console=self.console, refresh_per_second=
+        with Live(layout, console=self.console, refresh_per_second=1, screen=True) as live:
            while self.running:
                self._update_layout(layout)
                await asyncio.sleep(0.25)
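The monitor change pins the Live refresh rate to once per second and moves the display onto the alternate screen. A minimal sketch of that Rich pattern in isolation (the layout content below is illustrative, not the monitor's real layout):

import asyncio

from rich.console import Console
from rich.layout import Layout
from rich.live import Live


async def display_loop() -> None:
    console = Console()
    layout = Layout("starting...")  # placeholder renderable

    # refresh_per_second=1 keeps redraws cheap; screen=True uses the alternate screen
    with Live(layout, console=console, refresh_per_second=1, screen=True):
        for i in range(10):
            layout.update(f"tick {i}")
            await asyncio.sleep(0.25)


if __name__ == "__main__":
    asyncio.run(display_loop())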
caption_flow/orchestrator.py
CHANGED
@@ -363,6 +363,7 @@ class Orchestrator:
        self.ssl_context = self._setup_ssl()
 
        # Statistics
+        self.is_generating_stats = False
        self.stats = {
            "total_chunks": 0,
            "completed_chunks": 0,
@@ -1409,28 +1410,36 @@ class Orchestrator:
        finally:
            del self.data_workers[worker_id]
 
-    async def
-        """
-
-        logger.info("Monitor connected")
-
+    async def _send_leaderboard_to_monitor(self, websocket: WebSocketServerProtocol):
+        """Send leaderboard data to a specific monitor."""
+        total_start = time.time()
        try:
-
-
-
-            # Send chunk stats
-            chunk_stats = self.chunk_manager.get_stats()
-            await websocket.send(safe_json_dumps({"type": "chunk_stats", "data": chunk_stats}))
+            if websocket not in self.monitors:
+                return
 
-            #
+            # Get contributors asynchronously
+            contributors_start = time.time()
            contributors = await self.storage.get_top_contributors(10)
+            logger.debug(
+                f"Contributors retrieved in {(time.time() - contributors_start)*1000:.1f}ms"
+            )
 
-            #
-
-
-
+            # Get worker counts in thread pool
+            worker_counts_start = time.time()
+            loop = asyncio.get_event_loop()
+            worker_counts = await loop.run_in_executor(
+                None,
+                lambda: (
+                    self.get_workers_by_user_stats() if hasattr(self, "workers_by_user") else {}
+                ),
+            )
+            logger.debug(
+                f"Worker counts retrieved in {(time.time() - worker_counts_start)*1000:.1f}ms"
            )
 
+            # Build enhanced contributors list
+            build_start = time.time()
+            enhanced_contributors = []
            for contributor in contributors:
                contrib_dict = {
                    "contributor_id": contributor.contributor_id,
@@ -1442,40 +1451,157 @@ class Orchestrator:
                    ),
                }
                enhanced_contributors.append(contrib_dict)
+            logger.debug(f"Enhanced contributors built in {(time.time() - build_start)*1000:.1f}ms")
 
-
-
+            # Cache for future monitors
+            self._cached_leaderboard = enhanced_contributors
+
+            # Send if still connected
+            if websocket in self.monitors:
+                send_start = time.time()
+                await websocket.send(
+                    safe_json_dumps({"type": "leaderboard", "data": enhanced_contributors})
+                )
+                logger.debug(
+                    f"Leaderboard sent to monitor in {(time.time() - send_start)*1000:.1f}ms"
+                )
+
+            logger.debug(
+                f"Leaderboard send to monitor completed in {(time.time() - total_start)*1000:.1f}ms"
+            )
+
+        except websockets.exceptions.ConnectionClosed:
+            logger.debug("Monitor disconnected during leaderboard send")
+        except Exception as e:
+            logger.error(f"Error sending leaderboard to monitor: {e}")
+
+    async def _send_initial_monitor_data(self, websocket: WebSocketServerProtocol):
+        """Send initial data to monitor in a separate task to avoid blocking."""
+        total_start = time.time()
+        try:
+            # Check if websocket is still in monitors set
+            if websocket not in self.monitors:
+                logger.debug("Monitor disconnected before initial data send")
+                return
+
+            # Send current stats (already in memory)
+            stats_start = time.time()
+            await websocket.send(safe_json_dumps({"type": "stats", "data": self.stats}))
+            logger.debug(f"Monitor stats sent in {(time.time() - stats_start)*1000:.1f}ms")
+
+            # Get chunk stats asynchronously
+            chunk_stats_start = time.time()
+            loop = asyncio.get_event_loop()
+            chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
+            logger.debug(f"Chunk stats retrieved in {(time.time() - chunk_stats_start)*1000:.1f}ms")
+
+            if websocket not in self.monitors:
+                return
+
+            chunk_send_start = time.time()
+            await websocket.send(safe_json_dumps({"type": "chunk_stats", "data": chunk_stats}))
+            logger.debug(f"Chunk stats sent in {(time.time() - chunk_send_start)*1000:.1f}ms")
+
+            # For leaderboard, check if we have a cached version first
+            if hasattr(self, "_cached_leaderboard") and self._cached_leaderboard:
+                # Use cached leaderboard if available
+                cache_send_start = time.time()
+                await websocket.send(
+                    safe_json_dumps({"type": "leaderboard", "data": self._cached_leaderboard})
+                )
+                logger.debug(
+                    f"Cached leaderboard sent in {(time.time() - cache_send_start)*1000:.1f}ms"
+                )
+            else:
+                # Schedule leaderboard update separately
+                leaderboard_task_start = time.time()
+                asyncio.create_task(self._send_leaderboard_to_monitor(websocket))
+                logger.debug(
+                    f"Leaderboard task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
+                )
+
+            logger.debug(
+                f"Monitor initial data send completed in {(time.time() - total_start)*1000:.1f}ms"
+            )
+
+        except websockets.exceptions.ConnectionClosed:
+            logger.debug("Monitor disconnected during initial data send")
+        except Exception as e:
+            logger.error(f"Error sending initial monitor data: {e}")
+
+    async def _handle_monitor(self, websocket: WebSocketServerProtocol):
+        """Handle monitor connection - truly non-blocking version."""
+        monitor_start = time.time()
+        self.monitors.add(websocket)
+        logger.info(f"Monitor connected (total monitors: {len(self.monitors)})")
+
+        try:
+            # Send welcome message immediately
+            welcome_start = time.time()
+            await websocket.send(safe_json_dumps({"type": "welcome", "role": "monitor"}))
+            logger.debug(f"Monitor welcome sent in {(time.time() - welcome_start)*1000:.1f}ms")
+
+            # Schedule initial data send as a separate task to avoid blocking
+            task_create_start = time.time()
+            asyncio.create_task(self._send_initial_monitor_data(websocket))
+            logger.debug(
+                f"Monitor initial data task created in {(time.time() - task_create_start)*1000:.1f}ms"
            )
 
-            #
-
-
+            # Just keep the connection alive - no blocking work here
+            try:
+                async for message in websocket:
+                    # Handle any incoming messages from monitor if needed
+                    # For now, just ignore them
+                    pass
+            except websockets.exceptions.ConnectionClosed:
+                pass  # Normal disconnection
 
        except websockets.exceptions.ConnectionClosed:
            logger.info("Monitor disconnected")
+        except Exception as e:
+            logger.error(f"Error in monitor handler: {e}")
        finally:
            self.monitors.discard(websocket)
+            logger.debug(f"Monitor handler completed in {(time.time() - monitor_start)*1000:.1f}ms")
 
    async def _broadcast_stats(self):
-        """Broadcast statistics to all monitors -
+        """Broadcast statistics to all monitors - truly non-blocking version."""
        if not self.monitors:
            return
+        if self.is_generating_stats:
+            return  # Already generating stats, skip this call
+        self.is_generating_stats = True
+        total_start = time.time()
+
+        # Prepare all the data first
+        data_prep_start = time.time()
+        loop = asyncio.get_event_loop()
 
-        # Get storage stats
+        # Get storage stats (already async)
+        storage_stats_start = time.time()
        storage_stats = await self.storage.get_storage_stats()
+        logger.debug(f"Storage stats retrieved in {(time.time() - storage_stats_start)*1000:.1f}ms")
+
+        caption_stats_start = time.time()
        caption_stats = await self.storage.get_caption_stats()
+        logger.debug(f"Caption stats retrieved in {(time.time() - caption_stats_start)*1000:.1f}ms")
 
-        #
-
-
+        # Get chunk stats in thread pool
+        chunk_stats_start = time.time()
+        chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
+        logger.debug(f"Chunk stats retrieved in {(time.time() - chunk_stats_start)*1000:.1f}ms")
 
-        #
-
-
-
+        # Build stats dict
+        build_stats_start = time.time()
+        stats_update = self.stats.copy()
+        stats_update.update({f"chunks_{k}": v for k, v in chunk_stats.items()})
+        stats_update.update(storage_stats)
+        stats_update["field_breakdown"] = caption_stats.get("field_stats", {})
+        stats_update["output_fields_list"] = caption_stats.get("output_fields", [])
 
        # Add rate information
-
+        stats_update.update(
            {
                "current_rate": self.rate_tracker["current_rate"],
                "average_rate": self.rate_tracker["average_rate"],
@@ -1483,41 +1609,106 @@ class Orchestrator:
            }
        )
 
-        # Add vLLM info
-
-
+        # Add vLLM info
+        stats_update["vllm_model"] = self.vllm_config.get("model", "unknown")
+        stats_update["vllm_batch_size"] = self.vllm_config.get("batch_size", 0)
 
-        #
+        # Add stage information
        stages = self.vllm_config.get("stages", [])
        if stages:
-
-
+            stats_update["stage_count"] = len(stages)
+            stats_update["stage_names"] = [s.get("name", "unnamed") for s in stages]
        else:
-
-
+            stats_update["stage_count"] = 1
+            stats_update["stage_names"] = ["default"]
 
+        # Get field stats
+        field_stats_start = time.time()
        field_stats = await self.storage.get_output_field_stats()
-
+        stats_update["output_fields"] = field_stats
+        logger.debug(f"Field stats retrieved in {(time.time() - field_stats_start)*1000:.1f}ms")
 
-
+        # Update our internal stats
+        self.stats = stats_update
+        logger.debug(f"Stats prepared in {(time.time() - build_stats_start)*1000:.1f}ms")
 
-
-
-
-
+        logger.debug(f"Total data preparation took {(time.time() - data_prep_start)*1000:.1f}ms")
+
+        # Create message once
+        message_create_start = time.time()
+        stats_message = safe_json_dumps({"type": "stats", "data": self.stats})
+        logger.debug(f"Stats message created in {(time.time() - message_create_start)*1000:.1f}ms")
+
+        # Send to all monitors asynchronously in parallel
+        send_start = time.time()
+
+        async def send_to_monitor(monitor):
            try:
-                await monitor.send(
+                await monitor.send(stats_message)
            except websockets.exceptions.ConnectionClosed:
-
+                return monitor  # Return for removal
+            except Exception as e:
+                logger.debug(f"Error sending stats to monitor: {e}")
+                return monitor  # Return for removal
+            return None
 
-        #
+        # Send to all monitors in parallel
+        monitors_copy = self.monitors.copy()
+        results = await asyncio.gather(
+            *[send_to_monitor(m) for m in monitors_copy], return_exceptions=True
+        )
+
+        # Remove disconnected monitors
+        disconnected = {
+            m
+            for m, r in zip(monitors_copy, results)
+            if r is not None and not isinstance(r, Exception)
+        }
+        self.monitors -= disconnected
+
+        logger.debug(
+            f"Stats sent to {len(monitors_copy)} monitors in {(time.time() - send_start)*1000:.1f}ms"
+        )
+
+        # Send leaderboard update in a separate task to avoid blocking
+        leaderboard_task_start = time.time()
+        asyncio.create_task(self._broadcast_leaderboard())
+        self.is_generating_stats = False
+        logger.debug(
+            f"Leaderboard broadcast task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
+        )
+        logger.debug(f"Stats broadcast completed in {(time.time() - total_start)*1000:.1f}ms")
+
+    async def _broadcast_leaderboard(self):
+        """Send leaderboard updates to monitors - separate from stats to avoid blocking."""
+        if not self.monitors:
+            return
+
+        total_start = time.time()
        try:
+            # Get contributors
+            contributors_start = time.time()
            contributors = await self.storage.get_top_contributors(10)
-
-
-
+            logger.debug(
+                f"Contributors retrieved for broadcast in {(time.time() - contributors_start)*1000:.1f}ms"
+            )
+
+            # Get worker counts
+            worker_counts_start = time.time()
+            loop = asyncio.get_event_loop()
+            worker_counts = await loop.run_in_executor(
+                None,
+                lambda: (
+                    self.get_workers_by_user_stats() if hasattr(self, "workers_by_user") else {}
+                ),
+            )
+            logger.debug(
+                f"Worker counts retrieved for broadcast in {(time.time() - worker_counts_start)*1000:.1f}ms"
            )
 
+            # Build enhanced contributors list
+            build_start = time.time()
+            enhanced_contributors = []
            for contributor in contributors:
                contrib_dict = {
                    "contributor_id": contributor.contributor_id,
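The broadcast path above serializes each message once and fans it out to every monitor concurrently, collecting the sockets whose send failed so they can be pruned. A condensed sketch of that pattern on its own, independent of the orchestrator's state (the connection objects are stand-ins for websocket connections, not caption-flow classes):

import asyncio
from typing import Optional, Set


async def fan_out(message: str, connections: Set) -> None:
    """Send one pre-serialized message to all connections in parallel."""

    async def send_one(conn) -> Optional[object]:
        try:
            await conn.send(message)
        except Exception:
            return conn  # mark this connection for removal
        return None

    targets = list(connections)  # snapshot so the set can change while we send
    results = await asyncio.gather(*(send_one(c) for c in targets), return_exceptions=True)

    # Anything returned (and not an exception) is a dead connection
    dead = {c for c, r in zip(targets, results) if r is not None and not isinstance(r, Exception)}
    connections -= dead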
@@ -1529,26 +1720,64 @@ class Orchestrator:
                    ),
                }
                enhanced_contributors.append(contrib_dict)
+            logger.debug(
+                f"Enhanced contributors built for broadcast in {(time.time() - build_start)*1000:.1f}ms"
+            )
+
+            # Cache it
+            self._cached_leaderboard = enhanced_contributors
 
+            # Create message once
+            message_create_start = time.time()
            leaderboard_message = safe_json_dumps(
                {"type": "leaderboard", "data": enhanced_contributors}
            )
+            logger.debug(
+                f"Leaderboard message created in {(time.time() - message_create_start)*1000:.1f}ms"
+            )
 
-            # Send to all monitors
-
-
+            # Send to all monitors in parallel
+            send_start = time.time()
+
+            async def send_leaderboard(monitor):
                try:
                    await monitor.send(leaderboard_message)
-                except
-
+                except:
+                    return monitor  # Mark for removal
+                return None
+
+            monitors_copy = self.monitors.copy()
+            results = await asyncio.gather(
+                *[send_leaderboard(m) for m in monitors_copy], return_exceptions=True
+            )
 
+            # Remove disconnected
+            disconnected = {
+                m
+                for m, r in zip(monitors_copy, results)
+                if r is not None and not isinstance(r, Exception)
+            }
            self.monitors -= disconnected
 
-
-
+            logger.debug(
+                f"Leaderboard sent to {len(monitors_copy)} monitors in {(time.time() - send_start)*1000:.1f}ms"
+            )
+            logger.debug(
+                f"Leaderboard broadcast completed in {(time.time() - total_start)*1000:.1f}ms"
+            )
 
-
-
+        except Exception as e:
+            logger.error(f"Error broadcasting leaderboard: {e}")
+
+    def _get_queue_stats(self) -> Dict[str, int]:
+        """Get queue statistics - synchronous helper for thread pool."""
+        with self.chunk_manager.lock:
+            return {
+                "pending_chunks": len(self.chunk_manager.pending_chunks),
+                "assigned_chunks": sum(
+                    len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
+                ),
+            }
 
    async def _flush_processed_items(self):
        """Flush batched processed items to chunk tracker."""
@@ -1591,12 +1820,14 @@ class Orchestrator:
        self.last_item_batch_flush = time.time()
 
    def get_workers_by_user_stats(self) -> Dict[str, Any]:
-        """Get statistics about workers grouped by user/token."""
+        """Get statistics about workers grouped by user/token - thread-safe version."""
        if not hasattr(self, "workers_by_user"):
            return {}
 
+        # Create a copy to avoid issues with concurrent modification
        stats = {}
-
+        workers_snapshot = dict(self.workers_by_user)
+        for user, worker_ids in workers_snapshot.items():
            stats[user] = {"worker_count": len(worker_ids), "worker_ids": list(worker_ids)}
        return stats
 
@@ -1621,21 +1852,63 @@ class Orchestrator:
    async def _heartbeat_loop(self):
        """Send periodic heartbeats to maintain connections."""
        while True:
-
+            try:
+                await asyncio.sleep(30)
 
-
-
-
-            try:
-                await ws.ping()
-            except:
-                disconnected.append(worker_id)
+                # Create a copy of worker items to avoid modification during iteration
+                worker_items = list(self.workers.items())
+                disconnected = []
 
-
-
-
-
-
+                for worker_id, ws in worker_items:
+                    try:
+                        # Check if worker still exists before pinging
+                        if worker_id not in self.workers:
+                            continue
+
+                        # Send ping with timeout
+                        pong_waiter = await ws.ping()
+                        try:
+                            await asyncio.wait_for(pong_waiter, timeout=10)
+                        except asyncio.TimeoutError:
+                            logger.warning(f"Worker {worker_id} failed to respond to ping")
+                            disconnected.append(worker_id)
+                    except websockets.exceptions.ConnectionClosed:
+                        logger.info(f"Worker {worker_id} connection already closed")
+                        disconnected.append(worker_id)
+                    except Exception as e:
+                        logger.error(f"Error pinging worker {worker_id}: {e}")
+                        disconnected.append(worker_id)
+
+                # Clean up disconnected workers
+                for worker_id in disconnected:
+                    if worker_id in self.workers:
+                        logger.info(f"Removing unresponsive worker {worker_id}")
+                        del self.workers[worker_id]
+                        self.chunk_manager.release_worker_chunks(worker_id)
+
+                        # Update stats
+                        self.stats["connected_workers"] = len(self.workers)
+
+                        # Also clean up from workers_by_user if it exists
+                        if hasattr(self, "workers_by_user"):
+                            worker_user = (
+                                worker_id.rsplit("_", 1)[0] if "_" in worker_id else worker_id
+                            )
+                            if worker_user in self.workers_by_user:
+                                self.workers_by_user[worker_user].discard(worker_id)
+                                if not self.workers_by_user[worker_user]:
+                                    del self.workers_by_user[worker_user]
+
+                        # Notify monitors
+                        await self._broadcast_stats()
+                        await self._send_activity(
+                            f"Worker {worker_id} removed due to heartbeat timeout"
+                        )
+
+            except Exception as e:
+                logger.error(f"Error in heartbeat loop: {e}", exc_info=True)
+                # Continue the loop even if there's an error
+                await asyncio.sleep(5)
 
    async def _checkpoint_loop(self):
        """Periodically checkpoint storage."""
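The reworked heartbeat waits on each ping's pong with a timeout instead of assuming ws.ping() fails fast. With the websockets library, ping() returns a waiter that resolves when the matching pong arrives, so the timeout can be applied separately. A small sketch of that idea (the URL and interval are illustrative):

import asyncio

import websockets


async def ping_with_timeout(ws, timeout: float = 10.0) -> bool:
    """Return True if the peer answered the ping within `timeout` seconds."""
    try:
        pong_waiter = await ws.ping()                  # queues the ping frame
        await asyncio.wait_for(pong_waiter, timeout)   # waits for the matching pong
        return True
    except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
        return False


async def main() -> None:
    async with websockets.connect("ws://localhost:8765") as ws:  # illustrative endpoint
        while True:
            if not await ping_with_timeout(ws):
                print("peer unresponsive, dropping connection")
                break
            await asyncio.sleep(30)


if __name__ == "__main__":
    asyncio.run(main())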
@@ -1663,7 +1936,10 @@ class Orchestrator:
        )
 
    async def _stats_update_loop(self):
-        """Periodically update and broadcast stats."""
+        """Periodically update and broadcast stats - non-blocking version."""
+        # Get the event loop for running blocking operations
+        loop = asyncio.get_event_loop()
+
        # Track session start values
        storage_stats = await self.storage.get_storage_stats()
        session_start_outputs = storage_stats["total_captions"]  # This now counts ALL outputs
@@ -1675,8 +1951,8 @@ class Orchestrator:
        while True:
            await asyncio.sleep(10)
 
-            # Update chunk stats
-            chunk_stats = self.chunk_manager.get_stats
+            # Update chunk stats in thread pool to avoid blocking
+            chunk_stats = await loop.run_in_executor(None, self.chunk_manager.get_stats)
            storage_stats = await self.storage.get_storage_stats()
            current_total_outputs = storage_stats["total_captions"]  # ALL outputs
            if self.chunk_tracker:
@@ -1690,12 +1966,9 @@ class Orchestrator:
            self.stats["total_outputs"] = current_total_outputs
            self.stats["total_captions"] = current_total_outputs  # Keep for backward compatibility
 
-            #
-
-
-            self.stats["assigned_chunks"] = sum(
-                len(chunks) for chunks in self.chunk_manager.assigned_chunks.values()
-            )
+            # Get queue stats in thread pool to avoid blocking
+            queue_stats = await loop.run_in_executor(None, self._get_queue_stats)
+            self.stats.update(queue_stats)
 
            # Calculate if we need more chunks
            worker_count = self.stats.get("connected_workers", 0)
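Several of the orchestrator changes follow the same recipe: move synchronous, lock-holding work (chunk and queue statistics) off the event loop with run_in_executor, and schedule slow sends as separate tasks. A self-contained sketch of the executor half of that pattern (slow_stats is a stand-in for something like chunk_manager.get_stats, not a caption-flow function):

import asyncio
import time
from typing import Dict


def slow_stats() -> Dict[str, int]:
    """Stand-in for a blocking, lock-protected stats call."""
    time.sleep(0.5)  # pretend we are holding a lock / scanning state
    return {"pending_chunks": 3, "assigned_chunks": 7}


async def stats_loop() -> None:
    loop = asyncio.get_event_loop()
    for _ in range(3):
        # The default (None) executor is a thread pool, so the event loop stays responsive
        stats = await loop.run_in_executor(None, slow_stats)
        print(stats)
        await asyncio.sleep(1)


if __name__ == "__main__":
    asyncio.run(stats_loop())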
caption_flow/utils/dataset_loader.py
CHANGED
@@ -234,6 +234,54 @@ class DatasetLoader:
            for key, url, image_data in ds:
                yield key, url, image_data
 
+    def _create_dataset_at_position(self, dataset_path: str, split: str, start_idx: int):
+        """Create a dataset iterator positioned at start_idx using state_dict if available."""
+        try:
+            # Load dataset in streaming mode
+            dataset = load_dataset(
+                dataset_path,
+                split=split,
+                streaming=True,
+                token=self.token,
+            )
+
+            # Check if the dataset supports state_dict (newer versions of datasets library)
+            if hasattr(dataset, "load_state_dict") and hasattr(dataset, "state_dict"):
+                # Try to use the dataset's native state management
+                try:
+                    # Get current state
+                    state = dataset.state_dict()
+
+                    # Modify the state to skip to start_idx
+                    if "epoch" in state:
+                        state["epoch"] = 0
+                    if "num_examples_since_previous_state" in state:
+                        state["num_examples_since_previous_state"] = start_idx
+
+                    # For newer datasets with examples_iterable state
+                    if "examples_iterable" in state:
+                        if isinstance(state["examples_iterable"], dict):
+                            if "shard_example_idx" in state["examples_iterable"]:
+                                state["examples_iterable"]["shard_example_idx"] = start_idx
+
+                    # Load the modified state
+                    dataset.load_state_dict(state)
+                    logger.info(f"Positioned dataset at index {start_idx} using state_dict")
+                    return dataset
+                except Exception as e:
+                    logger.debug(f"Could not use state_dict approach: {e}")
+
+            # Fall back to skip() for large skips
+            if start_idx > 0:
+                logger.info(f"Using skip() to position dataset at index {start_idx}")
+                dataset = dataset.skip(start_idx)
+
+            return dataset
+
+        except Exception as e:
+            logger.warning(f"Error creating positioned dataset: {e}")
+            return None
+
    def _iterate_hf_dataset_shard_with_metadata(
        self, shard_url: str, processed_keys: Optional[set] = None
    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
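_create_dataset_at_position leans on the resumable-streaming support in recent releases of the datasets library: an IterableDataset exposes state_dict()/load_state_dict(), with skip() as the fallback. A rough sketch of positioning a streaming dataset partway through (the dataset name is illustrative, and the exact state keys vary between datasets versions, which is why the package mutates them defensively):

from typing import Optional

from datasets import load_dataset


def open_at(dataset_path: str, split: str, start_idx: int, token: Optional[str] = None):
    """Open a streaming HF dataset roughly positioned at start_idx."""
    ds = load_dataset(dataset_path, split=split, streaming=True, token=token)

    # Newer datasets releases support checkpoint/resume on IterableDataset
    if hasattr(ds, "state_dict") and hasattr(ds, "load_state_dict"):
        state = ds.state_dict()
        # The layout of this dict is version-dependent; adjust only known keys
        ex = state.get("examples_iterable")
        if isinstance(ex, dict) and "shard_example_idx" in ex:
            ex["shard_example_idx"] = start_idx
            ds.load_state_dict(state)
            return ds

    # Otherwise fall back to skip(), which simply discards the first start_idx items
    return ds.skip(start_idx)


# Illustrative usage:
# ds = open_at("user/some-image-dataset", "train", 5000)
# for example in ds:
#     ...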
@@ -248,7 +296,65 @@ class DatasetLoader:
        )
 
        try:
-            #
+            # Try optimized approach for large skips
+            if start_idx > 100:
+                dataset = self._create_dataset_at_position(dataset_path, self.split, start_idx)
+                if dataset:
+                    items_processed = 0
+
+                    for item in dataset:
+                        # Stop after processing chunk_size items
+                        if items_processed >= chunk_size:
+                            break
+
+                        # Generate a unique key for this item
+                        key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
+
+                        if key in processed_keys:
+                            items_processed += 1
+                            continue
+
+                        try:
+                            # Extract image data
+                            if self.image_column in item:
+                                img_data = item[self.image_column]
+
+                                # Process image to bytes
+                                image_bytes = ImageProcessor.process_image_data(img_data)
+
+                                if image_bytes:
+                                    # Extract all metadata (excluding the image column)
+                                    metadata = {
+                                        k: v for k, v in item.items() if k != self.image_column
+                                    }
+
+                                    # URL is virtual for HF datasets
+                                    url = f"hf://{dataset_path}#{start_idx + items_processed}"
+                                    items_processed += 1
+                                    yield key, url, image_bytes, metadata
+                                else:
+                                    logger.warning(
+                                        f"Failed to process image for item at index {start_idx + items_processed}"
+                                    )
+                                    items_processed += 1
+                                    continue
+                            else:
+                                logger.warning(
+                                    f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
+                                    f"Available columns: {list(item.keys())}"
+                                )
+                                items_processed += 1
+
+                        except Exception as e:
+                            logger.error(
+                                f"Error processing item at index {start_idx + items_processed}: {e}"
+                            )
+                            items_processed += 1
+                            continue
+
+                    return
+
+            # Fall back to regular approach for small skips or if StatefulDataLoader not available
            dataset = load_dataset(
                dataset_path,
                split=self.split,
@@ -256,7 +362,7 @@ class DatasetLoader:
                token=self.token,
            )
 
-            # Skip to start index if needed
+            # Skip to start index if needed
            if start_idx > 0:
                dataset = dataset.skip(start_idx)
 
@@ -267,7 +373,7 @@ class DatasetLoader:
                if items_processed >= chunk_size:
                    break
 
-                # Generate a unique key for this item
+                # Generate a unique key for this item
                key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
 
                if key in processed_keys:
@@ -337,7 +443,76 @@ class DatasetLoader:
        )
 
        try:
-            #
+            # Try optimized approach for large skips
+            if start_idx > 100:
+                dataset = self._create_dataset_at_position(dataset_path, self.split, start_idx)
+                if dataset:
+                    items_processed = 0
+
+                    for item in dataset:
+                        # Stop after processing chunk_size items
+                        if items_processed >= chunk_size:
+                            logger.info(f"Completed chunk: processed {items_processed} items")
+                            break
+
+                        # Also stop if we've reached the dataset end
+                        if (
+                            self._hf_total_items
+                            and (start_idx + items_processed) >= self._hf_total_items
+                        ):
+                            logger.info(
+                                f"Reached dataset end at item {start_idx + items_processed} "
+                                f"(total: {self._hf_total_items})"
+                            )
+                            break
+
+                        # Generate a unique key for this item
+                        key = f"{dataset_path.replace('/', '_')}_{start_idx + items_processed:08d}"
+
+                        if key in processed_keys:
+                            items_processed += 1
+                            continue
+
+                        try:
+                            # Extract image data
+                            if self.image_column in item:
+                                img_data = item[self.image_column]
+
+                                # Delegate image processing to ImageProcessor
+                                image_bytes = ImageProcessor.process_image_data(img_data)
+
+                                if image_bytes:
+                                    # URL is virtual for HF datasets
+                                    url = f"hf://{dataset_path}#{start_idx + items_processed}"
+                                    items_processed += 1
+                                    yield key, url, image_bytes
+                                else:
+                                    logger.warning(
+                                        f"Failed to process image for item at index {start_idx + items_processed}"
+                                    )
+                                    items_processed += 1
+                                    continue
+                            else:
+                                logger.warning(
+                                    f"No image column '{self.image_column}' found in item at index {start_idx + items_processed}. "
+                                    f"Available columns: {list(item.keys())}"
+                                )
+                                items_processed += 1
+
+                        except Exception as e:
+                            logger.error(
+                                f"Error processing item at index {start_idx + items_processed}: {e}"
+                            )
+                            items_processed += 1
+                            continue
+
+                    logger.info(
+                        f"Virtual shard complete: processed {items_processed} items "
+                        f"(start_idx: {start_idx})"
+                    )
+                    return
+
+            # Fall back to regular approach for small skips or if StatefulDataLoader not available
            dataset = load_dataset(
                dataset_path,
                split=self.split,
{caption_flow-0.2.0.dist-info → caption_flow-0.2.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: caption-flow
-Version: 0.2.
+Version: 0.2.1
 Summary: Self-contained distributed community captioning system
 Author-email: bghira <bghira@users.github.com>
 License: MIT
@@ -32,6 +32,7 @@ Requires-Dist: pandas<3.0.0,>=2.3.1
 Requires-Dist: arrow<2.0.0,>=1.3.0
 Requires-Dist: datasets<5.0.0,>=4.0.0
 Requires-Dist: boto3<2.0.0,>=1.40.11
+Requires-Dist: torchdata<0.12.0,>=0.11.0
 Provides-Extra: dev
 Requires-Dist: pytest>=7.4.0; extra == "dev"
 Requires-Dist: pytest-asyncio>=0.21.0; extra == "dev"
{caption_flow-0.2.0.dist-info → caption_flow-0.2.1.dist-info}/RECORD
CHANGED
@@ -1,8 +1,8 @@
 caption_flow/__init__.py,sha256=NLPJ25lRN7xHqncXweINDNwbt0q8lgjZ30G21zlPdRs,303
-caption_flow/cli.py,sha256=
+caption_flow/cli.py,sha256=bHxx66CPsCmSieaH3pw8NZBojIIbniRTdU9mEBHMmWA,28832
 caption_flow/models.py,sha256=qo6lQiO10UISbaBVr6Cs-fSW_pmjwE6kmiTmmU_l3Wk,2140
-caption_flow/monitor.py,sha256=
-caption_flow/orchestrator.py,sha256=
+caption_flow/monitor.py,sha256=ZZCSasYLKJ-UzA3-RoAtytv-tbNA-m3h5YjlZg_vukg,7870
+caption_flow/orchestrator.py,sha256=bZ8NnGdqoXSmu7Nq-_7cOSH1DLHkBT88cne0uDyPeNY,89112
 caption_flow/storage.py,sha256=hC6ZHT_PHFoUVjqD5JUwy3_79oAD1e1H30neA_xsz7s,40748
 caption_flow/utils/__init__.py,sha256=F1BChVoCsj9zn1GJRBOLHET1kLW6xrAmsbzcR7hHy6Y,202
 caption_flow/utils/auth.py,sha256=UrxX2n8OEEcfMD1Ey27TxGfrJFmUCpC59x-SCrQJoVE,2253
@@ -10,7 +10,7 @@ caption_flow/utils/caption_utils.py,sha256=esUMAdcCkNjRroZ0Bhxv0_yKlLtMf0XeDCTt-
 caption_flow/utils/certificates.py,sha256=eu4blQZEkL9NRaY1ynQWg1asvDorRYhGRZea7STonJE,4635
 caption_flow/utils/checkpoint_tracker.py,sha256=8tsTFF-HcygitK92YcS-QWzeg-qRm9AuCpQoQRfC8M0,3335
 caption_flow/utils/chunk_tracker.py,sha256=hKn8CN6ubErc9kuCWZMj12ZCZKxVlqXqAEocbzjfa-k,17296
-caption_flow/utils/dataset_loader.py,sha256=
+caption_flow/utils/dataset_loader.py,sha256=ZplJv655ZMyUbaZC4BBiL5II18sBy4JSJhxGZtK_VmA,29107
 caption_flow/utils/image_processor.py,sha256=Zl8TAv9gYPdAYat3UiTuuNdIb2fXNfZ35AxsxuovJTs,5650
 caption_flow/utils/job_queue.py,sha256=itdfXcrkvGjmXn4qtpgMF63k1ufRBaejDe4V6WcxzgU,1104
 caption_flow/utils/json_utils.py,sha256=IiZYn8uCM-3pYmyIbX2fmaOIyutArn67SqAyp0ggNpU,5396
@@ -21,9 +21,9 @@ caption_flow/utils/vllm_config.py,sha256=TC7Rmjk0zRKbBXbWUXrFL4Z58hzax_-4L0pXZn0
 caption_flow/workers/base.py,sha256=jPm_Xw4Lxd0cnrPs-biBqKRQKkTOJLvHLolmp0Gb1CI,7530
 caption_flow/workers/caption.py,sha256=NZ9kTjk2uOoNwyyNSkB_arYk213vLr5mowHN-OjiFkk,54631
 caption_flow/workers/data.py,sha256=0Tg8NE0wdONeMlivYQ4nvbcfWdLuU51O7vR8_YSnJgo,14813
-caption_flow-0.2.
-caption_flow-0.2.
-caption_flow-0.2.
-caption_flow-0.2.
-caption_flow-0.2.
-caption_flow-0.2.
+caption_flow-0.2.1.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+caption_flow-0.2.1.dist-info/METADATA,sha256=fxNfSOqkCklb96aq3ZFU7SvRuXEBUQ11xbjkQn7Yzuo,11941
+caption_flow-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+caption_flow-0.2.1.dist-info/entry_points.txt,sha256=KnVlyrGKZj6p2zNyuEnCx4Y6jvJ4V-mcfN0lddPKTlQ,55
+caption_flow-0.2.1.dist-info/top_level.txt,sha256=_bXpKRutqded0FQ80dCChIz26ETV7tL4d4e2E_Y1FXs,13
+caption_flow-0.2.1.dist-info/RECORD,,

{caption_flow-0.2.0.dist-info → caption_flow-0.2.1.dist-info}/WHEEL
File without changes
{caption_flow-0.2.0.dist-info → caption_flow-0.2.1.dist-info}/entry_points.txt
File without changes
{caption_flow-0.2.0.dist-info → caption_flow-0.2.1.dist-info}/licenses/LICENSE
File without changes
{caption_flow-0.2.0.dist-info → caption_flow-0.2.1.dist-info}/top_level.txt
File without changes