caption-flow 0.3.4__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +3 -3
- caption_flow/cli.py +921 -427
- caption_flow/models.py +45 -3
- caption_flow/monitor.py +2 -3
- caption_flow/orchestrator.py +153 -104
- caption_flow/processors/__init__.py +3 -3
- caption_flow/processors/base.py +8 -7
- caption_flow/processors/huggingface.py +463 -68
- caption_flow/processors/local_filesystem.py +24 -28
- caption_flow/processors/webdataset.py +28 -22
- caption_flow/storage/exporter.py +420 -339
- caption_flow/storage/manager.py +636 -756
- caption_flow/utils/__init__.py +1 -1
- caption_flow/utils/auth.py +1 -1
- caption_flow/utils/caption_utils.py +1 -1
- caption_flow/utils/certificates.py +15 -8
- caption_flow/utils/checkpoint_tracker.py +30 -28
- caption_flow/utils/chunk_tracker.py +153 -56
- caption_flow/utils/image_processor.py +9 -9
- caption_flow/utils/json_utils.py +37 -20
- caption_flow/utils/prompt_template.py +24 -16
- caption_flow/utils/vllm_config.py +5 -4
- caption_flow/viewer.py +4 -12
- caption_flow/workers/base.py +5 -4
- caption_flow/workers/caption.py +303 -92
- caption_flow/workers/data.py +6 -8
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/METADATA +9 -4
- caption_flow-0.4.1.dist-info/RECORD +33 -0
- caption_flow-0.3.4.dist-info/RECORD +0 -33
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/top_level.txt +0 -0
caption_flow/cli.py
CHANGED
@@ -1,21 +1,22 @@
|
|
1
1
|
"""Command-line interface for CaptionFlow with smart configuration handling."""
|
2
2
|
|
3
3
|
import asyncio
|
4
|
+
import datetime as _datetime
|
4
5
|
import json
|
5
6
|
import logging
|
6
7
|
import os
|
7
8
|
import sys
|
9
|
+
from datetime import datetime
|
8
10
|
from pathlib import Path
|
9
|
-
from typing import
|
11
|
+
from typing import Any, Dict, List, Optional
|
10
12
|
|
11
13
|
import click
|
12
14
|
import yaml
|
13
15
|
from rich.console import Console
|
14
16
|
from rich.logging import RichHandler
|
15
|
-
from datetime import datetime
|
16
17
|
|
17
|
-
from .orchestrator import Orchestrator
|
18
18
|
from .monitor import Monitor
|
19
|
+
from .orchestrator import Orchestrator
|
19
20
|
from .utils.certificates import CertificateManager
|
20
21
|
|
21
22
|
console = Console()
|
@@ -48,8 +49,7 @@ class ConfigManager:
|
|
48
49
|
def find_config(
|
49
50
|
cls, component: str, explicit_path: Optional[str] = None
|
50
51
|
) -> Optional[Dict[str, Any]]:
|
51
|
-
"""
|
52
|
-
Find and load configuration for a component.
|
52
|
+
"""Find and load configuration for a component.
|
53
53
|
|
54
54
|
Search order:
|
55
55
|
1. Explicit path if provided
|
@@ -120,22 +120,76 @@ class ConfigManager:
|
|
120
120
|
|
121
121
|
|
122
122
|
def setup_logging(verbose: bool = False):
|
123
|
-
"""Configure logging with rich handler
|
123
|
+
"""Configure logging with rich handler and file output to XDG state directory."""
|
124
124
|
level = logging.DEBUG if verbose else logging.INFO
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
125
|
+
|
126
|
+
# Determine log directory based on environment or XDG spec
|
127
|
+
log_dir_env = os.environ.get("CAPTIONFLOW_LOG_DIR")
|
128
|
+
if log_dir_env:
|
129
|
+
log_dir = Path(log_dir_env)
|
130
|
+
else:
|
131
|
+
# Use XDG_STATE_HOME for logs, with platform-specific fallbacks
|
132
|
+
xdg_state_home = os.environ.get("XDG_STATE_HOME")
|
133
|
+
if xdg_state_home:
|
134
|
+
base_dir = Path(xdg_state_home)
|
135
|
+
elif sys.platform == "darwin":
|
136
|
+
base_dir = Path.home() / "Library" / "Logs"
|
137
|
+
else:
|
138
|
+
# Default to ~/.local/state on Linux and other systems
|
139
|
+
base_dir = Path.home() / ".local" / "state"
|
140
|
+
log_dir = base_dir / "caption-flow"
|
141
|
+
|
142
|
+
try:
|
143
|
+
# Ensure log directory exists
|
144
|
+
log_dir.mkdir(parents=True, exist_ok=True)
|
145
|
+
log_file_path = log_dir / "caption_flow.log"
|
146
|
+
|
147
|
+
# Set up handlers
|
148
|
+
handlers: List[logging.Handler] = [
|
149
|
+
RichHandler(
|
150
|
+
console=console,
|
151
|
+
rich_tracebacks=True,
|
152
|
+
show_path=False,
|
153
|
+
show_time=True,
|
154
|
+
)
|
155
|
+
]
|
156
|
+
|
157
|
+
# Add file handler
|
158
|
+
file_handler = logging.FileHandler(log_file_path, mode="a")
|
159
|
+
file_handler.setFormatter(
|
160
|
+
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
161
|
+
)
|
162
|
+
handlers.append(file_handler)
|
163
|
+
log_msg = f"Logging to {log_file_path}"
|
164
|
+
|
165
|
+
except (OSError, PermissionError) as e:
|
166
|
+
# Fallback to only console logging if file logging fails
|
167
|
+
handlers = [
|
130
168
|
RichHandler(
|
131
169
|
console=console,
|
132
170
|
rich_tracebacks=True,
|
133
171
|
show_path=False,
|
134
|
-
show_time=True,
|
172
|
+
show_time=True,
|
135
173
|
)
|
136
|
-
]
|
174
|
+
]
|
175
|
+
log_file = log_dir / "caption_flow.log"
|
176
|
+
log_msg = f"[yellow]Warning: Could not write to log file {log_file}: {e}[/yellow]"
|
177
|
+
|
178
|
+
logging.basicConfig(
|
179
|
+
level=level,
|
180
|
+
format="%(message)s", # RichHandler overrides this format for console
|
181
|
+
datefmt="[%Y-%m-%d %H:%M:%S]",
|
182
|
+
handlers=handlers,
|
137
183
|
)
|
138
184
|
|
185
|
+
# Suppress noisy libraries
|
186
|
+
logging.getLogger("websockets").setLevel(logging.WARNING)
|
187
|
+
logging.getLogger("pyarrow").setLevel(logging.WARNING)
|
188
|
+
|
189
|
+
# Use a dedicated logger to print the log file path to avoid format issues
|
190
|
+
if "log_msg" in locals():
|
191
|
+
logging.getLogger("setup").info(log_msg)
|
192
|
+
|
139
193
|
|
140
194
|
def apply_cli_overrides(config: Dict[str, Any], **kwargs) -> Dict[str, Any]:
|
141
195
|
"""Apply CLI arguments as overrides to config, filtering out None values."""
|
@@ -189,9 +243,11 @@ def orchestrator(ctx, config: Optional[str], **kwargs):
|
|
189
243
|
config_data["ssl"]["cert"] = kwargs["cert"]
|
190
244
|
config_data["ssl"]["key"] = kwargs["key"]
|
191
245
|
elif not config_data.get("ssl"):
|
192
|
-
|
193
|
-
"[yellow]Warning: Running without SSL.
|
246
|
+
warning_msg = (
|
247
|
+
"[yellow]Warning: Running without SSL. "
|
248
|
+
"Use --cert and --key for production.[/yellow]"
|
194
249
|
)
|
250
|
+
console.print(warning_msg)
|
195
251
|
|
196
252
|
if kwargs.get("vllm") and "vllm" not in config_data:
|
197
253
|
raise ValueError("Must provide vLLM config.")
|
@@ -259,33 +315,11 @@ def worker(ctx, config: Optional[str], **kwargs):
|
|
259
315
|
asyncio.run(worker_instance.shutdown())
|
260
316
|
|
261
317
|
|
262
|
-
|
263
|
-
|
264
|
-
@click.option("--server", help="Orchestrator WebSocket URL")
|
265
|
-
@click.option("--token", help="Authentication token")
|
266
|
-
@click.option("--no-verify-ssl", is_flag=True, help="Skip SSL verification")
|
267
|
-
@click.option("--debug", is_flag=True, help="Enable debug output")
|
268
|
-
@click.pass_context
|
269
|
-
def monitor(
|
270
|
-
ctx,
|
271
|
-
config: Optional[str],
|
272
|
-
server: Optional[str],
|
273
|
-
token: Optional[str],
|
274
|
-
no_verify_ssl: bool,
|
275
|
-
debug: bool,
|
276
|
-
):
|
277
|
-
"""Start the monitoring TUI."""
|
278
|
-
|
279
|
-
# Enable debug logging if requested
|
280
|
-
if debug:
|
281
|
-
setup_logging(verbose=True)
|
282
|
-
console.print("[yellow]Debug mode enabled[/yellow]")
|
283
|
-
|
284
|
-
# Load configuration
|
318
|
+
def _load_monitor_config(config, server, token):
|
319
|
+
"""Load monitor configuration from file or fallback to orchestrator config."""
|
285
320
|
base_config = ConfigManager.find_config("monitor", config)
|
286
321
|
|
287
322
|
if not base_config:
|
288
|
-
# Try to find monitor config in orchestrator config as fallback
|
289
323
|
orch_config = ConfigManager.find_config("orchestrator")
|
290
324
|
if orch_config and "monitor" in orch_config:
|
291
325
|
base_config = {"monitor": orch_config["monitor"]}
|
@@ -295,15 +329,11 @@ def monitor(
|
|
295
329
|
if not server or not token:
|
296
330
|
console.print("[yellow]No monitor config found, using CLI args[/yellow]")
|
297
331
|
|
298
|
-
|
299
|
-
|
300
|
-
if "monitor" in base_config:
|
301
|
-
config_data = base_config["monitor"]
|
302
|
-
# Case 2: Config IS the monitor config (no wrapper)
|
303
|
-
else:
|
304
|
-
config_data = base_config
|
332
|
+
return base_config.get("monitor", base_config)
|
333
|
+
|
305
334
|
|
306
|
-
|
335
|
+
def _apply_monitor_overrides(config_data, server, token, no_verify_ssl):
|
336
|
+
"""Apply CLI overrides to monitor configuration."""
|
307
337
|
if server:
|
308
338
|
config_data["server"] = server
|
309
339
|
if token:
|
@@ -311,17 +341,20 @@ def monitor(
|
|
311
341
|
if no_verify_ssl:
|
312
342
|
config_data["verify_ssl"] = False
|
313
343
|
|
314
|
-
# Debug output
|
315
|
-
if debug:
|
316
|
-
console.print("\n[cyan]Final monitor configuration:[/cyan]")
|
317
|
-
console.print(f" Server: {config_data.get('server', 'NOT SET')}")
|
318
|
-
console.print(
|
319
|
-
f" Token: {'***' + config_data.get('token', '')[-4:] if config_data.get('token') else 'NOT SET'}"
|
320
|
-
)
|
321
|
-
console.print(f" Verify SSL: {config_data.get('verify_ssl', True)}")
|
322
|
-
console.print()
|
323
344
|
|
324
|
-
|
345
|
+
def _debug_monitor_config(config_data):
|
346
|
+
"""Print debug information about monitor configuration."""
|
347
|
+
console.print("\n[cyan]Final monitor configuration:[/cyan]")
|
348
|
+
console.print(f" Server: {config_data.get('server', 'NOT SET')}")
|
349
|
+
console.print(
|
350
|
+
f" Token: {'***' + config_data.get('token', '')[-4:] if config_data.get('token') else 'NOT SET'}"
|
351
|
+
)
|
352
|
+
console.print(f" Verify SSL: {config_data.get('verify_ssl', True)}")
|
353
|
+
console.print()
|
354
|
+
|
355
|
+
|
356
|
+
def _validate_monitor_config(config_data):
|
357
|
+
"""Validate required monitor configuration fields."""
|
325
358
|
if not config_data.get("server"):
|
326
359
|
console.print("[red]Error: --server required (or set 'server' in monitor.yaml)[/red]")
|
327
360
|
console.print("\n[dim]Example monitor.yaml:[/dim]")
|
@@ -336,12 +369,43 @@ def monitor(
|
|
336
369
|
console.print("token: your-token-here")
|
337
370
|
sys.exit(1)
|
338
371
|
|
339
|
-
|
372
|
+
|
373
|
+
def _set_monitor_defaults(config_data):
|
374
|
+
"""Set default values for optional monitor settings."""
|
340
375
|
config_data.setdefault("refresh_interval", 1.0)
|
341
376
|
config_data.setdefault("show_inactive_workers", False)
|
342
377
|
config_data.setdefault("max_log_lines", 100)
|
343
378
|
|
344
|
-
|
379
|
+
|
380
|
+
@main.command()
|
381
|
+
@click.option("--config", type=click.Path(exists=True), help="Configuration file")
|
382
|
+
@click.option("--server", help="Orchestrator WebSocket URL")
|
383
|
+
@click.option("--token", help="Authentication token")
|
384
|
+
@click.option("--no-verify-ssl", is_flag=True, help="Skip SSL verification")
|
385
|
+
@click.option("--debug", is_flag=True, help="Enable debug output")
|
386
|
+
@click.pass_context
|
387
|
+
def monitor(
|
388
|
+
ctx,
|
389
|
+
config: Optional[str],
|
390
|
+
server: Optional[str],
|
391
|
+
token: Optional[str],
|
392
|
+
no_verify_ssl: bool,
|
393
|
+
debug: bool,
|
394
|
+
):
|
395
|
+
"""Start the monitoring TUI."""
|
396
|
+
if debug:
|
397
|
+
setup_logging(verbose=True)
|
398
|
+
console.print("[yellow]Debug mode enabled[/yellow]")
|
399
|
+
|
400
|
+
config_data = _load_monitor_config(config, server, token)
|
401
|
+
_apply_monitor_overrides(config_data, server, token, no_verify_ssl)
|
402
|
+
|
403
|
+
if debug:
|
404
|
+
_debug_monitor_config(config_data)
|
405
|
+
|
406
|
+
_validate_monitor_config(config_data)
|
407
|
+
_set_monitor_defaults(config_data)
|
408
|
+
|
345
409
|
try:
|
346
410
|
monitor_instance = Monitor(config_data)
|
347
411
|
|
@@ -406,7 +470,7 @@ def view(ctx, data_dir: str, refresh_rate: int, no_images: bool):
|
|
406
470
|
viewer.disable_images = True
|
407
471
|
viewer.refresh_rate = refresh_rate
|
408
472
|
|
409
|
-
console.print(
|
473
|
+
console.print("[cyan]Starting dataset viewer...[/cyan]")
|
410
474
|
console.print(f"[dim]Data directory: {data_path}[/dim]")
|
411
475
|
|
412
476
|
asyncio.run(viewer.run())
|
@@ -424,6 +488,400 @@ def view(ctx, data_dir: str, refresh_rate: int, no_images: bool):
|
|
424
488
|
sys.exit(1)
|
425
489
|
|
426
490
|
|
491
|
+
def _load_admin_credentials(config, server, token):
|
492
|
+
"""Load admin server and token from config if not provided."""
|
493
|
+
if server and token:
|
494
|
+
return server, token
|
495
|
+
|
496
|
+
base_config = ConfigManager.find_config("orchestrator", config) or {}
|
497
|
+
admin_config = base_config.get("admin", {})
|
498
|
+
admin_tokens = base_config.get("orchestrator", {}).get("auth", {}).get("admin_tokens", [])
|
499
|
+
|
500
|
+
final_server = server or admin_config.get("server", "ws://localhost:8765")
|
501
|
+
final_token = token or admin_config.get("token")
|
502
|
+
|
503
|
+
if not final_token and admin_tokens:
|
504
|
+
console.print("Using first admin token.")
|
505
|
+
final_token = admin_tokens[0].get("token")
|
506
|
+
|
507
|
+
return final_server, final_token
|
508
|
+
|
509
|
+
|
510
|
+
def _setup_ssl_context(server, no_verify_ssl):
|
511
|
+
"""Setup SSL context for websocket connection."""
|
512
|
+
import ssl
|
513
|
+
|
514
|
+
ssl_context = None
|
515
|
+
if server.startswith("wss://"):
|
516
|
+
ssl_context = ssl.create_default_context()
|
517
|
+
if no_verify_ssl:
|
518
|
+
ssl_context.check_hostname = False
|
519
|
+
ssl_context.verify_mode = ssl.CERT_NONE
|
520
|
+
|
521
|
+
return ssl_context
|
522
|
+
|
523
|
+
|
524
|
+
async def _authenticate_admin(websocket, token):
|
525
|
+
"""Authenticate as admin with the websocket."""
|
526
|
+
await websocket.send(json.dumps({"token": token, "role": "admin"}))
|
527
|
+
|
528
|
+
response = await websocket.recv()
|
529
|
+
auth_response = json.loads(response)
|
530
|
+
|
531
|
+
if "error" in auth_response:
|
532
|
+
console.print(f"[red]Authentication failed: {auth_response['error']}[/red]")
|
533
|
+
return False
|
534
|
+
|
535
|
+
console.print("[green]✓ Authenticated as admin[/green]")
|
536
|
+
return True
|
537
|
+
|
538
|
+
|
539
|
+
async def _send_reload_command(websocket, new_cfg):
|
540
|
+
"""Send reload command and handle response."""
|
541
|
+
await websocket.send(json.dumps({"type": "reload_config", "config": new_cfg}))
|
542
|
+
|
543
|
+
response = await websocket.recv()
|
544
|
+
reload_response = json.loads(response)
|
545
|
+
|
546
|
+
if reload_response.get("type") == "reload_complete":
|
547
|
+
if "message" in reload_response and "No changes" in reload_response["message"]:
|
548
|
+
console.print(f"[yellow]{reload_response['message']}[/yellow]")
|
549
|
+
else:
|
550
|
+
console.print("[green]✓ Configuration reloaded successfully![/green]")
|
551
|
+
|
552
|
+
if "updated" in reload_response and reload_response["updated"]:
|
553
|
+
console.print("\n[cyan]Updated sections:[/cyan]")
|
554
|
+
for section in reload_response["updated"]:
|
555
|
+
console.print(f" • {section}")
|
556
|
+
|
557
|
+
if "warnings" in reload_response and reload_response["warnings"]:
|
558
|
+
console.print("\n[yellow]Warnings:[/yellow]")
|
559
|
+
for warning in reload_response["warnings"]:
|
560
|
+
console.print(f" ⚠ {warning}")
|
561
|
+
|
562
|
+
return True
|
563
|
+
else:
|
564
|
+
error = reload_response.get("error", "Unknown error")
|
565
|
+
console.print(f"[red]Reload failed: {error} ({reload_response=})[/red]")
|
566
|
+
return False
|
567
|
+
|
568
|
+
|
569
|
+
def _add_token_to_config(config_data: Dict[str, Any], role: str, name: str, token: str) -> bool:
|
570
|
+
"""Add a new token to the config data."""
|
571
|
+
# Ensure the auth section exists
|
572
|
+
if "orchestrator" not in config_data:
|
573
|
+
config_data["orchestrator"] = {}
|
574
|
+
if "auth" not in config_data["orchestrator"]:
|
575
|
+
config_data["orchestrator"]["auth"] = {}
|
576
|
+
|
577
|
+
auth_config = config_data["orchestrator"]["auth"]
|
578
|
+
token_key = f"{role}_tokens"
|
579
|
+
|
580
|
+
# Initialize token list if it doesn't exist
|
581
|
+
if token_key not in auth_config:
|
582
|
+
auth_config[token_key] = []
|
583
|
+
|
584
|
+
# Check if token already exists
|
585
|
+
for existing_token in auth_config[token_key]:
|
586
|
+
if existing_token.get("token") == token:
|
587
|
+
console.print(f"[yellow]Token already exists for {role}: {name}[/yellow]")
|
588
|
+
return False
|
589
|
+
if existing_token.get("name") == name:
|
590
|
+
console.print(f"[yellow]Name already exists for {role}: {name}[/yellow]")
|
591
|
+
return False
|
592
|
+
|
593
|
+
# Add the new token
|
594
|
+
auth_config[token_key].append({"name": name, "token": token})
|
595
|
+
console.print(f"[green]✓ Added {role} token for {name}[/green]")
|
596
|
+
return True
|
597
|
+
|
598
|
+
|
599
|
+
def _remove_token_from_config(config_data: Dict[str, Any], role: str, identifier: str) -> bool:
|
600
|
+
"""Remove a token from the config data by name or token."""
|
601
|
+
auth_config = config_data.get("orchestrator", {}).get("auth", {})
|
602
|
+
token_key = f"{role}_tokens"
|
603
|
+
|
604
|
+
if token_key not in auth_config:
|
605
|
+
console.print(f"[red]No {role} tokens found in config[/red]")
|
606
|
+
return False
|
607
|
+
|
608
|
+
tokens = auth_config[token_key]
|
609
|
+
removed = False
|
610
|
+
|
611
|
+
for i, token_entry in enumerate(tokens):
|
612
|
+
if token_entry.get("name") == identifier or token_entry.get("token") == identifier:
|
613
|
+
removed_entry = tokens.pop(i)
|
614
|
+
console.print(f"[green]✓ Removed {role} token: {removed_entry['name']}[/green]")
|
615
|
+
removed = True
|
616
|
+
break
|
617
|
+
|
618
|
+
if not removed:
|
619
|
+
console.print(f"[red]Token not found for {role}: {identifier}[/red]")
|
620
|
+
|
621
|
+
return removed
|
622
|
+
|
623
|
+
|
624
|
+
def _list_tokens_in_config(config_data: Dict[str, Any], role: Optional[str] = None):
|
625
|
+
"""List tokens in the config data."""
|
626
|
+
auth_config = config_data.get("orchestrator", {}).get("auth", {})
|
627
|
+
|
628
|
+
if not auth_config:
|
629
|
+
console.print("[yellow]No auth configuration found[/yellow]")
|
630
|
+
return
|
631
|
+
|
632
|
+
roles_to_show = [role] if role else ["worker", "admin", "monitor"]
|
633
|
+
|
634
|
+
for token_role in roles_to_show:
|
635
|
+
token_key = f"{token_role}_tokens"
|
636
|
+
tokens = auth_config.get(token_key, [])
|
637
|
+
|
638
|
+
if tokens:
|
639
|
+
console.print(f"\n[cyan]{token_role.title()} tokens:[/cyan]")
|
640
|
+
for token_entry in tokens:
|
641
|
+
name = token_entry.get("name", "Unknown")
|
642
|
+
token = token_entry.get("token", "")
|
643
|
+
masked_token = f"***{token[-4:]}" if len(token) > 4 else "***"
|
644
|
+
console.print(f" • {name}: {masked_token}")
|
645
|
+
else:
|
646
|
+
console.print(f"\n[dim]No {token_role} tokens configured[/dim]")
|
647
|
+
|
648
|
+
|
649
|
+
def _save_config_file(config_data: Dict[str, Any], config_path: Path) -> bool:
|
650
|
+
"""Save the config data to a file."""
|
651
|
+
try:
|
652
|
+
with open(config_path, "w") as f:
|
653
|
+
yaml.safe_dump(config_data, f, default_flow_style=False, sort_keys=False)
|
654
|
+
console.print(f"[green]✓ Configuration saved to {config_path}[/green]")
|
655
|
+
return True
|
656
|
+
except Exception as e:
|
657
|
+
console.print(f"[red]Error saving config: {e}[/red]")
|
658
|
+
return False
|
659
|
+
|
660
|
+
|
661
|
+
async def _reload_orchestrator_config(
|
662
|
+
server: str, token: str, config_data: Dict[str, Any], no_verify_ssl: bool
|
663
|
+
) -> bool:
|
664
|
+
"""Reload the orchestrator configuration."""
|
665
|
+
import websockets
|
666
|
+
|
667
|
+
ssl_context = _setup_ssl_context(server, no_verify_ssl)
|
668
|
+
|
669
|
+
try:
|
670
|
+
async with websockets.connect(
|
671
|
+
server, ssl=ssl_context, ping_interval=20, ping_timeout=60, close_timeout=10
|
672
|
+
) as websocket:
|
673
|
+
if not await _authenticate_admin(websocket, token):
|
674
|
+
return False
|
675
|
+
|
676
|
+
return await _send_reload_command(websocket, config_data)
|
677
|
+
except Exception as e:
|
678
|
+
console.print(f"[red]Error connecting to orchestrator: {e}[/red]")
|
679
|
+
return False
|
680
|
+
|
681
|
+
|
682
|
+
@main.group()
|
683
|
+
@click.option("--config", type=click.Path(exists=True), help="Configuration file")
|
684
|
+
@click.option("--server", help="Orchestrator WebSocket URL")
|
685
|
+
@click.option("--token", help="Admin authentication token")
|
686
|
+
@click.option("--no-verify-ssl", is_flag=True, help="Skip SSL verification")
|
687
|
+
@click.pass_context
|
688
|
+
def auth(
|
689
|
+
ctx, config: Optional[str], server: Optional[str], token: Optional[str], no_verify_ssl: bool
|
690
|
+
):
|
691
|
+
"""Manage authentication tokens for the orchestrator."""
|
692
|
+
ctx.ensure_object(dict)
|
693
|
+
ctx.obj.update(
|
694
|
+
{"config": config, "server": server, "token": token, "no_verify_ssl": no_verify_ssl}
|
695
|
+
)
|
696
|
+
|
697
|
+
|
698
|
+
@auth.command()
|
699
|
+
@click.argument("role", type=click.Choice(["worker", "admin", "monitor"]))
|
700
|
+
@click.argument("name")
|
701
|
+
@click.argument("token_value")
|
702
|
+
@click.option(
|
703
|
+
"--no-reload", is_flag=True, help="Don't reload orchestrator config after adding token"
|
704
|
+
)
|
705
|
+
@click.pass_context
|
706
|
+
def add(ctx, role: str, name: str, token_value: str, no_reload: bool):
|
707
|
+
"""Add a new authentication token.
|
708
|
+
|
709
|
+
ROLE: Type of token (worker, admin, monitor)
|
710
|
+
NAME: Display name for the token
|
711
|
+
TOKEN_VALUE: The actual token string
|
712
|
+
"""
|
713
|
+
config_file = ctx.obj.get("config")
|
714
|
+
server = ctx.obj.get("server")
|
715
|
+
admin_token = ctx.obj.get("token")
|
716
|
+
no_verify_ssl = ctx.obj.get("no_verify_ssl", False)
|
717
|
+
|
718
|
+
# Load config
|
719
|
+
config_data = ConfigManager.find_config("orchestrator", config_file)
|
720
|
+
if not config_data:
|
721
|
+
console.print("[red]No orchestrator config found[/red]")
|
722
|
+
console.print("[dim]Use --config to specify config file path[/dim]")
|
723
|
+
sys.exit(1)
|
724
|
+
|
725
|
+
# Find config file path for saving
|
726
|
+
config_path = None
|
727
|
+
if config_file:
|
728
|
+
config_path = Path(config_file)
|
729
|
+
else:
|
730
|
+
# Try to find the config file that was loaded
|
731
|
+
for search_path in [
|
732
|
+
Path.cwd() / "orchestrator.yaml",
|
733
|
+
Path.cwd() / "config" / "orchestrator.yaml",
|
734
|
+
Path.home() / ".caption-flow" / "orchestrator.yaml",
|
735
|
+
ConfigManager.get_xdg_config_home() / "caption-flow" / "orchestrator.yaml",
|
736
|
+
]:
|
737
|
+
if search_path.exists():
|
738
|
+
config_path = search_path
|
739
|
+
break
|
740
|
+
|
741
|
+
if not config_path:
|
742
|
+
console.print("[red]Could not determine config file to save to[/red]")
|
743
|
+
console.print("[dim]Use --config to specify config file path[/dim]")
|
744
|
+
sys.exit(1)
|
745
|
+
|
746
|
+
# Add token to config
|
747
|
+
if not _add_token_to_config(config_data, role, name, token_value):
|
748
|
+
sys.exit(1)
|
749
|
+
|
750
|
+
# Save config file
|
751
|
+
if not _save_config_file(config_data, config_path):
|
752
|
+
sys.exit(1)
|
753
|
+
|
754
|
+
# Reload orchestrator if requested
|
755
|
+
if not no_reload:
|
756
|
+
server, admin_token = _load_admin_credentials(config_file, server, admin_token)
|
757
|
+
|
758
|
+
if not server:
|
759
|
+
console.print("[yellow]No server specified, skipping orchestrator reload[/yellow]")
|
760
|
+
console.print("[dim]Use --server to reload orchestrator config[/dim]")
|
761
|
+
elif not admin_token:
|
762
|
+
console.print("[yellow]No admin token specified, skipping orchestrator reload[/yellow]")
|
763
|
+
console.print("[dim]Use --token to reload orchestrator config[/dim]")
|
764
|
+
else:
|
765
|
+
console.print(f"[cyan]Reloading orchestrator config...[/cyan]")
|
766
|
+
success = asyncio.run(
|
767
|
+
_reload_orchestrator_config(server, admin_token, config_data, no_verify_ssl)
|
768
|
+
)
|
769
|
+
if not success:
|
770
|
+
console.print("[yellow]Config file updated but orchestrator reload failed[/yellow]")
|
771
|
+
console.print("[dim]You may need to restart the orchestrator manually[/dim]")
|
772
|
+
|
773
|
+
|
774
|
+
@auth.command()
|
775
|
+
@click.argument("role", type=click.Choice(["worker", "admin", "monitor"]))
|
776
|
+
@click.argument("identifier")
|
777
|
+
@click.option(
|
778
|
+
"--no-reload", is_flag=True, help="Don't reload orchestrator config after removing token"
|
779
|
+
)
|
780
|
+
@click.pass_context
|
781
|
+
def remove(ctx, role: str, identifier: str, no_reload: bool):
|
782
|
+
"""Remove an authentication token.
|
783
|
+
|
784
|
+
ROLE: Type of token (worker, admin, monitor)
|
785
|
+
IDENTIFIER: Name or token value to remove
|
786
|
+
"""
|
787
|
+
config_file = ctx.obj.get("config")
|
788
|
+
server = ctx.obj.get("server")
|
789
|
+
admin_token = ctx.obj.get("token")
|
790
|
+
no_verify_ssl = ctx.obj.get("no_verify_ssl", False)
|
791
|
+
|
792
|
+
# Load config
|
793
|
+
config_data = ConfigManager.find_config("orchestrator", config_file)
|
794
|
+
if not config_data:
|
795
|
+
console.print("[red]No orchestrator config found[/red]")
|
796
|
+
sys.exit(1)
|
797
|
+
|
798
|
+
# Find config file path for saving
|
799
|
+
config_path = None
|
800
|
+
if config_file:
|
801
|
+
config_path = Path(config_file)
|
802
|
+
else:
|
803
|
+
# Try to find the config file that was loaded
|
804
|
+
for search_path in [
|
805
|
+
Path.cwd() / "orchestrator.yaml",
|
806
|
+
Path.cwd() / "config" / "orchestrator.yaml",
|
807
|
+
Path.home() / ".caption-flow" / "orchestrator.yaml",
|
808
|
+
ConfigManager.get_xdg_config_home() / "caption-flow" / "orchestrator.yaml",
|
809
|
+
]:
|
810
|
+
if search_path.exists():
|
811
|
+
config_path = search_path
|
812
|
+
break
|
813
|
+
|
814
|
+
if not config_path:
|
815
|
+
console.print("[red]Could not determine config file to save to[/red]")
|
816
|
+
sys.exit(1)
|
817
|
+
|
818
|
+
# Remove token from config
|
819
|
+
if not _remove_token_from_config(config_data, role, identifier):
|
820
|
+
sys.exit(1)
|
821
|
+
|
822
|
+
# Save config file
|
823
|
+
if not _save_config_file(config_data, config_path):
|
824
|
+
sys.exit(1)
|
825
|
+
|
826
|
+
# Reload orchestrator if requested
|
827
|
+
if not no_reload:
|
828
|
+
server, admin_token = _load_admin_credentials(config_file, server, admin_token)
|
829
|
+
|
830
|
+
if not server:
|
831
|
+
console.print("[yellow]No server specified, skipping orchestrator reload[/yellow]")
|
832
|
+
elif not admin_token:
|
833
|
+
console.print("[yellow]No admin token specified, skipping orchestrator reload[/yellow]")
|
834
|
+
else:
|
835
|
+
console.print(f"[cyan]Reloading orchestrator config...[/cyan]")
|
836
|
+
success = asyncio.run(
|
837
|
+
_reload_orchestrator_config(server, admin_token, config_data, no_verify_ssl)
|
838
|
+
)
|
839
|
+
if not success:
|
840
|
+
console.print("[yellow]Config file updated but orchestrator reload failed[/yellow]")
|
841
|
+
|
842
|
+
|
843
|
+
@auth.command()
|
844
|
+
@click.argument("role", type=click.Choice(["worker", "admin", "monitor", "all"]), required=False)
|
845
|
+
@click.pass_context
|
846
|
+
def list(ctx, role: Optional[str]):
|
847
|
+
"""List authentication tokens.
|
848
|
+
|
849
|
+
ROLE: Type of tokens to list (worker, admin, monitor, all). Default: all
|
850
|
+
"""
|
851
|
+
config_file = ctx.obj.get("config")
|
852
|
+
|
853
|
+
# Load config
|
854
|
+
config_data = ConfigManager.find_config("orchestrator", config_file)
|
855
|
+
if not config_data:
|
856
|
+
console.print("[red]No orchestrator config found[/red]")
|
857
|
+
sys.exit(1)
|
858
|
+
|
859
|
+
# Show tokens
|
860
|
+
if role == "all" or role is None:
|
861
|
+
_list_tokens_in_config(config_data)
|
862
|
+
else:
|
863
|
+
_list_tokens_in_config(config_data, role)
|
864
|
+
|
865
|
+
|
866
|
+
@auth.command()
|
867
|
+
@click.option("--length", default=32, help="Token length (default: 32)")
|
868
|
+
@click.option("--count", default=1, help="Number of tokens to generate (default: 1)")
|
869
|
+
def generate(length: int, count: int):
|
870
|
+
"""Generate random authentication tokens."""
|
871
|
+
import secrets
|
872
|
+
import string
|
873
|
+
|
874
|
+
alphabet = string.ascii_letters + string.digits + "-_"
|
875
|
+
|
876
|
+
console.print(
|
877
|
+
f"[cyan]Generated {count} token{'s' if count > 1 else ''} ({length} characters each):[/cyan]\n"
|
878
|
+
)
|
879
|
+
|
880
|
+
for i in range(count):
|
881
|
+
token = "".join(secrets.choice(alphabet) for _ in range(length))
|
882
|
+
console.print(f" {i + 1}: {token}")
|
883
|
+
|
884
|
+
|
427
885
|
@main.command()
|
428
886
|
@click.option("--config", type=click.Path(exists=True), help="Configuration file")
|
429
887
|
@click.option("--server", help="Orchestrator WebSocket URL")
|
@@ -441,27 +899,8 @@ def reload_config(
|
|
441
899
|
):
|
442
900
|
"""Reload orchestrator configuration via admin connection."""
|
443
901
|
import websockets
|
444
|
-
import ssl
|
445
|
-
|
446
|
-
# Load base config to get server/token if not provided via CLI
|
447
|
-
if not server or not token:
|
448
|
-
base_config = ConfigManager.find_config("orchestrator", config) or {}
|
449
|
-
admin_config = base_config.get("admin", {})
|
450
|
-
admin_tokens = base_config.get("orchestrator", {}).get("auth", {}).get("admin_tokens", [])
|
451
|
-
has_admin_tokens = False
|
452
|
-
if len(admin_tokens) > 0:
|
453
|
-
has_admin_tokens = True
|
454
|
-
first_admin_token = admin_tokens[0].get("token", None)
|
455
|
-
# Do not print sensitive admin token to console.
|
456
902
|
|
457
|
-
|
458
|
-
server = admin_config.get("server", "ws://localhost:8765")
|
459
|
-
if not token:
|
460
|
-
token = admin_config.get("token", None)
|
461
|
-
if token is None and has_admin_tokens:
|
462
|
-
# grab the first one, we'll just assume we're localhost.
|
463
|
-
console.print("Using first admin token.")
|
464
|
-
token = first_admin_token
|
903
|
+
server, token = _load_admin_credentials(config, server, token)
|
465
904
|
|
466
905
|
if not server:
|
467
906
|
console.print("[red]Error: --server required (or set in config)[/red]")
|
@@ -472,66 +911,22 @@ def reload_config(
|
|
472
911
|
|
473
912
|
console.print(f"[cyan]Loading configuration from {new_config}...[/cyan]")
|
474
913
|
|
475
|
-
# Load the new configuration
|
476
914
|
new_cfg = ConfigManager.load_yaml(Path(new_config))
|
477
915
|
if not new_cfg:
|
478
916
|
console.print("[red]Failed to load configuration[/red]")
|
479
917
|
sys.exit(1)
|
480
918
|
|
481
|
-
|
482
|
-
ssl_context = None
|
483
|
-
if server.startswith("wss://"):
|
484
|
-
if no_verify_ssl:
|
485
|
-
ssl_context = ssl.create_default_context()
|
486
|
-
ssl_context.check_hostname = False
|
487
|
-
ssl_context.verify_mode = ssl.CERT_NONE
|
488
|
-
else:
|
489
|
-
ssl_context = ssl.create_default_context()
|
919
|
+
ssl_context = _setup_ssl_context(server, no_verify_ssl)
|
490
920
|
|
491
921
|
async def send_reload():
|
492
922
|
try:
|
493
923
|
async with websockets.connect(
|
494
924
|
server, ssl=ssl_context, ping_interval=20, ping_timeout=60, close_timeout=10
|
495
925
|
) as websocket:
|
496
|
-
|
497
|
-
await websocket.send(json.dumps({"token": token, "role": "admin"}))
|
498
|
-
|
499
|
-
response = await websocket.recv()
|
500
|
-
auth_response = json.loads(response)
|
501
|
-
|
502
|
-
if "error" in auth_response:
|
503
|
-
console.print(f"[red]Authentication failed: {auth_response['error']}[/red]")
|
926
|
+
if not await _authenticate_admin(websocket, token):
|
504
927
|
return False
|
505
928
|
|
506
|
-
|
507
|
-
|
508
|
-
# Send reload command
|
509
|
-
await websocket.send(json.dumps({"type": "reload_config", "config": new_cfg}))
|
510
|
-
|
511
|
-
response = await websocket.recv()
|
512
|
-
reload_response = json.loads(response)
|
513
|
-
|
514
|
-
if reload_response.get("type") == "reload_complete":
|
515
|
-
if "message" in reload_response and "No changes" in reload_response["message"]:
|
516
|
-
console.print(f"[yellow]{reload_response['message']}[/yellow]")
|
517
|
-
else:
|
518
|
-
console.print("[green]✓ Configuration reloaded successfully![/green]")
|
519
|
-
|
520
|
-
if "updated" in reload_response and reload_response["updated"]:
|
521
|
-
console.print("\n[cyan]Updated sections:[/cyan]")
|
522
|
-
for section in reload_response["updated"]:
|
523
|
-
console.print(f" • {section}")
|
524
|
-
|
525
|
-
if "warnings" in reload_response and reload_response["warnings"]:
|
526
|
-
console.print("\n[yellow]Warnings:[/yellow]")
|
527
|
-
for warning in reload_response["warnings"]:
|
528
|
-
console.print(f" ⚠ {warning}")
|
529
|
-
|
530
|
-
return True
|
531
|
-
else:
|
532
|
-
error = reload_response.get("error", "Unknown error")
|
533
|
-
console.print(f"[red]Reload failed: {error} ({reload_response=})[/red]")
|
534
|
-
return False
|
929
|
+
return await _send_reload_command(websocket, new_cfg)
|
535
930
|
|
536
931
|
except Exception as e:
|
537
932
|
console.print(f"[red]Error: {e}[/red]")
|
@@ -542,39 +937,20 @@ def reload_config(
|
|
542
937
|
sys.exit(1)
|
543
938
|
|
544
939
|
|
545
|
-
|
546
|
-
|
547
|
-
@click.option("--checkpoint-dir", default="./checkpoints", help="Checkpoint directory")
|
548
|
-
@click.option("--fix", is_flag=True, help="Fix issues by resetting abandoned chunks")
|
549
|
-
@click.option("--verbose", is_flag=True, help="Show detailed information")
|
550
|
-
def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
|
551
|
-
"""Scan for sparse or abandoned chunks and optionally fix them."""
|
552
|
-
from .utils.chunk_tracker import ChunkTracker
|
553
|
-
from .storage import StorageManager
|
554
|
-
import pyarrow.parquet as pq
|
555
|
-
|
556
|
-
console.print("[bold cyan]Scanning for sparse/abandoned chunks...[/bold cyan]\n")
|
557
|
-
|
558
|
-
checkpoint_path = Path(checkpoint_dir) / "chunks.json"
|
559
|
-
if not checkpoint_path.exists():
|
560
|
-
console.print("[red]No chunk checkpoint found![/red]")
|
561
|
-
return
|
562
|
-
|
563
|
-
tracker = ChunkTracker(checkpoint_path)
|
564
|
-
storage = StorageManager(Path(data_dir))
|
565
|
-
|
566
|
-
# Get and display stats
|
567
|
-
stats = tracker.get_stats()
|
940
|
+
def _display_chunk_stats(stats):
    """Print a colour-coded summary of chunk counts (total/completed/pending/assigned/failed)."""
    # (colour, label, stats key, trailing text) — the final row carries the blank line.
    rows = (
        ("green", "Total chunks", "total", ""),
        ("green", "Completed", "completed", ""),
        ("yellow", "Pending", "pending", ""),
        ("yellow", "Assigned", "assigned", ""),
        ("red", "Failed", "failed", "\n"),
    )
    for colour, label, key, tail in rows:
        console.print(f"[{colour}]{label}:[/{colour}] {stats[key]}{tail}")
573
947
|
|
574
|
-
|
948
|
+
|
949
|
+
def _find_abandoned_chunks(tracker):
|
950
|
+
"""Find chunks that have been assigned for too long."""
|
575
951
|
abandoned_chunks = []
|
576
952
|
stale_threshold = 3600 # 1 hour
|
577
|
-
current_time = datetime.
|
953
|
+
current_time = datetime.now(_datetime.UTC)
|
578
954
|
|
579
955
|
for chunk_id, chunk_state in tracker.chunks.items():
|
580
956
|
if chunk_state.status == "assigned" and chunk_state.assigned_at:
|
@@ -582,24 +958,31 @@ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
|
|
582
958
|
if age > stale_threshold:
|
583
959
|
abandoned_chunks.append((chunk_id, chunk_state, age))
|
584
960
|
|
585
|
-
|
586
|
-
console.print(f"[red]Found {len(abandoned_chunks)} abandoned chunks:[/red]")
|
587
|
-
for chunk_id, chunk_state, age in abandoned_chunks[:10]:
|
588
|
-
age_str = f"{age/3600:.1f} hours" if age > 3600 else f"{age/60:.1f} minutes"
|
589
|
-
console.print(f" • {chunk_id} (assigned to {chunk_state.assigned_to} {age_str} ago)")
|
961
|
+
return abandoned_chunks
|
590
962
|
|
591
|
-
if len(abandoned_chunks) > 10:
|
592
|
-
console.print(f" ... and {len(abandoned_chunks) - 10} more")
|
593
963
|
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
console.print(f"[green]✓ Reset {len(abandoned_chunks)} chunks[/green]")
|
964
|
+
def _display_abandoned_chunks(abandoned_chunks, fix, tracker):
|
965
|
+
"""Display abandoned chunks and optionally fix them."""
|
966
|
+
if not abandoned_chunks:
|
967
|
+
return
|
599
968
|
|
600
|
-
|
601
|
-
|
969
|
+
console.print(f"[red]Found {len(abandoned_chunks)} abandoned chunks:[/red]")
|
970
|
+
for chunk_id, chunk_state, age in abandoned_chunks[:10]:
|
971
|
+
age_str = f"{age / 3600:.1f} hours" if age > 3600 else f"{age / 60:.1f} minutes"
|
972
|
+
console.print(f" • {chunk_id} (assigned to {chunk_state.assigned_to} {age_str} ago)")
|
973
|
+
|
974
|
+
if len(abandoned_chunks) > 10:
|
975
|
+
console.print(f" ... and {len(abandoned_chunks) - 10} more")
|
976
|
+
|
977
|
+
if fix:
|
978
|
+
console.print("\n[yellow]Resetting abandoned chunks to pending...[/yellow]")
|
979
|
+
for chunk_id, _, _ in abandoned_chunks:
|
980
|
+
tracker.mark_failed(chunk_id)
|
981
|
+
console.print(f"[green]✓ Reset {len(abandoned_chunks)} chunks[/green]")
|
602
982
|
|
983
|
+
|
984
|
+
def _find_sparse_shards(tracker):
|
985
|
+
"""Find shards with gaps or issues."""
|
603
986
|
shards_summary = tracker.get_shards_summary()
|
604
987
|
sparse_shards = []
|
605
988
|
|
@@ -618,60 +1001,108 @@ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
|
|
618
1001
|
if has_gaps or shard_info["failed_chunks"] > 0:
|
619
1002
|
sparse_shards.append((shard_name, shard_info, has_gaps))
|
620
1003
|
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
1004
|
+
return sparse_shards
|
1005
|
+
|
1006
|
+
|
1007
|
+
def _display_sparse_shards(sparse_shards):
|
1008
|
+
"""Display sparse/incomplete shards."""
|
1009
|
+
if not sparse_shards:
|
1010
|
+
return
|
1011
|
+
|
1012
|
+
console.print(f"\n[yellow]Found {len(sparse_shards)} sparse/incomplete shards:[/yellow]")
|
1013
|
+
for shard_name, shard_info, has_gaps in sparse_shards[:5]:
|
1014
|
+
status = []
|
1015
|
+
if shard_info["pending_chunks"] > 0:
|
1016
|
+
status.append(f"{shard_info['pending_chunks']} pending")
|
1017
|
+
if shard_info["assigned_chunks"] > 0:
|
1018
|
+
status.append(f"{shard_info['assigned_chunks']} assigned")
|
1019
|
+
if shard_info["failed_chunks"] > 0:
|
1020
|
+
status.append(f"{shard_info['failed_chunks']} failed")
|
1021
|
+
if has_gaps:
|
1022
|
+
status.append("has gaps")
|
1023
|
+
|
1024
|
+
console.print(f" • {shard_name}: {', '.join(status)}")
|
1025
|
+
console.print(
|
1026
|
+
f" Progress: {shard_info['completed_chunks']}/{shard_info['total_chunks']} chunks"
|
1027
|
+
)
|
1028
|
+
|
1029
|
+
if len(sparse_shards) > 5:
|
1030
|
+
console.print(f" ... and {len(sparse_shards) - 5} more")
|
1031
|
+
|
1032
|
+
|
1033
|
+
def _cross_check_storage(storage, tracker, fix):
    """Compare tracker-completed chunks against chunk_ids persisted in storage.

    Chunks marked complete but absent from storage are listed and, with
    ``fix=True``, reset to pending via ``tracker.mark_failed``.  Chunks present
    in storage but unknown to the tracker are only counted and reported.
    """
    import pyarrow.parquet as pq

    console.print("\n[bold cyan]Cross-checking with stored captions...[/bold cyan]")

    try:
        # Only the chunk_id column is needed; drop empty/None ids.
        table = pq.read_table(storage.captions_path, columns=["chunk_id"])
        stored_chunk_ids = {c for c in table["chunk_id"].to_pylist() if c}

        tracker_completed = {c for c, s in tracker.chunks.items() if s.status == "completed"}

        missing_in_storage = tracker_completed - stored_chunk_ids
        missing_in_tracker = stored_chunk_ids - set(tracker.chunks.keys())

        if missing_in_storage:
            console.print(
                f"\n[red]Chunks marked complete but missing from storage:[/red] {len(missing_in_storage)}"
            )
            for chunk_id in list(missing_in_storage)[:5]:
                console.print(f"  • {chunk_id}")

            if fix:
                console.print("[yellow]Resetting these chunks to pending...[/yellow]")
                for chunk_id in missing_in_storage:
                    tracker.mark_failed(chunk_id)
                console.print(f"[green]✓ Reset {len(missing_in_storage)} chunks[/green]")

        if missing_in_tracker:
            console.print(
                f"\n[yellow]Chunks in storage but not tracked:[/yellow] {len(missing_in_tracker)}"
            )

    except Exception as e:
        # Best-effort consistency check: report the failure instead of aborting the scan.
        console.print(f"[red]Error reading storage: {e}[/red]")
|
649
1068
|
|
650
|
-
tracker_completed = set(c for c, s in tracker.chunks.items() if s.status == "completed")
|
651
1069
|
|
652
|
-
|
653
|
-
|
1070
|
+
@main.command()
|
1071
|
+
@click.option("--data-dir", default="./caption_data", help="Storage directory")
|
1072
|
+
@click.option("--checkpoint-dir", default="./checkpoints", help="Checkpoint directory")
|
1073
|
+
@click.option("--fix", is_flag=True, help="Fix issues by resetting abandoned chunks")
|
1074
|
+
@click.option("--verbose", is_flag=True, help="Show detailed information")
|
1075
|
+
def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
|
1076
|
+
"""Scan for sparse or abandoned chunks and optionally fix them."""
|
1077
|
+
from .storage import StorageManager
|
1078
|
+
from .utils.chunk_tracker import ChunkTracker
|
654
1079
|
|
655
|
-
|
656
|
-
console.print(
|
657
|
-
f"\n[red]Chunks marked complete but missing from storage:[/red] {len(missing_in_storage)}"
|
658
|
-
)
|
659
|
-
for chunk_id in list(missing_in_storage)[:5]:
|
660
|
-
console.print(f" • {chunk_id}")
|
1080
|
+
console.print("[bold cyan]Scanning for sparse/abandoned chunks...[/bold cyan]\n")
|
661
1081
|
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
console.print(f"[green]✓ Reset {len(missing_in_storage)} chunks[/green]")
|
1082
|
+
checkpoint_path = Path(checkpoint_dir) / "chunks.json"
|
1083
|
+
if not checkpoint_path.exists():
|
1084
|
+
console.print("[red]No chunk checkpoint found![/red]")
|
1085
|
+
return
|
667
1086
|
|
668
|
-
|
669
|
-
|
670
|
-
f"\n[yellow]Chunks in storage but not tracked:[/yellow] {len(missing_in_tracker)}"
|
671
|
-
)
|
1087
|
+
tracker = ChunkTracker(checkpoint_path)
|
1088
|
+
storage = StorageManager(Path(data_dir))
|
672
1089
|
|
673
|
-
|
674
|
-
|
1090
|
+
# Get and display stats
|
1091
|
+
stats = tracker.get_stats()
|
1092
|
+
_display_chunk_stats(stats)
|
1093
|
+
|
1094
|
+
# Find and handle abandoned chunks
|
1095
|
+
abandoned_chunks = _find_abandoned_chunks(tracker)
|
1096
|
+
_display_abandoned_chunks(abandoned_chunks, fix, tracker)
|
1097
|
+
|
1098
|
+
# Check for sparse shards
|
1099
|
+
console.print("\n[bold cyan]Checking for sparse shards...[/bold cyan]")
|
1100
|
+
sparse_shards = _find_sparse_shards(tracker)
|
1101
|
+
_display_sparse_shards(sparse_shards)
|
1102
|
+
|
1103
|
+
# Cross-check with storage if verbose
|
1104
|
+
if storage.captions_path.exists() and verbose:
|
1105
|
+
_cross_check_storage(storage, tracker, fix)
|
675
1106
|
|
676
1107
|
# Summary
|
677
1108
|
console.print("\n[bold cyan]Summary:[/bold cyan]")
|
@@ -695,32 +1126,257 @@ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
|
|
695
1126
|
tracker.save_checkpoint()
|
696
1127
|
|
697
1128
|
|
1129
|
+
def _display_export_stats(stats):
    """Render storage statistics (row/output counts, shard names, output fields)."""
    shard_list = ", ".join(stats["shards"])
    field_list = ", ".join(stats["output_fields"])

    console.print("\n[bold cyan]Storage Statistics:[/bold cyan]")
    console.print(f"[green]Total rows:[/green] {stats['total_rows']:,}")
    console.print(f"[green]Total outputs:[/green] {stats['total_outputs']:,}")
    console.print(f"[green]Shards:[/green] {stats['shard_count']} ({shard_list})")
    console.print(f"[green]Output fields:[/green] {field_list}")

    field_stats = stats.get("field_stats")
    if field_stats:
        console.print("\n[cyan]Field breakdown:[/cyan]")
        for field, count in field_stats.items():
            console.print(f"  • {field}: {count['total_items']:,} items")
|
1141
|
+
|
1142
|
+
|
1143
|
+
def _prepare_export_params(shard, shards, columns):
|
1144
|
+
"""Prepare shard filter and column list."""
|
1145
|
+
shard_filter = None
|
1146
|
+
if shard:
|
1147
|
+
shard_filter = [shard]
|
1148
|
+
elif shards:
|
1149
|
+
shard_filter = [s.strip() for s in shards.split(",")]
|
1150
|
+
|
1151
|
+
column_list = None
|
1152
|
+
if columns:
|
1153
|
+
column_list = [col.strip() for col in columns.split(",")]
|
1154
|
+
console.print(f"\n[cyan]Exporting columns:[/cyan] {', '.join(column_list)}")
|
1155
|
+
|
1156
|
+
return shard_filter, column_list
|
1157
|
+
|
1158
|
+
|
1159
|
+
async def _export_all_formats(
    exporter, output, shard_filter, column_list, limit, filename_column, export_column
):
    """Run every tabular/text export format in sequence and report per-format totals.

    A format that raises is reported as skipped and counted as zero items, so
    one failing exporter does not abort the remaining formats.
    """
    base_path = Path(output or "caption_export")
    totals = {}

    for fmt in ("jsonl", "csv", "parquet", "json", "txt"):
        console.print(f"\n[cyan]Exporting to {fmt.upper()}...[/cyan]")
        try:
            per_shard = await exporter.export_all_shards(
                fmt,
                base_path,
                columns=column_list,
                limit_per_shard=limit,
                shard_filter=shard_filter,
                filename_column=filename_column,
                export_column=export_column,
            )
            totals[fmt] = sum(per_shard.values())
        except Exception as e:
            console.print(f"[yellow]Skipping {fmt}: {e}[/yellow]")
            totals[fmt] = 0

    console.print("\n[green]✓ Export complete![/green]")
    for fmt, count in totals.items():
        if count > 0:
            console.print(f"  • {fmt.upper()}: {count:,} items")
|
1188
|
+
|
1189
|
+
|
1190
|
+
async def _export_to_lance(exporter, output, column_list, shard_filter):
    """Write the selected shards/columns out as a Lance dataset and report the row count."""
    target = output or "exported_captions.lance"
    console.print(f"\n[cyan]Exporting to Lance dataset:[/cyan] {target}")
    row_count = await exporter.export_to_lance(
        target, columns=column_list, shard_filter=shard_filter
    )
    console.print(f"[green]✓ Exported {row_count:,} rows to Lance dataset[/green]")
|
1198
|
+
|
1199
|
+
|
1200
|
+
async def _export_to_huggingface(exporter, hf_dataset, license, private, nsfw, tags, shard_filter):
    """Upload the dataset to the Hugging Face Hub; exits(1) if no repo name was given.

    NOTE: ``license`` intentionally shadows the builtin to mirror the CLI option name.
    """
    if not hf_dataset:
        console.print("[red]Error: --hf-dataset required for huggingface_hub format[/red]")
        console.print("[dim]Example: --hf-dataset username/my-caption-dataset[/dim]")
        sys.exit(1)

    tag_list = [tag.strip() for tag in tags.split(",")] if tags else None

    console.print(f"\n[cyan]Uploading to Hugging Face Hub:[/cyan] {hf_dataset}")
    if private:
        console.print("[dim]Privacy: Private dataset[/dim]")
    if nsfw:
        console.print("[dim]Content: Not for all audiences[/dim]")
    if tag_list:
        console.print(f"[dim]Tags: {', '.join(tag_list)}[/dim]")
    if shard_filter:
        console.print(f"[dim]Shards: {', '.join(shard_filter)}[/dim]")

    repo_url = await exporter.export_to_huggingface_hub(
        dataset_name=hf_dataset,
        license=license,
        private=private,
        nsfw=nsfw,
        tags=tag_list,
        shard_filter=shard_filter,
    )
    console.print(f"[green]✓ Dataset uploaded to: {repo_url}[/green]")
|
1230
|
+
|
1231
|
+
|
1232
|
+
async def _export_single_format(
    exporter,
    format,
    output,
    shard_filter,
    column_list,
    limit,
    filename_column,
    export_column,
    verbose,
):
    """Export one format: a single named shard when exactly one is selected, else all shards."""
    destination = output or "export"
    single_shard = shard_filter[0] if shard_filter and len(shard_filter) == 1 else None

    if single_shard is not None:
        console.print(f"\n[cyan]Exporting shard {single_shard} to {format.upper()}...[/cyan]")
        count = await exporter.export_shard(
            single_shard,
            format,
            destination,
            columns=column_list,
            limit=limit,
            filename_column=filename_column,
            export_column=export_column,
        )
        console.print(f"[green]✓ Exported {count:,} items[/green]")
        return

    console.print(f"\n[cyan]Exporting to {format.upper()}...[/cyan]")
    per_shard = await exporter.export_all_shards(
        format,
        destination,
        columns=column_list,
        limit_per_shard=limit,
        shard_filter=shard_filter,
        filename_column=filename_column,
        export_column=export_column,
    )

    grand_total = sum(per_shard.values())
    console.print(f"[green]✓ Exported {grand_total:,} items total[/green]")

    if verbose and len(per_shard) > 1:
        console.print("\n[dim]Per-shard breakdown:[/dim]")
        for shard_name, count in sorted(per_shard.items()):
            console.print(f"  • {shard_name}: {count:,} items")
|
1277
|
+
|
1278
|
+
|
1279
|
+
def _validate_export_setup(data_dir):
    """Check that the storage directory exists and return a StorageManager for it.

    Exits the process with status 1 when the directory is missing.
    """
    from .storage import StorageManager

    root = Path(data_dir)
    if not root.exists():
        console.print(f"[red]Storage directory not found: {data_dir}[/red]")
        sys.exit(1)

    return StorageManager(root)
|
1289
|
+
|
1290
|
+
|
1291
|
+
async def _run_export_process(
    storage,
    format,
    output,
    shard,
    shards,
    columns,
    limit,
    filename_column,
    export_column,
    verbose,
    hf_dataset,
    license,
    private,
    nsfw,
    tags,
    stats_only,
    optimize,
    include_empty,
):
    """Drive the export: show stats, optionally optimize, then dispatch by format.

    ``include_empty`` is accepted for CLI-signature compatibility but is not
    referenced in this function.
    """
    from .storage.exporter import LanceStorageExporter

    await storage.initialize()

    # Stats are always shown first; --stats-only stops here.
    _display_export_stats(await storage.get_caption_stats())
    if stats_only:
        return

    if optimize:
        console.print("\n[yellow]Optimizing storage...[/yellow]")
        await storage.optimize_storage()

    shard_filter, column_list = _prepare_export_params(shard, shards, columns)
    exporter = LanceStorageExporter(storage)

    if format == "all":
        await _export_all_formats(
            exporter, output, shard_filter, column_list, limit, filename_column, export_column
        )
    elif format == "lance":
        await _export_to_lance(exporter, output, column_list, shard_filter)
    elif format == "huggingface_hub":
        await _export_to_huggingface(
            exporter, hf_dataset, license, private, nsfw, tags, shard_filter
        )
    else:
        await _export_single_format(
            exporter,
            format,
            output,
            shard_filter,
            column_list,
            limit,
            filename_column,
            export_column,
            verbose,
        )
|
1351
|
+
|
1352
|
+
|
698
1353
|
@main.command()
|
699
1354
|
@click.option("--data-dir", default="./caption_data", help="Storage directory")
|
700
1355
|
@click.option(
|
701
1356
|
"--format",
|
702
1357
|
type=click.Choice(
|
703
|
-
["jsonl", "json", "csv", "txt", "huggingface_hub", "all"],
|
1358
|
+
["jsonl", "json", "csv", "txt", "parquet", "lance", "huggingface_hub", "all"],
|
1359
|
+
case_sensitive=False,
|
704
1360
|
),
|
705
1361
|
default="jsonl",
|
706
1362
|
help="Export format (default: jsonl)",
|
707
1363
|
)
|
708
|
-
@click.option("--output",
|
709
|
-
@click.option("--limit", type=int, help="
|
710
|
-
@click.option("--columns", help="Comma-separated list of columns to
|
711
|
-
@click.option("--export-column", default="captions", help="Column to export
|
712
|
-
@click.option("--filename-column", default="filename", help="
|
713
|
-
@click.option("--
|
714
|
-
@click.option("--
|
715
|
-
@click.option(
|
716
|
-
|
717
|
-
)
|
718
|
-
@click.option("--verbose", is_flag=True, help="
|
719
|
-
@click.option("--hf-dataset", help="
|
720
|
-
@click.option("--license", help="
|
721
|
-
@click.option("--private", is_flag=True, help="Make
|
722
|
-
@click.option("--nsfw", is_flag=True, help="
|
723
|
-
@click.option("--tags", help="Comma-separated tags for
|
1364
|
+
@click.option("--output", help="Output filename or directory")
|
1365
|
+
@click.option("--limit", type=int, help="Maximum number of items to export")
|
1366
|
+
@click.option("--columns", help="Comma-separated list of columns to include")
|
1367
|
+
@click.option("--export-column", default="captions", help="Column to export (default: captions)")
|
1368
|
+
@click.option("--filename-column", default="filename", help="Filename column (default: filename)")
|
1369
|
+
@click.option("--shard", help="Export only specific shard (e.g., 'data-001')")
|
1370
|
+
@click.option("--shards", help="Comma-separated list of shards to export")
|
1371
|
+
@click.option("--include-empty", is_flag=True, help="Include items with empty/null export column")
|
1372
|
+
@click.option("--stats-only", is_flag=True, help="Show statistics only, don't export")
|
1373
|
+
@click.option("--optimize", is_flag=True, help="Optimize storage before export")
|
1374
|
+
@click.option("--verbose", is_flag=True, help="Verbose output")
|
1375
|
+
@click.option("--hf-dataset", help="HuggingFace Hub dataset name (for huggingface_hub format)")
|
1376
|
+
@click.option("--license", default="MIT", help="Dataset license (default: MIT)")
|
1377
|
+
@click.option("--private", is_flag=True, help="Make HuggingFace dataset private")
|
1378
|
+
@click.option("--nsfw", is_flag=True, help="Mark dataset as NSFW")
|
1379
|
+
@click.option("--tags", help="Comma-separated tags for HuggingFace dataset")
|
724
1380
|
def export(
|
725
1381
|
data_dir: str,
|
726
1382
|
format: str,
|
@@ -729,219 +1385,57 @@ def export(
|
|
729
1385
|
columns: Optional[str],
|
730
1386
|
export_column: str,
|
731
1387
|
filename_column: str,
|
1388
|
+
shard: Optional[str],
|
1389
|
+
shards: Optional[str],
|
732
1390
|
include_empty: bool,
|
733
1391
|
stats_only: bool,
|
734
1392
|
optimize: bool,
|
735
1393
|
verbose: bool,
|
736
1394
|
hf_dataset: Optional[str],
|
737
|
-
license:
|
1395
|
+
license: str,
|
738
1396
|
private: bool,
|
739
1397
|
nsfw: bool,
|
740
1398
|
tags: Optional[str],
|
741
1399
|
):
|
742
|
-
"""Export caption data to various formats."""
|
743
|
-
from .storage import
|
744
|
-
from .storage.exporter import StorageExporter, ExportError
|
745
|
-
|
746
|
-
# Initialize storage manager
|
747
|
-
storage_path = Path(data_dir)
|
748
|
-
if not storage_path.exists():
|
749
|
-
console.print(f"[red]Storage directory not found: {data_dir}[/red]")
|
750
|
-
sys.exit(1)
|
751
|
-
|
752
|
-
storage = StorageManager(storage_path)
|
753
|
-
|
754
|
-
async def run_export():
|
755
|
-
await storage.initialize()
|
756
|
-
|
757
|
-
# Show statistics first
|
758
|
-
stats = await storage.get_caption_stats()
|
759
|
-
console.print("\n[bold cyan]Storage Statistics:[/bold cyan]")
|
760
|
-
console.print(f"[green]Total rows:[/green] {stats['total_rows']:,}")
|
761
|
-
console.print(f"[green]Total outputs:[/green] {stats['total_outputs']:,}")
|
762
|
-
console.print(f"[green]Output fields:[/green] {', '.join(stats['output_fields'])}")
|
763
|
-
|
764
|
-
if stats.get("field_stats"):
|
765
|
-
console.print("\n[cyan]Field breakdown:[/cyan]")
|
766
|
-
for field, field_stat in stats["field_stats"].items():
|
767
|
-
console.print(
|
768
|
-
f" • {field}: {field_stat['total_items']:,} items "
|
769
|
-
f"in {field_stat['rows_with_data']:,} rows"
|
770
|
-
)
|
771
|
-
|
772
|
-
if stats_only:
|
773
|
-
return
|
774
|
-
|
775
|
-
# Optimize storage if requested
|
776
|
-
if optimize:
|
777
|
-
console.print("\n[yellow]Optimizing storage (removing empty columns)...[/yellow]")
|
778
|
-
await storage.optimize_storage()
|
779
|
-
|
780
|
-
# Prepare columns list
|
781
|
-
column_list = None
|
782
|
-
if columns:
|
783
|
-
column_list = [col.strip() for col in columns.split(",")]
|
784
|
-
console.print(f"\n[cyan]Exporting columns:[/cyan] {', '.join(column_list)}")
|
785
|
-
|
786
|
-
# Get storage contents
|
787
|
-
console.print("\n[yellow]Loading data...[/yellow]")
|
788
|
-
try:
|
789
|
-
contents = await storage.get_storage_contents(
|
790
|
-
limit=limit, columns=column_list, include_metadata=True
|
791
|
-
)
|
792
|
-
except ValueError as e:
|
793
|
-
console.print(f"[red]Error: {e}[/red]")
|
794
|
-
sys.exit(1)
|
1400
|
+
"""Export caption data to various formats with per-shard support."""
|
1401
|
+
from .storage.exporter import ExportError
|
795
1402
|
|
796
|
-
|
797
|
-
console.print("[yellow]No data to export![/yellow]")
|
798
|
-
return
|
799
|
-
|
800
|
-
# Filter out empty rows if not including empty
|
801
|
-
if not include_empty and format in ["txt", "json"]:
|
802
|
-
original_count = len(contents.rows)
|
803
|
-
contents.rows = [
|
804
|
-
row
|
805
|
-
for row in contents.rows
|
806
|
-
if row.get(export_column)
|
807
|
-
and (not isinstance(row[export_column], list) or len(row[export_column]) > 0)
|
808
|
-
]
|
809
|
-
filtered_count = original_count - len(contents.rows)
|
810
|
-
if filtered_count > 0:
|
811
|
-
console.print(f"[dim]Filtered {filtered_count} empty rows[/dim]")
|
812
|
-
|
813
|
-
# Create exporter
|
814
|
-
exporter = StorageExporter(contents)
|
815
|
-
|
816
|
-
# Determine output paths
|
817
|
-
if format == "all":
|
818
|
-
# Export to all formats
|
819
|
-
base_name = output or "caption_export"
|
820
|
-
base_path = Path(base_name)
|
821
|
-
|
822
|
-
formats_exported = []
|
823
|
-
|
824
|
-
# JSONL
|
825
|
-
jsonl_path = base_path.with_suffix(".jsonl")
|
826
|
-
console.print(f"\n[cyan]Exporting to JSONL:[/cyan] {jsonl_path}")
|
827
|
-
rows = exporter.to_jsonl(jsonl_path)
|
828
|
-
formats_exported.append(f"JSONL: {rows:,} rows")
|
829
|
-
|
830
|
-
# CSV
|
831
|
-
csv_path = base_path.with_suffix(".csv")
|
832
|
-
console.print(f"[cyan]Exporting to CSV:[/cyan] {csv_path}")
|
833
|
-
try:
|
834
|
-
rows = exporter.to_csv(csv_path)
|
835
|
-
formats_exported.append(f"CSV: {rows:,} rows")
|
836
|
-
except ExportError as e:
|
837
|
-
console.print(f"[yellow]Skipping CSV: {e}[/yellow]")
|
838
|
-
|
839
|
-
# JSON files
|
840
|
-
json_dir = base_path.parent / f"{base_path.stem}_json"
|
841
|
-
console.print(f"[cyan]Exporting to JSON files:[/cyan] {json_dir}/")
|
842
|
-
try:
|
843
|
-
files = exporter.to_json(json_dir, filename_column)
|
844
|
-
formats_exported.append(f"JSON: {files:,} files")
|
845
|
-
except ExportError as e:
|
846
|
-
console.print(f"[yellow]Skipping JSON files: {e}[/yellow]")
|
847
|
-
|
848
|
-
# Text files
|
849
|
-
txt_dir = base_path.parent / f"{base_path.stem}_txt"
|
850
|
-
console.print(f"[cyan]Exporting to text files:[/cyan] {txt_dir}/")
|
851
|
-
try:
|
852
|
-
files = exporter.to_txt(txt_dir, filename_column, export_column)
|
853
|
-
formats_exported.append(f"Text: {files:,} files")
|
854
|
-
except ExportError as e:
|
855
|
-
console.print(f"[yellow]Skipping text files: {e}[/yellow]")
|
856
|
-
|
857
|
-
console.print(f"\n[green]✓ Export complete![/green]")
|
858
|
-
for fmt in formats_exported:
|
859
|
-
console.print(f" • {fmt}")
|
1403
|
+
storage = _validate_export_setup(data_dir)
|
860
1404
|
|
861
|
-
else:
|
862
|
-
# Single format export
|
863
|
-
try:
|
864
|
-
if format == "jsonl":
|
865
|
-
output_path = output or "captions.jsonl"
|
866
|
-
console.print(f"\n[cyan]Exporting to JSONL:[/cyan] {output_path}")
|
867
|
-
rows = exporter.to_jsonl(output_path)
|
868
|
-
console.print(f"[green]✓ Exported {rows:,} rows[/green]")
|
869
|
-
|
870
|
-
elif format == "csv":
|
871
|
-
output_path = output or "captions.csv"
|
872
|
-
console.print(f"\n[cyan]Exporting to CSV:[/cyan] {output_path}")
|
873
|
-
rows = exporter.to_csv(output_path)
|
874
|
-
console.print(f"[green]✓ Exported {rows:,} rows[/green]")
|
875
|
-
|
876
|
-
elif format == "json":
|
877
|
-
output_dir = output or "./json_output"
|
878
|
-
console.print(f"\n[cyan]Exporting to JSON files:[/cyan] {output_dir}/")
|
879
|
-
files = exporter.to_json(output_dir, filename_column)
|
880
|
-
console.print(f"[green]✓ Created {files:,} JSON files[/green]")
|
881
|
-
|
882
|
-
elif format == "txt":
|
883
|
-
output_dir = output or "./txt_output"
|
884
|
-
console.print(f"\n[cyan]Exporting to text files:[/cyan] {output_dir}/")
|
885
|
-
console.print(f"[dim]Export column: {export_column}[/dim]")
|
886
|
-
files = exporter.to_txt(output_dir, filename_column, export_column)
|
887
|
-
console.print(f"[green]✓ Created {files:,} text files[/green]")
|
888
|
-
|
889
|
-
elif format == "huggingface_hub":
|
890
|
-
# Validate required parameters
|
891
|
-
if not hf_dataset:
|
892
|
-
console.print(
|
893
|
-
"[red]Error: --hf-dataset required for huggingface_hub format[/red]"
|
894
|
-
)
|
895
|
-
console.print(
|
896
|
-
"[dim]Example: --hf-dataset username/my-caption-dataset[/dim]"
|
897
|
-
)
|
898
|
-
sys.exit(1)
|
899
|
-
|
900
|
-
# Parse tags
|
901
|
-
tag_list = None
|
902
|
-
if tags:
|
903
|
-
tag_list = [tag.strip() for tag in tags.split(",")]
|
904
|
-
|
905
|
-
console.print(f"\n[cyan]Uploading to Hugging Face Hub:[/cyan] {hf_dataset}")
|
906
|
-
if private:
|
907
|
-
console.print("[dim]Privacy: Private dataset[/dim]")
|
908
|
-
if nsfw:
|
909
|
-
console.print("[dim]Content: Not for all audiences[/dim]")
|
910
|
-
if tag_list:
|
911
|
-
console.print(f"[dim]Tags: {', '.join(tag_list)}[/dim]")
|
912
|
-
|
913
|
-
url = exporter.to_huggingface_hub(
|
914
|
-
dataset_name=hf_dataset,
|
915
|
-
license=license,
|
916
|
-
private=private,
|
917
|
-
nsfw=nsfw,
|
918
|
-
tags=tag_list,
|
919
|
-
)
|
920
|
-
console.print(f"[green]✓ Dataset uploaded to: {url}[/green]")
|
921
|
-
|
922
|
-
except ExportError as e:
|
923
|
-
console.print(f"[red]Export error: {e}[/red]")
|
924
|
-
sys.exit(1)
|
925
|
-
|
926
|
-
# Show export metadata
|
927
|
-
if verbose and contents.metadata:
|
928
|
-
console.print("\n[dim]Export metadata:[/dim]")
|
929
|
-
console.print(f" Timestamp: {contents.metadata.get('export_timestamp')}")
|
930
|
-
console.print(f" Total available: {contents.metadata.get('total_available_rows'):,}")
|
931
|
-
console.print(f" Rows exported: {contents.metadata.get('rows_exported'):,}")
|
932
|
-
|
933
|
-
# Run the async export
|
934
1405
|
try:
|
935
|
-
asyncio.run(
|
1406
|
+
asyncio.run(
|
1407
|
+
_run_export_process(
|
1408
|
+
storage,
|
1409
|
+
format,
|
1410
|
+
output,
|
1411
|
+
shard,
|
1412
|
+
shards,
|
1413
|
+
columns,
|
1414
|
+
limit,
|
1415
|
+
filename_column,
|
1416
|
+
export_column,
|
1417
|
+
verbose,
|
1418
|
+
hf_dataset,
|
1419
|
+
license,
|
1420
|
+
private,
|
1421
|
+
nsfw,
|
1422
|
+
tags,
|
1423
|
+
stats_only,
|
1424
|
+
optimize,
|
1425
|
+
include_empty,
|
1426
|
+
)
|
1427
|
+
)
|
1428
|
+
except ExportError as e:
|
1429
|
+
console.print(f"[red]Export error: {e}[/red]")
|
1430
|
+
sys.exit(1)
|
936
1431
|
except KeyboardInterrupt:
|
937
1432
|
console.print("\n[yellow]Export cancelled[/yellow]")
|
938
1433
|
sys.exit(1)
|
939
1434
|
except Exception as e:
|
940
1435
|
console.print(f"[red]Unexpected error: {e}[/red]")
|
941
|
-
|
942
|
-
import traceback
|
1436
|
+
import traceback
|
943
1437
|
|
944
|
-
|
1438
|
+
traceback.print_exc()
|
945
1439
|
sys.exit(1)
|
946
1440
|
|
947
1441
|
|
@@ -963,7 +1457,7 @@ def generate_cert(
|
|
963
1457
|
cert_path, key_path = cert_manager.generate_self_signed(Path(output_dir), cert_domain)
|
964
1458
|
console.print(f"[green]✓[/green] Certificate: {cert_path}")
|
965
1459
|
console.print(f"[green]✓[/green] Key: {key_path}")
|
966
|
-
console.print(
|
1460
|
+
console.print("\n[cyan]Use these paths in your config or CLI:[/cyan]")
|
967
1461
|
console.print(f" --cert {cert_path}")
|
968
1462
|
console.print(f" --key {key_path}")
|
969
1463
|
elif domain and email:
|
@@ -980,7 +1474,7 @@ def generate_cert(
|
|
980
1474
|
)
|
981
1475
|
console.print(f"[green]✓[/green] Certificate: {cert_path}")
|
982
1476
|
console.print(f"[green]✓[/green] Key: {key_path}")
|
983
|
-
console.print(
|
1477
|
+
console.print("\n[cyan]Use these paths in your config or CLI:[/cyan]")
|
984
1478
|
console.print(f" --cert {cert_path}")
|
985
1479
|
console.print(f" --key {key_path}")
|
986
1480
|
|
@@ -1024,13 +1518,13 @@ def inspect_cert(cert_path: str):
|
|
1024
1518
|
|
1025
1519
|
from datetime import datetime
|
1026
1520
|
|
1027
|
-
if info["not_after"] < datetime.
|
1521
|
+
if info["not_after"] < datetime.now(_datetime.UTC):
|
1028
1522
|
console.print("[red]✗ Certificate has expired![/red]")
|
1029
|
-
elif (info["not_after"] - datetime.
|
1030
|
-
days_left = (info["not_after"] - datetime.
|
1523
|
+
elif (info["not_after"] - datetime.now(_datetime.UTC)).days < 30:
|
1524
|
+
days_left = (info["not_after"] - datetime.now(_datetime.UTC)).days
|
1031
1525
|
console.print(f"[yellow]⚠ Certificate expires in {days_left} days[/yellow]")
|
1032
1526
|
else:
|
1033
|
-
days_left = (info["not_after"] - datetime.
|
1527
|
+
days_left = (info["not_after"] - datetime.now(_datetime.UTC)).days
|
1034
1528
|
console.print(f"[green]✓ Certificate valid for {days_left} more days[/green]")
|
1035
1529
|
|
1036
1530
|
except Exception as e:
|