caption-flow 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. caption_flow/__init__.py +3 -3
  2. caption_flow/cli.py +934 -415
  3. caption_flow/models.py +45 -3
  4. caption_flow/monitor.py +2 -3
  5. caption_flow/orchestrator.py +153 -104
  6. caption_flow/processors/__init__.py +3 -3
  7. caption_flow/processors/base.py +8 -7
  8. caption_flow/processors/huggingface.py +439 -67
  9. caption_flow/processors/local_filesystem.py +24 -28
  10. caption_flow/processors/webdataset.py +28 -22
  11. caption_flow/storage/exporter.py +420 -339
  12. caption_flow/storage/manager.py +636 -756
  13. caption_flow/utils/__init__.py +1 -1
  14. caption_flow/utils/auth.py +1 -1
  15. caption_flow/utils/caption_utils.py +1 -1
  16. caption_flow/utils/certificates.py +15 -8
  17. caption_flow/utils/checkpoint_tracker.py +30 -28
  18. caption_flow/utils/chunk_tracker.py +153 -56
  19. caption_flow/utils/image_processor.py +9 -9
  20. caption_flow/utils/json_utils.py +37 -20
  21. caption_flow/utils/prompt_template.py +24 -16
  22. caption_flow/utils/vllm_config.py +5 -4
  23. caption_flow/viewer.py +4 -12
  24. caption_flow/workers/base.py +5 -4
  25. caption_flow/workers/caption.py +265 -90
  26. caption_flow/workers/data.py +6 -8
  27. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/METADATA +5 -4
  28. caption_flow-0.4.0.dist-info/RECORD +33 -0
  29. caption_flow-0.3.4.dist-info/RECORD +0 -33
  30. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/WHEEL +0 -0
  31. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/entry_points.txt +0 -0
  32. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/licenses/LICENSE +0 -0
  33. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/top_level.txt +0 -0
caption_flow/cli.py CHANGED
@@ -1,21 +1,22 @@
  """Command-line interface for CaptionFlow with smart configuration handling."""

  import asyncio
+ import datetime as _datetime
  import json
  import logging
  import os
  import sys
+ from datetime import datetime
  from pathlib import Path
- from typing import Optional, Dict, Any, List
+ from typing import Any, Dict, List, Optional

  import click
  import yaml
  from rich.console import Console
  from rich.logging import RichHandler
- from datetime import datetime

- from .orchestrator import Orchestrator
  from .monitor import Monitor
+ from .orchestrator import Orchestrator
  from .utils.certificates import CertificateManager

  console = Console()
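
Note: the two new datetime imports work together. The module alias _datetime exposes the UTC constant, while the class import keeps existing datetime.now(...) call sites short. A minimal sketch of how they combine (requires Python 3.11+ for datetime.UTC):

    import datetime as _datetime
    from datetime import datetime

    # Timezone-aware "now", as used by scan_chunks later in this file
    current_time = datetime.now(_datetime.UTC)
    print(current_time.tzinfo)  # UTC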
@@ -48,8 +49,7 @@ class ConfigManager:
  def find_config(
  cls, component: str, explicit_path: Optional[str] = None
  ) -> Optional[Dict[str, Any]]:
- """
- Find and load configuration for a component.
+ """Find and load configuration for a component.

  Search order:
  1. Explicit path if provided
@@ -120,22 +120,76 @@ class ConfigManager:


  def setup_logging(verbose: bool = False):
- """Configure logging with rich handler, including timestamp."""
+ """Configure logging with rich handler and file output to XDG state directory."""
  level = logging.DEBUG if verbose else logging.INFO
- logging.basicConfig(
- level=level,
- format="%(name)s: %(message)s",
- datefmt="[%Y-%m-%d %H:%M:%S]",
- handlers=[
+
+ # Determine log directory based on environment or XDG spec
+ log_dir_env = os.environ.get("CAPTIONFLOW_LOG_DIR")
+ if log_dir_env:
+ log_dir = Path(log_dir_env)
+ else:
+ # Use XDG_STATE_HOME for logs, with platform-specific fallbacks
+ xdg_state_home = os.environ.get("XDG_STATE_HOME")
+ if xdg_state_home:
+ base_dir = Path(xdg_state_home)
+ elif sys.platform == "darwin":
+ base_dir = Path.home() / "Library" / "Logs"
+ else:
+ # Default to ~/.local/state on Linux and other systems
+ base_dir = Path.home() / ".local" / "state"
+ log_dir = base_dir / "caption-flow"
+
+ try:
+ # Ensure log directory exists
+ log_dir.mkdir(parents=True, exist_ok=True)
+ log_file_path = log_dir / "caption_flow.log"
+
+ # Set up handlers
+ handlers: List[logging.Handler] = [
+ RichHandler(
+ console=console,
+ rich_tracebacks=True,
+ show_path=False,
+ show_time=True,
+ )
+ ]
+
+ # Add file handler
+ file_handler = logging.FileHandler(log_file_path, mode="a")
+ file_handler.setFormatter(
+ logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+ )
+ handlers.append(file_handler)
+ log_msg = f"Logging to {log_file_path}"
+
+ except (OSError, PermissionError) as e:
+ # Fallback to only console logging if file logging fails
+ handlers = [
  RichHandler(
  console=console,
  rich_tracebacks=True,
  show_path=False,
- show_time=True, # Enables timestamp in RichHandler output
+ show_time=True,
  )
- ],
+ ]
+ log_file = log_dir / "caption_flow.log"
+ log_msg = f"[yellow]Warning: Could not write to log file {log_file}: {e}[/yellow]"
+
+ logging.basicConfig(
+ level=level,
+ format="%(message)s", # RichHandler overrides this format for console
+ datefmt="[%Y-%m-%d %H:%M:%S]",
+ handlers=handlers,
  )

+ # Suppress noisy libraries
+ logging.getLogger("websockets").setLevel(logging.WARNING)
+ logging.getLogger("pyarrow").setLevel(logging.WARNING)
+
+ # Use a dedicated logger to print the log file path to avoid format issues
+ if "log_msg" in locals():
+ logging.getLogger("setup").info(log_msg)
+

  def apply_cli_overrides(config: Dict[str, Any], **kwargs) -> Dict[str, Any]:
  """Apply CLI arguments as overrides to config, filtering out None values."""
@@ -189,9 +243,11 @@ def orchestrator(ctx, config: Optional[str], **kwargs):
  config_data["ssl"]["cert"] = kwargs["cert"]
  config_data["ssl"]["key"] = kwargs["key"]
  elif not config_data.get("ssl"):
- console.print(
- "[yellow]Warning: Running without SSL. Use --cert and --key for production.[/yellow]"
+ warning_msg = (
+ "[yellow]Warning: Running without SSL. "
+ "Use --cert and --key for production.[/yellow]"
  )
+ console.print(warning_msg)

  if kwargs.get("vllm") and "vllm" not in config_data:
  raise ValueError("Must provide vLLM config.")
@@ -259,33 +315,11 @@ def worker(ctx, config: Optional[str], **kwargs):
  asyncio.run(worker_instance.shutdown())


- @main.command()
- @click.option("--config", type=click.Path(exists=True), help="Configuration file")
- @click.option("--server", help="Orchestrator WebSocket URL")
- @click.option("--token", help="Authentication token")
- @click.option("--no-verify-ssl", is_flag=True, help="Skip SSL verification")
- @click.option("--debug", is_flag=True, help="Enable debug output")
- @click.pass_context
- def monitor(
- ctx,
- config: Optional[str],
- server: Optional[str],
- token: Optional[str],
- no_verify_ssl: bool,
- debug: bool,
- ):
- """Start the monitoring TUI."""
-
- # Enable debug logging if requested
- if debug:
- setup_logging(verbose=True)
- console.print("[yellow]Debug mode enabled[/yellow]")
-
- # Load configuration
+ def _load_monitor_config(config, server, token):
+ """Load monitor configuration from file or fallback to orchestrator config."""
  base_config = ConfigManager.find_config("monitor", config)

  if not base_config:
- # Try to find monitor config in orchestrator config as fallback
  orch_config = ConfigManager.find_config("orchestrator")
  if orch_config and "monitor" in orch_config:
  base_config = {"monitor": orch_config["monitor"]}
@@ -295,15 +329,11 @@ def monitor(
  if not server or not token:
  console.print("[yellow]No monitor config found, using CLI args[/yellow]")

- # Handle different config structures
- # Case 1: Config has top-level 'monitor' section
- if "monitor" in base_config:
- config_data = base_config["monitor"]
- # Case 2: Config IS the monitor config (no wrapper)
- else:
- config_data = base_config
+ return base_config.get("monitor", base_config)
+

- # Apply CLI overrides (CLI always wins)
+ def _apply_monitor_overrides(config_data, server, token, no_verify_ssl):
+ """Apply CLI overrides to monitor configuration."""
  if server:
  config_data["server"] = server
  if token:
@@ -311,17 +341,20 @@ def monitor(
  if no_verify_ssl:
  config_data["verify_ssl"] = False

- # Debug output
- if debug:
- console.print("\n[cyan]Final monitor configuration:[/cyan]")
- console.print(f" Server: {config_data.get('server', 'NOT SET')}")
- console.print(
- f" Token: {'***' + config_data.get('token', '')[-4:] if config_data.get('token') else 'NOT SET'}"
- )
- console.print(f" Verify SSL: {config_data.get('verify_ssl', True)}")
- console.print()

- # Validate required fields
+ def _debug_monitor_config(config_data):
+ """Print debug information about monitor configuration."""
+ console.print("\n[cyan]Final monitor configuration:[/cyan]")
+ console.print(f" Server: {config_data.get('server', 'NOT SET')}")
+ console.print(
+ f" Token: {'***' + config_data.get('token', '')[-4:] if config_data.get('token') else 'NOT SET'}"
+ )
+ console.print(f" Verify SSL: {config_data.get('verify_ssl', True)}")
+ console.print()
+
+
+ def _validate_monitor_config(config_data):
+ """Validate required monitor configuration fields."""
  if not config_data.get("server"):
  console.print("[red]Error: --server required (or set 'server' in monitor.yaml)[/red]")
  console.print("\n[dim]Example monitor.yaml:[/dim]")
@@ -336,12 +369,43 @@ def monitor(
  console.print("token: your-token-here")
  sys.exit(1)

- # Set defaults for optional settings
+
+ def _set_monitor_defaults(config_data):
+ """Set default values for optional monitor settings."""
  config_data.setdefault("refresh_interval", 1.0)
  config_data.setdefault("show_inactive_workers", False)
  config_data.setdefault("max_log_lines", 100)

- # Create and start monitor
+
+ @main.command()
+ @click.option("--config", type=click.Path(exists=True), help="Configuration file")
+ @click.option("--server", help="Orchestrator WebSocket URL")
+ @click.option("--token", help="Authentication token")
+ @click.option("--no-verify-ssl", is_flag=True, help="Skip SSL verification")
+ @click.option("--debug", is_flag=True, help="Enable debug output")
+ @click.pass_context
+ def monitor(
+ ctx,
+ config: Optional[str],
+ server: Optional[str],
+ token: Optional[str],
+ no_verify_ssl: bool,
+ debug: bool,
+ ):
+ """Start the monitoring TUI."""
+ if debug:
+ setup_logging(verbose=True)
+ console.print("[yellow]Debug mode enabled[/yellow]")
+
+ config_data = _load_monitor_config(config, server, token)
+ _apply_monitor_overrides(config_data, server, token, no_verify_ssl)
+
+ if debug:
+ _debug_monitor_config(config_data)
+
+ _validate_monitor_config(config_data)
+ _set_monitor_defaults(config_data)
+
  try:
  monitor_instance = Monitor(config_data)

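Note: the monitor command is now a thin pipeline over the helpers introduced above: load config, apply CLI overrides, validate, set defaults. The override step gives CLI flags precedence over file values; a toy illustration with plain dicts (not the actual helpers):

    file_config = {"server": "wss://example.org:8765", "token": "file-token"}

    def apply_overrides(config, server=None, token=None, no_verify_ssl=False):
        # CLI values always win, matching _apply_monitor_overrides
        if server:
            config["server"] = server
        if token:
            config["token"] = token
        if no_verify_ssl:
            config["verify_ssl"] = False
        return config

    merged = apply_overrides(dict(file_config), token="cli-token")
    assert merged == {"server": "wss://example.org:8765", "token": "cli-token"}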
@@ -406,7 +470,7 @@ def view(ctx, data_dir: str, refresh_rate: int, no_images: bool):
  viewer.disable_images = True
  viewer.refresh_rate = refresh_rate

- console.print(f"[cyan]Starting dataset viewer...[/cyan]")
+ console.print("[cyan]Starting dataset viewer...[/cyan]")
  console.print(f"[dim]Data directory: {data_path}[/dim]")

  asyncio.run(viewer.run())
@@ -424,6 +488,400 @@ def view(ctx, data_dir: str, refresh_rate: int, no_images: bool):
  sys.exit(1)


+ def _load_admin_credentials(config, server, token):
+ """Load admin server and token from config if not provided."""
+ if server and token:
+ return server, token
+
+ base_config = ConfigManager.find_config("orchestrator", config) or {}
+ admin_config = base_config.get("admin", {})
+ admin_tokens = base_config.get("orchestrator", {}).get("auth", {}).get("admin_tokens", [])
+
+ final_server = server or admin_config.get("server", "ws://localhost:8765")
+ final_token = token or admin_config.get("token")
+
+ if not final_token and admin_tokens:
+ console.print("Using first admin token.")
+ final_token = admin_tokens[0].get("token")
+
+ return final_server, final_token
+
+
+ def _setup_ssl_context(server, no_verify_ssl):
+ """Setup SSL context for websocket connection."""
+ import ssl
+
+ ssl_context = None
+ if server.startswith("wss://"):
+ ssl_context = ssl.create_default_context()
+ if no_verify_ssl:
+ ssl_context.check_hostname = False
+ ssl_context.verify_mode = ssl.CERT_NONE
+
+ return ssl_context
+
+
+ async def _authenticate_admin(websocket, token):
+ """Authenticate as admin with the websocket."""
+ await websocket.send(json.dumps({"token": token, "role": "admin"}))
+
+ response = await websocket.recv()
+ auth_response = json.loads(response)
+
+ if "error" in auth_response:
+ console.print(f"[red]Authentication failed: {auth_response['error']}[/red]")
+ return False
+
+ console.print("[green]✓ Authenticated as admin[/green]")
+ return True
+
+
+ async def _send_reload_command(websocket, new_cfg):
+ """Send reload command and handle response."""
+ await websocket.send(json.dumps({"type": "reload_config", "config": new_cfg}))
+
+ response = await websocket.recv()
+ reload_response = json.loads(response)
+
+ if reload_response.get("type") == "reload_complete":
+ if "message" in reload_response and "No changes" in reload_response["message"]:
+ console.print(f"[yellow]{reload_response['message']}[/yellow]")
+ else:
+ console.print("[green]✓ Configuration reloaded successfully![/green]")
+
+ if "updated" in reload_response and reload_response["updated"]:
+ console.print("\n[cyan]Updated sections:[/cyan]")
+ for section in reload_response["updated"]:
+ console.print(f" • {section}")
+
+ if "warnings" in reload_response and reload_response["warnings"]:
+ console.print("\n[yellow]Warnings:[/yellow]")
+ for warning in reload_response["warnings"]:
+ console.print(f" ⚠ {warning}")
+
+ return True
+ else:
+ error = reload_response.get("error", "Unknown error")
+ console.print(f"[red]Reload failed: {error} ({reload_response=})[/red]")
+ return False
+
+
+ def _add_token_to_config(config_data: Dict[str, Any], role: str, name: str, token: str) -> bool:
+ """Add a new token to the config data."""
+ # Ensure the auth section exists
+ if "orchestrator" not in config_data:
+ config_data["orchestrator"] = {}
+ if "auth" not in config_data["orchestrator"]:
+ config_data["orchestrator"]["auth"] = {}
+
+ auth_config = config_data["orchestrator"]["auth"]
+ token_key = f"{role}_tokens"
+
+ # Initialize token list if it doesn't exist
+ if token_key not in auth_config:
+ auth_config[token_key] = []
+
+ # Check if token already exists
+ for existing_token in auth_config[token_key]:
+ if existing_token.get("token") == token:
+ console.print(f"[yellow]Token already exists for {role}: {name}[/yellow]")
+ return False
+ if existing_token.get("name") == name:
+ console.print(f"[yellow]Name already exists for {role}: {name}[/yellow]")
+ return False
+
+ # Add the new token
+ auth_config[token_key].append({"name": name, "token": token})
+ console.print(f"[green]✓ Added {role} token for {name}[/green]")
+ return True
+
+
+ def _remove_token_from_config(config_data: Dict[str, Any], role: str, identifier: str) -> bool:
+ """Remove a token from the config data by name or token."""
+ auth_config = config_data.get("orchestrator", {}).get("auth", {})
+ token_key = f"{role}_tokens"
+
+ if token_key not in auth_config:
+ console.print(f"[red]No {role} tokens found in config[/red]")
+ return False
+
+ tokens = auth_config[token_key]
+ removed = False
+
+ for i, token_entry in enumerate(tokens):
+ if token_entry.get("name") == identifier or token_entry.get("token") == identifier:
+ removed_entry = tokens.pop(i)
+ console.print(f"[green]✓ Removed {role} token: {removed_entry['name']}[/green]")
+ removed = True
+ break
+
+ if not removed:
+ console.print(f"[red]Token not found for {role}: {identifier}[/red]")
+
+ return removed
+
+
+ def _list_tokens_in_config(config_data: Dict[str, Any], role: Optional[str] = None):
+ """List tokens in the config data."""
+ auth_config = config_data.get("orchestrator", {}).get("auth", {})
+
+ if not auth_config:
+ console.print("[yellow]No auth configuration found[/yellow]")
+ return
+
+ roles_to_show = [role] if role else ["worker", "admin", "monitor"]
+
+ for token_role in roles_to_show:
+ token_key = f"{token_role}_tokens"
+ tokens = auth_config.get(token_key, [])
+
+ if tokens:
+ console.print(f"\n[cyan]{token_role.title()} tokens:[/cyan]")
+ for token_entry in tokens:
+ name = token_entry.get("name", "Unknown")
+ token = token_entry.get("token", "")
+ masked_token = f"***{token[-4:]}" if len(token) > 4 else "***"
+ console.print(f" • {name}: {masked_token}")
+ else:
+ console.print(f"\n[dim]No {token_role} tokens configured[/dim]")
+
+
+ def _save_config_file(config_data: Dict[str, Any], config_path: Path) -> bool:
+ """Save the config data to a file."""
+ try:
+ with open(config_path, "w") as f:
+ yaml.safe_dump(config_data, f, default_flow_style=False, sort_keys=False)
+ console.print(f"[green]✓ Configuration saved to {config_path}[/green]")
+ return True
+ except Exception as e:
+ console.print(f"[red]Error saving config: {e}[/red]")
+ return False
+
+
+ async def _reload_orchestrator_config(
+ server: str, token: str, config_data: Dict[str, Any], no_verify_ssl: bool
+ ) -> bool:
+ """Reload the orchestrator configuration."""
+ import websockets
+
+ ssl_context = _setup_ssl_context(server, no_verify_ssl)
+
+ try:
+ async with websockets.connect(
+ server, ssl=ssl_context, ping_interval=20, ping_timeout=60, close_timeout=10
+ ) as websocket:
+ if not await _authenticate_admin(websocket, token):
+ return False
+
+ return await _send_reload_command(websocket, config_data)
+ except Exception as e:
+ console.print(f"[red]Error connecting to orchestrator: {e}[/red]")
+ return False
+
+
+ @main.group()
+ @click.option("--config", type=click.Path(exists=True), help="Configuration file")
+ @click.option("--server", help="Orchestrator WebSocket URL")
+ @click.option("--token", help="Admin authentication token")
+ @click.option("--no-verify-ssl", is_flag=True, help="Skip SSL verification")
+ @click.pass_context
+ def auth(
+ ctx, config: Optional[str], server: Optional[str], token: Optional[str], no_verify_ssl: bool
+ ):
+ """Manage authentication tokens for the orchestrator."""
+ ctx.ensure_object(dict)
+ ctx.obj.update(
+ {"config": config, "server": server, "token": token, "no_verify_ssl": no_verify_ssl}
+ )
+
+
+ @auth.command()
+ @click.argument("role", type=click.Choice(["worker", "admin", "monitor"]))
+ @click.argument("name")
+ @click.argument("token_value")
+ @click.option(
+ "--no-reload", is_flag=True, help="Don't reload orchestrator config after adding token"
+ )
+ @click.pass_context
+ def add(ctx, role: str, name: str, token_value: str, no_reload: bool):
+ """Add a new authentication token.
+
+ ROLE: Type of token (worker, admin, monitor)
+ NAME: Display name for the token
+ TOKEN_VALUE: The actual token string
+ """
+ config_file = ctx.obj.get("config")
+ server = ctx.obj.get("server")
+ admin_token = ctx.obj.get("token")
+ no_verify_ssl = ctx.obj.get("no_verify_ssl", False)
+
+ # Load config
+ config_data = ConfigManager.find_config("orchestrator", config_file)
+ if not config_data:
+ console.print("[red]No orchestrator config found[/red]")
+ console.print("[dim]Use --config to specify config file path[/dim]")
+ sys.exit(1)
+
+ # Find config file path for saving
+ config_path = None
+ if config_file:
+ config_path = Path(config_file)
+ else:
+ # Try to find the config file that was loaded
+ for search_path in [
+ Path.cwd() / "orchestrator.yaml",
+ Path.cwd() / "config" / "orchestrator.yaml",
+ Path.home() / ".caption-flow" / "orchestrator.yaml",
+ ConfigManager.get_xdg_config_home() / "caption-flow" / "orchestrator.yaml",
+ ]:
+ if search_path.exists():
+ config_path = search_path
+ break
+
+ if not config_path:
+ console.print("[red]Could not determine config file to save to[/red]")
+ console.print("[dim]Use --config to specify config file path[/dim]")
+ sys.exit(1)
+
+ # Add token to config
+ if not _add_token_to_config(config_data, role, name, token_value):
+ sys.exit(1)
+
+ # Save config file
+ if not _save_config_file(config_data, config_path):
+ sys.exit(1)
+
+ # Reload orchestrator if requested
+ if not no_reload:
+ server, admin_token = _load_admin_credentials(config_file, server, admin_token)
+
+ if not server:
+ console.print("[yellow]No server specified, skipping orchestrator reload[/yellow]")
+ console.print("[dim]Use --server to reload orchestrator config[/dim]")
+ elif not admin_token:
+ console.print("[yellow]No admin token specified, skipping orchestrator reload[/yellow]")
+ console.print("[dim]Use --token to reload orchestrator config[/dim]")
+ else:
+ console.print(f"[cyan]Reloading orchestrator config...[/cyan]")
+ success = asyncio.run(
+ _reload_orchestrator_config(server, admin_token, config_data, no_verify_ssl)
+ )
+ if not success:
+ console.print("[yellow]Config file updated but orchestrator reload failed[/yellow]")
+ console.print("[dim]You may need to restart the orchestrator manually[/dim]")
+
+
+ @auth.command()
+ @click.argument("role", type=click.Choice(["worker", "admin", "monitor"]))
+ @click.argument("identifier")
+ @click.option(
+ "--no-reload", is_flag=True, help="Don't reload orchestrator config after removing token"
+ )
+ @click.pass_context
+ def remove(ctx, role: str, identifier: str, no_reload: bool):
+ """Remove an authentication token.
+
+ ROLE: Type of token (worker, admin, monitor)
+ IDENTIFIER: Name or token value to remove
+ """
+ config_file = ctx.obj.get("config")
+ server = ctx.obj.get("server")
+ admin_token = ctx.obj.get("token")
+ no_verify_ssl = ctx.obj.get("no_verify_ssl", False)
+
+ # Load config
+ config_data = ConfigManager.find_config("orchestrator", config_file)
+ if not config_data:
+ console.print("[red]No orchestrator config found[/red]")
+ sys.exit(1)
+
+ # Find config file path for saving
+ config_path = None
+ if config_file:
+ config_path = Path(config_file)
+ else:
+ # Try to find the config file that was loaded
+ for search_path in [
+ Path.cwd() / "orchestrator.yaml",
+ Path.cwd() / "config" / "orchestrator.yaml",
+ Path.home() / ".caption-flow" / "orchestrator.yaml",
+ ConfigManager.get_xdg_config_home() / "caption-flow" / "orchestrator.yaml",
+ ]:
+ if search_path.exists():
+ config_path = search_path
+ break
+
+ if not config_path:
+ console.print("[red]Could not determine config file to save to[/red]")
+ sys.exit(1)
+
+ # Remove token from config
+ if not _remove_token_from_config(config_data, role, identifier):
+ sys.exit(1)
+
+ # Save config file
+ if not _save_config_file(config_data, config_path):
+ sys.exit(1)
+
+ # Reload orchestrator if requested
+ if not no_reload:
+ server, admin_token = _load_admin_credentials(config_file, server, admin_token)
+
+ if not server:
+ console.print("[yellow]No server specified, skipping orchestrator reload[/yellow]")
+ elif not admin_token:
+ console.print("[yellow]No admin token specified, skipping orchestrator reload[/yellow]")
+ else:
+ console.print(f"[cyan]Reloading orchestrator config...[/cyan]")
+ success = asyncio.run(
+ _reload_orchestrator_config(server, admin_token, config_data, no_verify_ssl)
+ )
+ if not success:
+ console.print("[yellow]Config file updated but orchestrator reload failed[/yellow]")
+
+
+ @auth.command()
+ @click.argument("role", type=click.Choice(["worker", "admin", "monitor", "all"]), required=False)
+ @click.pass_context
+ def list(ctx, role: Optional[str]):
+ """List authentication tokens.
+
+ ROLE: Type of tokens to list (worker, admin, monitor, all). Default: all
+ """
+ config_file = ctx.obj.get("config")
+
+ # Load config
+ config_data = ConfigManager.find_config("orchestrator", config_file)
+ if not config_data:
+ console.print("[red]No orchestrator config found[/red]")
+ sys.exit(1)
+
+ # Show tokens
+ if role == "all" or role is None:
+ _list_tokens_in_config(config_data)
+ else:
+ _list_tokens_in_config(config_data, role)
+
+
+ @auth.command()
+ @click.option("--length", default=32, help="Token length (default: 32)")
+ @click.option("--count", default=1, help="Number of tokens to generate (default: 1)")
+ def generate(length: int, count: int):
+ """Generate random authentication tokens."""
+ import secrets
+ import string
+
+ alphabet = string.ascii_letters + string.digits + "-_"
+
+ console.print(
+ f"[cyan]Generated {count} token{'s' if count > 1 else ''} ({length} characters each):[/cyan]\n"
+ )
+
+ for i in range(count):
+ token = "".join(secrets.choice(alphabet) for _ in range(length))
+ console.print(f" {i + 1}: {token}")
+
+
  @main.command()
  @click.option("--config", type=click.Path(exists=True), help="Configuration file")
  @click.option("--server", help="Orchestrator WebSocket URL")
@@ -441,27 +899,8 @@ def reload_config(
  ):
  """Reload orchestrator configuration via admin connection."""
  import websockets
- import ssl

- # Load base config to get server/token if not provided via CLI
- if not server or not token:
- base_config = ConfigManager.find_config("orchestrator", config) or {}
- admin_config = base_config.get("admin", {})
- admin_tokens = base_config.get("orchestrator", {}).get("auth", {}).get("admin_tokens", [])
- has_admin_tokens = False
- if len(admin_tokens) > 0:
- has_admin_tokens = True
- first_admin_token = admin_tokens[0].get("token", None)
- # Do not print sensitive admin token to console.
-
- if not server:
- server = admin_config.get("server", "ws://localhost:8765")
- if not token:
- token = admin_config.get("token", None)
- if token is None and has_admin_tokens:
- # grab the first one, we'll just assume we're localhost.
- console.print("Using first admin token.")
- token = first_admin_token
+ server, token = _load_admin_credentials(config, server, token)

  if not server:
  console.print("[red]Error: --server required (or set in config)[/red]")
@@ -472,66 +911,22 @@ def reload_config(

  console.print(f"[cyan]Loading configuration from {new_config}...[/cyan]")

- # Load the new configuration
  new_cfg = ConfigManager.load_yaml(Path(new_config))
  if not new_cfg:
  console.print("[red]Failed to load configuration[/red]")
  sys.exit(1)

- # Setup SSL
- ssl_context = None
- if server.startswith("wss://"):
- if no_verify_ssl:
- ssl_context = ssl.create_default_context()
- ssl_context.check_hostname = False
- ssl_context.verify_mode = ssl.CERT_NONE
- else:
- ssl_context = ssl.create_default_context()
+ ssl_context = _setup_ssl_context(server, no_verify_ssl)

  async def send_reload():
  try:
  async with websockets.connect(
  server, ssl=ssl_context, ping_interval=20, ping_timeout=60, close_timeout=10
  ) as websocket:
- # Authenticate as admin
- await websocket.send(json.dumps({"token": token, "role": "admin"}))
-
- response = await websocket.recv()
- auth_response = json.loads(response)
-
- if "error" in auth_response:
- console.print(f"[red]Authentication failed: {auth_response['error']}[/red]")
+ if not await _authenticate_admin(websocket, token):
  return False

- console.print("[green]✓ Authenticated as admin[/green]")
-
- # Send reload command
- await websocket.send(json.dumps({"type": "reload_config", "config": new_cfg}))
-
- response = await websocket.recv()
- reload_response = json.loads(response)
-
- if reload_response.get("type") == "reload_complete":
- if "message" in reload_response and "No changes" in reload_response["message"]:
- console.print(f"[yellow]{reload_response['message']}[/yellow]")
- else:
- console.print("[green]✓ Configuration reloaded successfully![/green]")
-
- if "updated" in reload_response and reload_response["updated"]:
- console.print("\n[cyan]Updated sections:[/cyan]")
- for section in reload_response["updated"]:
- console.print(f" • {section}")
-
- if "warnings" in reload_response and reload_response["warnings"]:
- console.print("\n[yellow]Warnings:[/yellow]")
- for warning in reload_response["warnings"]:
- console.print(f" ⚠ {warning}")
-
- return True
- else:
- error = reload_response.get("error", "Unknown error")
- console.print(f"[red]Reload failed: {error} ({reload_response=})[/red]")
- return False
+ return await _send_reload_command(websocket, new_cfg)

  except Exception as e:
  console.print(f"[red]Error: {e}[/red]")
@@ -542,39 +937,20 @@ def reload_config(
  sys.exit(1)


- @main.command()
- @click.option("--data-dir", default="./caption_data", help="Storage directory")
- @click.option("--checkpoint-dir", default="./checkpoints", help="Checkpoint directory")
- @click.option("--fix", is_flag=True, help="Fix issues by resetting abandoned chunks")
- @click.option("--verbose", is_flag=True, help="Show detailed information")
- def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
- """Scan for sparse or abandoned chunks and optionally fix them."""
- from .utils.chunk_tracker import ChunkTracker
- from .storage import StorageManager
- import pyarrow.parquet as pq
-
- console.print("[bold cyan]Scanning for sparse/abandoned chunks...[/bold cyan]\n")
-
- checkpoint_path = Path(checkpoint_dir) / "chunks.json"
- if not checkpoint_path.exists():
- console.print("[red]No chunk checkpoint found![/red]")
- return
-
- tracker = ChunkTracker(checkpoint_path)
- storage = StorageManager(Path(data_dir))
-
- # Get and display stats
- stats = tracker.get_stats()
+ def _display_chunk_stats(stats):
+ """Display chunk statistics."""
  console.print(f"[green]Total chunks:[/green] {stats['total']}")
  console.print(f"[green]Completed:[/green] {stats['completed']}")
  console.print(f"[yellow]Pending:[/yellow] {stats['pending']}")
  console.print(f"[yellow]Assigned:[/yellow] {stats['assigned']}")
  console.print(f"[red]Failed:[/red] {stats['failed']}\n")

- # Find abandoned chunks
+
+ def _find_abandoned_chunks(tracker):
+ """Find chunks that have been assigned for too long."""
  abandoned_chunks = []
  stale_threshold = 3600 # 1 hour
- current_time = datetime.utcnow()
+ current_time = datetime.now(_datetime.UTC)

  for chunk_id, chunk_state in tracker.chunks.items():
  if chunk_state.status == "assigned" and chunk_state.assigned_at:
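
Note: the stale-chunk check above replaces the deprecated naive datetime.utcnow() with an aware datetime.now(_datetime.UTC). A quick sketch of the age computation, assuming assigned_at is itself timezone-aware (subtracting a naive value from an aware one would raise TypeError):

    import datetime

    assigned_at = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)  # Python 3.11+
    now = datetime.datetime.now(datetime.UTC)
    age_seconds = (now - assigned_at).total_seconds()
    print(age_seconds > 3600)  # stale once past the 1-hour threshold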
@@ -582,24 +958,31 @@ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
  if age > stale_threshold:
  abandoned_chunks.append((chunk_id, chunk_state, age))

- if abandoned_chunks:
- console.print(f"[red]Found {len(abandoned_chunks)} abandoned chunks:[/red]")
- for chunk_id, chunk_state, age in abandoned_chunks[:10]:
- age_str = f"{age/3600:.1f} hours" if age > 3600 else f"{age/60:.1f} minutes"
- console.print(f" • {chunk_id} (assigned to {chunk_state.assigned_to} {age_str} ago)")
+ return abandoned_chunks

- if len(abandoned_chunks) > 10:
- console.print(f" ... and {len(abandoned_chunks) - 10} more")

- if fix:
- console.print("\n[yellow]Resetting abandoned chunks to pending...[/yellow]")
- for chunk_id, _, _ in abandoned_chunks:
- tracker.mark_failed(chunk_id)
- console.print(f"[green]✓ Reset {len(abandoned_chunks)} chunks[/green]")
+ def _display_abandoned_chunks(abandoned_chunks, fix, tracker):
+ """Display abandoned chunks and optionally fix them."""
+ if not abandoned_chunks:
+ return

- # Check for sparse shards
- console.print("\n[bold cyan]Checking for sparse shards...[/bold cyan]")
+ console.print(f"[red]Found {len(abandoned_chunks)} abandoned chunks:[/red]")
+ for chunk_id, chunk_state, age in abandoned_chunks[:10]:
+ age_str = f"{age / 3600:.1f} hours" if age > 3600 else f"{age / 60:.1f} minutes"
+ console.print(f" • {chunk_id} (assigned to {chunk_state.assigned_to} {age_str} ago)")
+
+ if len(abandoned_chunks) > 10:
+ console.print(f" ... and {len(abandoned_chunks) - 10} more")
+
+ if fix:
+ console.print("\n[yellow]Resetting abandoned chunks to pending...[/yellow]")
+ for chunk_id, _, _ in abandoned_chunks:
+ tracker.mark_failed(chunk_id)
+ console.print(f"[green]✓ Reset {len(abandoned_chunks)} chunks[/green]")

+
+ def _find_sparse_shards(tracker):
+ """Find shards with gaps or issues."""
  shards_summary = tracker.get_shards_summary()
  sparse_shards = []

@@ -618,60 +1001,108 @@ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
  if has_gaps or shard_info["failed_chunks"] > 0:
  sparse_shards.append((shard_name, shard_info, has_gaps))

- if sparse_shards:
- console.print(f"\n[yellow]Found {len(sparse_shards)} sparse/incomplete shards:[/yellow]")
- for shard_name, shard_info, has_gaps in sparse_shards[:5]:
- status = []
- if shard_info["pending_chunks"] > 0:
- status.append(f"{shard_info['pending_chunks']} pending")
- if shard_info["assigned_chunks"] > 0:
- status.append(f"{shard_info['assigned_chunks']} assigned")
- if shard_info["failed_chunks"] > 0:
- status.append(f"{shard_info['failed_chunks']} failed")
- if has_gaps:
- status.append("has gaps")
-
- console.print(f" {shard_name}: {', '.join(status)}")
+ return sparse_shards
+
+
+ def _display_sparse_shards(sparse_shards):
+ """Display sparse/incomplete shards."""
+ if not sparse_shards:
+ return
+
+ console.print(f"\n[yellow]Found {len(sparse_shards)} sparse/incomplete shards:[/yellow]")
+ for shard_name, shard_info, has_gaps in sparse_shards[:5]:
+ status = []
+ if shard_info["pending_chunks"] > 0:
+ status.append(f"{shard_info['pending_chunks']} pending")
+ if shard_info["assigned_chunks"] > 0:
+ status.append(f"{shard_info['assigned_chunks']} assigned")
+ if shard_info["failed_chunks"] > 0:
+ status.append(f"{shard_info['failed_chunks']} failed")
+ if has_gaps:
+ status.append("has gaps")
+
+ console.print(f" • {shard_name}: {', '.join(status)}")
+ console.print(
+ f" Progress: {shard_info['completed_chunks']}/{shard_info['total_chunks']} chunks"
+ )
+
+ if len(sparse_shards) > 5:
+ console.print(f" ... and {len(sparse_shards) - 5} more")
+
+
+ def _cross_check_storage(storage, tracker, fix):
+ """Cross-check chunk tracker against storage."""
+ import pyarrow.parquet as pq
+
+ console.print("\n[bold cyan]Cross-checking with stored captions...[/bold cyan]")
+
+ try:
+ table = pq.read_table(storage.captions_path, columns=["chunk_id"])
+ stored_chunk_ids = set(c for c in table["chunk_id"].to_pylist() if c)
+
+ tracker_completed = set(c for c, s in tracker.chunks.items() if s.status == "completed")
+
+ missing_in_storage = tracker_completed - stored_chunk_ids
+ missing_in_tracker = stored_chunk_ids - set(tracker.chunks.keys())
+
+ if missing_in_storage:
  console.print(
- f" Progress: {shard_info['completed_chunks']}/{shard_info['total_chunks']} chunks"
+ f"\n[red]Chunks marked complete but missing from storage:[/red] {len(missing_in_storage)}"
  )
+ for chunk_id in list(missing_in_storage)[:5]:
+ console.print(f" • {chunk_id}")

- if len(sparse_shards) > 5:
- console.print(f" ... and {len(sparse_shards) - 5} more")
+ if fix:
+ console.print("[yellow]Resetting these chunks to pending...[/yellow]")
+ for chunk_id in missing_in_storage:
+ tracker.mark_failed(chunk_id)
+ console.print(f"[green]✓ Reset {len(missing_in_storage)} chunks[/green]")

- # Cross-check with storage if verbose
- if storage.captions_path.exists() and verbose:
- console.print("\n[bold cyan]Cross-checking with stored captions...[/bold cyan]")
+ if missing_in_tracker:
+ console.print(
+ f"\n[yellow]Chunks in storage but not tracked:[/yellow] {len(missing_in_tracker)}"
+ )

- try:
- table = pq.read_table(storage.captions_path, columns=["chunk_id"])
- stored_chunk_ids = set(c for c in table["chunk_id"].to_pylist() if c)
+ except Exception as e:
+ console.print(f"[red]Error reading storage: {e}[/red]")

- tracker_completed = set(c for c, s in tracker.chunks.items() if s.status == "completed")

- missing_in_storage = tracker_completed - stored_chunk_ids
- missing_in_tracker = stored_chunk_ids - set(tracker.chunks.keys())
+ @main.command()
+ @click.option("--data-dir", default="./caption_data", help="Storage directory")
+ @click.option("--checkpoint-dir", default="./checkpoints", help="Checkpoint directory")
+ @click.option("--fix", is_flag=True, help="Fix issues by resetting abandoned chunks")
+ @click.option("--verbose", is_flag=True, help="Show detailed information")
+ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
+ """Scan for sparse or abandoned chunks and optionally fix them."""
+ from .storage import StorageManager
+ from .utils.chunk_tracker import ChunkTracker

- if missing_in_storage:
- console.print(
- f"\n[red]Chunks marked complete but missing from storage:[/red] {len(missing_in_storage)}"
- )
- for chunk_id in list(missing_in_storage)[:5]:
- console.print(f" • {chunk_id}")
+ console.print("[bold cyan]Scanning for sparse/abandoned chunks...[/bold cyan]\n")

- if fix:
- console.print("[yellow]Resetting these chunks to pending...[/yellow]")
- for chunk_id in missing_in_storage:
- tracker.mark_failed(chunk_id)
- console.print(f"[green]✓ Reset {len(missing_in_storage)} chunks[/green]")
+ checkpoint_path = Path(checkpoint_dir) / "chunks.json"
+ if not checkpoint_path.exists():
+ console.print("[red]No chunk checkpoint found![/red]")
+ return

- if missing_in_tracker:
- console.print(
- f"\n[yellow]Chunks in storage but not tracked:[/yellow] {len(missing_in_tracker)}"
- )
+ tracker = ChunkTracker(checkpoint_path)
+ storage = StorageManager(Path(data_dir))

- except Exception as e:
- console.print(f"[red]Error reading storage: {e}[/red]")
+ # Get and display stats
+ stats = tracker.get_stats()
+ _display_chunk_stats(stats)
+
+ # Find and handle abandoned chunks
+ abandoned_chunks = _find_abandoned_chunks(tracker)
+ _display_abandoned_chunks(abandoned_chunks, fix, tracker)
+
+ # Check for sparse shards
+ console.print("\n[bold cyan]Checking for sparse shards...[/bold cyan]")
+ sparse_shards = _find_sparse_shards(tracker)
+ _display_sparse_shards(sparse_shards)
+
+ # Cross-check with storage if verbose
+ if storage.captions_path.exists() and verbose:
+ _cross_check_storage(storage, tracker, fix)

  # Summary
  console.print("\n[bold cyan]Summary:[/bold cyan]")
@@ -695,12 +1126,163 @@ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
  tracker.save_checkpoint()


+ def _display_export_stats(stats):
+ """Display storage statistics."""
+ console.print("\n[bold cyan]Storage Statistics:[/bold cyan]")
+ console.print(f"[green]Total rows:[/green] {stats['total_rows']:,}")
+ console.print(f"[green]Total outputs:[/green] {stats['total_outputs']:,}")
+ console.print(f"[green]Shards:[/green] {stats['shard_count']} ({', '.join(stats['shards'])})")
+ console.print(f"[green]Output fields:[/green] {', '.join(stats['output_fields'])}")
+
+ if stats.get("field_stats"):
+ console.print("\n[cyan]Field breakdown:[/cyan]")
+ for field, count in stats["field_stats"].items():
+ console.print(f" • {field}: {count['total_items']:,} items")
+
+
+ def _prepare_export_params(shard, shards, columns):
+ """Prepare shard filter and column list."""
+ shard_filter = None
+ if shard:
+ shard_filter = [shard]
+ elif shards:
+ shard_filter = [s.strip() for s in shards.split(",")]
+
+ column_list = None
+ if columns:
+ column_list = [col.strip() for col in columns.split(",")]
+ console.print(f"\n[cyan]Exporting columns:[/cyan] {', '.join(column_list)}")
+
+ return shard_filter, column_list
+
+
+ async def _export_all_formats(
+ exporter, output, shard_filter, column_list, limit, filename_column, export_column
+ ):
+ """Export to all formats."""
+ base_name = output or "caption_export"
+ base_path = Path(base_name)
+ results = {}
+
+ for export_format in ["jsonl", "csv", "parquet", "json", "txt"]:
+ console.print(f"\n[cyan]Exporting to {export_format.upper()}...[/cyan]")
+ try:
+ format_results = await exporter.export_all_shards(
+ export_format,
+ base_path,
+ columns=column_list,
+ limit_per_shard=limit,
+ shard_filter=shard_filter,
+ filename_column=filename_column,
+ export_column=export_column,
+ )
+ results[export_format] = sum(format_results.values())
+ except Exception as e:
+ console.print(f"[yellow]Skipping {export_format}: {e}[/yellow]")
+ results[export_format] = 0
+
+ console.print("\n[green]✓ Export complete![/green]")
+ for fmt, count in results.items():
+ if count > 0:
+ console.print(f" • {fmt.upper()}: {count:,} items")
+
+
+ async def _export_to_lance(exporter, output, column_list, shard_filter):
+ """Export to Lance dataset."""
+ output_path = output or "exported_captions.lance"
+ console.print(f"\n[cyan]Exporting to Lance dataset:[/cyan] {output_path}")
+ total_rows = await exporter.export_to_lance(
+ output_path, columns=column_list, shard_filter=shard_filter
+ )
+ console.print(f"[green]✓ Exported {total_rows:,} rows to Lance dataset[/green]")
+
+
+ async def _export_to_huggingface(exporter, hf_dataset, license, private, nsfw, tags, shard_filter):
+ """Export to Hugging Face Hub."""
+ if not hf_dataset:
+ console.print("[red]Error: --hf-dataset required for huggingface_hub format[/red]")
+ console.print("[dim]Example: --hf-dataset username/my-caption-dataset[/dim]")
+ sys.exit(1)
+
+ tag_list = None
+ if tags:
+ tag_list = [tag.strip() for tag in tags.split(",")]
+
+ console.print(f"\n[cyan]Uploading to Hugging Face Hub:[/cyan] {hf_dataset}")
+ if private:
+ console.print("[dim]Privacy: Private dataset[/dim]")
+ if nsfw:
+ console.print("[dim]Content: Not for all audiences[/dim]")
+ if tag_list:
+ console.print(f"[dim]Tags: {', '.join(tag_list)}[/dim]")
+ if shard_filter:
+ console.print(f"[dim]Shards: {', '.join(shard_filter)}[/dim]")
+
+ url = await exporter.export_to_huggingface_hub(
+ dataset_name=hf_dataset,
+ license=license,
+ private=private,
+ nsfw=nsfw,
+ tags=tag_list,
+ shard_filter=shard_filter,
+ )
+ console.print(f"[green]✓ Dataset uploaded to: {url}[/green]")
+
+
+ async def _export_single_format(
+ exporter,
+ format,
+ output,
+ shard_filter,
+ column_list,
+ limit,
+ filename_column,
+ export_column,
+ verbose,
+ ):
+ """Export to a single format."""
+ output_path = output or "export"
+
+ if shard_filter and len(shard_filter) == 1:
+ console.print(f"\n[cyan]Exporting shard {shard_filter[0]} to {format.upper()}...[/cyan]")
+ count = await exporter.export_shard(
+ shard_filter[0],
+ format,
+ output_path,
+ columns=column_list,
+ limit=limit,
+ filename_column=filename_column,
+ export_column=export_column,
+ )
+ console.print(f"[green]✓ Exported {count:,} items[/green]")
+ else:
+ console.print(f"\n[cyan]Exporting to {format.upper()}...[/cyan]")
+ results = await exporter.export_all_shards(
+ format,
+ output_path,
+ columns=column_list,
+ limit_per_shard=limit,
+ shard_filter=shard_filter,
+ filename_column=filename_column,
+ export_column=export_column,
+ )
+
+ total = sum(results.values())
+ console.print(f"[green]✓ Exported {total:,} items total[/green]")
+
+ if verbose and len(results) > 1:
+ console.print("\n[dim]Per-shard breakdown:[/dim]")
+ for shard_name, count in sorted(results.items()):
+ console.print(f" • {shard_name}: {count:,} items")
+
+
  @main.command()
  @click.option("--data-dir", default="./caption_data", help="Storage directory")
  @click.option(
  "--format",
  type=click.Choice(
- ["jsonl", "json", "csv", "txt", "huggingface_hub", "all"], case_sensitive=False
+ ["jsonl", "json", "csv", "txt", "parquet", "lance", "huggingface_hub", "all"],
+ case_sensitive=False,
  ),
  default="jsonl",
  help="Export format (default: jsonl)",
@@ -710,17 +1292,117 @@ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
  @click.option("--columns", help="Comma-separated list of columns to export (default: all)")
  @click.option("--export-column", default="captions", help="Column to export for txt format")
  @click.option("--filename-column", default="filename", help="Column containing filenames")
+ @click.option("--shard", help="Specific shard to export (e.g., data-0001)")
+ @click.option("--shards", help="Comma-separated list of shards to export")
  @click.option("--include-empty", is_flag=True, help="Include rows with empty export column")
  @click.option("--stats-only", is_flag=True, help="Show statistics without exporting")
- @click.option(
- "--optimize", is_flag=True, help="Optimize storage before export (remove empty columns)"
- )
+ @click.option("--optimize", is_flag=True, help="Optimize storage before export")
  @click.option("--verbose", is_flag=True, help="Show detailed export progress")
  @click.option("--hf-dataset", help="Dataset name on HF Hub (e.g., username/dataset-name)")
- @click.option("--license", help="License for the dataset (required for new HF datasets)")
+ @click.option("--license", default="apache-2.0", help="License for the dataset")
  @click.option("--private", is_flag=True, help="Make HF dataset private")
  @click.option("--nsfw", is_flag=True, help="Add not-for-all-audiences tag")
  @click.option("--tags", help="Comma-separated tags for HF dataset")
+ def _validate_export_setup(data_dir):
+ """Validate export setup and create storage manager."""
+ from .storage import StorageManager
+
+ storage_path = Path(data_dir)
+ if not storage_path.exists():
+ console.print(f"[red]Storage directory not found: {data_dir}[/red]")
+ sys.exit(1)
+
+ return StorageManager(storage_path)
+
+
+ async def _run_export_process(
+ storage,
+ format,
+ output,
+ shard,
+ shards,
+ columns,
+ limit,
+ filename_column,
+ export_column,
+ verbose,
+ hf_dataset,
+ license,
+ private,
+ nsfw,
+ tags,
+ stats_only,
+ optimize,
+ ):
+ """Execute the main export process."""
+ from .storage.exporter import LanceStorageExporter
+
+ await storage.initialize()
+
+ stats = await storage.get_caption_stats()
+ _display_export_stats(stats)
+
+ if stats_only:
+ return
+
+ if optimize:
+ console.print("\n[yellow]Optimizing storage...[/yellow]")
+ await storage.optimize_storage()
+
+ shard_filter, column_list = _prepare_export_params(shard, shards, columns)
+ exporter = LanceStorageExporter(storage)
+
+ if format == "all":
+ await _export_all_formats(
+ exporter, output, shard_filter, column_list, limit, filename_column, export_column
+ )
+ elif format == "lance":
+ await _export_to_lance(exporter, output, column_list, shard_filter)
+ elif format == "huggingface_hub":
+ await _export_to_huggingface(
+ exporter, hf_dataset, license, private, nsfw, tags, shard_filter
+ )
+ else:
+ await _export_single_format(
+ exporter,
+ format,
+ output,
+ shard_filter,
+ column_list,
+ limit,
+ filename_column,
+ export_column,
+ verbose,
+ )
+
+
+ @main.command()
+ @click.option("--data-dir", default="./caption_data", help="Storage directory")
+ @click.option(
+ "--format",
+ type=click.Choice(
+ ["jsonl", "json", "csv", "txt", "parquet", "lance", "huggingface_hub", "all"],
+ case_sensitive=False,
+ ),
+ default="jsonl",
+ help="Export format (default: jsonl)",
+ )
+ @click.option("--output", help="Output filename or directory")
+ @click.option("--limit", type=int, help="Maximum number of items to export")
+ @click.option("--columns", help="Comma-separated list of columns to include")
+ @click.option("--export-column", default="captions", help="Column to export (default: captions)")
+ @click.option("--filename-column", default="filename", help="Filename column (default: filename)")
+ @click.option("--shard", help="Export only specific shard (e.g., 'data-001')")
+ @click.option("--shards", help="Comma-separated list of shards to export")
+ @click.option("--include-empty", is_flag=True, help="Include items with empty/null export column")
+ @click.option("--stats-only", is_flag=True, help="Show statistics only, don't export")
+ @click.option("--optimize", is_flag=True, help="Optimize storage before export")
+ @click.option("--verbose", is_flag=True, help="Verbose output")
+ @click.option("--hf-dataset", help="HuggingFace Hub dataset name (for huggingface_hub format)")
+ @click.option("--license", default="MIT", help="Dataset license (default: MIT)")
+ @click.option("--private", is_flag=True, help="Make HuggingFace dataset private")
+ @click.option("--nsfw", is_flag=True, help="Mark dataset as NSFW")
+ @click.option("--tags", help="Comma-separated tags for HuggingFace dataset")
  def export(
  data_dir: str,
  format: str,
@@ -729,219 +1411,56 @@ def export(
  columns: Optional[str],
  export_column: str,
  filename_column: str,
+ shard: Optional[str],
+ shards: Optional[str],
  include_empty: bool,
  stats_only: bool,
  optimize: bool,
  verbose: bool,
  hf_dataset: Optional[str],
- license: Optional[str],
+ license: str,
  private: bool,
  nsfw: bool,
  tags: Optional[str],
  ):
- """Export caption data to various formats."""
- from .storage import StorageManager
- from .storage.exporter import StorageExporter, ExportError
+ """Export caption data to various formats with per-shard support."""
+ from .storage.exporter import ExportError

- # Initialize storage manager
- storage_path = Path(data_dir)
- if not storage_path.exists():
- console.print(f"[red]Storage directory not found: {data_dir}[/red]")
- sys.exit(1)
-
- storage = StorageManager(storage_path)
-
- async def run_export():
- await storage.initialize()
-
- # Show statistics first
- stats = await storage.get_caption_stats()
- console.print("\n[bold cyan]Storage Statistics:[/bold cyan]")
- console.print(f"[green]Total rows:[/green] {stats['total_rows']:,}")
- console.print(f"[green]Total outputs:[/green] {stats['total_outputs']:,}")
- console.print(f"[green]Output fields:[/green] {', '.join(stats['output_fields'])}")
-
- if stats.get("field_stats"):
- console.print("\n[cyan]Field breakdown:[/cyan]")
- for field, field_stat in stats["field_stats"].items():
- console.print(
- f" • {field}: {field_stat['total_items']:,} items "
- f"in {field_stat['rows_with_data']:,} rows"
- )
-
- if stats_only:
- return
-
- # Optimize storage if requested
- if optimize:
- console.print("\n[yellow]Optimizing storage (removing empty columns)...[/yellow]")
- await storage.optimize_storage()
-
- # Prepare columns list
- column_list = None
- if columns:
- column_list = [col.strip() for col in columns.split(",")]
- console.print(f"\n[cyan]Exporting columns:[/cyan] {', '.join(column_list)}")
-
- # Get storage contents
- console.print("\n[yellow]Loading data...[/yellow]")
- try:
- contents = await storage.get_storage_contents(
- limit=limit, columns=column_list, include_metadata=True
- )
- except ValueError as e:
- console.print(f"[red]Error: {e}[/red]")
- sys.exit(1)
+ storage = _validate_export_setup(data_dir)

- if not contents.rows:
- console.print("[yellow]No data to export![/yellow]")
- return
-
- # Filter out empty rows if not including empty
- if not include_empty and format in ["txt", "json"]:
- original_count = len(contents.rows)
- contents.rows = [
- row
- for row in contents.rows
- if row.get(export_column)
- and (not isinstance(row[export_column], list) or len(row[export_column]) > 0)
- ]
- filtered_count = original_count - len(contents.rows)
- if filtered_count > 0:
- console.print(f"[dim]Filtered {filtered_count} empty rows[/dim]")
-
- # Create exporter
- exporter = StorageExporter(contents)
-
- # Determine output paths
- if format == "all":
- # Export to all formats
- base_name = output or "caption_export"
- base_path = Path(base_name)
-
- formats_exported = []
-
- # JSONL
- jsonl_path = base_path.with_suffix(".jsonl")
- console.print(f"\n[cyan]Exporting to JSONL:[/cyan] {jsonl_path}")
- rows = exporter.to_jsonl(jsonl_path)
- formats_exported.append(f"JSONL: {rows:,} rows")
-
- # CSV
- csv_path = base_path.with_suffix(".csv")
- console.print(f"[cyan]Exporting to CSV:[/cyan] {csv_path}")
- try:
- rows = exporter.to_csv(csv_path)
- formats_exported.append(f"CSV: {rows:,} rows")
- except ExportError as e:
- console.print(f"[yellow]Skipping CSV: {e}[/yellow]")
-
- # JSON files
- json_dir = base_path.parent / f"{base_path.stem}_json"
- console.print(f"[cyan]Exporting to JSON files:[/cyan] {json_dir}/")
- try:
- files = exporter.to_json(json_dir, filename_column)
- formats_exported.append(f"JSON: {files:,} files")
- except ExportError as e:
- console.print(f"[yellow]Skipping JSON files: {e}[/yellow]")
-
- # Text files
- txt_dir = base_path.parent / f"{base_path.stem}_txt"
- console.print(f"[cyan]Exporting to text files:[/cyan] {txt_dir}/")
- try:
- files = exporter.to_txt(txt_dir, filename_column, export_column)
- formats_exported.append(f"Text: {files:,} files")
- except ExportError as e:
- console.print(f"[yellow]Skipping text files: {e}[/yellow]")
-
- console.print(f"\n[green]✓ Export complete![/green]")
- for fmt in formats_exported:
- console.print(f" • {fmt}")
-
- else:
- # Single format export
- try:
- if format == "jsonl":
- output_path = output or "captions.jsonl"
- console.print(f"\n[cyan]Exporting to JSONL:[/cyan] {output_path}")
- rows = exporter.to_jsonl(output_path)
- console.print(f"[green]✓ Exported {rows:,} rows[/green]")
-
- elif format == "csv":
- output_path = output or "captions.csv"
- console.print(f"\n[cyan]Exporting to CSV:[/cyan] {output_path}")
- rows = exporter.to_csv(output_path)
- console.print(f"[green]✓ Exported {rows:,} rows[/green]")
-
- elif format == "json":
- output_dir = output or "./json_output"
- console.print(f"\n[cyan]Exporting to JSON files:[/cyan] {output_dir}/")
- files = exporter.to_json(output_dir, filename_column)
- console.print(f"[green]✓ Created {files:,} JSON files[/green]")
-
- elif format == "txt":
- output_dir = output or "./txt_output"
- console.print(f"\n[cyan]Exporting to text files:[/cyan] {output_dir}/")
- console.print(f"[dim]Export column: {export_column}[/dim]")
- files = exporter.to_txt(output_dir, filename_column, export_column)
- console.print(f"[green]✓ Created {files:,} text files[/green]")
888
-
889
- elif format == "huggingface_hub":
890
- # Validate required parameters
891
- if not hf_dataset:
892
- console.print(
893
- "[red]Error: --hf-dataset required for huggingface_hub format[/red]"
894
- )
895
- console.print(
896
- "[dim]Example: --hf-dataset username/my-caption-dataset[/dim]"
897
- )
898
- sys.exit(1)
899
-
900
- # Parse tags
901
- tag_list = None
902
- if tags:
903
- tag_list = [tag.strip() for tag in tags.split(",")]
904
-
905
- console.print(f"\n[cyan]Uploading to Hugging Face Hub:[/cyan] {hf_dataset}")
906
- if private:
907
- console.print("[dim]Privacy: Private dataset[/dim]")
908
- if nsfw:
909
- console.print("[dim]Content: Not for all audiences[/dim]")
910
- if tag_list:
911
- console.print(f"[dim]Tags: {', '.join(tag_list)}[/dim]")
912
-
913
- url = exporter.to_huggingface_hub(
914
- dataset_name=hf_dataset,
915
- license=license,
916
- private=private,
917
- nsfw=nsfw,
918
- tags=tag_list,
919
- )
920
- console.print(f"[green]✓ Dataset uploaded to: {url}[/green]")
921
-
922
- except ExportError as e:
923
- console.print(f"[red]Export error: {e}[/red]")
924
- sys.exit(1)
925
-
926
- # Show export metadata
927
- if verbose and contents.metadata:
928
- console.print("\n[dim]Export metadata:[/dim]")
929
- console.print(f" Timestamp: {contents.metadata.get('export_timestamp')}")
930
- console.print(f" Total available: {contents.metadata.get('total_available_rows'):,}")
931
- console.print(f" Rows exported: {contents.metadata.get('rows_exported'):,}")
932
-
933
- # Run the async export
934
1431
  try:
935
- asyncio.run(run_export())
1432
+ asyncio.run(
1433
+ _run_export_process(
1434
+ storage,
1435
+ format,
1436
+ output,
1437
+ shard,
1438
+ shards,
1439
+ columns,
1440
+ limit,
1441
+ filename_column,
1442
+ export_column,
1443
+ verbose,
1444
+ hf_dataset,
1445
+ license,
1446
+ private,
1447
+ nsfw,
1448
+ tags,
1449
+ stats_only,
1450
+ optimize,
1451
+ )
1452
+ )
1453
+ except ExportError as e:
1454
+ console.print(f"[red]Export error: {e}[/red]")
1455
+ sys.exit(1)
936
1456
  except KeyboardInterrupt:
937
1457
  console.print("\n[yellow]Export cancelled[/yellow]")
938
1458
  sys.exit(1)
939
1459
  except Exception as e:
940
1460
  console.print(f"[red]Unexpected error: {e}[/red]")
941
- if verbose:
942
- import traceback
1461
+ import traceback
943
1462
 
944
- traceback.print_exc()
1463
+ traceback.print_exc()
945
1464
  sys.exit(1)
946
1465
 
947
1466
 
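The entire body of the old `run_export` closure moves into an async `_run_export_process` helper, of which this hunk only shows the call site. Its signature can be inferred from the positional arguments above; the following is a sketch of that inference, not the released definition:

```python
async def _run_export_process(
    storage,                      # StorageManager from _validate_export_setup
    format: str,
    output: Optional[str],
    shard: Optional[str],
    shards: Optional[str],
    columns: Optional[str],
    limit: Optional[int],
    filename_column: str,
    export_column: str,
    verbose: bool,
    hf_dataset: Optional[str],
    license: str,
    private: bool,
    nsfw: bool,
    tags: Optional[str],
    stats_only: bool,
    optimize: bool,
) -> None:
    """Hypothetical signature inferred from the call site above."""
    await storage.initialize()
    # ...stats display, optional optimization, shard selection, format dispatch...
```

Two behavioral changes are visible at the call site itself: `ExportError` is now caught outside `asyncio.run` in the command body, and the unexpected-error handler prints a traceback unconditionally rather than only under `--verbose`.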
@@ -963,7 +1482,7 @@ def generate_cert(
          cert_path, key_path = cert_manager.generate_self_signed(Path(output_dir), cert_domain)
          console.print(f"[green]✓[/green] Certificate: {cert_path}")
          console.print(f"[green]✓[/green] Key: {key_path}")
-         console.print(f"\n[cyan]Use these paths in your config or CLI:[/cyan]")
+         console.print("\n[cyan]Use these paths in your config or CLI:[/cyan]")
          console.print(f" --cert {cert_path}")
          console.print(f" --key {key_path}")
      elif domain and email:
@@ -980,7 +1499,7 @@ def generate_cert(
          )
          console.print(f"[green]✓[/green] Certificate: {cert_path}")
          console.print(f"[green]✓[/green] Key: {key_path}")
-         console.print(f"\n[cyan]Use these paths in your config or CLI:[/cyan]")
+         console.print("\n[cyan]Use these paths in your config or CLI:[/cyan]")
          console.print(f" --cert {cert_path}")
          console.print(f" --key {key_path}")
  
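Both `generate_cert` hunks make the same small fix: an f-string with no placeholders is just a plain string literal, the kind of thing linters such as Ruff flag (rule F541). Illustrative only:

```python
# Before: the f-prefix does nothing here (Ruff F541, "f-string without any placeholders")
msg = f"\n[cyan]Use these paths in your config or CLI:[/cyan]"
# After: identical value, clearer intent
msg = "\n[cyan]Use these paths in your config or CLI:[/cyan]"
```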
@@ -1024,13 +1543,13 @@ def inspect_cert(cert_path: str):
  
          from datetime import datetime
  
-         if info["not_after"] < datetime.utcnow():
+         if info["not_after"] < datetime.now(_datetime.UTC):
              console.print("[red]✗ Certificate has expired![/red]")
-         elif (info["not_after"] - datetime.utcnow()).days < 30:
-             days_left = (info["not_after"] - datetime.utcnow()).days
+         elif (info["not_after"] - datetime.now(_datetime.UTC)).days < 30:
+             days_left = (info["not_after"] - datetime.now(_datetime.UTC)).days
              console.print(f"[yellow]⚠ Certificate expires in {days_left} days[/yellow]")
          else:
-             days_left = (info["not_after"] - datetime.utcnow()).days
+             days_left = (info["not_after"] - datetime.now(_datetime.UTC)).days
              console.print(f"[green]✓ Certificate valid for {days_left} more days[/green]")
  
      except Exception as e:
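The `inspect_cert` changes swap out `datetime.utcnow()`, which returns a naive datetime (tzinfo=None) and is deprecated as of Python 3.12, for `datetime.now(_datetime.UTC)`, which returns an aware one via the module-level `import datetime as _datetime` alias. The subtlety is that comparing or subtracting aware and naive datetimes raises TypeError, so this only works if `info["not_after"]` is also timezone-aware — presumably sourced from an aware expiry timestamp such as cryptography's `not_valid_after_utc` rather than the deprecated naive `not_valid_after`. A standalone illustration:

```python
import datetime as _datetime
from datetime import datetime

naive = datetime.utcnow()            # deprecated since Python 3.12; tzinfo is None
aware = datetime.now(_datetime.UTC)  # timezone-aware UTC timestamp (Python 3.11+)

print(aware - aware)  # fine: a timedelta
try:
    print(aware - naive)
except TypeError as e:
    # "can't subtract offset-naive and offset-aware datetimes"
    print(e)
```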