caption-flow 0.1.0 (caption_flow-0.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,201 @@
+ """JSON serialization utilities for handling special types like datetime."""
+
+ import json
+ from datetime import datetime, date
+ from decimal import Decimal
+ from pathlib import Path
+ from typing import Any, Dict, List, Union
+ from dataclasses import asdict, is_dataclass
+ from enum import Enum
+
+
+ def safe_json_dumps(obj: Any, **kwargs) -> str:
+     """
+     Safely serialize objects to JSON, handling special types.
+
+     Args:
+         obj: Object to serialize
+         **kwargs: Additional arguments to pass to json.dumps
+
+     Returns:
+         JSON string representation
+     """
+     return json.dumps(obj, default=json_serializer, **kwargs)
+
+
+ def safe_dict(obj: Any) -> Dict[str, Any]:
+     """
+     Convert an object to a dictionary, handling special types.
+
+     Args:
+         obj: Object to convert (dataclass, dict, etc.)
+
+     Returns:
+         Dictionary with JSON-serializable values
+     """
+     if is_dataclass(obj):
+         data = asdict(obj)
+     elif hasattr(obj, "__dict__"):
+         data = obj.__dict__.copy()
+     elif isinstance(obj, dict):
+         data = obj.copy()
+     else:
+         return obj
+
+     return sanitize_dict(data)
+
+
+ def sanitize_dict(data: Dict[str, Any]) -> Dict[str, Any]:
+     """
+     Recursively sanitize a dictionary to ensure all values are JSON-serializable.
+
+     Args:
+         data: Dictionary to sanitize
+
+     Returns:
+         Sanitized dictionary
+     """
+     result = {}
+
+     for key, value in data.items():
+         if value is None:
+             result[key] = None
+         elif isinstance(value, (datetime, date)):
+             result[key] = value.isoformat()
+         elif isinstance(value, Decimal):
+             result[key] = float(value)
+         elif isinstance(value, Path):
+             result[key] = str(value)
+         elif isinstance(value, Enum):
+             result[key] = value.value
+         elif isinstance(value, (list, tuple)):
+             result[key] = [sanitize_value(item) for item in value]
+         elif isinstance(value, dict):
+             result[key] = sanitize_dict(value)
+         elif is_dataclass(value):
+             result[key] = sanitize_dict(asdict(value))
+         elif hasattr(value, "__dict__"):
+             result[key] = sanitize_dict(value.__dict__)
+         else:
+             result[key] = value
+
+     return result
+
+
+ def sanitize_value(value: Any) -> Any:
+     """
+     Sanitize a single value for JSON serialization.
+
+     Args:
+         value: Value to sanitize
+
+     Returns:
+         JSON-serializable value
+     """
+     if value is None:
+         return None
+     elif isinstance(value, (datetime, date)):
+         return value.isoformat()
+     elif isinstance(value, Decimal):
+         return float(value)
+     elif isinstance(value, Path):
+         return str(value)
+     elif isinstance(value, Enum):
+         return value.value
+     elif isinstance(value, dict):
+         return sanitize_dict(value)
+     elif isinstance(value, (list, tuple)):
+         return [sanitize_value(item) for item in value]
+     elif is_dataclass(value):
+         return sanitize_dict(asdict(value))
+     elif hasattr(value, "__dict__"):
+         return sanitize_dict(value.__dict__)
+     else:
+         return value
+
+
+ def json_serializer(obj: Any) -> Any:
+     """
+     Default JSON serializer for special types.
+
+     Args:
+         obj: Object to serialize
+
+     Returns:
+         JSON-serializable representation
+
+     Raises:
+         TypeError: If object type is not supported
+     """
+     if isinstance(obj, (datetime, date)):
+         return obj.isoformat()
+     elif isinstance(obj, Decimal):
+         return float(obj)
+     elif isinstance(obj, Path):
+         return str(obj)
+     elif isinstance(obj, Enum):
+         return obj.value
+     elif type(obj).__name__ == "int64":
+         return int(obj)
+     elif is_dataclass(obj):
+         return sanitize_dict(asdict(obj))
+     elif hasattr(obj, "__dict__"):
+         return sanitize_dict(obj.__dict__)
+     else:
+         raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
+
+
+ def parse_datetime(dt_string: Union[str, datetime, None]) -> Union[datetime, None]:
+     """
+     Parse a datetime string or return existing datetime.
+
+     Args:
+         dt_string: ISO format datetime string, datetime object, or None
+
+     Returns:
+         datetime object or None
+     """
+     if dt_string is None:
+         return None
+     elif isinstance(dt_string, datetime):
+         return dt_string
+     elif isinstance(dt_string, str):
+         try:
+             return datetime.fromisoformat(dt_string.replace("Z", "+00:00"))
+         except ValueError:
+             # Try parsing without timezone
+             return datetime.fromisoformat(dt_string)
+     else:
+         raise ValueError(f"Cannot parse datetime from {type(dt_string).__name__}")
+
+
+ # Convenience functions for common use cases
+ def to_json_dict(obj: Any) -> Dict[str, Any]:
+     """
+     Convert any object to a JSON-serializable dictionary.
+
+     This is a convenience wrapper around safe_dict.
+
+     Args:
+         obj: Object to convert
+
+     Returns:
+         JSON-serializable dictionary
+     """
+     return safe_dict(obj)
+
+
+ def to_json_string(obj: Any, indent: int = None) -> str:
+     """
+     Convert any object to a JSON string.
+
+     This is a convenience wrapper around safe_json_dumps.
+
+     Args:
+         obj: Object to convert
+         indent: Number of spaces for indentation (None for compact)
+
+     Returns:
+         JSON string
+     """
+     return safe_json_dumps(obj, indent=indent)
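
A minimal usage sketch for the serialization helpers above, not part of the wheel itself. The import path caption_flow.utils.json_utils is an assumption (this hunk's file header is not shown in the diff), and JobRecord is a made-up dataclass used only for illustration.

from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path

# Assumed module path; adjust to wherever these helpers actually live in the package.
from caption_flow.utils.json_utils import safe_dict, safe_json_dumps


@dataclass
class JobRecord:
    """Hypothetical dataclass used only for this example."""

    job_id: str
    created_at: datetime
    output_dir: Path


record = JobRecord("job-1", datetime(2024, 1, 1, tzinfo=timezone.utc), Path("/tmp/captions"))

# safe_dict: dataclass -> plain dict; the datetime becomes an ISO string, the Path becomes str.
print(safe_dict(record))

# safe_json_dumps: json.dumps with json_serializer wired in as the default= hook.
print(safe_json_dumps(record, indent=2))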
@@ -0,0 +1,164 @@
+ """vLLM configuration management utilities."""
+
+ import logging
+ from typing import Dict, Any, Optional, Tuple, List
+ from dataclasses import dataclass, field
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class VLLMConfigChange:
+     """Represents changes between vLLM configurations."""
+
+     requires_reload: bool = False
+     model_changed: bool = False
+     sampling_changed: bool = False
+     prompts_changed: bool = False
+     changed_fields: List[str] = field(default_factory=list)
+
+
+ class VLLMConfigManager:
+     """Manages vLLM configuration changes and reloading."""
+
+     # Fields that require full vLLM reload
+     RELOAD_REQUIRED_FIELDS = {
+         "model",
+         "tensor_parallel_size",
+         "max_model_len",
+         "dtype",
+         "gpu_memory_utilization",
+         "enforce_eager",
+         "limit_mm_per_prompt",
+         "disable_mm_preprocessor_cache",
+     }
+
+     # Fields that can be updated without reload
+     RUNTIME_UPDATEABLE_FIELDS = {
+         "batch_size",
+         "sampling",
+         "inference_prompts",
+     }
+
+     def __init__(self):
+         self.current_config: Optional[Dict[str, Any]] = None
+         self.current_sampling_params = None
+
+     def analyze_config_change(
+         self, old_config: Optional[Dict[str, Any]], new_config: Dict[str, Any]
+     ) -> VLLMConfigChange:
+         """Analyze differences between configs to determine required actions."""
+         change = VLLMConfigChange()
+
+         if not old_config:
+             # First time setup
+             change.requires_reload = True
+             change.model_changed = True
+             logger.info("Initial vLLM configuration - full load required")
+             return change
+
+         # Check each field for changes
+         all_keys = set(old_config.keys()) | set(new_config.keys())
+
+         for key in all_keys:
+             old_value = old_config.get(key)
+             new_value = new_config.get(key)
+
+             if old_value != new_value:
+                 change.changed_fields.append(key)
+
+                 if key in self.RELOAD_REQUIRED_FIELDS:
+                     change.requires_reload = True
+                     if key == "model":
+                         change.model_changed = True
+                         logger.info(f"Model changed from {old_value} to {new_value}")
+                 elif key == "sampling":
+                     change.sampling_changed = True
+                 elif key == "inference_prompts":
+                     change.prompts_changed = True
+
+         if change.changed_fields:
+             logger.info(f"vLLM config changes detected: {change.changed_fields}")
+             if change.requires_reload:
+                 logger.info("Changes require vLLM reload")
+             else:
+                 logger.info("Changes can be applied without reload")
+         else:
+             logger.debug("No vLLM config changes detected")
+
+         return change
+
+     def create_sampling_params(self, vllm_config: Dict[str, Any]):
+         """Create SamplingParams from config."""
+         from vllm import SamplingParams
+
+         sampling_config = vllm_config.get("sampling", {})
+
+         params = SamplingParams(
+             temperature=sampling_config.get("temperature", 0.7),
+             top_p=sampling_config.get("top_p", 0.95),
+             max_tokens=sampling_config.get("max_tokens", 256),
+             stop=sampling_config.get("stop", ["<|end|>", "<|endoftext|>", "<|im_end|>"]),
+             repetition_penalty=sampling_config.get("repetition_penalty", 1.05),
+             skip_special_tokens=sampling_config.get("skip_special_tokens", True),
+         )
+
+         self.current_sampling_params = params
+         return params
+
+     def should_reload_vllm(
+         self, old_config: Optional[Dict[str, Any]], new_config: Dict[str, Any]
+     ) -> bool:
+         """Quick check if vLLM needs to be reloaded."""
+         change = self.analyze_config_change(old_config, new_config)
+         return change.requires_reload
+
+     def get_vllm_init_params(self, vllm_config: Dict[str, Any]) -> Dict[str, Any]:
+         """Extract vLLM initialization parameters from config."""
+         return {
+             "model": vllm_config["model"],
+             "trust_remote_code": True,
+             "tensor_parallel_size": vllm_config.get("tensor_parallel_size", 1),
+             "max_model_len": vllm_config.get("max_model_len", 16384),
+             "enforce_eager": vllm_config.get("enforce_eager", True),
+             "gpu_memory_utilization": vllm_config.get("gpu_memory_utilization", 0.92),
+             "dtype": vllm_config.get("dtype", "float16"),
+             "limit_mm_per_prompt": vllm_config.get("limit_mm_per_prompt", {"image": 1}),
+             "disable_mm_preprocessor_cache": vllm_config.get("disable_mm_preprocessor_cache", True),
+         }
+
+     def requires_tokenizer_reload(
+         self, old_config: Optional[Dict[str, Any]], new_config: Dict[str, Any]
+     ) -> bool:
+         """Check if tokenizer/processor need to be reloaded."""
+         if not old_config:
+             return True
+
+         # Tokenizer/processor depend on the model
+         return old_config.get("model") != new_config.get("model")
+
+     def update_runtime_config(
+         self, vllm_instance, old_config: Dict[str, Any], new_config: Dict[str, Any]
+     ) -> Tuple[bool, Optional[Any]]:
+         """
+         Update vLLM configuration at runtime without reload.
+
+         Returns:
+             Tuple of (success, new_sampling_params)
+         """
+         change = self.analyze_config_change(old_config, new_config)
+
+         if change.requires_reload:
+             logger.warning("Config changes require reload, cannot update at runtime")
+             return False, None
+
+         # Update sampling params if changed
+         new_sampling_params = None
+         if change.sampling_changed:
+             new_sampling_params = self.create_sampling_params(new_config)
+             logger.info("Updated sampling parameters")
+
+         # Note: batch_size and prompts are handled by the worker directly
+         # as they don't affect the vLLM instance itself
+
+         return True, new_sampling_params
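
A short sketch of how VLLMConfigManager's change analysis behaves, again illustrative rather than part of the wheel. The import path and the model identifiers are placeholders; only the field names come from the source above.

# Assumed module path; this hunk's file header is not shown in the diff.
from caption_flow.utils.vllm_config import VLLMConfigManager

manager = VLLMConfigManager()

old_cfg = {
    "model": "some-org/some-vl-model",  # placeholder model id
    "batch_size": 8,
    "sampling": {"temperature": 0.7, "max_tokens": 256},
}
new_cfg = {
    "model": "some-org/some-vl-model",
    "batch_size": 16,  # runtime-updateable field
    "sampling": {"temperature": 0.5, "max_tokens": 256},
}

change = manager.analyze_config_change(old_cfg, new_cfg)
# batch_size and sampling are runtime-updateable, so no reload is required.
assert not change.requires_reload
assert change.sampling_changed
assert "batch_size" in change.changed_fields

# Changing a RELOAD_REQUIRED_FIELDS entry such as "model" flips requires_reload.
new_cfg["model"] = "some-org/another-vl-model"
assert manager.should_reload_vllm(old_cfg, new_cfg)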
caption_flow/worker.py ADDED
@@ -0,0 +1,300 @@
+ """Worker node for distributed captioning."""
+
+ import asyncio
+ import json
+ import logging
+ import ssl
+ from typing import Dict, Any, Optional
+ from pathlib import Path
+
+ import websockets
+ import websockets.exceptions
+ from websockets.client import WebSocketClientProtocol
+
+ from .models import Job, JobStatus
+ from .utils.image_processor import ImageProcessor
+
+ logger = logging.getLogger(__name__)
+
+
+ class Worker:
+     """Worker node that processes captioning jobs."""
+
+     def __init__(self, config: Dict[str, Any]):
+         self.config = config
+         self.server_url = config["server"]
+         self.token = config["token"]
+         self.name = config.get("name", "worker")
+         self.batch_size = config.get("batch_size", 32)
+
+         # Dataset configuration will be received from orchestrator
+         self.dataset_config = None
+         self.dataset_type = None
+         self.dataset_path = None
+
+         # SSL configuration
+         self.ssl_context = self._setup_ssl()
+
+         # Components
+         self.image_processor = ImageProcessor()
+
+         # State
+         self.worker_id: Optional[str] = None
+         self.websocket: Optional[WebSocketClientProtocol] = None
+         self.running = False
+         self.current_job: Optional[Job] = None
+
+         # Metrics
+         self.processed_count = 0
+         self.error_count = 0
+
+     def _setup_ssl(self) -> Optional[ssl.SSLContext]:
+         """Configure SSL context."""
+         # Check if URL is WSS (requires SSL)
+         if self.server_url.startswith("ws://"):
+             logger.warning(
+                 "Using insecure WebSocket connection (ws://). Consider using wss:// for production."
+             )
+             return None  # No SSL for ws://
+
+         if not self.config.get("verify_ssl", True):
+             # Disable SSL verification for development
+             context = ssl.create_default_context()
+             context.check_hostname = False
+             context.verify_mode = ssl.CERT_NONE
+             return context
+
+         return ssl.create_default_context()
+
+     async def start(self):
+         """Start the worker and connect to orchestrator."""
+         self.running = True
+
+         while self.running:
+             try:
+                 await self._connect_and_run()
+             except websockets.exceptions.ConnectionClosed as e:
+                 logger.warning(f"Connection closed: {e}")
+                 if self.running:
+                     logger.info("Reconnecting in 5 seconds...")
+                     await asyncio.sleep(5)
+             except Exception as e:
+                 logger.error(f"Connection error: {e}")
+                 if self.running:
+                     logger.info("Reconnecting in 5 seconds...")
+                     await asyncio.sleep(5)
+
+     async def _connect_and_run(self):
+         """Connect to orchestrator and process jobs."""
+         logger.info(f"Connecting to {self.server_url}")
+
+         try:
+             async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
+                 self.websocket = websocket
+
+                 # Authenticate
+                 await websocket.send(json.dumps({"token": self.token, "name": self.name}))
+
+                 # Wait for welcome message with dataset configuration
+                 welcome = await websocket.recv()
+                 welcome_data = json.loads(welcome)
+
+                 if "error" in welcome_data:
+                     logger.error(f"Authentication failed: {welcome_data['error']}")
+                     self.running = False
+                     return
+
+                 self.worker_id = welcome_data.get("worker_id")
+
+                 # Extract and store dataset configuration from orchestrator
+                 if "dataset_config" in welcome_data:
+                     self.dataset_config = welcome_data["dataset_config"]
+                     self.dataset_type = self.dataset_config.get("dataset_type")
+                     self.dataset_path = self.dataset_config.get("dataset_path")
+                     logger.info(
+                         f"Received dataset configuration from orchestrator: "
+                         f"type={self.dataset_type}, path={self.dataset_path}"
+                     )
+                 else:
+                     logger.warning("No dataset configuration received from orchestrator")
+
+                 logger.info(f"Connected as {self.worker_id}")
+
+                 # Create tasks for concurrent operations
+                 tasks = [
+                     asyncio.create_task(self._heartbeat_loop()),
+                     asyncio.create_task(self._job_processing_loop()),
+                     asyncio.create_task(self._message_handler()),
+                 ]
+
+                 try:
+                     # Wait for any task to complete (usually due to connection close)
+                     done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+
+                     # Cancel remaining tasks
+                     for task in pending:
+                         task.cancel()
+                         try:
+                             await task
+                         except asyncio.CancelledError:
+                             pass
+
+                     # Check if we had an error in completed tasks
+                     for task in done:
+                         try:
+                             task.result()
+                         except websockets.exceptions.ConnectionClosed:
+                             logger.info("WebSocket connection closed")
+                         except Exception as e:
+                             logger.error(f"Task error: {e}")
+
+                 except websockets.exceptions.ConnectionClosed:
+                     logger.info("Connection closed by orchestrator")
+
+         except websockets.exceptions.ConnectionClosed as e:
+             logger.info(f"Failed to connect: {e}")
+             raise
+         except Exception as e:
+             logger.error(f"Unexpected error in connection: {e}")
+             raise
+         finally:
+             self.websocket = None
+             self.current_job = None
+
+     async def _job_processing_loop(self):
+         """Main loop for requesting and processing jobs."""
+         while self.running and self.websocket:
+             try:
+                 # Request a job
+                 await self.websocket.send(json.dumps({"type": "request_job"}))
+
+                 # Wait a bit for response
+                 await asyncio.sleep(1)
+
+                 if self.current_job:
+                     await self._process_job(self.current_job)
+                     self.current_job = None
+                 else:
+                     # No job available, wait before requesting again
+                     await asyncio.sleep(5)
+
+             except websockets.exceptions.ConnectionClosed:
+                 logger.info("Connection closed during job processing")
+                 break
+             except Exception as e:
+                 logger.error(f"Job processing error: {e}")
+                 self.error_count += 1
+                 await asyncio.sleep(1)
+
+     async def _message_handler(self):
+         """Handle incoming messages from orchestrator."""
+         try:
+             async for message in self.websocket:
+                 try:
+                     data = json.loads(message)
+                     msg_type = data.get("type")
+
+                     if msg_type == "job":
+                         job_data = data["job"]
+                         self.current_job = Job(**job_data)
+                         logger.info(f"Received job {self.current_job.job_id}")
+
+                     elif msg_type == "no_jobs":
+                         logger.debug("No jobs available")
+
+                     elif msg_type == "ack":
+                         logger.debug(f"Job {data['job_id']} acknowledged")
+                         self.processed_count += 1
+
+                 except json.JSONDecodeError as e:
+                     logger.error(f"Invalid message format: {e}")
+                 except Exception as e:
+                     logger.error(f"Error handling message: {e}")
+
+         except websockets.exceptions.ConnectionClosed:
+             logger.info("Connection closed while waiting for messages")
+         except Exception as e:
+             logger.error(f"Message handler error: {e}")
+
+     async def _process_job(self, job: Job):
+         """Process a single captioning job."""
+         if not self.websocket:
+             logger.warning(f"No websocket connection, skipping job {job.job_id}")
+             return
+
+         logger.info(f"Processing job {job.job_id}")
+
+         try:
+             # Load and preprocess images
+             images = await self._load_images(job)
+
+             # TODO: Here you would integrate your captioning model
+             # For now, using placeholder
+             caption = f"[Generated caption for {job.item_key}]"
+
+             # Submit result
+             await self.websocket.send(
+                 json.dumps(
+                     {
+                         "type": "submit_caption",
+                         "job_id": job.job_id,
+                         "dataset": job.dataset,
+                         "shard": job.shard,
+                         "item_key": job.item_key,
+                         "caption": caption,
+                     }
+                 )
+             )
+
+             logger.info(f"Completed job {job.job_id}")
+
+         except websockets.exceptions.ConnectionClosed:
+             logger.warning(f"Connection lost while processing job {job.job_id}")
+             raise  # Re-raise to trigger reconnection
+         except Exception as e:
+             logger.error(f"Failed to process job {job.job_id}: {e}")
+
+             # Report failure if still connected
+             if self.websocket:
+                 try:
+                     await self.websocket.send(
+                         json.dumps({"type": "job_failed", "job_id": job.job_id, "error": str(e)})
+                     )
+                 except:
+                     pass  # Connection might be closed
+
+     async def _load_images(self, job: Job):
+         """Load and preprocess images for a job."""
+         # This would load actual images from the dataset
+         # Now can use self.dataset_type and self.dataset_path received from orchestrator
+         # For now, returning placeholder
+         return []
+
+     async def _heartbeat_loop(self):
+         """Send periodic heartbeats to orchestrator."""
+         while self.running and self.websocket:
+             try:
+                 await self.websocket.send(
+                     json.dumps(
+                         {
+                             "type": "heartbeat",
+                             "processed": self.processed_count,
+                             "errors": self.error_count,
+                         }
+                     )
+                 )
+                 await asyncio.sleep(30)
+             except websockets.exceptions.ConnectionClosed:
+                 logger.info("Connection closed during heartbeat")
+                 break
+             except Exception as e:
+                 logger.error(f"Heartbeat error: {e}")
+                 break
+
+     async def shutdown(self):
+         """Graceful shutdown."""
+         logger.info("Shutting down worker...")
+         self.running = False
+
+         if self.websocket:
+             await self.websocket.close()
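
Finally, a sketch of driving the Worker from a small launcher script. The configuration keys match those read in Worker.__init__; the server URL and token are placeholders, and this launcher is not part of the wheel (the package's own CLI entry point, if any, is not shown in this diff).

import asyncio

from caption_flow.worker import Worker  # module path taken from the file header above

config = {
    "server": "wss://orchestrator.example.com:8765",  # placeholder endpoint
    "token": "worker-token",                          # placeholder credential
    "name": "gpu-worker-01",
    "batch_size": 32,
    "verify_ssl": True,  # set False only for development against self-signed certs
}

worker = Worker(config)


async def main():
    try:
        # start() loops: connect, authenticate, then run the heartbeat,
        # job-request, and message-handler tasks until shutdown() is called.
        await worker.start()
    finally:
        await worker.shutdown()


if __name__ == "__main__":
    asyncio.run(main())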