caption-flow 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -3
- caption_flow/cli.py +934 -415
- caption_flow/models.py +45 -3
- caption_flow/monitor.py +2 -3
- caption_flow/orchestrator.py +153 -104
- caption_flow/processors/__init__.py +3 -3
- caption_flow/processors/base.py +8 -7
- caption_flow/processors/huggingface.py +439 -67
- caption_flow/processors/local_filesystem.py +24 -28
- caption_flow/processors/webdataset.py +28 -22
- caption_flow/storage/exporter.py +420 -339
- caption_flow/storage/manager.py +636 -756
- caption_flow/utils/__init__.py +1 -1
- caption_flow/utils/auth.py +1 -1
- caption_flow/utils/caption_utils.py +1 -1
- caption_flow/utils/certificates.py +15 -8
- caption_flow/utils/checkpoint_tracker.py +30 -28
- caption_flow/utils/chunk_tracker.py +153 -56
- caption_flow/utils/image_processor.py +9 -9
- caption_flow/utils/json_utils.py +37 -20
- caption_flow/utils/prompt_template.py +24 -16
- caption_flow/utils/vllm_config.py +5 -4
- caption_flow/viewer.py +4 -12
- caption_flow/workers/base.py +5 -4
- caption_flow/workers/caption.py +265 -90
- caption_flow/workers/data.py +6 -8
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/METADATA +5 -4
- caption_flow-0.4.0.dist-info/RECORD +33 -0
- caption_flow-0.3.4.dist-info/RECORD +0 -33
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/WHEEL +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/top_level.txt +0 -0
caption_flow/cli.py
CHANGED
@@ -1,21 +1,22 @@
|
|
1
1
|
"""Command-line interface for CaptionFlow with smart configuration handling."""
|
2
2
|
|
3
3
|
import asyncio
|
4
|
+
import datetime as _datetime
|
4
5
|
import json
|
5
6
|
import logging
|
6
7
|
import os
|
7
8
|
import sys
|
9
|
+
from datetime import datetime
|
8
10
|
from pathlib import Path
|
9
|
-
from typing import
|
11
|
+
from typing import Any, Dict, List, Optional
|
10
12
|
|
11
13
|
import click
|
12
14
|
import yaml
|
13
15
|
from rich.console import Console
|
14
16
|
from rich.logging import RichHandler
|
15
|
-
from datetime import datetime
|
16
17
|
|
17
|
-
from .orchestrator import Orchestrator
|
18
18
|
from .monitor import Monitor
|
19
|
+
from .orchestrator import Orchestrator
|
19
20
|
from .utils.certificates import CertificateManager
|
20
21
|
|
21
22
|
console = Console()
|
@@ -48,8 +49,7 @@ class ConfigManager:
|
|
48
49
|
def find_config(
|
49
50
|
cls, component: str, explicit_path: Optional[str] = None
|
50
51
|
) -> Optional[Dict[str, Any]]:
|
51
|
-
"""
|
52
|
-
Find and load configuration for a component.
|
52
|
+
"""Find and load configuration for a component.
|
53
53
|
|
54
54
|
Search order:
|
55
55
|
1. Explicit path if provided
|
@@ -120,22 +120,76 @@ class ConfigManager:
|
|
120
120
|
|
121
121
|
|
122
122
|
def setup_logging(verbose: bool = False):
|
123
|
-
"""Configure logging with rich handler
|
123
|
+
"""Configure logging with rich handler and file output to XDG state directory."""
|
124
124
|
level = logging.DEBUG if verbose else logging.INFO
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
125
|
+
|
126
|
+
# Determine log directory based on environment or XDG spec
|
127
|
+
log_dir_env = os.environ.get("CAPTIONFLOW_LOG_DIR")
|
128
|
+
if log_dir_env:
|
129
|
+
log_dir = Path(log_dir_env)
|
130
|
+
else:
|
131
|
+
# Use XDG_STATE_HOME for logs, with platform-specific fallbacks
|
132
|
+
xdg_state_home = os.environ.get("XDG_STATE_HOME")
|
133
|
+
if xdg_state_home:
|
134
|
+
base_dir = Path(xdg_state_home)
|
135
|
+
elif sys.platform == "darwin":
|
136
|
+
base_dir = Path.home() / "Library" / "Logs"
|
137
|
+
else:
|
138
|
+
# Default to ~/.local/state on Linux and other systems
|
139
|
+
base_dir = Path.home() / ".local" / "state"
|
140
|
+
log_dir = base_dir / "caption-flow"
|
141
|
+
|
142
|
+
try:
|
143
|
+
# Ensure log directory exists
|
144
|
+
log_dir.mkdir(parents=True, exist_ok=True)
|
145
|
+
log_file_path = log_dir / "caption_flow.log"
|
146
|
+
|
147
|
+
# Set up handlers
|
148
|
+
handlers: List[logging.Handler] = [
|
149
|
+
RichHandler(
|
150
|
+
console=console,
|
151
|
+
rich_tracebacks=True,
|
152
|
+
show_path=False,
|
153
|
+
show_time=True,
|
154
|
+
)
|
155
|
+
]
|
156
|
+
|
157
|
+
# Add file handler
|
158
|
+
file_handler = logging.FileHandler(log_file_path, mode="a")
|
159
|
+
file_handler.setFormatter(
|
160
|
+
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
161
|
+
)
|
162
|
+
handlers.append(file_handler)
|
163
|
+
log_msg = f"Logging to {log_file_path}"
|
164
|
+
|
165
|
+
except (OSError, PermissionError) as e:
|
166
|
+
# Fallback to only console logging if file logging fails
|
167
|
+
handlers = [
|
130
168
|
RichHandler(
|
131
169
|
console=console,
|
132
170
|
rich_tracebacks=True,
|
133
171
|
show_path=False,
|
134
|
-
show_time=True,
|
172
|
+
show_time=True,
|
135
173
|
)
|
136
|
-
]
|
174
|
+
]
|
175
|
+
log_file = log_dir / "caption_flow.log"
|
176
|
+
log_msg = f"[yellow]Warning: Could not write to log file {log_file}: {e}[/yellow]"
|
177
|
+
|
178
|
+
logging.basicConfig(
|
179
|
+
level=level,
|
180
|
+
format="%(message)s", # RichHandler overrides this format for console
|
181
|
+
datefmt="[%Y-%m-%d %H:%M:%S]",
|
182
|
+
handlers=handlers,
|
137
183
|
)
|
138
184
|
|
185
|
+
# Suppress noisy libraries
|
186
|
+
logging.getLogger("websockets").setLevel(logging.WARNING)
|
187
|
+
logging.getLogger("pyarrow").setLevel(logging.WARNING)
|
188
|
+
|
189
|
+
# Use a dedicated logger to print the log file path to avoid format issues
|
190
|
+
if "log_msg" in locals():
|
191
|
+
logging.getLogger("setup").info(log_msg)
|
192
|
+
|
139
193
|
|
140
194
|
def apply_cli_overrides(config: Dict[str, Any], **kwargs) -> Dict[str, Any]:
|
141
195
|
"""Apply CLI arguments as overrides to config, filtering out None values."""
|
@@ -189,9 +243,11 @@ def orchestrator(ctx, config: Optional[str], **kwargs):
|
|
189
243
|
config_data["ssl"]["cert"] = kwargs["cert"]
|
190
244
|
config_data["ssl"]["key"] = kwargs["key"]
|
191
245
|
elif not config_data.get("ssl"):
|
192
|
-
|
193
|
-
"[yellow]Warning: Running without SSL.
|
246
|
+
warning_msg = (
|
247
|
+
"[yellow]Warning: Running without SSL. "
|
248
|
+
"Use --cert and --key for production.[/yellow]"
|
194
249
|
)
|
250
|
+
console.print(warning_msg)
|
195
251
|
|
196
252
|
if kwargs.get("vllm") and "vllm" not in config_data:
|
197
253
|
raise ValueError("Must provide vLLM config.")
|
@@ -259,33 +315,11 @@ def worker(ctx, config: Optional[str], **kwargs):
|
|
259
315
|
asyncio.run(worker_instance.shutdown())
|
260
316
|
|
261
317
|
|
262
|
-
|
263
|
-
|
264
|
-
@click.option("--server", help="Orchestrator WebSocket URL")
|
265
|
-
@click.option("--token", help="Authentication token")
|
266
|
-
@click.option("--no-verify-ssl", is_flag=True, help="Skip SSL verification")
|
267
|
-
@click.option("--debug", is_flag=True, help="Enable debug output")
|
268
|
-
@click.pass_context
|
269
|
-
def monitor(
|
270
|
-
ctx,
|
271
|
-
config: Optional[str],
|
272
|
-
server: Optional[str],
|
273
|
-
token: Optional[str],
|
274
|
-
no_verify_ssl: bool,
|
275
|
-
debug: bool,
|
276
|
-
):
|
277
|
-
"""Start the monitoring TUI."""
|
278
|
-
|
279
|
-
# Enable debug logging if requested
|
280
|
-
if debug:
|
281
|
-
setup_logging(verbose=True)
|
282
|
-
console.print("[yellow]Debug mode enabled[/yellow]")
|
283
|
-
|
284
|
-
# Load configuration
|
318
|
+
def _load_monitor_config(config, server, token):
|
319
|
+
"""Load monitor configuration from file or fallback to orchestrator config."""
|
285
320
|
base_config = ConfigManager.find_config("monitor", config)
|
286
321
|
|
287
322
|
if not base_config:
|
288
|
-
# Try to find monitor config in orchestrator config as fallback
|
289
323
|
orch_config = ConfigManager.find_config("orchestrator")
|
290
324
|
if orch_config and "monitor" in orch_config:
|
291
325
|
base_config = {"monitor": orch_config["monitor"]}
|
@@ -295,15 +329,11 @@ def monitor(
|
|
295
329
|
if not server or not token:
|
296
330
|
console.print("[yellow]No monitor config found, using CLI args[/yellow]")
|
297
331
|
|
298
|
-
|
299
|
-
|
300
|
-
if "monitor" in base_config:
|
301
|
-
config_data = base_config["monitor"]
|
302
|
-
# Case 2: Config IS the monitor config (no wrapper)
|
303
|
-
else:
|
304
|
-
config_data = base_config
|
332
|
+
return base_config.get("monitor", base_config)
|
333
|
+
|
305
334
|
|
306
|
-
|
335
|
+
def _apply_monitor_overrides(config_data, server, token, no_verify_ssl):
|
336
|
+
"""Apply CLI overrides to monitor configuration."""
|
307
337
|
if server:
|
308
338
|
config_data["server"] = server
|
309
339
|
if token:
|
@@ -311,17 +341,20 @@ def monitor(
|
|
311
341
|
if no_verify_ssl:
|
312
342
|
config_data["verify_ssl"] = False
|
313
343
|
|
314
|
-
# Debug output
|
315
|
-
if debug:
|
316
|
-
console.print("\n[cyan]Final monitor configuration:[/cyan]")
|
317
|
-
console.print(f" Server: {config_data.get('server', 'NOT SET')}")
|
318
|
-
console.print(
|
319
|
-
f" Token: {'***' + config_data.get('token', '')[-4:] if config_data.get('token') else 'NOT SET'}"
|
320
|
-
)
|
321
|
-
console.print(f" Verify SSL: {config_data.get('verify_ssl', True)}")
|
322
|
-
console.print()
|
323
344
|
|
324
|
-
|
345
|
+
def _debug_monitor_config(config_data):
|
346
|
+
"""Print debug information about monitor configuration."""
|
347
|
+
console.print("\n[cyan]Final monitor configuration:[/cyan]")
|
348
|
+
console.print(f" Server: {config_data.get('server', 'NOT SET')}")
|
349
|
+
console.print(
|
350
|
+
f" Token: {'***' + config_data.get('token', '')[-4:] if config_data.get('token') else 'NOT SET'}"
|
351
|
+
)
|
352
|
+
console.print(f" Verify SSL: {config_data.get('verify_ssl', True)}")
|
353
|
+
console.print()
|
354
|
+
|
355
|
+
|
356
|
+
def _validate_monitor_config(config_data):
|
357
|
+
"""Validate required monitor configuration fields."""
|
325
358
|
if not config_data.get("server"):
|
326
359
|
console.print("[red]Error: --server required (or set 'server' in monitor.yaml)[/red]")
|
327
360
|
console.print("\n[dim]Example monitor.yaml:[/dim]")
|
@@ -336,12 +369,43 @@ def monitor(
|
|
336
369
|
console.print("token: your-token-here")
|
337
370
|
sys.exit(1)
|
338
371
|
|
339
|
-
|
372
|
+
|
373
|
+
def _set_monitor_defaults(config_data):
|
374
|
+
"""Set default values for optional monitor settings."""
|
340
375
|
config_data.setdefault("refresh_interval", 1.0)
|
341
376
|
config_data.setdefault("show_inactive_workers", False)
|
342
377
|
config_data.setdefault("max_log_lines", 100)
|
343
378
|
|
344
|
-
|
379
|
+
|
380
|
+
@main.command()
|
381
|
+
@click.option("--config", type=click.Path(exists=True), help="Configuration file")
|
382
|
+
@click.option("--server", help="Orchestrator WebSocket URL")
|
383
|
+
@click.option("--token", help="Authentication token")
|
384
|
+
@click.option("--no-verify-ssl", is_flag=True, help="Skip SSL verification")
|
385
|
+
@click.option("--debug", is_flag=True, help="Enable debug output")
|
386
|
+
@click.pass_context
|
387
|
+
def monitor(
|
388
|
+
ctx,
|
389
|
+
config: Optional[str],
|
390
|
+
server: Optional[str],
|
391
|
+
token: Optional[str],
|
392
|
+
no_verify_ssl: bool,
|
393
|
+
debug: bool,
|
394
|
+
):
|
395
|
+
"""Start the monitoring TUI."""
|
396
|
+
if debug:
|
397
|
+
setup_logging(verbose=True)
|
398
|
+
console.print("[yellow]Debug mode enabled[/yellow]")
|
399
|
+
|
400
|
+
config_data = _load_monitor_config(config, server, token)
|
401
|
+
_apply_monitor_overrides(config_data, server, token, no_verify_ssl)
|
402
|
+
|
403
|
+
if debug:
|
404
|
+
_debug_monitor_config(config_data)
|
405
|
+
|
406
|
+
_validate_monitor_config(config_data)
|
407
|
+
_set_monitor_defaults(config_data)
|
408
|
+
|
345
409
|
try:
|
346
410
|
monitor_instance = Monitor(config_data)
|
347
411
|
|
@@ -406,7 +470,7 @@ def view(ctx, data_dir: str, refresh_rate: int, no_images: bool):
|
|
406
470
|
viewer.disable_images = True
|
407
471
|
viewer.refresh_rate = refresh_rate
|
408
472
|
|
409
|
-
console.print(
|
473
|
+
console.print("[cyan]Starting dataset viewer...[/cyan]")
|
410
474
|
console.print(f"[dim]Data directory: {data_path}[/dim]")
|
411
475
|
|
412
476
|
asyncio.run(viewer.run())
|
@@ -424,6 +488,400 @@ def view(ctx, data_dir: str, refresh_rate: int, no_images: bool):
|
|
424
488
|
sys.exit(1)
|
425
489
|
|
426
490
|
|
491
|
+
def _load_admin_credentials(config, server, token):
|
492
|
+
"""Load admin server and token from config if not provided."""
|
493
|
+
if server and token:
|
494
|
+
return server, token
|
495
|
+
|
496
|
+
base_config = ConfigManager.find_config("orchestrator", config) or {}
|
497
|
+
admin_config = base_config.get("admin", {})
|
498
|
+
admin_tokens = base_config.get("orchestrator", {}).get("auth", {}).get("admin_tokens", [])
|
499
|
+
|
500
|
+
final_server = server or admin_config.get("server", "ws://localhost:8765")
|
501
|
+
final_token = token or admin_config.get("token")
|
502
|
+
|
503
|
+
if not final_token and admin_tokens:
|
504
|
+
console.print("Using first admin token.")
|
505
|
+
final_token = admin_tokens[0].get("token")
|
506
|
+
|
507
|
+
return final_server, final_token
|
508
|
+
|
509
|
+
|
510
|
+
def _setup_ssl_context(server, no_verify_ssl):
|
511
|
+
"""Setup SSL context for websocket connection."""
|
512
|
+
import ssl
|
513
|
+
|
514
|
+
ssl_context = None
|
515
|
+
if server.startswith("wss://"):
|
516
|
+
ssl_context = ssl.create_default_context()
|
517
|
+
if no_verify_ssl:
|
518
|
+
ssl_context.check_hostname = False
|
519
|
+
ssl_context.verify_mode = ssl.CERT_NONE
|
520
|
+
|
521
|
+
return ssl_context
|
522
|
+
|
523
|
+
|
524
|
+
async def _authenticate_admin(websocket, token):
|
525
|
+
"""Authenticate as admin with the websocket."""
|
526
|
+
await websocket.send(json.dumps({"token": token, "role": "admin"}))
|
527
|
+
|
528
|
+
response = await websocket.recv()
|
529
|
+
auth_response = json.loads(response)
|
530
|
+
|
531
|
+
if "error" in auth_response:
|
532
|
+
console.print(f"[red]Authentication failed: {auth_response['error']}[/red]")
|
533
|
+
return False
|
534
|
+
|
535
|
+
console.print("[green]✓ Authenticated as admin[/green]")
|
536
|
+
return True
|
537
|
+
|
538
|
+
|
539
|
+
async def _send_reload_command(websocket, new_cfg):
|
540
|
+
"""Send reload command and handle response."""
|
541
|
+
await websocket.send(json.dumps({"type": "reload_config", "config": new_cfg}))
|
542
|
+
|
543
|
+
response = await websocket.recv()
|
544
|
+
reload_response = json.loads(response)
|
545
|
+
|
546
|
+
if reload_response.get("type") == "reload_complete":
|
547
|
+
if "message" in reload_response and "No changes" in reload_response["message"]:
|
548
|
+
console.print(f"[yellow]{reload_response['message']}[/yellow]")
|
549
|
+
else:
|
550
|
+
console.print("[green]✓ Configuration reloaded successfully![/green]")
|
551
|
+
|
552
|
+
if "updated" in reload_response and reload_response["updated"]:
|
553
|
+
console.print("\n[cyan]Updated sections:[/cyan]")
|
554
|
+
for section in reload_response["updated"]:
|
555
|
+
console.print(f" • {section}")
|
556
|
+
|
557
|
+
if "warnings" in reload_response and reload_response["warnings"]:
|
558
|
+
console.print("\n[yellow]Warnings:[/yellow]")
|
559
|
+
for warning in reload_response["warnings"]:
|
560
|
+
console.print(f" ⚠ {warning}")
|
561
|
+
|
562
|
+
return True
|
563
|
+
else:
|
564
|
+
error = reload_response.get("error", "Unknown error")
|
565
|
+
console.print(f"[red]Reload failed: {error} ({reload_response=})[/red]")
|
566
|
+
return False
|
567
|
+
|
568
|
+
|
569
|
+
def _add_token_to_config(config_data: Dict[str, Any], role: str, name: str, token: str) -> bool:
|
570
|
+
"""Add a new token to the config data."""
|
571
|
+
# Ensure the auth section exists
|
572
|
+
if "orchestrator" not in config_data:
|
573
|
+
config_data["orchestrator"] = {}
|
574
|
+
if "auth" not in config_data["orchestrator"]:
|
575
|
+
config_data["orchestrator"]["auth"] = {}
|
576
|
+
|
577
|
+
auth_config = config_data["orchestrator"]["auth"]
|
578
|
+
token_key = f"{role}_tokens"
|
579
|
+
|
580
|
+
# Initialize token list if it doesn't exist
|
581
|
+
if token_key not in auth_config:
|
582
|
+
auth_config[token_key] = []
|
583
|
+
|
584
|
+
# Check if token already exists
|
585
|
+
for existing_token in auth_config[token_key]:
|
586
|
+
if existing_token.get("token") == token:
|
587
|
+
console.print(f"[yellow]Token already exists for {role}: {name}[/yellow]")
|
588
|
+
return False
|
589
|
+
if existing_token.get("name") == name:
|
590
|
+
console.print(f"[yellow]Name already exists for {role}: {name}[/yellow]")
|
591
|
+
return False
|
592
|
+
|
593
|
+
# Add the new token
|
594
|
+
auth_config[token_key].append({"name": name, "token": token})
|
595
|
+
console.print(f"[green]✓ Added {role} token for {name}[/green]")
|
596
|
+
return True
|
597
|
+
|
598
|
+
|
599
|
+
def _remove_token_from_config(config_data: Dict[str, Any], role: str, identifier: str) -> bool:
|
600
|
+
"""Remove a token from the config data by name or token."""
|
601
|
+
auth_config = config_data.get("orchestrator", {}).get("auth", {})
|
602
|
+
token_key = f"{role}_tokens"
|
603
|
+
|
604
|
+
if token_key not in auth_config:
|
605
|
+
console.print(f"[red]No {role} tokens found in config[/red]")
|
606
|
+
return False
|
607
|
+
|
608
|
+
tokens = auth_config[token_key]
|
609
|
+
removed = False
|
610
|
+
|
611
|
+
for i, token_entry in enumerate(tokens):
|
612
|
+
if token_entry.get("name") == identifier or token_entry.get("token") == identifier:
|
613
|
+
removed_entry = tokens.pop(i)
|
614
|
+
console.print(f"[green]✓ Removed {role} token: {removed_entry['name']}[/green]")
|
615
|
+
removed = True
|
616
|
+
break
|
617
|
+
|
618
|
+
if not removed:
|
619
|
+
console.print(f"[red]Token not found for {role}: {identifier}[/red]")
|
620
|
+
|
621
|
+
return removed
|
622
|
+
|
623
|
+
|
624
|
+
def _list_tokens_in_config(config_data: Dict[str, Any], role: Optional[str] = None):
|
625
|
+
"""List tokens in the config data."""
|
626
|
+
auth_config = config_data.get("orchestrator", {}).get("auth", {})
|
627
|
+
|
628
|
+
if not auth_config:
|
629
|
+
console.print("[yellow]No auth configuration found[/yellow]")
|
630
|
+
return
|
631
|
+
|
632
|
+
roles_to_show = [role] if role else ["worker", "admin", "monitor"]
|
633
|
+
|
634
|
+
for token_role in roles_to_show:
|
635
|
+
token_key = f"{token_role}_tokens"
|
636
|
+
tokens = auth_config.get(token_key, [])
|
637
|
+
|
638
|
+
if tokens:
|
639
|
+
console.print(f"\n[cyan]{token_role.title()} tokens:[/cyan]")
|
640
|
+
for token_entry in tokens:
|
641
|
+
name = token_entry.get("name", "Unknown")
|
642
|
+
token = token_entry.get("token", "")
|
643
|
+
masked_token = f"***{token[-4:]}" if len(token) > 4 else "***"
|
644
|
+
console.print(f" • {name}: {masked_token}")
|
645
|
+
else:
|
646
|
+
console.print(f"\n[dim]No {token_role} tokens configured[/dim]")
|
647
|
+
|
648
|
+
|
649
|
+
def _save_config_file(config_data: Dict[str, Any], config_path: Path) -> bool:
|
650
|
+
"""Save the config data to a file."""
|
651
|
+
try:
|
652
|
+
with open(config_path, "w") as f:
|
653
|
+
yaml.safe_dump(config_data, f, default_flow_style=False, sort_keys=False)
|
654
|
+
console.print(f"[green]✓ Configuration saved to {config_path}[/green]")
|
655
|
+
return True
|
656
|
+
except Exception as e:
|
657
|
+
console.print(f"[red]Error saving config: {e}[/red]")
|
658
|
+
return False
|
659
|
+
|
660
|
+
|
661
|
+
async def _reload_orchestrator_config(
|
662
|
+
server: str, token: str, config_data: Dict[str, Any], no_verify_ssl: bool
|
663
|
+
) -> bool:
|
664
|
+
"""Reload the orchestrator configuration."""
|
665
|
+
import websockets
|
666
|
+
|
667
|
+
ssl_context = _setup_ssl_context(server, no_verify_ssl)
|
668
|
+
|
669
|
+
try:
|
670
|
+
async with websockets.connect(
|
671
|
+
server, ssl=ssl_context, ping_interval=20, ping_timeout=60, close_timeout=10
|
672
|
+
) as websocket:
|
673
|
+
if not await _authenticate_admin(websocket, token):
|
674
|
+
return False
|
675
|
+
|
676
|
+
return await _send_reload_command(websocket, config_data)
|
677
|
+
except Exception as e:
|
678
|
+
console.print(f"[red]Error connecting to orchestrator: {e}[/red]")
|
679
|
+
return False
|
680
|
+
|
681
|
+
|
682
|
+
@main.group()
|
683
|
+
@click.option("--config", type=click.Path(exists=True), help="Configuration file")
|
684
|
+
@click.option("--server", help="Orchestrator WebSocket URL")
|
685
|
+
@click.option("--token", help="Admin authentication token")
|
686
|
+
@click.option("--no-verify-ssl", is_flag=True, help="Skip SSL verification")
|
687
|
+
@click.pass_context
|
688
|
+
def auth(
|
689
|
+
ctx, config: Optional[str], server: Optional[str], token: Optional[str], no_verify_ssl: bool
|
690
|
+
):
|
691
|
+
"""Manage authentication tokens for the orchestrator."""
|
692
|
+
ctx.ensure_object(dict)
|
693
|
+
ctx.obj.update(
|
694
|
+
{"config": config, "server": server, "token": token, "no_verify_ssl": no_verify_ssl}
|
695
|
+
)
|
696
|
+
|
697
|
+
|
698
|
+
@auth.command()
|
699
|
+
@click.argument("role", type=click.Choice(["worker", "admin", "monitor"]))
|
700
|
+
@click.argument("name")
|
701
|
+
@click.argument("token_value")
|
702
|
+
@click.option(
|
703
|
+
"--no-reload", is_flag=True, help="Don't reload orchestrator config after adding token"
|
704
|
+
)
|
705
|
+
@click.pass_context
|
706
|
+
def add(ctx, role: str, name: str, token_value: str, no_reload: bool):
|
707
|
+
"""Add a new authentication token.
|
708
|
+
|
709
|
+
ROLE: Type of token (worker, admin, monitor)
|
710
|
+
NAME: Display name for the token
|
711
|
+
TOKEN_VALUE: The actual token string
|
712
|
+
"""
|
713
|
+
config_file = ctx.obj.get("config")
|
714
|
+
server = ctx.obj.get("server")
|
715
|
+
admin_token = ctx.obj.get("token")
|
716
|
+
no_verify_ssl = ctx.obj.get("no_verify_ssl", False)
|
717
|
+
|
718
|
+
# Load config
|
719
|
+
config_data = ConfigManager.find_config("orchestrator", config_file)
|
720
|
+
if not config_data:
|
721
|
+
console.print("[red]No orchestrator config found[/red]")
|
722
|
+
console.print("[dim]Use --config to specify config file path[/dim]")
|
723
|
+
sys.exit(1)
|
724
|
+
|
725
|
+
# Find config file path for saving
|
726
|
+
config_path = None
|
727
|
+
if config_file:
|
728
|
+
config_path = Path(config_file)
|
729
|
+
else:
|
730
|
+
# Try to find the config file that was loaded
|
731
|
+
for search_path in [
|
732
|
+
Path.cwd() / "orchestrator.yaml",
|
733
|
+
Path.cwd() / "config" / "orchestrator.yaml",
|
734
|
+
Path.home() / ".caption-flow" / "orchestrator.yaml",
|
735
|
+
ConfigManager.get_xdg_config_home() / "caption-flow" / "orchestrator.yaml",
|
736
|
+
]:
|
737
|
+
if search_path.exists():
|
738
|
+
config_path = search_path
|
739
|
+
break
|
740
|
+
|
741
|
+
if not config_path:
|
742
|
+
console.print("[red]Could not determine config file to save to[/red]")
|
743
|
+
console.print("[dim]Use --config to specify config file path[/dim]")
|
744
|
+
sys.exit(1)
|
745
|
+
|
746
|
+
# Add token to config
|
747
|
+
if not _add_token_to_config(config_data, role, name, token_value):
|
748
|
+
sys.exit(1)
|
749
|
+
|
750
|
+
# Save config file
|
751
|
+
if not _save_config_file(config_data, config_path):
|
752
|
+
sys.exit(1)
|
753
|
+
|
754
|
+
# Reload orchestrator if requested
|
755
|
+
if not no_reload:
|
756
|
+
server, admin_token = _load_admin_credentials(config_file, server, admin_token)
|
757
|
+
|
758
|
+
if not server:
|
759
|
+
console.print("[yellow]No server specified, skipping orchestrator reload[/yellow]")
|
760
|
+
console.print("[dim]Use --server to reload orchestrator config[/dim]")
|
761
|
+
elif not admin_token:
|
762
|
+
console.print("[yellow]No admin token specified, skipping orchestrator reload[/yellow]")
|
763
|
+
console.print("[dim]Use --token to reload orchestrator config[/dim]")
|
764
|
+
else:
|
765
|
+
console.print(f"[cyan]Reloading orchestrator config...[/cyan]")
|
766
|
+
success = asyncio.run(
|
767
|
+
_reload_orchestrator_config(server, admin_token, config_data, no_verify_ssl)
|
768
|
+
)
|
769
|
+
if not success:
|
770
|
+
console.print("[yellow]Config file updated but orchestrator reload failed[/yellow]")
|
771
|
+
console.print("[dim]You may need to restart the orchestrator manually[/dim]")
|
772
|
+
|
773
|
+
|
774
|
+
@auth.command()
|
775
|
+
@click.argument("role", type=click.Choice(["worker", "admin", "monitor"]))
|
776
|
+
@click.argument("identifier")
|
777
|
+
@click.option(
|
778
|
+
"--no-reload", is_flag=True, help="Don't reload orchestrator config after removing token"
|
779
|
+
)
|
780
|
+
@click.pass_context
|
781
|
+
def remove(ctx, role: str, identifier: str, no_reload: bool):
|
782
|
+
"""Remove an authentication token.
|
783
|
+
|
784
|
+
ROLE: Type of token (worker, admin, monitor)
|
785
|
+
IDENTIFIER: Name or token value to remove
|
786
|
+
"""
|
787
|
+
config_file = ctx.obj.get("config")
|
788
|
+
server = ctx.obj.get("server")
|
789
|
+
admin_token = ctx.obj.get("token")
|
790
|
+
no_verify_ssl = ctx.obj.get("no_verify_ssl", False)
|
791
|
+
|
792
|
+
# Load config
|
793
|
+
config_data = ConfigManager.find_config("orchestrator", config_file)
|
794
|
+
if not config_data:
|
795
|
+
console.print("[red]No orchestrator config found[/red]")
|
796
|
+
sys.exit(1)
|
797
|
+
|
798
|
+
# Find config file path for saving
|
799
|
+
config_path = None
|
800
|
+
if config_file:
|
801
|
+
config_path = Path(config_file)
|
802
|
+
else:
|
803
|
+
# Try to find the config file that was loaded
|
804
|
+
for search_path in [
|
805
|
+
Path.cwd() / "orchestrator.yaml",
|
806
|
+
Path.cwd() / "config" / "orchestrator.yaml",
|
807
|
+
Path.home() / ".caption-flow" / "orchestrator.yaml",
|
808
|
+
ConfigManager.get_xdg_config_home() / "caption-flow" / "orchestrator.yaml",
|
809
|
+
]:
|
810
|
+
if search_path.exists():
|
811
|
+
config_path = search_path
|
812
|
+
break
|
813
|
+
|
814
|
+
if not config_path:
|
815
|
+
console.print("[red]Could not determine config file to save to[/red]")
|
816
|
+
sys.exit(1)
|
817
|
+
|
818
|
+
# Remove token from config
|
819
|
+
if not _remove_token_from_config(config_data, role, identifier):
|
820
|
+
sys.exit(1)
|
821
|
+
|
822
|
+
# Save config file
|
823
|
+
if not _save_config_file(config_data, config_path):
|
824
|
+
sys.exit(1)
|
825
|
+
|
826
|
+
# Reload orchestrator if requested
|
827
|
+
if not no_reload:
|
828
|
+
server, admin_token = _load_admin_credentials(config_file, server, admin_token)
|
829
|
+
|
830
|
+
if not server:
|
831
|
+
console.print("[yellow]No server specified, skipping orchestrator reload[/yellow]")
|
832
|
+
elif not admin_token:
|
833
|
+
console.print("[yellow]No admin token specified, skipping orchestrator reload[/yellow]")
|
834
|
+
else:
|
835
|
+
console.print(f"[cyan]Reloading orchestrator config...[/cyan]")
|
836
|
+
success = asyncio.run(
|
837
|
+
_reload_orchestrator_config(server, admin_token, config_data, no_verify_ssl)
|
838
|
+
)
|
839
|
+
if not success:
|
840
|
+
console.print("[yellow]Config file updated but orchestrator reload failed[/yellow]")
|
841
|
+
|
842
|
+
|
843
|
+
@auth.command()
|
844
|
+
@click.argument("role", type=click.Choice(["worker", "admin", "monitor", "all"]), required=False)
|
845
|
+
@click.pass_context
|
846
|
+
def list(ctx, role: Optional[str]):
|
847
|
+
"""List authentication tokens.
|
848
|
+
|
849
|
+
ROLE: Type of tokens to list (worker, admin, monitor, all). Default: all
|
850
|
+
"""
|
851
|
+
config_file = ctx.obj.get("config")
|
852
|
+
|
853
|
+
# Load config
|
854
|
+
config_data = ConfigManager.find_config("orchestrator", config_file)
|
855
|
+
if not config_data:
|
856
|
+
console.print("[red]No orchestrator config found[/red]")
|
857
|
+
sys.exit(1)
|
858
|
+
|
859
|
+
# Show tokens
|
860
|
+
if role == "all" or role is None:
|
861
|
+
_list_tokens_in_config(config_data)
|
862
|
+
else:
|
863
|
+
_list_tokens_in_config(config_data, role)
|
864
|
+
|
865
|
+
|
866
|
+
@auth.command()
|
867
|
+
@click.option("--length", default=32, help="Token length (default: 32)")
|
868
|
+
@click.option("--count", default=1, help="Number of tokens to generate (default: 1)")
|
869
|
+
def generate(length: int, count: int):
|
870
|
+
"""Generate random authentication tokens."""
|
871
|
+
import secrets
|
872
|
+
import string
|
873
|
+
|
874
|
+
alphabet = string.ascii_letters + string.digits + "-_"
|
875
|
+
|
876
|
+
console.print(
|
877
|
+
f"[cyan]Generated {count} token{'s' if count > 1 else ''} ({length} characters each):[/cyan]\n"
|
878
|
+
)
|
879
|
+
|
880
|
+
for i in range(count):
|
881
|
+
token = "".join(secrets.choice(alphabet) for _ in range(length))
|
882
|
+
console.print(f" {i + 1}: {token}")
|
883
|
+
|
884
|
+
|
427
885
|
@main.command()
|
428
886
|
@click.option("--config", type=click.Path(exists=True), help="Configuration file")
|
429
887
|
@click.option("--server", help="Orchestrator WebSocket URL")
|
@@ -441,27 +899,8 @@ def reload_config(
|
|
441
899
|
):
|
442
900
|
"""Reload orchestrator configuration via admin connection."""
|
443
901
|
import websockets
|
444
|
-
import ssl
|
445
902
|
|
446
|
-
|
447
|
-
if not server or not token:
|
448
|
-
base_config = ConfigManager.find_config("orchestrator", config) or {}
|
449
|
-
admin_config = base_config.get("admin", {})
|
450
|
-
admin_tokens = base_config.get("orchestrator", {}).get("auth", {}).get("admin_tokens", [])
|
451
|
-
has_admin_tokens = False
|
452
|
-
if len(admin_tokens) > 0:
|
453
|
-
has_admin_tokens = True
|
454
|
-
first_admin_token = admin_tokens[0].get("token", None)
|
455
|
-
# Do not print sensitive admin token to console.
|
456
|
-
|
457
|
-
if not server:
|
458
|
-
server = admin_config.get("server", "ws://localhost:8765")
|
459
|
-
if not token:
|
460
|
-
token = admin_config.get("token", None)
|
461
|
-
if token is None and has_admin_tokens:
|
462
|
-
# grab the first one, we'll just assume we're localhost.
|
463
|
-
console.print("Using first admin token.")
|
464
|
-
token = first_admin_token
|
903
|
+
server, token = _load_admin_credentials(config, server, token)
|
465
904
|
|
466
905
|
if not server:
|
467
906
|
console.print("[red]Error: --server required (or set in config)[/red]")
|
@@ -472,66 +911,22 @@ def reload_config(
|
|
472
911
|
|
473
912
|
console.print(f"[cyan]Loading configuration from {new_config}...[/cyan]")
|
474
913
|
|
475
|
-
# Load the new configuration
|
476
914
|
new_cfg = ConfigManager.load_yaml(Path(new_config))
|
477
915
|
if not new_cfg:
|
478
916
|
console.print("[red]Failed to load configuration[/red]")
|
479
917
|
sys.exit(1)
|
480
918
|
|
481
|
-
|
482
|
-
ssl_context = None
|
483
|
-
if server.startswith("wss://"):
|
484
|
-
if no_verify_ssl:
|
485
|
-
ssl_context = ssl.create_default_context()
|
486
|
-
ssl_context.check_hostname = False
|
487
|
-
ssl_context.verify_mode = ssl.CERT_NONE
|
488
|
-
else:
|
489
|
-
ssl_context = ssl.create_default_context()
|
919
|
+
ssl_context = _setup_ssl_context(server, no_verify_ssl)
|
490
920
|
|
491
921
|
async def send_reload():
|
492
922
|
try:
|
493
923
|
async with websockets.connect(
|
494
924
|
server, ssl=ssl_context, ping_interval=20, ping_timeout=60, close_timeout=10
|
495
925
|
) as websocket:
|
496
|
-
|
497
|
-
await websocket.send(json.dumps({"token": token, "role": "admin"}))
|
498
|
-
|
499
|
-
response = await websocket.recv()
|
500
|
-
auth_response = json.loads(response)
|
501
|
-
|
502
|
-
if "error" in auth_response:
|
503
|
-
console.print(f"[red]Authentication failed: {auth_response['error']}[/red]")
|
926
|
+
if not await _authenticate_admin(websocket, token):
|
504
927
|
return False
|
505
928
|
|
506
|
-
|
507
|
-
|
508
|
-
# Send reload command
|
509
|
-
await websocket.send(json.dumps({"type": "reload_config", "config": new_cfg}))
|
510
|
-
|
511
|
-
response = await websocket.recv()
|
512
|
-
reload_response = json.loads(response)
|
513
|
-
|
514
|
-
if reload_response.get("type") == "reload_complete":
|
515
|
-
if "message" in reload_response and "No changes" in reload_response["message"]:
|
516
|
-
console.print(f"[yellow]{reload_response['message']}[/yellow]")
|
517
|
-
else:
|
518
|
-
console.print("[green]✓ Configuration reloaded successfully![/green]")
|
519
|
-
|
520
|
-
if "updated" in reload_response and reload_response["updated"]:
|
521
|
-
console.print("\n[cyan]Updated sections:[/cyan]")
|
522
|
-
for section in reload_response["updated"]:
|
523
|
-
console.print(f" • {section}")
|
524
|
-
|
525
|
-
if "warnings" in reload_response and reload_response["warnings"]:
|
526
|
-
console.print("\n[yellow]Warnings:[/yellow]")
|
527
|
-
for warning in reload_response["warnings"]:
|
528
|
-
console.print(f" ⚠ {warning}")
|
529
|
-
|
530
|
-
return True
|
531
|
-
else:
|
532
|
-
error = reload_response.get("error", "Unknown error")
|
533
|
-
console.print(f"[red]Reload failed: {error} ({reload_response=})[/red]")
|
534
|
-
return False
|
929
|
+
return await _send_reload_command(websocket, new_cfg)
|
535
930
|
|
536
931
|
except Exception as e:
|
537
932
|
console.print(f"[red]Error: {e}[/red]")
|
@@ -542,39 +937,20 @@ def reload_config(
|
|
542
937
|
sys.exit(1)
|
543
938
|
|
544
939
|
|
545
|
-
|
546
|
-
|
547
|
-
@click.option("--checkpoint-dir", default="./checkpoints", help="Checkpoint directory")
|
548
|
-
@click.option("--fix", is_flag=True, help="Fix issues by resetting abandoned chunks")
|
549
|
-
@click.option("--verbose", is_flag=True, help="Show detailed information")
|
550
|
-
def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
|
551
|
-
"""Scan for sparse or abandoned chunks and optionally fix them."""
|
552
|
-
from .utils.chunk_tracker import ChunkTracker
|
553
|
-
from .storage import StorageManager
|
554
|
-
import pyarrow.parquet as pq
|
555
|
-
|
556
|
-
console.print("[bold cyan]Scanning for sparse/abandoned chunks...[/bold cyan]\n")
|
557
|
-
|
558
|
-
checkpoint_path = Path(checkpoint_dir) / "chunks.json"
|
559
|
-
if not checkpoint_path.exists():
|
560
|
-
console.print("[red]No chunk checkpoint found![/red]")
|
561
|
-
return
|
562
|
-
|
563
|
-
tracker = ChunkTracker(checkpoint_path)
|
564
|
-
storage = StorageManager(Path(data_dir))
|
565
|
-
|
566
|
-
# Get and display stats
|
567
|
-
stats = tracker.get_stats()
|
940
|
+
def _display_chunk_stats(stats):
    """Print a colorized summary of chunk counts from a tracker stats dict."""
    # (label, color, stats key, trailing text) in display order; the final row
    # carries the blank line that separates this summary from what follows.
    rows = (
        ("Total chunks", "green", "total", ""),
        ("Completed", "green", "completed", ""),
        ("Pending", "yellow", "pending", ""),
        ("Assigned", "yellow", "assigned", ""),
        ("Failed", "red", "failed", "\n"),
    )
    for label, color, key, tail in rows:
        console.print(f"[{color}]{label}:[/{color}] {stats[key]}{tail}")
|
-
|
948
|
+
|
949
|
+
def _find_abandoned_chunks(tracker):
|
950
|
+
"""Find chunks that have been assigned for too long."""
|
575
951
|
abandoned_chunks = []
|
576
952
|
stale_threshold = 3600 # 1 hour
|
577
|
-
current_time = datetime.
|
953
|
+
current_time = datetime.now(_datetime.UTC)
|
578
954
|
|
579
955
|
for chunk_id, chunk_state in tracker.chunks.items():
|
580
956
|
if chunk_state.status == "assigned" and chunk_state.assigned_at:
|
@@ -582,24 +958,31 @@ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
|
|
582
958
|
if age > stale_threshold:
|
583
959
|
abandoned_chunks.append((chunk_id, chunk_state, age))
|
584
960
|
|
585
|
-
|
586
|
-
console.print(f"[red]Found {len(abandoned_chunks)} abandoned chunks:[/red]")
|
587
|
-
for chunk_id, chunk_state, age in abandoned_chunks[:10]:
|
588
|
-
age_str = f"{age/3600:.1f} hours" if age > 3600 else f"{age/60:.1f} minutes"
|
589
|
-
console.print(f" • {chunk_id} (assigned to {chunk_state.assigned_to} {age_str} ago)")
|
961
|
+
return abandoned_chunks
|
590
962
|
|
591
|
-
if len(abandoned_chunks) > 10:
|
592
|
-
console.print(f" ... and {len(abandoned_chunks) - 10} more")
|
593
963
|
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
console.print(f"[green]✓ Reset {len(abandoned_chunks)} chunks[/green]")
|
964
|
+
def _display_abandoned_chunks(abandoned_chunks, fix, tracker):
|
965
|
+
"""Display abandoned chunks and optionally fix them."""
|
966
|
+
if not abandoned_chunks:
|
967
|
+
return
|
599
968
|
|
600
|
-
|
601
|
-
|
969
|
+
console.print(f"[red]Found {len(abandoned_chunks)} abandoned chunks:[/red]")
|
970
|
+
for chunk_id, chunk_state, age in abandoned_chunks[:10]:
|
971
|
+
age_str = f"{age / 3600:.1f} hours" if age > 3600 else f"{age / 60:.1f} minutes"
|
972
|
+
console.print(f" • {chunk_id} (assigned to {chunk_state.assigned_to} {age_str} ago)")
|
973
|
+
|
974
|
+
if len(abandoned_chunks) > 10:
|
975
|
+
console.print(f" ... and {len(abandoned_chunks) - 10} more")
|
976
|
+
|
977
|
+
if fix:
|
978
|
+
console.print("\n[yellow]Resetting abandoned chunks to pending...[/yellow]")
|
979
|
+
for chunk_id, _, _ in abandoned_chunks:
|
980
|
+
tracker.mark_failed(chunk_id)
|
981
|
+
console.print(f"[green]✓ Reset {len(abandoned_chunks)} chunks[/green]")
|
602
982
|
|
983
|
+
|
984
|
+
def _find_sparse_shards(tracker):
|
985
|
+
"""Find shards with gaps or issues."""
|
603
986
|
shards_summary = tracker.get_shards_summary()
|
604
987
|
sparse_shards = []
|
605
988
|
|
@@ -618,60 +1001,108 @@ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
|
|
618
1001
|
if has_gaps or shard_info["failed_chunks"] > 0:
|
619
1002
|
sparse_shards.append((shard_name, shard_info, has_gaps))
|
620
1003
|
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
1004
|
+
return sparse_shards
|
1005
|
+
|
1006
|
+
|
1007
|
+
def _display_sparse_shards(sparse_shards):
|
1008
|
+
"""Display sparse/incomplete shards."""
|
1009
|
+
if not sparse_shards:
|
1010
|
+
return
|
1011
|
+
|
1012
|
+
console.print(f"\n[yellow]Found {len(sparse_shards)} sparse/incomplete shards:[/yellow]")
|
1013
|
+
for shard_name, shard_info, has_gaps in sparse_shards[:5]:
|
1014
|
+
status = []
|
1015
|
+
if shard_info["pending_chunks"] > 0:
|
1016
|
+
status.append(f"{shard_info['pending_chunks']} pending")
|
1017
|
+
if shard_info["assigned_chunks"] > 0:
|
1018
|
+
status.append(f"{shard_info['assigned_chunks']} assigned")
|
1019
|
+
if shard_info["failed_chunks"] > 0:
|
1020
|
+
status.append(f"{shard_info['failed_chunks']} failed")
|
1021
|
+
if has_gaps:
|
1022
|
+
status.append("has gaps")
|
1023
|
+
|
1024
|
+
console.print(f" • {shard_name}: {', '.join(status)}")
|
1025
|
+
console.print(
|
1026
|
+
f" Progress: {shard_info['completed_chunks']}/{shard_info['total_chunks']} chunks"
|
1027
|
+
)
|
1028
|
+
|
1029
|
+
if len(sparse_shards) > 5:
|
1030
|
+
console.print(f" ... and {len(sparse_shards) - 5} more")
|
1031
|
+
|
1032
|
+
|
1033
|
+
def _cross_check_storage(storage, tracker, fix):
    """Compare the chunk tracker's view against the captions parquet file.

    Reports chunks the tracker marks complete that have no rows in storage
    (optionally requeueing them via ``mark_failed`` when *fix* is set) and
    chunk ids present in storage that the tracker has never seen. Any read
    failure is reported instead of raised.
    """
    import pyarrow.parquet as pq

    console.print("\n[bold cyan]Cross-checking with stored captions...[/bold cyan]")

    try:
        # Only the chunk_id column is needed for the reconciliation.
        table = pq.read_table(storage.captions_path, columns=["chunk_id"])
        stored_chunk_ids = {cid for cid in table["chunk_id"].to_pylist() if cid}

        tracker_completed = {
            cid for cid, state in tracker.chunks.items() if state.status == "completed"
        }

        missing_in_storage = tracker_completed - stored_chunk_ids
        missing_in_tracker = stored_chunk_ids - set(tracker.chunks.keys())

        if missing_in_storage:
            console.print(
                f"\n[red]Chunks marked complete but missing from storage:[/red] {len(missing_in_storage)}"
            )
            for chunk_id in list(missing_in_storage)[:5]:
                console.print(f" • {chunk_id}")

            if fix:
                console.print("[yellow]Resetting these chunks to pending...[/yellow]")
                for chunk_id in missing_in_storage:
                    tracker.mark_failed(chunk_id)
                console.print(f"[green]✓ Reset {len(missing_in_storage)} chunks[/green]")

        if missing_in_tracker:
            console.print(
                f"\n[yellow]Chunks in storage but not tracked:[/yellow] {len(missing_in_tracker)}"
            )

    except Exception as exc:
        console.print(f"[red]Error reading storage: {exc}[/red]")
|
649
1068
|
|
650
|
-
tracker_completed = set(c for c, s in tracker.chunks.items() if s.status == "completed")
|
651
1069
|
|
652
|
-
|
653
|
-
|
1070
|
+
@main.command()
|
1071
|
+
@click.option("--data-dir", default="./caption_data", help="Storage directory")
|
1072
|
+
@click.option("--checkpoint-dir", default="./checkpoints", help="Checkpoint directory")
|
1073
|
+
@click.option("--fix", is_flag=True, help="Fix issues by resetting abandoned chunks")
|
1074
|
+
@click.option("--verbose", is_flag=True, help="Show detailed information")
|
1075
|
+
def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
|
1076
|
+
"""Scan for sparse or abandoned chunks and optionally fix them."""
|
1077
|
+
from .storage import StorageManager
|
1078
|
+
from .utils.chunk_tracker import ChunkTracker
|
654
1079
|
|
655
|
-
|
656
|
-
console.print(
|
657
|
-
f"\n[red]Chunks marked complete but missing from storage:[/red] {len(missing_in_storage)}"
|
658
|
-
)
|
659
|
-
for chunk_id in list(missing_in_storage)[:5]:
|
660
|
-
console.print(f" • {chunk_id}")
|
1080
|
+
console.print("[bold cyan]Scanning for sparse/abandoned chunks...[/bold cyan]\n")
|
661
1081
|
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
console.print(f"[green]✓ Reset {len(missing_in_storage)} chunks[/green]")
|
1082
|
+
checkpoint_path = Path(checkpoint_dir) / "chunks.json"
|
1083
|
+
if not checkpoint_path.exists():
|
1084
|
+
console.print("[red]No chunk checkpoint found![/red]")
|
1085
|
+
return
|
667
1086
|
|
668
|
-
|
669
|
-
|
670
|
-
f"\n[yellow]Chunks in storage but not tracked:[/yellow] {len(missing_in_tracker)}"
|
671
|
-
)
|
1087
|
+
tracker = ChunkTracker(checkpoint_path)
|
1088
|
+
storage = StorageManager(Path(data_dir))
|
672
1089
|
|
673
|
-
|
674
|
-
|
1090
|
+
# Get and display stats
|
1091
|
+
stats = tracker.get_stats()
|
1092
|
+
_display_chunk_stats(stats)
|
1093
|
+
|
1094
|
+
# Find and handle abandoned chunks
|
1095
|
+
abandoned_chunks = _find_abandoned_chunks(tracker)
|
1096
|
+
_display_abandoned_chunks(abandoned_chunks, fix, tracker)
|
1097
|
+
|
1098
|
+
# Check for sparse shards
|
1099
|
+
console.print("\n[bold cyan]Checking for sparse shards...[/bold cyan]")
|
1100
|
+
sparse_shards = _find_sparse_shards(tracker)
|
1101
|
+
_display_sparse_shards(sparse_shards)
|
1102
|
+
|
1103
|
+
# Cross-check with storage if verbose
|
1104
|
+
if storage.captions_path.exists() and verbose:
|
1105
|
+
_cross_check_storage(storage, tracker, fix)
|
675
1106
|
|
676
1107
|
# Summary
|
677
1108
|
console.print("\n[bold cyan]Summary:[/bold cyan]")
|
@@ -695,12 +1126,163 @@ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
|
|
695
1126
|
tracker.save_checkpoint()
|
696
1127
|
|
697
1128
|
|
1129
|
+
def _display_export_stats(stats):
    """Render the storage-level statistics header shown before an export run."""
    shard_list = ", ".join(stats["shards"])
    field_list = ", ".join(stats["output_fields"])

    console.print("\n[bold cyan]Storage Statistics:[/bold cyan]")
    console.print(f"[green]Total rows:[/green] {stats['total_rows']:,}")
    console.print(f"[green]Total outputs:[/green] {stats['total_outputs']:,}")
    console.print(f"[green]Shards:[/green] {stats['shard_count']} ({shard_list})")
    console.print(f"[green]Output fields:[/green] {field_list}")

    field_stats = stats.get("field_stats")
    if field_stats:
        console.print("\n[cyan]Field breakdown:[/cyan]")
        for field, count in field_stats.items():
            console.print(f" • {field}: {count['total_items']:,} items")
|
1141
|
+
|
1142
|
+
|
1143
|
+
def _prepare_export_params(shard, shards, columns):
|
1144
|
+
"""Prepare shard filter and column list."""
|
1145
|
+
shard_filter = None
|
1146
|
+
if shard:
|
1147
|
+
shard_filter = [shard]
|
1148
|
+
elif shards:
|
1149
|
+
shard_filter = [s.strip() for s in shards.split(",")]
|
1150
|
+
|
1151
|
+
column_list = None
|
1152
|
+
if columns:
|
1153
|
+
column_list = [col.strip() for col in columns.split(",")]
|
1154
|
+
console.print(f"\n[cyan]Exporting columns:[/cyan] {', '.join(column_list)}")
|
1155
|
+
|
1156
|
+
return shard_filter, column_list
|
1157
|
+
|
1158
|
+
|
1159
|
+
async def _export_all_formats(
    exporter, output, shard_filter, column_list, limit, filename_column, export_column
):
    """Run every tabular/text export format in sequence and print a summary.

    A format that raises is reported and recorded as zero items rather than
    aborting the remaining formats.
    """
    target = Path(output or "caption_export")
    totals = {}

    for fmt in ("jsonl", "csv", "parquet", "json", "txt"):
        console.print(f"\n[cyan]Exporting to {fmt.upper()}...[/cyan]")
        try:
            per_shard = await exporter.export_all_shards(
                fmt,
                target,
                columns=column_list,
                limit_per_shard=limit,
                shard_filter=shard_filter,
                filename_column=filename_column,
                export_column=export_column,
            )
        except Exception as exc:  # best effort: keep going with other formats
            console.print(f"[yellow]Skipping {fmt}: {exc}[/yellow]")
            totals[fmt] = 0
        else:
            totals[fmt] = sum(per_shard.values())

    console.print("\n[green]✓ Export complete![/green]")
    for fmt, count in totals.items():
        if count > 0:
            console.print(f" • {fmt.upper()}: {count:,} items")
|
1188
|
+
|
1189
|
+
|
1190
|
+
async def _export_to_lance(exporter, output, column_list, shard_filter):
    """Write the selected shards/columns out as a Lance dataset."""
    target = output if output else "exported_captions.lance"
    console.print(f"\n[cyan]Exporting to Lance dataset:[/cyan] {target}")
    row_count = await exporter.export_to_lance(
        target, columns=column_list, shard_filter=shard_filter
    )
    console.print(f"[green]✓ Exported {row_count:,} rows to Lance dataset[/green]")
|
1198
|
+
|
1199
|
+
|
1200
|
+
async def _export_to_huggingface(exporter, hf_dataset, license, private, nsfw, tags, shard_filter):
    """Upload the selected shards to the Hugging Face Hub.

    Exits the process with status 1 when no ``--hf-dataset`` name was given.
    """
    if not hf_dataset:
        console.print("[red]Error: --hf-dataset required for huggingface_hub format[/red]")
        console.print("[dim]Example: --hf-dataset username/my-caption-dataset[/dim]")
        sys.exit(1)

    tag_list = [tag.strip() for tag in tags.split(",")] if tags else None

    console.print(f"\n[cyan]Uploading to Hugging Face Hub:[/cyan] {hf_dataset}")
    # Echo the optional settings, in the same order they were declared.
    for note in (
        "[dim]Privacy: Private dataset[/dim]" if private else None,
        "[dim]Content: Not for all audiences[/dim]" if nsfw else None,
        f"[dim]Tags: {', '.join(tag_list)}[/dim]" if tag_list else None,
        f"[dim]Shards: {', '.join(shard_filter)}[/dim]" if shard_filter else None,
    ):
        if note:
            console.print(note)

    url = await exporter.export_to_huggingface_hub(
        dataset_name=hf_dataset,
        license=license,
        private=private,
        nsfw=nsfw,
        tags=tag_list,
        shard_filter=shard_filter,
    )
    console.print(f"[green]✓ Dataset uploaded to: {url}[/green]")
|
1230
|
+
|
1231
|
+
|
1232
|
+
async def _export_single_format(
    exporter,
    format,
    output,
    shard_filter,
    column_list,
    limit,
    filename_column,
    export_column,
    verbose,
):
    """Export one format, using the single-shard fast path when possible.

    A filter naming exactly one shard goes through ``export_shard``; anything
    else (no filter, or several shards) goes through ``export_all_shards``
    with an optional per-shard breakdown when *verbose* is set.
    """
    output_path = output or "export"

    single_shard = shard_filter[0] if shard_filter and len(shard_filter) == 1 else None
    if single_shard is not None:
        console.print(f"\n[cyan]Exporting shard {single_shard} to {format.upper()}...[/cyan]")
        count = await exporter.export_shard(
            single_shard,
            format,
            output_path,
            columns=column_list,
            limit=limit,
            filename_column=filename_column,
            export_column=export_column,
        )
        console.print(f"[green]✓ Exported {count:,} items[/green]")
        return

    console.print(f"\n[cyan]Exporting to {format.upper()}...[/cyan]")
    results = await exporter.export_all_shards(
        format,
        output_path,
        columns=column_list,
        limit_per_shard=limit,
        shard_filter=shard_filter,
        filename_column=filename_column,
        export_column=export_column,
    )

    total = sum(results.values())
    console.print(f"[green]✓ Exported {total:,} items total[/green]")

    if verbose and len(results) > 1:
        console.print("\n[dim]Per-shard breakdown:[/dim]")
        for shard_name, count in sorted(results.items()):
            console.print(f" • {shard_name}: {count:,} items")
|
1277
|
+
|
1278
|
+
|
698
1279
|
# NOTE(review): in 0.4.0 a stray ``@main.command()`` plus a full copy of the
# export option stack was left immediately above this helper, so Python bound
# the decorators to it — registering a spurious CLI command and replacing the
# function with a click ``Command`` object, which breaks the plain
# ``_validate_export_setup(data_dir)`` call in ``_run_export_process``. The
# ``export`` command below declares its own complete decorator stack, so the
# duplicated decorators are removed here and the helper stays a plain function.
def _validate_export_setup(data_dir):
    """Validate that the storage directory exists and build a StorageManager.

    Exits the process with status 1 when *data_dir* does not exist.
    """
    from .storage import StorageManager

    storage_path = Path(data_dir)
    if not storage_path.exists():
        console.print(f"[red]Storage directory not found: {data_dir}[/red]")
        sys.exit(1)

    return StorageManager(storage_path)
|
1316
|
+
|
1317
|
+
|
1318
|
+
async def _run_export_process(
    storage,
    format,
    output,
    shard,
    shards,
    columns,
    limit,
    filename_column,
    export_column,
    verbose,
    hf_dataset,
    license,
    private,
    nsfw,
    tags,
    stats_only,
    optimize,
):
    """Drive a full export: show stats, optionally optimize, then dispatch by format."""
    from .storage.exporter import LanceStorageExporter

    await storage.initialize()

    # Stats are always shown first; --stats-only stops right after.
    _display_export_stats(await storage.get_caption_stats())
    if stats_only:
        return

    if optimize:
        console.print("\n[yellow]Optimizing storage...[/yellow]")
        await storage.optimize_storage()

    shard_filter, column_list = _prepare_export_params(shard, shards, columns)
    exporter = LanceStorageExporter(storage)

    # Dispatch on the requested format; "all" and the two dataset targets get
    # dedicated helpers, everything else is a plain single-format export.
    if format == "all":
        await _export_all_formats(
            exporter, output, shard_filter, column_list, limit, filename_column, export_column
        )
    elif format == "lance":
        await _export_to_lance(exporter, output, column_list, shard_filter)
    elif format == "huggingface_hub":
        await _export_to_huggingface(
            exporter, hf_dataset, license, private, nsfw, tags, shard_filter
        )
    else:
        await _export_single_format(
            exporter,
            format,
            output,
            shard_filter,
            column_list,
            limit,
            filename_column,
            export_column,
            verbose,
        )
|
1377
|
+
|
1378
|
+
|
1379
|
+
@main.command()
@click.option("--data-dir", default="./caption_data", help="Storage directory")
@click.option(
    "--format",
    type=click.Choice(
        ["jsonl", "json", "csv", "txt", "parquet", "lance", "huggingface_hub", "all"],
        case_sensitive=False,
    ),
    default="jsonl",
    help="Export format (default: jsonl)",
)
@click.option("--output", help="Output filename or directory")
@click.option("--limit", type=int, help="Maximum number of items to export")
@click.option("--columns", help="Comma-separated list of columns to include")
@click.option("--export-column", default="captions", help="Column to export (default: captions)")
@click.option("--filename-column", default="filename", help="Filename column (default: filename)")
@click.option("--shard", help="Export only specific shard (e.g., 'data-001')")
@click.option("--shards", help="Comma-separated list of shards to export")
@click.option("--include-empty", is_flag=True, help="Include items with empty/null export column")
@click.option("--stats-only", is_flag=True, help="Show statistics only, don't export")
@click.option("--optimize", is_flag=True, help="Optimize storage before export")
@click.option("--verbose", is_flag=True, help="Verbose output")
@click.option("--hf-dataset", help="HuggingFace Hub dataset name (for huggingface_hub format)")
@click.option("--license", default="MIT", help="Dataset license (default: MIT)")
@click.option("--private", is_flag=True, help="Make HuggingFace dataset private")
@click.option("--nsfw", is_flag=True, help="Mark dataset as NSFW")
@click.option("--tags", help="Comma-separated tags for HuggingFace dataset")
def export(
    data_dir: str,
    format: str,
    output: Optional[str],
    limit: Optional[int],
    columns: Optional[str],
    export_column: str,
    filename_column: str,
    shard: Optional[str],
    shards: Optional[str],
    include_empty: bool,
    stats_only: bool,
    optimize: bool,
    verbose: bool,
    hf_dataset: Optional[str],
    license: str,
    private: bool,
    nsfw: bool,
    tags: Optional[str],
):
    """Export caption data to various formats with per-shard support."""
    # NOTE(review): --include-empty is accepted but not forwarded to
    # _run_export_process, so it currently appears to have no effect — confirm
    # whether the 0.3.x empty-row filtering was dropped intentionally.
    from .storage.exporter import ExportError

    storage = _validate_export_setup(data_dir)

    try:
        asyncio.run(
            _run_export_process(
                storage,
                format,
                output,
                shard,
                shards,
                columns,
                limit,
                filename_column,
                export_column,
                verbose,
                hf_dataset,
                license,
                private,
                nsfw,
                tags,
                stats_only,
                optimize,
            )
        )
    except ExportError as err:
        console.print(f"[red]Export error: {err}[/red]")
        sys.exit(1)
    except KeyboardInterrupt:
        console.print("\n[yellow]Export cancelled[/yellow]")
        sys.exit(1)
    except Exception as err:
        console.print(f"[red]Unexpected error: {err}[/red]")
        import traceback

        traceback.print_exc()
        sys.exit(1)
|
946
1465
|
|
947
1466
|
|
@@ -963,7 +1482,7 @@ def generate_cert(
|
|
963
1482
|
cert_path, key_path = cert_manager.generate_self_signed(Path(output_dir), cert_domain)
|
964
1483
|
console.print(f"[green]✓[/green] Certificate: {cert_path}")
|
965
1484
|
console.print(f"[green]✓[/green] Key: {key_path}")
|
966
|
-
console.print(
|
1485
|
+
console.print("\n[cyan]Use these paths in your config or CLI:[/cyan]")
|
967
1486
|
console.print(f" --cert {cert_path}")
|
968
1487
|
console.print(f" --key {key_path}")
|
969
1488
|
elif domain and email:
|
@@ -980,7 +1499,7 @@ def generate_cert(
|
|
980
1499
|
)
|
981
1500
|
console.print(f"[green]✓[/green] Certificate: {cert_path}")
|
982
1501
|
console.print(f"[green]✓[/green] Key: {key_path}")
|
983
|
-
console.print(
|
1502
|
+
console.print("\n[cyan]Use these paths in your config or CLI:[/cyan]")
|
984
1503
|
console.print(f" --cert {cert_path}")
|
985
1504
|
console.print(f" --key {key_path}")
|
986
1505
|
|
@@ -1024,13 +1543,13 @@ def inspect_cert(cert_path: str):
|
|
1024
1543
|
|
1025
1544
|
from datetime import datetime
|
1026
1545
|
|
1027
|
-
if info["not_after"] < datetime.utcnow():
|
1546
|
+
if info["not_after"] < datetime.now(_datetime.UTC):
|
1028
1547
|
console.print("[red]✗ Certificate has expired![/red]")
|
1029
|
-
elif (info["not_after"] - datetime.utcnow()).days < 30:
|
1030
|
-
days_left = (info["not_after"] - datetime.utcnow()).days
|
1548
|
+
elif (info["not_after"] - datetime.now(_datetime.UTC)).days < 30:
|
1549
|
+
days_left = (info["not_after"] - datetime.now(_datetime.UTC)).days
|
1031
1550
|
console.print(f"[yellow]⚠ Certificate expires in {days_left} days[/yellow]")
|
1032
1551
|
else:
|
1033
|
-
days_left = (info["not_after"] - datetime.utcnow()).days
|
1552
|
+
days_left = (info["not_after"] - datetime.now(_datetime.UTC)).days
|
1034
1553
|
console.print(f"[green]✓ Certificate valid for {days_left} more days[/green]")
|
1035
1554
|
|
1036
1555
|
except Exception as e:
|