caption-flow 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. caption_flow/__init__.py +3 -3
  2. caption_flow/cli.py +937 -416
  3. caption_flow/models.py +45 -3
  4. caption_flow/monitor.py +5 -3
  5. caption_flow/orchestrator.py +186 -116
  6. caption_flow/processors/__init__.py +3 -3
  7. caption_flow/processors/base.py +8 -7
  8. caption_flow/processors/huggingface.py +440 -68
  9. caption_flow/processors/local_filesystem.py +24 -28
  10. caption_flow/processors/webdataset.py +66 -25
  11. caption_flow/storage/exporter.py +420 -339
  12. caption_flow/storage/manager.py +636 -756
  13. caption_flow/utils/__init__.py +1 -1
  14. caption_flow/utils/auth.py +1 -1
  15. caption_flow/utils/caption_utils.py +1 -1
  16. caption_flow/utils/certificates.py +15 -8
  17. caption_flow/utils/checkpoint_tracker.py +41 -19
  18. caption_flow/utils/chunk_tracker.py +200 -65
  19. caption_flow/utils/image_processor.py +9 -9
  20. caption_flow/utils/json_utils.py +37 -20
  21. caption_flow/utils/prompt_template.py +24 -16
  22. caption_flow/utils/vllm_config.py +5 -4
  23. caption_flow/viewer.py +4 -12
  24. caption_flow/workers/base.py +12 -6
  25. caption_flow/workers/caption.py +272 -91
  26. caption_flow/workers/data.py +6 -8
  27. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/METADATA +5 -4
  28. caption_flow-0.4.0.dist-info/RECORD +33 -0
  29. caption_flow-0.3.3.dist-info/RECORD +0 -33
  30. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/WHEEL +0 -0
  31. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/entry_points.txt +0 -0
  32. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/licenses/LICENSE +0 -0
  33. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/top_level.txt +0 -0
caption_flow/cli.py CHANGED
@@ -1,21 +1,22 @@
  """Command-line interface for CaptionFlow with smart configuration handling."""

  import asyncio
+ import datetime as _datetime
  import json
  import logging
  import os
  import sys
+ from datetime import datetime
  from pathlib import Path
- from typing import Optional, Dict, Any, List
+ from typing import Any, Dict, List, Optional

  import click
  import yaml
  from rich.console import Console
  from rich.logging import RichHandler
- from datetime import datetime

- from .orchestrator import Orchestrator
  from .monitor import Monitor
+ from .orchestrator import Orchestrator
  from .utils.certificates import CertificateManager

  console = Console()
@@ -48,8 +49,7 @@ class ConfigManager:
      def find_config(
          cls, component: str, explicit_path: Optional[str] = None
      ) -> Optional[Dict[str, Any]]:
-         """
-         Find and load configuration for a component.
+         """Find and load configuration for a component.

          Search order:
          1. Explicit path if provided
@@ -120,22 +120,76 @@ class ConfigManager:


  def setup_logging(verbose: bool = False):
-     """Configure logging with rich handler, including timestamp."""
+     """Configure logging with rich handler and file output to XDG state directory."""
      level = logging.DEBUG if verbose else logging.INFO
-     logging.basicConfig(
-         level=level,
-         format="%(message)s",
-         datefmt="[%Y-%m-%d %H:%M:%S]",
-         handlers=[
+
+     # Determine log directory based on environment or XDG spec
+     log_dir_env = os.environ.get("CAPTIONFLOW_LOG_DIR")
+     if log_dir_env:
+         log_dir = Path(log_dir_env)
+     else:
+         # Use XDG_STATE_HOME for logs, with platform-specific fallbacks
+         xdg_state_home = os.environ.get("XDG_STATE_HOME")
+         if xdg_state_home:
+             base_dir = Path(xdg_state_home)
+         elif sys.platform == "darwin":
+             base_dir = Path.home() / "Library" / "Logs"
+         else:
+             # Default to ~/.local/state on Linux and other systems
+             base_dir = Path.home() / ".local" / "state"
+         log_dir = base_dir / "caption-flow"
+
+     try:
+         # Ensure log directory exists
+         log_dir.mkdir(parents=True, exist_ok=True)
+         log_file_path = log_dir / "caption_flow.log"
+
+         # Set up handlers
+         handlers: List[logging.Handler] = [
+             RichHandler(
+                 console=console,
+                 rich_tracebacks=True,
+                 show_path=False,
+                 show_time=True,
+             )
+         ]
+
+         # Add file handler
+         file_handler = logging.FileHandler(log_file_path, mode="a")
+         file_handler.setFormatter(
+             logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+         )
+         handlers.append(file_handler)
+         log_msg = f"Logging to {log_file_path}"
+
+     except (OSError, PermissionError) as e:
+         # Fallback to only console logging if file logging fails
+         handlers = [
              RichHandler(
                  console=console,
                  rich_tracebacks=True,
                  show_path=False,
-                 show_time=True,  # Enables timestamp in RichHandler output
+                 show_time=True,
              )
-         ],
+         ]
+         log_file = log_dir / "caption_flow.log"
+         log_msg = f"[yellow]Warning: Could not write to log file {log_file}: {e}[/yellow]"
+
+     logging.basicConfig(
+         level=level,
+         format="%(message)s",  # RichHandler overrides this format for console
+         datefmt="[%Y-%m-%d %H:%M:%S]",
+         handlers=handlers,
      )

+     # Suppress noisy libraries
+     logging.getLogger("websockets").setLevel(logging.WARNING)
+     logging.getLogger("pyarrow").setLevel(logging.WARNING)
+
+     # Use a dedicated logger to print the log file path to avoid format issues
+     if "log_msg" in locals():
+         logging.getLogger("setup").info(log_msg)
+


  def apply_cli_overrides(config: Dict[str, Any], **kwargs) -> Dict[str, Any]:
      """Apply CLI arguments as overrides to config, filtering out None values."""
@@ -189,9 +243,11 @@ def orchestrator(ctx, config: Optional[str], **kwargs):
          config_data["ssl"]["cert"] = kwargs["cert"]
          config_data["ssl"]["key"] = kwargs["key"]
      elif not config_data.get("ssl"):
-         console.print(
-             "[yellow]Warning: Running without SSL. Use --cert and --key for production.[/yellow]"
+         warning_msg = (
+             "[yellow]Warning: Running without SSL. "
+             "Use --cert and --key for production.[/yellow]"
          )
+         console.print(warning_msg)

      if kwargs.get("vllm") and "vllm" not in config_data:
          raise ValueError("Must provide vLLM config.")
@@ -259,33 +315,11 @@ def worker(ctx, config: Optional[str], **kwargs):
          asyncio.run(worker_instance.shutdown())


- @main.command()
- @click.option("--config", type=click.Path(exists=True), help="Configuration file")
- @click.option("--server", help="Orchestrator WebSocket URL")
- @click.option("--token", help="Authentication token")
- @click.option("--no-verify-ssl", is_flag=True, help="Skip SSL verification")
- @click.option("--debug", is_flag=True, help="Enable debug output")
- @click.pass_context
- def monitor(
-     ctx,
-     config: Optional[str],
-     server: Optional[str],
-     token: Optional[str],
-     no_verify_ssl: bool,
-     debug: bool,
- ):
-     """Start the monitoring TUI."""
-
-     # Enable debug logging if requested
-     if debug:
-         setup_logging(verbose=True)
-         console.print("[yellow]Debug mode enabled[/yellow]")
-
-     # Load configuration
+ def _load_monitor_config(config, server, token):
+     """Load monitor configuration from file or fallback to orchestrator config."""
      base_config = ConfigManager.find_config("monitor", config)

      if not base_config:
-         # Try to find monitor config in orchestrator config as fallback
          orch_config = ConfigManager.find_config("orchestrator")
          if orch_config and "monitor" in orch_config:
              base_config = {"monitor": orch_config["monitor"]}
@@ -295,15 +329,11 @@ def monitor(
          if not server or not token:
              console.print("[yellow]No monitor config found, using CLI args[/yellow]")

-     # Handle different config structures
-     # Case 1: Config has top-level 'monitor' section
-     if "monitor" in base_config:
-         config_data = base_config["monitor"]
-     # Case 2: Config IS the monitor config (no wrapper)
-     else:
-         config_data = base_config
+     return base_config.get("monitor", base_config)
+

-     # Apply CLI overrides (CLI always wins)
+ def _apply_monitor_overrides(config_data, server, token, no_verify_ssl):
+     """Apply CLI overrides to monitor configuration."""
      if server:
          config_data["server"] = server
      if token:
@@ -311,17 +341,20 @@ def monitor(
      if no_verify_ssl:
          config_data["verify_ssl"] = False

-     # Debug output
-     if debug:
-         console.print("\n[cyan]Final monitor configuration:[/cyan]")
-         console.print(f" Server: {config_data.get('server', 'NOT SET')}")
-         console.print(
-             f" Token: {'***' + config_data.get('token', '')[-4:] if config_data.get('token') else 'NOT SET'}"
-         )
-         console.print(f" Verify SSL: {config_data.get('verify_ssl', True)}")
-         console.print()

-     # Validate required fields
+ def _debug_monitor_config(config_data):
+     """Print debug information about monitor configuration."""
+     console.print("\n[cyan]Final monitor configuration:[/cyan]")
+     console.print(f" Server: {config_data.get('server', 'NOT SET')}")
+     console.print(
+         f" Token: {'***' + config_data.get('token', '')[-4:] if config_data.get('token') else 'NOT SET'}"
+     )
+     console.print(f" Verify SSL: {config_data.get('verify_ssl', True)}")
+     console.print()
+
+
+ def _validate_monitor_config(config_data):
+     """Validate required monitor configuration fields."""
      if not config_data.get("server"):
          console.print("[red]Error: --server required (or set 'server' in monitor.yaml)[/red]")
          console.print("\n[dim]Example monitor.yaml:[/dim]")
@@ -336,12 +369,43 @@ def monitor(
          console.print("token: your-token-here")
          sys.exit(1)

-     # Set defaults for optional settings
+
+ def _set_monitor_defaults(config_data):
+     """Set default values for optional monitor settings."""
      config_data.setdefault("refresh_interval", 1.0)
      config_data.setdefault("show_inactive_workers", False)
      config_data.setdefault("max_log_lines", 100)

-     # Create and start monitor
+
+ @main.command()
+ @click.option("--config", type=click.Path(exists=True), help="Configuration file")
+ @click.option("--server", help="Orchestrator WebSocket URL")
+ @click.option("--token", help="Authentication token")
+ @click.option("--no-verify-ssl", is_flag=True, help="Skip SSL verification")
+ @click.option("--debug", is_flag=True, help="Enable debug output")
+ @click.pass_context
+ def monitor(
+     ctx,
+     config: Optional[str],
+     server: Optional[str],
+     token: Optional[str],
+     no_verify_ssl: bool,
+     debug: bool,
+ ):
+     """Start the monitoring TUI."""
+     if debug:
+         setup_logging(verbose=True)
+         console.print("[yellow]Debug mode enabled[/yellow]")
+
+     config_data = _load_monitor_config(config, server, token)
+     _apply_monitor_overrides(config_data, server, token, no_verify_ssl)
+
+     if debug:
+         _debug_monitor_config(config_data)
+
+     _validate_monitor_config(config_data)
+     _set_monitor_defaults(config_data)
+
      try:
          monitor_instance = Monitor(config_data)

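The old two-branch handling of wrapped versus flat monitor configs collapses to a single dict.get with the dict itself as the default. Both shapes below resolve to the same section (values are made up):

    wrapped = {"monitor": {"server": "wss://localhost:8765", "token": "example-token"}}
    flat = {"server": "wss://localhost:8765", "token": "example-token"}

    # dict.get falls back to the whole dict when there is no "monitor" wrapper
    assert wrapped.get("monitor", wrapped) == flat
    assert flat.get("monitor", flat) == flat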
@@ -406,7 +470,7 @@ def view(ctx, data_dir: str, refresh_rate: int, no_images: bool):
          viewer.disable_images = True
      viewer.refresh_rate = refresh_rate

-     console.print(f"[cyan]Starting dataset viewer...[/cyan]")
+     console.print("[cyan]Starting dataset viewer...[/cyan]")
      console.print(f"[dim]Data directory: {data_path}[/dim]")

      asyncio.run(viewer.run())
@@ -424,6 +488,400 @@ def view(ctx, data_dir: str, refresh_rate: int, no_images: bool):
          sys.exit(1)


+ def _load_admin_credentials(config, server, token):
+     """Load admin server and token from config if not provided."""
+     if server and token:
+         return server, token
+
+     base_config = ConfigManager.find_config("orchestrator", config) or {}
+     admin_config = base_config.get("admin", {})
+     admin_tokens = base_config.get("orchestrator", {}).get("auth", {}).get("admin_tokens", [])
+
+     final_server = server or admin_config.get("server", "ws://localhost:8765")
+     final_token = token or admin_config.get("token")
+
+     if not final_token and admin_tokens:
+         console.print("Using first admin token.")
+         final_token = admin_tokens[0].get("token")
+
+     return final_server, final_token
+
+
+ def _setup_ssl_context(server, no_verify_ssl):
+     """Setup SSL context for websocket connection."""
+     import ssl
+
+     ssl_context = None
+     if server.startswith("wss://"):
+         ssl_context = ssl.create_default_context()
+         if no_verify_ssl:
+             ssl_context.check_hostname = False
+             ssl_context.verify_mode = ssl.CERT_NONE
+
+     return ssl_context
+
+
+ async def _authenticate_admin(websocket, token):
+     """Authenticate as admin with the websocket."""
+     await websocket.send(json.dumps({"token": token, "role": "admin"}))
+
+     response = await websocket.recv()
+     auth_response = json.loads(response)
+
+     if "error" in auth_response:
+         console.print(f"[red]Authentication failed: {auth_response['error']}[/red]")
+         return False
+
+     console.print("[green]✓ Authenticated as admin[/green]")
+     return True
+
+
+ async def _send_reload_command(websocket, new_cfg):
+     """Send reload command and handle response."""
+     await websocket.send(json.dumps({"type": "reload_config", "config": new_cfg}))
+
+     response = await websocket.recv()
+     reload_response = json.loads(response)
+
+     if reload_response.get("type") == "reload_complete":
+         if "message" in reload_response and "No changes" in reload_response["message"]:
+             console.print(f"[yellow]{reload_response['message']}[/yellow]")
+         else:
+             console.print("[green]✓ Configuration reloaded successfully![/green]")
+
+         if "updated" in reload_response and reload_response["updated"]:
+             console.print("\n[cyan]Updated sections:[/cyan]")
+             for section in reload_response["updated"]:
+                 console.print(f" • {section}")
+
+         if "warnings" in reload_response and reload_response["warnings"]:
+             console.print("\n[yellow]Warnings:[/yellow]")
+             for warning in reload_response["warnings"]:
+                 console.print(f" ⚠ {warning}")
+
+         return True
+     else:
+         error = reload_response.get("error", "Unknown error")
+         console.print(f"[red]Reload failed: {error} ({reload_response=})[/red]")
+         return False
+
+
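The helpers above define the admin wire protocol: send {"token": ..., "role": "admin"}, read one JSON reply, then exchange a command. A minimal standalone client sketch under that assumption (server URL and token are hypothetical):

    import asyncio
    import json

    import websockets  # third-party dependency the CLI already uses

    async def reload_once(server: str, token: str, new_cfg: dict) -> dict:
        async with websockets.connect(server, ping_interval=20, ping_timeout=60) as ws:
            # Authenticate exactly as _authenticate_admin does
            await ws.send(json.dumps({"token": token, "role": "admin"}))
            auth = json.loads(await ws.recv())
            if "error" in auth:
                raise RuntimeError(auth["error"])
            # Then issue the same command _send_reload_command sends
            await ws.send(json.dumps({"type": "reload_config", "config": new_cfg}))
            return json.loads(await ws.recv())

    # asyncio.run(reload_once("ws://localhost:8765", "my-admin-token", {}))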
+ def _add_token_to_config(config_data: Dict[str, Any], role: str, name: str, token: str) -> bool:
+     """Add a new token to the config data."""
+     # Ensure the auth section exists
+     if "orchestrator" not in config_data:
+         config_data["orchestrator"] = {}
+     if "auth" not in config_data["orchestrator"]:
+         config_data["orchestrator"]["auth"] = {}
+
+     auth_config = config_data["orchestrator"]["auth"]
+     token_key = f"{role}_tokens"
+
+     # Initialize token list if it doesn't exist
+     if token_key not in auth_config:
+         auth_config[token_key] = []
+
+     # Check if token already exists
+     for existing_token in auth_config[token_key]:
+         if existing_token.get("token") == token:
+             console.print(f"[yellow]Token already exists for {role}: {name}[/yellow]")
+             return False
+         if existing_token.get("name") == name:
+             console.print(f"[yellow]Name already exists for {role}: {name}[/yellow]")
+             return False
+
+     # Add the new token
+     auth_config[token_key].append({"name": name, "token": token})
+     console.print(f"[green]✓ Added {role} token for {name}[/green]")
+     return True
+
+
+ def _remove_token_from_config(config_data: Dict[str, Any], role: str, identifier: str) -> bool:
+     """Remove a token from the config data by name or token."""
+     auth_config = config_data.get("orchestrator", {}).get("auth", {})
+     token_key = f"{role}_tokens"
+
+     if token_key not in auth_config:
+         console.print(f"[red]No {role} tokens found in config[/red]")
+         return False
+
+     tokens = auth_config[token_key]
+     removed = False
+
+     for i, token_entry in enumerate(tokens):
+         if token_entry.get("name") == identifier or token_entry.get("token") == identifier:
+             removed_entry = tokens.pop(i)
+             console.print(f"[green]✓ Removed {role} token: {removed_entry['name']}[/green]")
+             removed = True
+             break
+
+     if not removed:
+         console.print(f"[red]Token not found for {role}: {identifier}[/red]")
+
+     return removed
+
+
+ def _list_tokens_in_config(config_data: Dict[str, Any], role: Optional[str] = None):
+     """List tokens in the config data."""
+     auth_config = config_data.get("orchestrator", {}).get("auth", {})
+
+     if not auth_config:
+         console.print("[yellow]No auth configuration found[/yellow]")
+         return
+
+     roles_to_show = [role] if role else ["worker", "admin", "monitor"]
+
+     for token_role in roles_to_show:
+         token_key = f"{token_role}_tokens"
+         tokens = auth_config.get(token_key, [])
+
+         if tokens:
+             console.print(f"\n[cyan]{token_role.title()} tokens:[/cyan]")
+             for token_entry in tokens:
+                 name = token_entry.get("name", "Unknown")
+                 token = token_entry.get("token", "")
+                 masked_token = f"***{token[-4:]}" if len(token) > 4 else "***"
+                 console.print(f" • {name}: {masked_token}")
+         else:
+             console.print(f"\n[dim]No {token_role} tokens configured[/dim]")
+
+
+ def _save_config_file(config_data: Dict[str, Any], config_path: Path) -> bool:
+     """Save the config data to a file."""
+     try:
+         with open(config_path, "w") as f:
+             yaml.safe_dump(config_data, f, default_flow_style=False, sort_keys=False)
+         console.print(f"[green]✓ Configuration saved to {config_path}[/green]")
+         return True
+     except Exception as e:
+         console.print(f"[red]Error saving config: {e}[/red]")
+         return False
+
+
+ async def _reload_orchestrator_config(
+     server: str, token: str, config_data: Dict[str, Any], no_verify_ssl: bool
+ ) -> bool:
+     """Reload the orchestrator configuration."""
+     import websockets
+
+     ssl_context = _setup_ssl_context(server, no_verify_ssl)
+
+     try:
+         async with websockets.connect(
+             server, ssl=ssl_context, ping_interval=20, ping_timeout=60, close_timeout=10
+         ) as websocket:
+             if not await _authenticate_admin(websocket, token):
+                 return False
+
+             return await _send_reload_command(websocket, config_data)
+     except Exception as e:
+         console.print(f"[red]Error connecting to orchestrator: {e}[/red]")
+         return False
+
+
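The token helpers all operate on one nested dict shape under orchestrator.auth, with one <role>_tokens list per role. A sketch of that structure and the helpers acting on it (names and token values are hypothetical):

    config_data = {
        "orchestrator": {
            "auth": {
                "worker_tokens": [{"name": "gpu-box-1", "token": "wk-example-123"}],
                "admin_tokens": [{"name": "ops", "token": "adm-example-456"}],
                "monitor_tokens": [],
            }
        }
    }

    # Appends {"name": ..., "token": ...} to monitor_tokens, refusing duplicates
    _add_token_to_config(config_data, "monitor", "dashboard", "mon-example-789")

    # Prints each role's tokens masked to their last four characters
    _list_tokens_in_config(config_data)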
+ @main.group()
+ @click.option("--config", type=click.Path(exists=True), help="Configuration file")
+ @click.option("--server", help="Orchestrator WebSocket URL")
+ @click.option("--token", help="Admin authentication token")
+ @click.option("--no-verify-ssl", is_flag=True, help="Skip SSL verification")
+ @click.pass_context
+ def auth(
+     ctx, config: Optional[str], server: Optional[str], token: Optional[str], no_verify_ssl: bool
+ ):
+     """Manage authentication tokens for the orchestrator."""
+     ctx.ensure_object(dict)
+     ctx.obj.update(
+         {"config": config, "server": server, "token": token, "no_verify_ssl": no_verify_ssl}
+     )
+
+
+ @auth.command()
+ @click.argument("role", type=click.Choice(["worker", "admin", "monitor"]))
+ @click.argument("name")
+ @click.argument("token_value")
+ @click.option(
+     "--no-reload", is_flag=True, help="Don't reload orchestrator config after adding token"
+ )
+ @click.pass_context
+ def add(ctx, role: str, name: str, token_value: str, no_reload: bool):
+     """Add a new authentication token.
+
+     ROLE: Type of token (worker, admin, monitor)
+     NAME: Display name for the token
+     TOKEN_VALUE: The actual token string
+     """
+     config_file = ctx.obj.get("config")
+     server = ctx.obj.get("server")
+     admin_token = ctx.obj.get("token")
+     no_verify_ssl = ctx.obj.get("no_verify_ssl", False)
+
+     # Load config
+     config_data = ConfigManager.find_config("orchestrator", config_file)
+     if not config_data:
+         console.print("[red]No orchestrator config found[/red]")
+         console.print("[dim]Use --config to specify config file path[/dim]")
+         sys.exit(1)
+
+     # Find config file path for saving
+     config_path = None
+     if config_file:
+         config_path = Path(config_file)
+     else:
+         # Try to find the config file that was loaded
+         for search_path in [
+             Path.cwd() / "orchestrator.yaml",
+             Path.cwd() / "config" / "orchestrator.yaml",
+             Path.home() / ".caption-flow" / "orchestrator.yaml",
+             ConfigManager.get_xdg_config_home() / "caption-flow" / "orchestrator.yaml",
+         ]:
+             if search_path.exists():
+                 config_path = search_path
+                 break
+
+     if not config_path:
+         console.print("[red]Could not determine config file to save to[/red]")
+         console.print("[dim]Use --config to specify config file path[/dim]")
+         sys.exit(1)
+
+     # Add token to config
+     if not _add_token_to_config(config_data, role, name, token_value):
+         sys.exit(1)
+
+     # Save config file
+     if not _save_config_file(config_data, config_path):
+         sys.exit(1)
+
+     # Reload orchestrator if requested
+     if not no_reload:
+         server, admin_token = _load_admin_credentials(config_file, server, admin_token)
+
+         if not server:
+             console.print("[yellow]No server specified, skipping orchestrator reload[/yellow]")
+             console.print("[dim]Use --server to reload orchestrator config[/dim]")
+         elif not admin_token:
+             console.print("[yellow]No admin token specified, skipping orchestrator reload[/yellow]")
+             console.print("[dim]Use --token to reload orchestrator config[/dim]")
+         else:
+             console.print(f"[cyan]Reloading orchestrator config...[/cyan]")
+             success = asyncio.run(
+                 _reload_orchestrator_config(server, admin_token, config_data, no_verify_ssl)
+             )
+             if not success:
+                 console.print("[yellow]Config file updated but orchestrator reload failed[/yellow]")
+                 console.print("[dim]You may need to restart the orchestrator manually[/dim]")
+
+
+ @auth.command()
+ @click.argument("role", type=click.Choice(["worker", "admin", "monitor"]))
+ @click.argument("identifier")
+ @click.option(
+     "--no-reload", is_flag=True, help="Don't reload orchestrator config after removing token"
+ )
+ @click.pass_context
+ def remove(ctx, role: str, identifier: str, no_reload: bool):
+     """Remove an authentication token.
+
+     ROLE: Type of token (worker, admin, monitor)
+     IDENTIFIER: Name or token value to remove
+     """
+     config_file = ctx.obj.get("config")
+     server = ctx.obj.get("server")
+     admin_token = ctx.obj.get("token")
+     no_verify_ssl = ctx.obj.get("no_verify_ssl", False)
+
+     # Load config
+     config_data = ConfigManager.find_config("orchestrator", config_file)
+     if not config_data:
+         console.print("[red]No orchestrator config found[/red]")
+         sys.exit(1)
+
+     # Find config file path for saving
+     config_path = None
+     if config_file:
+         config_path = Path(config_file)
+     else:
+         # Try to find the config file that was loaded
+         for search_path in [
+             Path.cwd() / "orchestrator.yaml",
+             Path.cwd() / "config" / "orchestrator.yaml",
+             Path.home() / ".caption-flow" / "orchestrator.yaml",
+             ConfigManager.get_xdg_config_home() / "caption-flow" / "orchestrator.yaml",
+         ]:
+             if search_path.exists():
+                 config_path = search_path
+                 break
+
+     if not config_path:
+         console.print("[red]Could not determine config file to save to[/red]")
+         sys.exit(1)
+
+     # Remove token from config
+     if not _remove_token_from_config(config_data, role, identifier):
+         sys.exit(1)
+
+     # Save config file
+     if not _save_config_file(config_data, config_path):
+         sys.exit(1)
+
+     # Reload orchestrator if requested
+     if not no_reload:
+         server, admin_token = _load_admin_credentials(config_file, server, admin_token)
+
+         if not server:
+             console.print("[yellow]No server specified, skipping orchestrator reload[/yellow]")
+         elif not admin_token:
+             console.print("[yellow]No admin token specified, skipping orchestrator reload[/yellow]")
+         else:
+             console.print(f"[cyan]Reloading orchestrator config...[/cyan]")
+             success = asyncio.run(
+                 _reload_orchestrator_config(server, admin_token, config_data, no_verify_ssl)
+             )
+             if not success:
+                 console.print("[yellow]Config file updated but orchestrator reload failed[/yellow]")
+
+
+ @auth.command()
+ @click.argument("role", type=click.Choice(["worker", "admin", "monitor", "all"]), required=False)
+ @click.pass_context
+ def list(ctx, role: Optional[str]):
+     """List authentication tokens.
+
+     ROLE: Type of tokens to list (worker, admin, monitor, all). Default: all
+     """
+     config_file = ctx.obj.get("config")
+
+     # Load config
+     config_data = ConfigManager.find_config("orchestrator", config_file)
+     if not config_data:
+         console.print("[red]No orchestrator config found[/red]")
+         sys.exit(1)
+
+     # Show tokens
+     if role == "all" or role is None:
+         _list_tokens_in_config(config_data)
+     else:
+         _list_tokens_in_config(config_data, role)
+
+
+ @auth.command()
+ @click.option("--length", default=32, help="Token length (default: 32)")
+ @click.option("--count", default=1, help="Number of tokens to generate (default: 1)")
+ def generate(length: int, count: int):
+     """Generate random authentication tokens."""
+     import secrets
+     import string
+
+     alphabet = string.ascii_letters + string.digits + "-_"
+
+     console.print(
+         f"[cyan]Generated {count} token{'s' if count > 1 else ''} ({length} characters each):[/cyan]\n"
+     )
+
+     for i in range(count):
+         token = "".join(secrets.choice(alphabet) for _ in range(length))
+         console.print(f" {i + 1}: {token}")
+
+
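Taken together this adds a small token-management CLI under the auth group. A quick way to exercise it without a running orchestrator is click's test runner; the sketch below assumes the command group is exported as caption_flow.cli.main, as the decorators above suggest (file paths and token values are hypothetical):

    from click.testing import CliRunner

    from caption_flow.cli import main

    runner = CliRunner()

    # Generate two 48-character candidate tokens
    print(runner.invoke(main, ["auth", "generate", "--count", "2", "--length", "48"]).output)

    # Register one for a worker; --no-reload skips the orchestrator round trip
    result = runner.invoke(
        main,
        ["auth", "--config", "orchestrator.yaml", "add", "worker", "gpu-box-1", "wk-example-123", "--no-reload"],
    )
    print(result.output)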
  @main.command()
  @click.option("--config", type=click.Path(exists=True), help="Configuration file")
  @click.option("--server", help="Orchestrator WebSocket URL")
@@ -441,27 +899,8 @@ def reload_config(
  ):
      """Reload orchestrator configuration via admin connection."""
      import websockets
-     import ssl

-     # Load base config to get server/token if not provided via CLI
-     if not server or not token:
-         base_config = ConfigManager.find_config("orchestrator", config) or {}
-         admin_config = base_config.get("admin", {})
-         admin_tokens = base_config.get("orchestrator", {}).get("auth", {}).get("admin_tokens", [])
-         has_admin_tokens = False
-         if len(admin_tokens) > 0:
-             has_admin_tokens = True
-             first_admin_token = admin_tokens[0].get("token", None)
-             # Do not print sensitive admin token to console.
-
-         if not server:
-             server = admin_config.get("server", "ws://localhost:8765")
-         if not token:
-             token = admin_config.get("token", None)
-         if token is None and has_admin_tokens:
-             # grab the first one, we'll just assume we're localhost.
-             console.print("Using first admin token.")
-             token = first_admin_token
+     server, token = _load_admin_credentials(config, server, token)

      if not server:
          console.print("[red]Error: --server required (or set in config)[/red]")
@@ -472,64 +911,22 @@ def reload_config(

      console.print(f"[cyan]Loading configuration from {new_config}...[/cyan]")

-     # Load the new configuration
      new_cfg = ConfigManager.load_yaml(Path(new_config))
      if not new_cfg:
          console.print("[red]Failed to load configuration[/red]")
          sys.exit(1)

-     # Setup SSL
-     ssl_context = None
-     if server.startswith("wss://"):
-         if no_verify_ssl:
-             ssl_context = ssl.create_default_context()
-             ssl_context.check_hostname = False
-             ssl_context.verify_mode = ssl.CERT_NONE
-         else:
-             ssl_context = ssl.create_default_context()
+     ssl_context = _setup_ssl_context(server, no_verify_ssl)

      async def send_reload():
          try:
-             async with websockets.connect(server, ssl=ssl_context) as websocket:
-                 # Authenticate as admin
-                 await websocket.send(json.dumps({"token": token, "role": "admin"}))
-
-                 response = await websocket.recv()
-                 auth_response = json.loads(response)
-
-                 if "error" in auth_response:
-                     console.print(f"[red]Authentication failed: {auth_response['error']}[/red]")
+             async with websockets.connect(
+                 server, ssl=ssl_context, ping_interval=20, ping_timeout=60, close_timeout=10
+             ) as websocket:
+                 if not await _authenticate_admin(websocket, token):
                      return False

-                 console.print("[green]✓ Authenticated as admin[/green]")
-
-                 # Send reload command
-                 await websocket.send(json.dumps({"type": "reload_config", "config": new_cfg}))
-
-                 response = await websocket.recv()
-                 reload_response = json.loads(response)
-
-                 if reload_response.get("type") == "reload_complete":
-                     if "message" in reload_response and "No changes" in reload_response["message"]:
-                         console.print(f"[yellow]{reload_response['message']}[/yellow]")
-                     else:
-                         console.print("[green]✓ Configuration reloaded successfully![/green]")
-
-                     if "updated" in reload_response and reload_response["updated"]:
-                         console.print("\n[cyan]Updated sections:[/cyan]")
-                         for section in reload_response["updated"]:
-                             console.print(f" • {section}")
-
-                     if "warnings" in reload_response and reload_response["warnings"]:
-                         console.print("\n[yellow]Warnings:[/yellow]")
-                         for warning in reload_response["warnings"]:
-                             console.print(f" ⚠ {warning}")
-
-                     return True
-                 else:
-                     error = reload_response.get("error", "Unknown error")
-                     console.print(f"[red]Reload failed: {error} ({reload_response=})[/red]")
-                     return False
+             return await _send_reload_command(websocket, new_cfg)

          except Exception as e:
              console.print(f"[red]Error: {e}[/red]")
@@ -540,39 +937,20 @@ def reload_config(
      sys.exit(1)


- @main.command()
- @click.option("--data-dir", default="./caption_data", help="Storage directory")
- @click.option("--checkpoint-dir", default="./checkpoints", help="Checkpoint directory")
- @click.option("--fix", is_flag=True, help="Fix issues by resetting abandoned chunks")
- @click.option("--verbose", is_flag=True, help="Show detailed information")
- def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
-     """Scan for sparse or abandoned chunks and optionally fix them."""
-     from .utils.chunk_tracker import ChunkTracker
-     from .storage import StorageManager
-     import pyarrow.parquet as pq
-
-     console.print("[bold cyan]Scanning for sparse/abandoned chunks...[/bold cyan]\n")
-
-     checkpoint_path = Path(checkpoint_dir) / "chunks.json"
-     if not checkpoint_path.exists():
-         console.print("[red]No chunk checkpoint found![/red]")
-         return
-
-     tracker = ChunkTracker(checkpoint_path)
-     storage = StorageManager(Path(data_dir))
-
-     # Get and display stats
-     stats = tracker.get_stats()
+ def _display_chunk_stats(stats):
+     """Display chunk statistics."""
      console.print(f"[green]Total chunks:[/green] {stats['total']}")
      console.print(f"[green]Completed:[/green] {stats['completed']}")
      console.print(f"[yellow]Pending:[/yellow] {stats['pending']}")
      console.print(f"[yellow]Assigned:[/yellow] {stats['assigned']}")
      console.print(f"[red]Failed:[/red] {stats['failed']}\n")

-     # Find abandoned chunks
+
+ def _find_abandoned_chunks(tracker):
+     """Find chunks that have been assigned for too long."""
      abandoned_chunks = []
      stale_threshold = 3600  # 1 hour
-     current_time = datetime.utcnow()
+     current_time = datetime.now(_datetime.UTC)

      for chunk_id, chunk_state in tracker.chunks.items():
          if chunk_state.status == "assigned" and chunk_state.assigned_at:
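The switch from datetime.utcnow() to datetime.now(_datetime.UTC) trades a naive timestamp for a timezone-aware one; utcnow() is deprecated as of Python 3.12, and datetime.UTC requires Python 3.11+. Since aware and naive datetimes cannot be subtracted, this presumably assumes the tracker's assigned_at values are stored timezone-aware as well. A minimal illustration:

    import datetime as _datetime
    from datetime import datetime

    assigned_at = datetime.now(_datetime.UTC)  # aware; a naive value would raise on subtraction
    age = (datetime.now(_datetime.UTC) - assigned_at).total_seconds()
    is_stale = age > 3600  # the same one-hour stale_threshold used above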
@@ -580,24 +958,31 @@ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
              if age > stale_threshold:
                  abandoned_chunks.append((chunk_id, chunk_state, age))

-     if abandoned_chunks:
-         console.print(f"[red]Found {len(abandoned_chunks)} abandoned chunks:[/red]")
-         for chunk_id, chunk_state, age in abandoned_chunks[:10]:
-             age_str = f"{age/3600:.1f} hours" if age > 3600 else f"{age/60:.1f} minutes"
-             console.print(f" • {chunk_id} (assigned to {chunk_state.assigned_to} {age_str} ago)")
+     return abandoned_chunks

-         if len(abandoned_chunks) > 10:
-             console.print(f" ... and {len(abandoned_chunks) - 10} more")

-         if fix:
-             console.print("\n[yellow]Resetting abandoned chunks to pending...[/yellow]")
-             for chunk_id, _, _ in abandoned_chunks:
-                 tracker.mark_failed(chunk_id)
-             console.print(f"[green]✓ Reset {len(abandoned_chunks)} chunks[/green]")
+ def _display_abandoned_chunks(abandoned_chunks, fix, tracker):
+     """Display abandoned chunks and optionally fix them."""
+     if not abandoned_chunks:
+         return

-     # Check for sparse shards
-     console.print("\n[bold cyan]Checking for sparse shards...[/bold cyan]")
+     console.print(f"[red]Found {len(abandoned_chunks)} abandoned chunks:[/red]")
+     for chunk_id, chunk_state, age in abandoned_chunks[:10]:
+         age_str = f"{age / 3600:.1f} hours" if age > 3600 else f"{age / 60:.1f} minutes"
+         console.print(f" • {chunk_id} (assigned to {chunk_state.assigned_to} {age_str} ago)")
+
+     if len(abandoned_chunks) > 10:
+         console.print(f" ... and {len(abandoned_chunks) - 10} more")
+
+     if fix:
+         console.print("\n[yellow]Resetting abandoned chunks to pending...[/yellow]")
+         for chunk_id, _, _ in abandoned_chunks:
+             tracker.mark_failed(chunk_id)
+         console.print(f"[green]✓ Reset {len(abandoned_chunks)} chunks[/green]")

+
+ def _find_sparse_shards(tracker):
+     """Find shards with gaps or issues."""
      shards_summary = tracker.get_shards_summary()
      sparse_shards = []

@@ -616,60 +1001,108 @@ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
          if has_gaps or shard_info["failed_chunks"] > 0:
              sparse_shards.append((shard_name, shard_info, has_gaps))

-     if sparse_shards:
-         console.print(f"\n[yellow]Found {len(sparse_shards)} sparse/incomplete shards:[/yellow]")
-         for shard_name, shard_info, has_gaps in sparse_shards[:5]:
-             status = []
-             if shard_info["pending_chunks"] > 0:
-                 status.append(f"{shard_info['pending_chunks']} pending")
-             if shard_info["assigned_chunks"] > 0:
-                 status.append(f"{shard_info['assigned_chunks']} assigned")
-             if shard_info["failed_chunks"] > 0:
-                 status.append(f"{shard_info['failed_chunks']} failed")
-             if has_gaps:
-                 status.append("has gaps")
-
-             console.print(f" {shard_name}: {', '.join(status)}")
+     return sparse_shards
+
+
+ def _display_sparse_shards(sparse_shards):
+     """Display sparse/incomplete shards."""
+     if not sparse_shards:
+         return
+
+     console.print(f"\n[yellow]Found {len(sparse_shards)} sparse/incomplete shards:[/yellow]")
+     for shard_name, shard_info, has_gaps in sparse_shards[:5]:
+         status = []
+         if shard_info["pending_chunks"] > 0:
+             status.append(f"{shard_info['pending_chunks']} pending")
+         if shard_info["assigned_chunks"] > 0:
+             status.append(f"{shard_info['assigned_chunks']} assigned")
+         if shard_info["failed_chunks"] > 0:
+             status.append(f"{shard_info['failed_chunks']} failed")
+         if has_gaps:
+             status.append("has gaps")
+
+         console.print(f" • {shard_name}: {', '.join(status)}")
+         console.print(
+             f" Progress: {shard_info['completed_chunks']}/{shard_info['total_chunks']} chunks"
+         )
+
+     if len(sparse_shards) > 5:
+         console.print(f" ... and {len(sparse_shards) - 5} more")
+
+
+ def _cross_check_storage(storage, tracker, fix):
+     """Cross-check chunk tracker against storage."""
+     import pyarrow.parquet as pq
+
+     console.print("\n[bold cyan]Cross-checking with stored captions...[/bold cyan]")
+
+     try:
+         table = pq.read_table(storage.captions_path, columns=["chunk_id"])
+         stored_chunk_ids = set(c for c in table["chunk_id"].to_pylist() if c)
+
+         tracker_completed = set(c for c, s in tracker.chunks.items() if s.status == "completed")
+
+         missing_in_storage = tracker_completed - stored_chunk_ids
+         missing_in_tracker = stored_chunk_ids - set(tracker.chunks.keys())
+
+         if missing_in_storage:
              console.print(
-                 f" Progress: {shard_info['completed_chunks']}/{shard_info['total_chunks']} chunks"
+                 f"\n[red]Chunks marked complete but missing from storage:[/red] {len(missing_in_storage)}"
              )
+             for chunk_id in list(missing_in_storage)[:5]:
+                 console.print(f" • {chunk_id}")

-         if len(sparse_shards) > 5:
-             console.print(f" ... and {len(sparse_shards) - 5} more")
+             if fix:
+                 console.print("[yellow]Resetting these chunks to pending...[/yellow]")
+                 for chunk_id in missing_in_storage:
+                     tracker.mark_failed(chunk_id)
+                 console.print(f"[green]✓ Reset {len(missing_in_storage)} chunks[/green]")

-     # Cross-check with storage if verbose
-     if storage.captions_path.exists() and verbose:
-         console.print("\n[bold cyan]Cross-checking with stored captions...[/bold cyan]")
+         if missing_in_tracker:
+             console.print(
+                 f"\n[yellow]Chunks in storage but not tracked:[/yellow] {len(missing_in_tracker)}"
+             )

-         try:
-             table = pq.read_table(storage.captions_path, columns=["chunk_id"])
-             stored_chunk_ids = set(c for c in table["chunk_id"].to_pylist() if c)
+     except Exception as e:
+         console.print(f"[red]Error reading storage: {e}[/red]")

-             tracker_completed = set(c for c, s in tracker.chunks.items() if s.status == "completed")

-             missing_in_storage = tracker_completed - stored_chunk_ids
-             missing_in_tracker = stored_chunk_ids - set(tracker.chunks.keys())
+ @main.command()
+ @click.option("--data-dir", default="./caption_data", help="Storage directory")
+ @click.option("--checkpoint-dir", default="./checkpoints", help="Checkpoint directory")
+ @click.option("--fix", is_flag=True, help="Fix issues by resetting abandoned chunks")
+ @click.option("--verbose", is_flag=True, help="Show detailed information")
+ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
+     """Scan for sparse or abandoned chunks and optionally fix them."""
+     from .storage import StorageManager
+     from .utils.chunk_tracker import ChunkTracker

-             if missing_in_storage:
-                 console.print(
-                     f"\n[red]Chunks marked complete but missing from storage:[/red] {len(missing_in_storage)}"
-                 )
-                 for chunk_id in list(missing_in_storage)[:5]:
-                     console.print(f" • {chunk_id}")
+     console.print("[bold cyan]Scanning for sparse/abandoned chunks...[/bold cyan]\n")

-                 if fix:
-                     console.print("[yellow]Resetting these chunks to pending...[/yellow]")
-                     for chunk_id in missing_in_storage:
-                         tracker.mark_failed(chunk_id)
-                     console.print(f"[green]✓ Reset {len(missing_in_storage)} chunks[/green]")
+     checkpoint_path = Path(checkpoint_dir) / "chunks.json"
+     if not checkpoint_path.exists():
+         console.print("[red]No chunk checkpoint found![/red]")
+         return

-             if missing_in_tracker:
-                 console.print(
-                     f"\n[yellow]Chunks in storage but not tracked:[/yellow] {len(missing_in_tracker)}"
-                 )
+     tracker = ChunkTracker(checkpoint_path)
+     storage = StorageManager(Path(data_dir))

-         except Exception as e:
-             console.print(f"[red]Error reading storage: {e}[/red]")
+     # Get and display stats
+     stats = tracker.get_stats()
+     _display_chunk_stats(stats)
+
+     # Find and handle abandoned chunks
+     abandoned_chunks = _find_abandoned_chunks(tracker)
+     _display_abandoned_chunks(abandoned_chunks, fix, tracker)
+
+     # Check for sparse shards
+     console.print("\n[bold cyan]Checking for sparse shards...[/bold cyan]")
+     sparse_shards = _find_sparse_shards(tracker)
+     _display_sparse_shards(sparse_shards)
+
+     # Cross-check with storage if verbose
+     if storage.captions_path.exists() and verbose:
+         _cross_check_storage(storage, tracker, fix)

      # Summary
      console.print("\n[bold cyan]Summary:[/bold cyan]")
@@ -693,12 +1126,163 @@ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
          tracker.save_checkpoint()


+ def _display_export_stats(stats):
+     """Display storage statistics."""
+     console.print("\n[bold cyan]Storage Statistics:[/bold cyan]")
+     console.print(f"[green]Total rows:[/green] {stats['total_rows']:,}")
+     console.print(f"[green]Total outputs:[/green] {stats['total_outputs']:,}")
+     console.print(f"[green]Shards:[/green] {stats['shard_count']} ({', '.join(stats['shards'])})")
+     console.print(f"[green]Output fields:[/green] {', '.join(stats['output_fields'])}")
+
+     if stats.get("field_stats"):
+         console.print("\n[cyan]Field breakdown:[/cyan]")
+         for field, count in stats["field_stats"].items():
+             console.print(f" • {field}: {count['total_items']:,} items")
+
+
+ def _prepare_export_params(shard, shards, columns):
+     """Prepare shard filter and column list."""
+     shard_filter = None
+     if shard:
+         shard_filter = [shard]
+     elif shards:
+         shard_filter = [s.strip() for s in shards.split(",")]
+
+     column_list = None
+     if columns:
+         column_list = [col.strip() for col in columns.split(",")]
+         console.print(f"\n[cyan]Exporting columns:[/cyan] {', '.join(column_list)}")
+
+     return shard_filter, column_list
+
+
+ async def _export_all_formats(
+     exporter, output, shard_filter, column_list, limit, filename_column, export_column
+ ):
+     """Export to all formats."""
+     base_name = output or "caption_export"
+     base_path = Path(base_name)
+     results = {}
+
+     for export_format in ["jsonl", "csv", "parquet", "json", "txt"]:
+         console.print(f"\n[cyan]Exporting to {export_format.upper()}...[/cyan]")
+         try:
+             format_results = await exporter.export_all_shards(
+                 export_format,
+                 base_path,
+                 columns=column_list,
+                 limit_per_shard=limit,
+                 shard_filter=shard_filter,
+                 filename_column=filename_column,
+                 export_column=export_column,
+             )
+             results[export_format] = sum(format_results.values())
+         except Exception as e:
+             console.print(f"[yellow]Skipping {export_format}: {e}[/yellow]")
+             results[export_format] = 0
+
+     console.print("\n[green]✓ Export complete![/green]")
+     for fmt, count in results.items():
+         if count > 0:
+             console.print(f" • {fmt.upper()}: {count:,} items")
+
+
+ async def _export_to_lance(exporter, output, column_list, shard_filter):
+     """Export to Lance dataset."""
+     output_path = output or "exported_captions.lance"
+     console.print(f"\n[cyan]Exporting to Lance dataset:[/cyan] {output_path}")
+     total_rows = await exporter.export_to_lance(
+         output_path, columns=column_list, shard_filter=shard_filter
+     )
+     console.print(f"[green]✓ Exported {total_rows:,} rows to Lance dataset[/green]")
+
+
+ async def _export_to_huggingface(exporter, hf_dataset, license, private, nsfw, tags, shard_filter):
+     """Export to Hugging Face Hub."""
+     if not hf_dataset:
+         console.print("[red]Error: --hf-dataset required for huggingface_hub format[/red]")
+         console.print("[dim]Example: --hf-dataset username/my-caption-dataset[/dim]")
+         sys.exit(1)
+
+     tag_list = None
+     if tags:
+         tag_list = [tag.strip() for tag in tags.split(",")]
+
+     console.print(f"\n[cyan]Uploading to Hugging Face Hub:[/cyan] {hf_dataset}")
+     if private:
+         console.print("[dim]Privacy: Private dataset[/dim]")
+     if nsfw:
+         console.print("[dim]Content: Not for all audiences[/dim]")
+     if tag_list:
+         console.print(f"[dim]Tags: {', '.join(tag_list)}[/dim]")
+     if shard_filter:
+         console.print(f"[dim]Shards: {', '.join(shard_filter)}[/dim]")
+
+     url = await exporter.export_to_huggingface_hub(
+         dataset_name=hf_dataset,
+         license=license,
+         private=private,
+         nsfw=nsfw,
+         tags=tag_list,
+         shard_filter=shard_filter,
+     )
+     console.print(f"[green]✓ Dataset uploaded to: {url}[/green]")
+
+
+ async def _export_single_format(
+     exporter,
+     format,
+     output,
+     shard_filter,
+     column_list,
+     limit,
+     filename_column,
+     export_column,
+     verbose,
+ ):
+     """Export to a single format."""
+     output_path = output or "export"
+
+     if shard_filter and len(shard_filter) == 1:
+         console.print(f"\n[cyan]Exporting shard {shard_filter[0]} to {format.upper()}...[/cyan]")
+         count = await exporter.export_shard(
+             shard_filter[0],
+             format,
+             output_path,
+             columns=column_list,
+             limit=limit,
+             filename_column=filename_column,
+             export_column=export_column,
+         )
+         console.print(f"[green]✓ Exported {count:,} items[/green]")
+     else:
+         console.print(f"\n[cyan]Exporting to {format.upper()}...[/cyan]")
+         results = await exporter.export_all_shards(
+             format,
+             output_path,
+             columns=column_list,
+             limit_per_shard=limit,
+             shard_filter=shard_filter,
+             filename_column=filename_column,
+             export_column=export_column,
+         )
+
+         total = sum(results.values())
+         console.print(f"[green]✓ Exported {total:,} items total[/green]")
+
+         if verbose and len(results) > 1:
+             console.print("\n[dim]Per-shard breakdown:[/dim]")
+             for shard_name, count in sorted(results.items()):
+                 console.print(f" • {shard_name}: {count:,} items")
+
+
  @main.command()
  @click.option("--data-dir", default="./caption_data", help="Storage directory")
  @click.option(
      "--format",
      type=click.Choice(
-         ["jsonl", "json", "csv", "txt", "huggingface_hub", "all"], case_sensitive=False
+         ["jsonl", "json", "csv", "txt", "parquet", "lance", "huggingface_hub", "all"],
+         case_sensitive=False,
      ),
      default="jsonl",
      help="Export format (default: jsonl)",
@@ -708,17 +1292,117 @@ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
  @click.option("--columns", help="Comma-separated list of columns to export (default: all)")
  @click.option("--export-column", default="captions", help="Column to export for txt format")
  @click.option("--filename-column", default="filename", help="Column containing filenames")
+ @click.option("--shard", help="Specific shard to export (e.g., data-0001)")
+ @click.option("--shards", help="Comma-separated list of shards to export")
  @click.option("--include-empty", is_flag=True, help="Include rows with empty export column")
  @click.option("--stats-only", is_flag=True, help="Show statistics without exporting")
- @click.option(
-     "--optimize", is_flag=True, help="Optimize storage before export (remove empty columns)"
- )
+ @click.option("--optimize", is_flag=True, help="Optimize storage before export")
  @click.option("--verbose", is_flag=True, help="Show detailed export progress")
  @click.option("--hf-dataset", help="Dataset name on HF Hub (e.g., username/dataset-name)")
- @click.option("--license", help="License for the dataset (required for new HF datasets)")
+ @click.option("--license", default="apache-2.0", help="License for the dataset")
  @click.option("--private", is_flag=True, help="Make HF dataset private")
  @click.option("--nsfw", is_flag=True, help="Add not-for-all-audiences tag")
  @click.option("--tags", help="Comma-separated tags for HF dataset")
+ def _validate_export_setup(data_dir):
+     """Validate export setup and create storage manager."""
+     from .storage import StorageManager
+
+     storage_path = Path(data_dir)
+     if not storage_path.exists():
+         console.print(f"[red]Storage directory not found: {data_dir}[/red]")
+         sys.exit(1)
+
+     return StorageManager(storage_path)
+
+
+ async def _run_export_process(
+     storage,
+     format,
+     output,
+     shard,
+     shards,
+     columns,
+     limit,
+     filename_column,
+     export_column,
+     verbose,
+     hf_dataset,
+     license,
+     private,
+     nsfw,
+     tags,
+     stats_only,
+     optimize,
+ ):
+     """Execute the main export process."""
+     from .storage.exporter import LanceStorageExporter
+
+     await storage.initialize()
+
+     stats = await storage.get_caption_stats()
+     _display_export_stats(stats)
+
+     if stats_only:
+         return
+
+     if optimize:
+         console.print("\n[yellow]Optimizing storage...[/yellow]")
+         await storage.optimize_storage()
+
+     shard_filter, column_list = _prepare_export_params(shard, shards, columns)
+     exporter = LanceStorageExporter(storage)
+
+     if format == "all":
+         await _export_all_formats(
+             exporter, output, shard_filter, column_list, limit, filename_column, export_column
+         )
+     elif format == "lance":
+         await _export_to_lance(exporter, output, column_list, shard_filter)
+     elif format == "huggingface_hub":
+         await _export_to_huggingface(
+             exporter, hf_dataset, license, private, nsfw, tags, shard_filter
+         )
+     else:
+         await _export_single_format(
+             exporter,
+             format,
+             output,
+             shard_filter,
+             column_list,
+             limit,
+             filename_column,
+             export_column,
+             verbose,
+         )
+
+
+ @main.command()
+ @click.option("--data-dir", default="./caption_data", help="Storage directory")
+ @click.option(
+     "--format",
+     type=click.Choice(
+         ["jsonl", "json", "csv", "txt", "parquet", "lance", "huggingface_hub", "all"],
+         case_sensitive=False,
+     ),
+     default="jsonl",
+     help="Export format (default: jsonl)",
+ )
+ @click.option("--output", help="Output filename or directory")
+ @click.option("--limit", type=int, help="Maximum number of items to export")
+ @click.option("--columns", help="Comma-separated list of columns to include")
+ @click.option("--export-column", default="captions", help="Column to export (default: captions)")
+ @click.option("--filename-column", default="filename", help="Filename column (default: filename)")
+ @click.option("--shard", help="Export only specific shard (e.g., 'data-001')")
+ @click.option("--shards", help="Comma-separated list of shards to export")
+ @click.option("--include-empty", is_flag=True, help="Include items with empty/null export column")
+ @click.option("--stats-only", is_flag=True, help="Show statistics only, don't export")
+ @click.option("--optimize", is_flag=True, help="Optimize storage before export")
+ @click.option("--verbose", is_flag=True, help="Verbose output")
+ @click.option("--hf-dataset", help="HuggingFace Hub dataset name (for huggingface_hub format)")
+ @click.option("--license", default="MIT", help="Dataset license (default: MIT)")
+ @click.option("--private", is_flag=True, help="Make HuggingFace dataset private")
+ @click.option("--nsfw", is_flag=True, help="Mark dataset as NSFW")
+ @click.option("--tags", help="Comma-separated tags for HuggingFace dataset")
  def export(
      data_dir: str,
      format: str,
@@ -727,219 +1411,56 @@ def export(
      columns: Optional[str],
      export_column: str,
      filename_column: str,
+     shard: Optional[str],
+     shards: Optional[str],
      include_empty: bool,
      stats_only: bool,
      optimize: bool,
      verbose: bool,
      hf_dataset: Optional[str],
-     license: Optional[str],
+     license: str,
      private: bool,
      nsfw: bool,
      tags: Optional[str],
  ):
-     """Export caption data to various formats."""
-     from .storage import StorageManager
-     from .storage.exporter import StorageExporter, ExportError
+     """Export caption data to various formats with per-shard support."""
+     from .storage.exporter import ExportError

-     # Initialize storage manager
-     storage_path = Path(data_dir)
-     if not storage_path.exists():
-         console.print(f"[red]Storage directory not found: {data_dir}[/red]")
-         sys.exit(1)
-
-     storage = StorageManager(storage_path)
-
-     async def run_export():
-         await storage.initialize()
-
-         # Show statistics first
-         stats = await storage.get_caption_stats()
-         console.print("\n[bold cyan]Storage Statistics:[/bold cyan]")
-         console.print(f"[green]Total rows:[/green] {stats['total_rows']:,}")
-         console.print(f"[green]Total outputs:[/green] {stats['total_outputs']:,}")
-         console.print(f"[green]Output fields:[/green] {', '.join(stats['output_fields'])}")
-
-         if stats.get("field_stats"):
-             console.print("\n[cyan]Field breakdown:[/cyan]")
-             for field, field_stat in stats["field_stats"].items():
-                 console.print(
-                     f" • {field}: {field_stat['total_items']:,} items "
-                     f"in {field_stat['rows_with_data']:,} rows"
-                 )
-
-         if stats_only:
-             return
-
-         # Optimize storage if requested
-         if optimize:
-             console.print("\n[yellow]Optimizing storage (removing empty columns)...[/yellow]")
-             await storage.optimize_storage()
-
-         # Prepare columns list
-         column_list = None
-         if columns:
-             column_list = [col.strip() for col in columns.split(",")]
-             console.print(f"\n[cyan]Exporting columns:[/cyan] {', '.join(column_list)}")
-
-         # Get storage contents
-         console.print("\n[yellow]Loading data...[/yellow]")
-         try:
-             contents = await storage.get_storage_contents(
-                 limit=limit, columns=column_list, include_metadata=True
-             )
-         except ValueError as e:
-             console.print(f"[red]Error: {e}[/red]")
-             sys.exit(1)
+     storage = _validate_export_setup(data_dir)

-         if not contents.rows:
-             console.print("[yellow]No data to export![/yellow]")
-             return
-
-         # Filter out empty rows if not including empty
-         if not include_empty and format in ["txt", "json"]:
-             original_count = len(contents.rows)
-             contents.rows = [
-                 row
-                 for row in contents.rows
-                 if row.get(export_column)
-                 and (not isinstance(row[export_column], list) or len(row[export_column]) > 0)
-             ]
-             filtered_count = original_count - len(contents.rows)
-             if filtered_count > 0:
-                 console.print(f"[dim]Filtered {filtered_count} empty rows[/dim]")
-
-         # Create exporter
-         exporter = StorageExporter(contents)
-
-         # Determine output paths
-         if format == "all":
-             # Export to all formats
-             base_name = output or "caption_export"
-             base_path = Path(base_name)
-
-             formats_exported = []
-
-             # JSONL
-             jsonl_path = base_path.with_suffix(".jsonl")
-             console.print(f"\n[cyan]Exporting to JSONL:[/cyan] {jsonl_path}")
-             rows = exporter.to_jsonl(jsonl_path)
-             formats_exported.append(f"JSONL: {rows:,} rows")
-
-             # CSV
-             csv_path = base_path.with_suffix(".csv")
-             console.print(f"[cyan]Exporting to CSV:[/cyan] {csv_path}")
-             try:
-                 rows = exporter.to_csv(csv_path)
-                 formats_exported.append(f"CSV: {rows:,} rows")
-             except ExportError as e:
-                 console.print(f"[yellow]Skipping CSV: {e}[/yellow]")
-
-             # JSON files
-             json_dir = base_path.parent / f"{base_path.stem}_json"
-             console.print(f"[cyan]Exporting to JSON files:[/cyan] {json_dir}/")
-             try:
-                 files = exporter.to_json(json_dir, filename_column)
-                 formats_exported.append(f"JSON: {files:,} files")
-             except ExportError as e:
-                 console.print(f"[yellow]Skipping JSON files: {e}[/yellow]")
-
-             # Text files
-             txt_dir = base_path.parent / f"{base_path.stem}_txt"
-             console.print(f"[cyan]Exporting to text files:[/cyan] {txt_dir}/")
-             try:
-                 files = exporter.to_txt(txt_dir, filename_column, export_column)
-                 formats_exported.append(f"Text: {files:,} files")
-             except ExportError as e:
-                 console.print(f"[yellow]Skipping text files: {e}[/yellow]")
-
-             console.print(f"\n[green]✓ Export complete![/green]")
-             for fmt in formats_exported:
-                 console.print(f" • {fmt}")
-
-         else:
-             # Single format export
-             try:
-                 if format == "jsonl":
-                     output_path = output or "captions.jsonl"
-                     console.print(f"\n[cyan]Exporting to JSONL:[/cyan] {output_path}")
-                     rows = exporter.to_jsonl(output_path)
-                     console.print(f"[green]✓ Exported {rows:,} rows[/green]")
-
-                 elif format == "csv":
-                     output_path = output or "captions.csv"
-                     console.print(f"\n[cyan]Exporting to CSV:[/cyan] {output_path}")
-                     rows = exporter.to_csv(output_path)
-                     console.print(f"[green]✓ Exported {rows:,} rows[/green]")
-
-                 elif format == "json":
-                     output_dir = output or "./json_output"
-                     console.print(f"\n[cyan]Exporting to JSON files:[/cyan] {output_dir}/")
-                     files = exporter.to_json(output_dir, filename_column)
-                     console.print(f"[green]✓ Created {files:,} JSON files[/green]")
-
-                 elif format == "txt":
-                     output_dir = output or "./txt_output"
-                     console.print(f"\n[cyan]Exporting to text files:[/cyan] {output_dir}/")
-                     console.print(f"[dim]Export column: {export_column}[/dim]")
-                     files = exporter.to_txt(output_dir, filename_column, export_column)
-                     console.print(f"[green]✓ Created {files:,} text files[/green]")
-
-                 elif format == "huggingface_hub":
-                     # Validate required parameters
-                     if not hf_dataset:
-                         console.print(
-                             "[red]Error: --hf-dataset required for huggingface_hub format[/red]"
-                         )
-                         console.print(
-                             "[dim]Example: --hf-dataset username/my-caption-dataset[/dim]"
-                         )
-                         sys.exit(1)
-
-                     # Parse tags
-                     tag_list = None
-                     if tags:
-                         tag_list = [tag.strip() for tag in tags.split(",")]
-
-                     console.print(f"\n[cyan]Uploading to Hugging Face Hub:[/cyan] {hf_dataset}")
-                     if private:
-                         console.print("[dim]Privacy: Private dataset[/dim]")
-                     if nsfw:
-                         console.print("[dim]Content: Not for all audiences[/dim]")
-                     if tag_list:
-                         console.print(f"[dim]Tags: {', '.join(tag_list)}[/dim]")
-
-                     url = exporter.to_huggingface_hub(
-                         dataset_name=hf_dataset,
-                         license=license,
-                         private=private,
-                         nsfw=nsfw,
-                         tags=tag_list,
-                     )
-                     console.print(f"[green]✓ Dataset uploaded to: {url}[/green]")
-
-             except ExportError as e:
-                 console.print(f"[red]Export error: {e}[/red]")
-                 sys.exit(1)
-
-         # Show export metadata
-         if verbose and contents.metadata:
-             console.print("\n[dim]Export metadata:[/dim]")
-             console.print(f" Timestamp: {contents.metadata.get('export_timestamp')}")
-             console.print(f" Total available: {contents.metadata.get('total_available_rows'):,}")
-             console.print(f" Rows exported: {contents.metadata.get('rows_exported'):,}")
-
-     # Run the async export
      try:
-         asyncio.run(run_export())
+         asyncio.run(
+             _run_export_process(
+                 storage,
+                 format,
+                 output,
+                 shard,
+                 shards,
+                 columns,
+                 limit,
+                 filename_column,
+                 export_column,
+                 verbose,
+                 hf_dataset,
+                 license,
+                 private,
+                 nsfw,
+                 tags,
+                 stats_only,
+                 optimize,
+             )
+         )
+     except ExportError as e:
+         console.print(f"[red]Export error: {e}[/red]")
+         sys.exit(1)
      except KeyboardInterrupt:
          console.print("\n[yellow]Export cancelled[/yellow]")
          sys.exit(1)
      except Exception as e:
          console.print(f"[red]Unexpected error: {e}[/red]")
-         if verbose:
-             import traceback
+         import traceback

-             traceback.print_exc()
+         traceback.print_exc()
          sys.exit(1)


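The 0.4.0 `export` command above delegates to two new helpers, `_validate_export_setup` and `_run_export_process`, and grows `--shard`/`--shards` options for per-shard exports. Below is a minimal sketch of driving the reworked command through click's `CliRunner`; the entry-point name `cli` and the `--data-dir` option spelling are assumptions not visible in this hunk, while `--format`, `--shards`, and `--stats-only` come from the options shown above.

```python
# Hedged sketch: exercising the new per-shard export flags via click's
# test runner. The group object `cli` and the --data-dir spelling are
# assumptions; the remaining flags appear in the diff above.
from click.testing import CliRunner

from caption_flow.cli import cli  # assumed entry-point object

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "export",
        "--data-dir", "./caption_data",         # hypothetical storage dir
        "--format", "jsonl",
        "--shards", "shard-00000,shard-00001",  # comma-separated, per help text
        "--stats-only",                         # per help text: stats, no export
    ],
)
print(result.exit_code)
print(result.output)
```

Since `--stats-only` is documented as showing statistics without exporting, a sketch like this can sanity-check shard selection without writing any files.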
@@ -961,7 +1482,7 @@ def generate_cert(
          cert_path, key_path = cert_manager.generate_self_signed(Path(output_dir), cert_domain)
          console.print(f"[green]✓[/green] Certificate: {cert_path}")
          console.print(f"[green]✓[/green] Key: {key_path}")
-         console.print(f"\n[cyan]Use these paths in your config or CLI:[/cyan]")
+         console.print("\n[cyan]Use these paths in your config or CLI:[/cyan]")
          console.print(f" --cert {cert_path}")
          console.print(f" --key {key_path}")
      elif domain and email:
@@ -978,7 +1499,7 @@ def generate_cert(
          )
          console.print(f"[green]✓[/green] Certificate: {cert_path}")
          console.print(f"[green]✓[/green] Key: {key_path}")
-         console.print(f"\n[cyan]Use these paths in your config or CLI:[/cyan]")
+         console.print("\n[cyan]Use these paths in your config or CLI:[/cyan]")
          console.print(f" --cert {cert_path}")
          console.print(f" --key {key_path}")

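For reference, the self-signed branch in this hunk reduces to a single `generate_self_signed` call. A standalone sketch follows, assuming `CertificateManager` can be constructed without arguments — only the call itself is visible in the diff.

```python
# Sketch of the self-signed branch shown above; the no-argument
# constructor is an assumption, generate_self_signed(...) is the
# call visible in the diff.
from pathlib import Path

from caption_flow.utils.certificates import CertificateManager

cert_manager = CertificateManager()  # assumed construction
cert_path, key_path = cert_manager.generate_self_signed(Path("./certs"), "localhost")
print(f"--cert {cert_path}")
print(f"--key {key_path}")
```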
@@ -1022,13 +1543,13 @@ def inspect_cert(cert_path: str):

          from datetime import datetime

-         if info["not_after"] < datetime.utcnow():
+         if info["not_after"] < datetime.now(_datetime.UTC):
              console.print("[red]✗ Certificate has expired![/red]")
-         elif (info["not_after"] - datetime.utcnow()).days < 30:
-             days_left = (info["not_after"] - datetime.utcnow()).days
+         elif (info["not_after"] - datetime.now(_datetime.UTC)).days < 30:
+             days_left = (info["not_after"] - datetime.now(_datetime.UTC)).days
              console.print(f"[yellow]⚠ Certificate expires in {days_left} days[/yellow]")
          else:
-             days_left = (info["not_after"] - datetime.utcnow()).days
+             days_left = (info["not_after"] - datetime.now(_datetime.UTC)).days
              console.print(f"[green]✓ Certificate valid for {days_left} more days[/green]")

      except Exception as e:
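The `inspect_cert` changes replace the deprecated, naive `datetime.utcnow()` with the timezone-aware `datetime.now(_datetime.UTC)`. Note that `datetime.UTC` is an alias for `timezone.utc` added in Python 3.11, and that an aware "now" only compares cleanly against an aware `not_after`, so the parsed certificate timestamps are presumably timezone-aware as of 0.4.0. A self-contained sketch of the same expiry check, with a hypothetical expiry value:

```python
# Standalone sketch of the aware-datetime expiry check; `not_after` is a
# hypothetical stand-in for the parsed certificate field.
from datetime import datetime, timezone

not_after = datetime(2026, 1, 1, tzinfo=timezone.utc)  # hypothetical expiry
now = datetime.now(timezone.utc)  # aware, unlike deprecated datetime.utcnow()

days_left = (not_after - now).days
if not_after < now:
    print("✗ Certificate has expired!")
elif days_left < 30:
    print(f"⚠ Certificate expires in {days_left} days")
else:
    print(f"✓ Certificate valid for {days_left} more days")
```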