caption-flow 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,6 @@
1
1
  """Utility modules for CaptionFlow."""
2
2
 
3
- from .dataset_loader import DatasetLoader, ShardTracker
3
+ from .dataset_loader import DatasetLoader
4
+ from .shard_tracker import ShardTracker
5
+ from .chunk_tracker import ChunkTracker
4
6
  from .caption_utils import CaptionUtils
@@ -17,7 +17,30 @@ class AuthManager:
17
17
  """Manages authentication tokens."""
18
18
 
19
19
  def __init__(self, config: Dict[str, Any]):
20
- self.reload_config(config=config)
20
+ self.worker_tokens = {}
21
+ self.admin_tokens = {}
22
+ self.monitor_tokens = {}
23
+ if "orchestrator" in config:
24
+ # compatibility with nested config as well.
25
+ config = config.get("orchestrator").get("auth")
26
+
27
+ # Load worker tokens
28
+ for worker in config.get("worker_tokens", []):
29
+ worker_name = worker.get("name", None)
30
+ assert worker_name is not None, "Worker token must have a name"
31
+ self.worker_tokens[worker["token"]] = worker_name
32
+
33
+ # Load admin tokens
34
+ for admin in config.get("admin_tokens", []):
35
+ admin_name = admin.get("name", None)
36
+ assert admin_name is not None, "Admin token must have a name"
37
+ self.admin_tokens[admin["token"]] = admin_name
38
+
39
+ # Load monitor tokens
40
+ for monitor in config.get("monitor_tokens", []):
41
+ monitor_name = monitor.get("name", None)
42
+ assert monitor_name is not None, "Monitor token must have a name"
43
+ self.monitor_tokens[monitor["token"]] = monitor_name
21
44
 
22
45
  def authenticate(self, token: str) -> Optional[str]:
23
46
  """Authenticate token and return role."""
@@ -41,27 +64,3 @@ class AuthManager:
41
64
  role=role, name=self.worker_tokens.get(token, f"Anonymous {role}"), token=token
42
65
  )
43
66
  return worker_auth_details
44
-
45
- def reload_config(self, config: dict) -> None:
46
- """Reload configuration from file."""
47
- self.worker_tokens = {}
48
- self.admin_tokens = {}
49
- self.monitor_tokens = {}
50
-
51
- # Load worker tokens
52
- for worker in config.get("worker_tokens", []):
53
- worker_name = worker.get("name", None)
54
- assert worker_name is not None, "Worker token must have a name"
55
- self.worker_tokens[worker["token"]] = worker_name
56
-
57
- # Load admin tokens
58
- for admin in config.get("admin_tokens", []):
59
- admin_name = admin.get("name", None)
60
- assert admin_name is not None, "Admin token must have a name"
61
- self.admin_tokens[admin["token"]] = admin_name
62
-
63
- # Load monitor tokens
64
- for monitor in config.get("monitor_tokens", []):
65
- monitor_name = monitor.get("name", None)
66
- assert monitor_name is not None, "Monitor token must have a name"
67
- self.monitor_tokens[monitor["token"]] = monitor_name
@@ -0,0 +1,92 @@
1
+ """Base class for checkpoint tracking with persistent state."""
2
+
3
import json
import logging
import os
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Any, Optional, Union
9
+
10
logger = logging.getLogger(__name__)


class CheckpointTracker(ABC):
    """Abstract base class for trackers that persist state to JSON checkpoints.

    Subclasses implement three hooks:

    * ``_get_default_state`` -- the state dict used when no usable checkpoint
      exists;
    * ``_deserialize_state`` -- apply a loaded (or default) dict to the instance;
    * ``_serialize_state`` -- return a JSON-compatible dict of the current state.

    The base class loads the checkpoint on construction and writes it
    atomically in ``save()``.
    """

    def __init__(self, checkpoint_path: Union[Path, str]):
        """Initialize the tracker and load any existing checkpoint.

        Args:
            checkpoint_path: Path of the JSON checkpoint file; a ``str`` is
                converted to ``Path``. Missing parent directories are created.

        ``load()`` runs before this returns, so the subclass hooks are invoked
        from here. Any attributes those hooks read must be set before the
        subclass calls ``super().__init__``.
        """
        self.checkpoint_path = Path(checkpoint_path)
        self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        self.load()

    @abstractmethod
    def _get_default_state(self) -> Dict[str, Any]:
        """Return the default state structure for a new checkpoint."""

    @abstractmethod
    def _deserialize_state(self, data: Dict[str, Any]) -> None:
        """Apply a loaded (or default) state dict to the instance."""

    @abstractmethod
    def _serialize_state(self) -> Dict[str, Any]:
        """Return the instance state as a JSON-serializable dict."""

    def load(self) -> None:
        """Load state from the checkpoint file, falling back to defaults.

        A missing file, or any failure to read, parse or deserialize it, leaves
        the instance initialized from ``_get_default_state()``. Failures are
        logged rather than raised.
        """
        if not self.checkpoint_path.exists():
            self._deserialize_state(self._get_default_state())
            return

        try:
            with open(self.checkpoint_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            self._deserialize_state(data)
            logger.info(f"Loaded checkpoint from {self.checkpoint_path}")
        except Exception as e:
            # NOTE(review): the unreadable file stays in place and the next
            # save() overwrites it; consider backing it up before resetting.
            logger.error(f"Failed to load checkpoint {self.checkpoint_path}: {e}")
            self._deserialize_state(self._get_default_state())

    def save(self) -> None:
        """Persist the current state to disk atomically; never raises.

        1. The state is serialized in memory first. A state that cannot be
           encoded as JSON is logged and the existing checkpoint is left
           untouched.
        2. The encoded text goes to a sibling ``<name>.tmp`` file. That file is
           flushed and fsynced, then moved over the checkpoint with
           ``Path.replace``.
        3. Only if step 2 fails is a direct, non-atomic write attempted as a
           last resort. Because the payload is already valid JSON, this write
           cannot stop halfway due to an encoding error.

        An ``updated_at`` key is added to the written data. It is a naive UTC
        ISO-8601 timestamp, the same format ``datetime.utcnow()`` produced.
        """
        try:
            # Copy so the timestamp is not injected into a dict the subclass
            # may still be holding as its live state.
            data = dict(self._serialize_state())
            data["updated_at"] = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
            payload = json.dumps(data, indent=2)
        except Exception as e:
            logger.error(f"Error saving checkpoint (state not serializable): {e}", exc_info=True)
            return

        tmp_file = self.checkpoint_path.with_name(self.checkpoint_path.name + ".tmp")
        try:
            with open(tmp_file, "w", encoding="utf-8") as f:
                f.write(payload)
                f.flush()
                # Make sure the bytes are on disk before the rename publishes them.
                os.fsync(f.fileno())
            tmp_file.replace(self.checkpoint_path)
            logger.debug(f"Saved checkpoint to {self.checkpoint_path}")
            return
        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}", exc_info=True)
            try:
                tmp_file.unlink()
            except OSError:
                pass  # Temp file was never created or is already gone.

        # Best-effort fallback when the atomic path failed (e.g. rename refused).
        try:
            with open(self.checkpoint_path, "w", encoding="utf-8") as f:
                f.write(payload)
            logger.info("Saved checkpoint using fallback direct write")
        except Exception as fallback_error:
            logger.error(f"Fallback save also failed: {fallback_error}")

    def get_stats(self) -> Dict[str, Any]:
        """Return basic checkpoint statistics; subclasses may extend this.

        ``last_modified`` is the file's ``st_mtime``, or ``None`` when the file
        cannot be stat'ed (e.g. it does not exist yet).
        """
        try:
            last_modified = self.checkpoint_path.stat().st_mtime
        except OSError:
            # Single stat() instead of exists()+stat() avoids a TOCTOU race.
            last_modified = None
        return {
            "checkpoint_path": str(self.checkpoint_path),
            "last_modified": last_modified,
        }