caption-flow 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +9 -0
- caption_flow/cli.py +709 -0
- caption_flow/models.py +82 -0
- caption_flow/monitor.py +211 -0
- caption_flow/orchestrator.py +1301 -0
- caption_flow/storage.py +694 -0
- caption_flow/utils/__init__.py +4 -0
- caption_flow/utils/auth.py +67 -0
- caption_flow/utils/caption_utils.py +172 -0
- caption_flow/utils/certificates.py +140 -0
- caption_flow/utils/chunk_tracker.py +365 -0
- caption_flow/utils/dataset_loader.py +186 -0
- caption_flow/utils/image_processor.py +51 -0
- caption_flow/utils/job_queue.py +41 -0
- caption_flow/utils/json_utils.py +201 -0
- caption_flow/utils/vllm_config.py +164 -0
- caption_flow/worker.py +300 -0
- caption_flow/worker_data.py +482 -0
- caption_flow/worker_vllm.py +1028 -0
- caption_flow-0.1.0.dist-info/METADATA +427 -0
- caption_flow-0.1.0.dist-info/RECORD +25 -0
- caption_flow-0.1.0.dist-info/WHEEL +5 -0
- caption_flow-0.1.0.dist-info/entry_points.txt +2 -0
- caption_flow-0.1.0.dist-info/licenses/LICENSE +661 -0
- caption_flow-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,201 @@
|
|
1
|
+
"""JSON serialization utilities for handling special types like datetime."""
|
2
|
+
|
3
|
+
import json
|
4
|
+
from datetime import datetime, date
|
5
|
+
from decimal import Decimal
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Any, Dict, List, Union
|
8
|
+
from dataclasses import asdict, is_dataclass
|
9
|
+
from enum import Enum
|
10
|
+
|
11
|
+
|
12
|
+
def safe_json_dumps(obj: Any, **kwargs) -> str:
    """
    Serialize ``obj`` to a JSON string, converting special types on the way.

    ``json_serializer`` is installed as the ``default`` hook of ``json.dumps``,
    so every type that function understands is converted instead of raising.

    Args:
        obj: Object to serialize.
        **kwargs: Extra keyword arguments forwarded to ``json.dumps``. Passing
            ``default`` here raises ``TypeError`` because it is already supplied.

    Returns:
        The JSON text.
    """
    encoded = json.dumps(obj, default=json_serializer, **kwargs)
    return encoded
|
24
|
+
|
25
|
+
|
26
|
+
def safe_dict(obj: Any) -> Dict[str, Any]:
    """
    Convert an object to a dictionary with JSON-serializable values.

    Args:
        obj: A dataclass, an instance of ``dict`` (or a subclass), or any
            object with a ``__dict__``.

    Returns:
        A sanitized copy of the object's fields/items. Objects matching none
        of the cases above are returned unchanged (not wrapped in a dict).
    """
    if is_dataclass(obj):
        data = asdict(obj)
    elif isinstance(obj, dict):
        # Must be checked before ``__dict__``: dict subclasses such as Counter
        # or OrderedDict also have an (empty) instance ``__dict__``, which would
        # otherwise be used instead of their actual items.
        data = obj.copy()
    elif hasattr(obj, "__dict__"):
        data = obj.__dict__.copy()
    else:
        return obj

    return sanitize_dict(data)
|
46
|
+
|
47
|
+
|
48
|
+
def sanitize_dict(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Recursively sanitize a dictionary so every value is JSON-serializable.

    Each value is converted with ``sanitize_value``, so dictionary values and
    standalone values always follow exactly the same conversion rules. Keys
    are copied unchanged.

    Args:
        data: Mapping to sanitize (anything providing ``.items()``).

    Returns:
        A new dictionary; ``data`` itself is not modified.
    """
    return {key: sanitize_value(value) for key, value in data.items()}
|
83
|
+
|
84
|
+
|
85
|
+
def sanitize_value(value: Any) -> Any:
    """
    Sanitize a single value for JSON serialization.

    Conversion rules, checked in this order:

    - ``None`` is returned as is.
    - ``datetime``/``date`` -> ISO-8601 string; ``Decimal`` -> ``float``;
      ``Path`` -> ``str``; ``Enum`` -> its ``.value``.
    - ``dict`` -> sanitized with ``sanitize_dict``.
    - ``list``/``tuple``/``set``/``frozenset`` -> list of sanitized items
      (set items come out in iteration order).
    - Objects exposing both ``dtype`` and ``tolist`` (e.g. NumPy scalars and
      arrays) -> their ``tolist()`` result, sanitized recursively.
    - Dataclasses and objects with ``__dict__`` -> sanitized dict.
    - Anything else is returned unchanged.

    Args:
        value: Value to sanitize

    Returns:
        JSON-serializable value (for the types listed above)
    """
    if value is None:
        return None
    elif isinstance(value, (datetime, date)):
        return value.isoformat()
    elif isinstance(value, Decimal):
        return float(value)
    elif isinstance(value, Path):
        return str(value)
    elif isinstance(value, Enum):
        return value.value
    elif isinstance(value, dict):
        return sanitize_dict(value)
    elif isinstance(value, (list, tuple, set, frozenset)):
        # json.dumps rejects sets, so they are emitted as lists like tuples.
        return [sanitize_value(item) for item in value]
    elif hasattr(value, "dtype") and hasattr(value, "tolist"):
        # NumPy-style values (e.g. numpy.int64) are rejected by json.dumps;
        # tolist() yields native Python scalars or nested lists.
        return sanitize_value(value.tolist())
    elif is_dataclass(value):
        return sanitize_dict(asdict(value))
    elif hasattr(value, "__dict__"):
        return sanitize_dict(value.__dict__)
    else:
        return value
|
115
|
+
|
116
|
+
|
117
|
+
def json_serializer(obj: Any) -> Any:
    """
    Default JSON serializer for special types.

    Meant to be passed as ``default=`` to ``json.dumps``, which calls it only
    for objects it cannot encode natively and then encodes whatever it returns.

    Args:
        obj: Object to serialize

    Returns:
        JSON-serializable representation:

        - ISO-8601 string for ``datetime``/``date``.
        - ``float`` for ``Decimal``.
        - ``str`` for ``Path``.
        - ``.value`` for ``Enum``.
        - A list for ``set``/``frozenset``.
        - ``tolist()`` output for objects with ``dtype`` and ``tolist`` (e.g. NumPy).
        - ``int`` for other types named ``int64``.
        - A sanitized dict for dataclasses and objects with ``__dict__``.

    Raises:
        TypeError: If object type is not supported
    """
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()
    elif isinstance(obj, Decimal):
        return float(obj)
    elif isinstance(obj, Path):
        return str(obj)
    elif isinstance(obj, Enum):
        return obj.value
    elif isinstance(obj, (set, frozenset)):
        return list(obj)
    elif hasattr(obj, "dtype") and hasattr(obj, "tolist"):
        # Covers every NumPy scalar/array type (int32, float32, bool_, ndarray),
        # not only int64; json re-encodes the returned native values.
        return obj.tolist()
    elif type(obj).__name__ == "int64":
        # Kept for int64-named types that do not expose dtype/tolist.
        return int(obj)
    elif is_dataclass(obj):
        return sanitize_dict(asdict(obj))
    elif hasattr(obj, "__dict__"):
        return sanitize_dict(obj.__dict__)
    else:
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
|
146
|
+
|
147
|
+
|
148
|
+
def parse_datetime(dt_string: Union[str, datetime, None]) -> Union[datetime, None]:
    """
    Parse a datetime string or return existing datetime.

    Strings are stripped of surrounding whitespace. A trailing ``Z``/``z`` UTC
    designator is rewritten to ``+00:00``, so parsing does not depend on which
    designators the running Python's ``datetime.fromisoformat`` supports.

    Args:
        dt_string: ISO format datetime string, datetime object, or None

    Returns:
        datetime object (timezone-aware when the string carries an offset or
        ``Z``) or None

    Raises:
        ValueError: If the string is not a valid ISO-8601 datetime, or the
            input is of any other type.
    """
    if dt_string is None:
        return None
    if isinstance(dt_string, datetime):
        return dt_string
    if not isinstance(dt_string, str):
        raise ValueError(f"Cannot parse datetime from {type(dt_string).__name__}")

    text = dt_string.strip()
    # Only a trailing designator is rewritten; "Z" elsewhere is not ISO-8601.
    if text[-1:] in ("Z", "z"):
        normalized = text[:-1] + "+00:00"
    else:
        normalized = text

    try:
        return datetime.fromisoformat(normalized)
    except ValueError:
        if normalized == text:
            # Nothing was rewritten, so retrying the same text cannot succeed.
            raise
        # Retry with the designator left in place.
        return datetime.fromisoformat(text)
|
170
|
+
|
171
|
+
|
172
|
+
# Convenience functions for common use cases
def to_json_dict(obj: Any) -> Dict[str, Any]:
    """
    Return a JSON-serializable dictionary built from ``obj``.

    A thin alias of ``safe_dict``; the conversion rules (including the case
    where ``obj`` is handed back unchanged) are those of ``safe_dict``.

    Args:
        obj: Object to convert

    Returns:
        Whatever ``safe_dict(obj)`` produces.
    """
    converted = safe_dict(obj)
    return converted
|
186
|
+
|
187
|
+
|
188
|
+
def to_json_string(obj: Any, indent: Union[int, str, None] = None) -> str:
    """
    Convert any object to a JSON string.

    This is a convenience wrapper around safe_json_dumps.

    Args:
        obj: Object to convert
        indent: Indentation forwarded to ``json.dumps``: a number of spaces,
            a string used per level (e.g. ``"\\t"``), or None for compact output.

    Returns:
        JSON string
    """
    return safe_json_dumps(obj, indent=indent)
|
@@ -0,0 +1,164 @@
|
|
1
|
+
"""vLLM configuration management utilities."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from typing import Dict, Any, Optional, Tuple, List
|
5
|
+
from dataclasses import dataclass, field
|
6
|
+
|
7
|
+
logger = logging.getLogger(__name__)


@dataclass
class VLLMConfigChange:
    """Represents changes between vLLM configurations."""

    # True when the engine must be rebuilt: a RELOAD_REQUIRED_FIELDS key
    # changed, or there was no previous config.
    requires_reload: bool = False
    # True when the "model" key changed, or on the initial load.
    model_changed: bool = False
    # True when the "sampling" section changed.
    sampling_changed: bool = False
    # True when the "inference_prompts" key changed.
    prompts_changed: bool = False
    # Top-level keys whose values differ (old config's key order first).
    changed_fields: List[str] = field(default_factory=list)


class VLLMConfigManager:
    """Manages vLLM configuration changes and reloading.

    Classifies the differences between two vLLM config dicts into changes that
    need a full engine reload and changes that can be applied at runtime, and
    builds ``SamplingParams`` from a config's ``sampling`` section.
    """

    # Fields that require full vLLM reload
    RELOAD_REQUIRED_FIELDS = {
        "model",
        "tensor_parallel_size",
        "max_model_len",
        "dtype",
        "gpu_memory_utilization",
        "enforce_eager",
        "limit_mm_per_prompt",
        "disable_mm_preprocessor_cache",
    }

    # Fields that can be updated without reload. Not consulted by
    # analyze_config_change: any key outside RELOAD_REQUIRED_FIELDS is
    # treated as runtime-updatable.
    RUNTIME_UPDATEABLE_FIELDS = {
        "batch_size",
        "sampling",
        "inference_prompts",
    }

    def __init__(self):
        # Not assigned by any method of this class; presumably maintained by
        # the owner of the manager (verify against callers).
        self.current_config: Optional[Dict[str, Any]] = None
        # Last SamplingParams built by create_sampling_params.
        self.current_sampling_params = None

    def analyze_config_change(
        self, old_config: Optional[Dict[str, Any]], new_config: Dict[str, Any]
    ) -> VLLMConfigChange:
        """Analyze differences between configs to determine required actions.

        Args:
            old_config: Previously applied config; None or empty means this is
                the initial load, which always requires a reload.
            new_config: Config about to be applied.

        Returns:
            A VLLMConfigChange describing which top-level keys differ and
            whether a reload is needed.
        """
        change = VLLMConfigChange()

        if not old_config:
            # First time setup
            change.requires_reload = True
            change.model_changed = True
            logger.info("Initial vLLM configuration - full load required")
            return change

        # Deterministic key order (old keys first, then keys only in the new
        # config). Iterating a set of strings would make changed_fields and the
        # logs vary between runs because of hash randomization.
        all_keys = list(dict.fromkeys([*old_config, *new_config]))

        for key in all_keys:
            old_value = old_config.get(key)
            new_value = new_config.get(key)
            if old_value == new_value:
                continue

            change.changed_fields.append(key)

            if key in self.RELOAD_REQUIRED_FIELDS:
                change.requires_reload = True
                if key == "model":
                    change.model_changed = True
                    logger.info(f"Model changed from {old_value} to {new_value}")
            elif key == "sampling":
                change.sampling_changed = True
            elif key == "inference_prompts":
                change.prompts_changed = True

        if change.changed_fields:
            logger.info(f"vLLM config changes detected: {change.changed_fields}")
            if change.requires_reload:
                logger.info("Changes require vLLM reload")
            else:
                logger.info("Changes can be applied without reload")
        else:
            logger.debug("No vLLM config changes detected")

        return change

    def create_sampling_params(self, vllm_config: Dict[str, Any]):
        """Create SamplingParams from config and remember them.

        Missing keys fall back to: temperature 0.7, top_p 0.95, max_tokens 256,
        stop ["<|end|>", "<|endoftext|>", "<|im_end|>"], repetition_penalty
        1.05, skip_special_tokens True.
        """
        from vllm import SamplingParams

        # A present-but-null "sampling" section is treated like a missing one.
        sampling_config = vllm_config.get("sampling") or {}

        params = SamplingParams(
            temperature=sampling_config.get("temperature", 0.7),
            top_p=sampling_config.get("top_p", 0.95),
            max_tokens=sampling_config.get("max_tokens", 256),
            stop=sampling_config.get("stop", ["<|end|>", "<|endoftext|>", "<|im_end|>"]),
            repetition_penalty=sampling_config.get("repetition_penalty", 1.05),
            skip_special_tokens=sampling_config.get("skip_special_tokens", True),
        )

        self.current_sampling_params = params
        return params

    def should_reload_vllm(
        self, old_config: Optional[Dict[str, Any]], new_config: Dict[str, Any]
    ) -> bool:
        """Quick check if vLLM needs to be reloaded."""
        change = self.analyze_config_change(old_config, new_config)
        return change.requires_reload

    def get_vllm_init_params(self, vllm_config: Dict[str, Any]) -> Dict[str, Any]:
        """Extract vLLM initialization parameters from config.

        Raises:
            KeyError: If ``model`` is missing from the config.
        """
        return {
            "model": vllm_config["model"],
            "trust_remote_code": True,
            "tensor_parallel_size": vllm_config.get("tensor_parallel_size", 1),
            "max_model_len": vllm_config.get("max_model_len", 16384),
            "enforce_eager": vllm_config.get("enforce_eager", True),
            "gpu_memory_utilization": vllm_config.get("gpu_memory_utilization", 0.92),
            "dtype": vllm_config.get("dtype", "float16"),
            "limit_mm_per_prompt": vllm_config.get("limit_mm_per_prompt", {"image": 1}),
            "disable_mm_preprocessor_cache": vllm_config.get("disable_mm_preprocessor_cache", True),
        }

    def requires_tokenizer_reload(
        self, old_config: Optional[Dict[str, Any]], new_config: Dict[str, Any]
    ) -> bool:
        """Check if tokenizer/processor need to be reloaded."""
        if not old_config:
            return True

        # Tokenizer/processor depend on the model
        return old_config.get("model") != new_config.get("model")

    def update_runtime_config(
        self, vllm_instance, old_config: Dict[str, Any], new_config: Dict[str, Any]
    ) -> Tuple[bool, Optional[Any]]:
        """
        Update vLLM configuration at runtime without reload.

        ``vllm_instance`` is accepted for interface symmetry but not used here.

        Returns:
            Tuple of (success, new_sampling_params); success is False when the
            change needs a reload, and new_sampling_params is None unless the
            sampling section changed.
        """
        change = self.analyze_config_change(old_config, new_config)

        if change.requires_reload:
            logger.warning("Config changes require reload, cannot update at runtime")
            return False, None

        # Update sampling params if changed
        new_sampling_params = None
        if change.sampling_changed:
            new_sampling_params = self.create_sampling_params(new_config)
            logger.info("Updated sampling parameters")

        # Note: batch_size and prompts are handled by the worker directly
        # as they don't affect the vLLM instance itself

        return True, new_sampling_params
|
caption_flow/worker.py
ADDED
@@ -0,0 +1,300 @@
|
|
1
|
+
"""Worker node for distributed captioning."""
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
import json
|
5
|
+
import logging
|
6
|
+
import ssl
|
7
|
+
from typing import Dict, Any, Optional
|
8
|
+
from pathlib import Path
|
9
|
+
|
10
|
+
import websockets
|
11
|
+
import websockets.exceptions
|
12
|
+
from websockets.client import WebSocketClientProtocol
|
13
|
+
|
14
|
+
from .models import Job, JobStatus
|
15
|
+
from .utils.image_processor import ImageProcessor
|
16
|
+
|
17
|
+
logger = logging.getLogger(__name__)


class Worker:
    """Worker node that processes captioning jobs.

    The worker authenticates to the orchestrator over a WebSocket and receives
    its dataset configuration in the welcome message. For each connection it
    then runs three concurrent loops: heartbeats, job requesting/processing,
    and incoming-message handling.
    """

    # Seconds to wait before reconnecting after a connection ends or fails.
    RECONNECT_DELAY = 5

    def __init__(self, config: Dict[str, Any]):
        """Create a worker from its configuration.

        Args:
            config: Worker settings. ``server`` (WebSocket URL) and ``token``
                are required; ``name``, ``batch_size`` and ``verify_ssl`` are
                optional.

        Raises:
            KeyError: If ``server`` or ``token`` is missing.
        """
        self.config = config
        self.server_url = config["server"]
        self.token = config["token"]
        self.name = config.get("name", "worker")
        self.batch_size = config.get("batch_size", 32)

        # Dataset configuration will be received from orchestrator
        self.dataset_config = None
        self.dataset_type = None
        self.dataset_path = None

        # SSL configuration
        self.ssl_context = self._setup_ssl()

        # Components
        self.image_processor = ImageProcessor()

        # State
        self.worker_id: Optional[str] = None
        self.websocket: Optional[WebSocketClientProtocol] = None
        self.running = False
        # Job currently being processed (None while idle).
        self.current_job: Optional[Job] = None
        # Jobs received but not yet processed. A list instead of a single slot,
        # so a late or extra "job" message cannot overwrite an unhandled job.
        self._pending_jobs: list = []

        # Metrics
        self.processed_count = 0
        self.error_count = 0

    def _setup_ssl(self) -> Optional[ssl.SSLContext]:
        """Build the SSL context for the connection.

        Returns:
            None for ``ws://`` URLs; an unverified context when the config sets
            ``verify_ssl`` to false; otherwise ``ssl.create_default_context()``.
        """
        # Check if URL is WSS (requires SSL)
        if self.server_url.startswith("ws://"):
            logger.warning(
                "Using insecure WebSocket connection (ws://). Consider using wss:// for production."
            )
            return None  # No SSL for ws://

        if not self.config.get("verify_ssl", True):
            # Disable SSL verification for development
            context = ssl.create_default_context()
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
            return context

        return ssl.create_default_context()

    async def start(self):
        """Run the worker, reconnecting until ``shutdown`` is called or auth fails."""
        self.running = True

        while self.running:
            try:
                await self._connect_and_run()
            except websockets.exceptions.ConnectionClosed as e:
                logger.warning(f"Connection closed: {e}")
            except Exception as e:
                logger.error(f"Connection error: {e}")

            # Back off before every reconnect, including after a clean close,
            # so a server that accepts and immediately drops us is not hammered.
            if self.running:
                logger.info(f"Reconnecting in {self.RECONNECT_DELAY} seconds...")
                await asyncio.sleep(self.RECONNECT_DELAY)

    async def _connect_and_run(self):
        """Open one connection, authenticate, and run the worker loops until it ends.

        On an authentication error reply, ``self.running`` is set to False so
        ``start`` stops reconnecting.

        Raises:
            websockets.exceptions.ConnectionClosed: If the connection drops
                outside the worker loops (e.g. during authentication).
            Exception: Any other connection/protocol error, after logging it.
        """
        logger.info(f"Connecting to {self.server_url}")

        try:
            async with websockets.connect(self.server_url, ssl=self.ssl_context) as websocket:
                self.websocket = websocket

                # Authenticate
                await websocket.send(json.dumps({"token": self.token, "name": self.name}))

                # Wait for welcome message with dataset configuration
                welcome_data = json.loads(await websocket.recv())

                if "error" in welcome_data:
                    logger.error(f"Authentication failed: {welcome_data['error']}")
                    self.running = False
                    return

                self.worker_id = welcome_data.get("worker_id")
                self._apply_dataset_config(welcome_data)
                logger.info(f"Connected as {self.worker_id}")

                await self._run_session_tasks()

        except websockets.exceptions.ConnectionClosed as e:
            logger.info(f"Failed to connect: {e}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in connection: {e}")
            raise
        finally:
            self.websocket = None
            self.current_job = None
            # Jobs received on this connection are not carried over.
            self._pending_jobs.clear()

    def _apply_dataset_config(self, welcome_data: Dict[str, Any]) -> None:
        """Store the dataset configuration from the welcome message, if present."""
        if "dataset_config" not in welcome_data:
            logger.warning("No dataset configuration received from orchestrator")
            return

        self.dataset_config = welcome_data["dataset_config"]
        self.dataset_type = self.dataset_config.get("dataset_type")
        self.dataset_path = self.dataset_config.get("dataset_path")
        logger.info(
            f"Received dataset configuration from orchestrator: "
            f"type={self.dataset_type}, path={self.dataset_path}"
        )

    async def _run_session_tasks(self) -> None:
        """Run the heartbeat, job and message loops until the first one finishes.

        The remaining loops are always cancelled and awaited, including when
        this coroutine is itself cancelled, so no task outlives the connection.
        """
        tasks = [
            asyncio.create_task(self._heartbeat_loop()),
            asyncio.create_task(self._job_processing_loop()),
            asyncio.create_task(self._message_handler()),
        ]

        try:
            # Wait for any task to complete (usually due to connection close)
            done, _pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

            # Check if we had an error in completed tasks
            for task in done:
                try:
                    task.result()
                except websockets.exceptions.ConnectionClosed:
                    logger.info("WebSocket connection closed")
                except Exception as e:
                    logger.error(f"Task error: {e}")

        except websockets.exceptions.ConnectionClosed:
            logger.info("Connection closed by orchestrator")
        finally:
            for task in tasks:
                if not task.done():
                    task.cancel()
            # return_exceptions=True collects the resulting CancelledErrors
            # (and already-logged failures) instead of raising them here.
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _job_processing_loop(self):
        """Request jobs from the orchestrator and process them one at a time."""
        while self.running and self.websocket:
            try:
                if not self._pending_jobs:
                    # Request a job, then give the orchestrator a moment to answer
                    await self.websocket.send(json.dumps({"type": "request_job"}))
                    await asyncio.sleep(1)

                if not self._pending_jobs:
                    # No job available, wait before requesting again
                    await asyncio.sleep(5)
                    continue

                self.current_job = self._pending_jobs.pop(0)
                try:
                    await self._process_job(self.current_job)
                finally:
                    self.current_job = None

            except websockets.exceptions.ConnectionClosed:
                logger.info("Connection closed during job processing")
                break
            except Exception as e:
                logger.error(f"Job processing error: {e}")
                self.error_count += 1
                await asyncio.sleep(1)

    async def _message_handler(self):
        """Handle incoming messages from orchestrator.

        Handles ``job`` (queued for processing), ``no_jobs`` and ``ack``
        (counted as processed); other message types are ignored.
        """
        try:
            async for message in self.websocket:
                try:
                    data = json.loads(message)
                    msg_type = data.get("type")

                    if msg_type == "job":
                        job = Job(**data["job"])
                        self._pending_jobs.append(job)
                        logger.info(f"Received job {job.job_id}")

                    elif msg_type == "no_jobs":
                        logger.debug("No jobs available")

                    elif msg_type == "ack":
                        logger.debug(f"Job {data['job_id']} acknowledged")
                        self.processed_count += 1

                except json.JSONDecodeError as e:
                    logger.error(f"Invalid message format: {e}")
                except Exception as e:
                    logger.error(f"Error handling message: {e}")

        except websockets.exceptions.ConnectionClosed:
            logger.info("Connection closed while waiting for messages")
        except Exception as e:
            logger.error(f"Message handler error: {e}")

    async def _process_job(self, job: Job):
        """Process a single captioning job and submit its caption.

        Raises:
            websockets.exceptions.ConnectionClosed: If the connection is lost,
                so the caller can trigger a reconnect. Other errors are logged,
                counted in ``error_count`` and reported as ``job_failed``.
        """
        if not self.websocket:
            logger.warning(f"No websocket connection, skipping job {job.job_id}")
            return

        logger.info(f"Processing job {job.job_id}")

        try:
            # Load and preprocess images
            images = await self._load_images(job)

            # TODO: Here you would integrate your captioning model
            # For now, using placeholder
            caption = f"[Generated caption for {job.item_key}]"

            # Submit result
            await self.websocket.send(
                json.dumps(
                    {
                        "type": "submit_caption",
                        "job_id": job.job_id,
                        "dataset": job.dataset,
                        "shard": job.shard,
                        "item_key": job.item_key,
                        "caption": caption,
                    }
                )
            )

            logger.info(f"Completed job {job.job_id}")

        except websockets.exceptions.ConnectionClosed:
            logger.warning(f"Connection lost while processing job {job.job_id}")
            raise  # Re-raise to trigger reconnection
        except Exception as e:
            logger.error(f"Failed to process job {job.job_id}: {e}")
            self.error_count += 1

            # Report failure if still connected
            if self.websocket:
                try:
                    await self.websocket.send(
                        json.dumps({"type": "job_failed", "job_id": job.job_id, "error": str(e)})
                    )
                except Exception as send_error:
                    # Best effort: the connection may already be closing.
                    logger.debug(f"Could not report failure of job {job.job_id}: {send_error}")

    async def _load_images(self, job: Job):
        """Load and preprocess images for a job (placeholder: returns an empty list)."""
        # This would load actual images from the dataset
        # Now can use self.dataset_type and self.dataset_path received from orchestrator
        # For now, returning placeholder
        return []

    async def _heartbeat_loop(self):
        """Send a heartbeat with processed/error counters every 30 seconds."""
        while self.running and self.websocket:
            try:
                await self.websocket.send(
                    json.dumps(
                        {
                            "type": "heartbeat",
                            "processed": self.processed_count,
                            "errors": self.error_count,
                        }
                    )
                )
                await asyncio.sleep(30)
            except websockets.exceptions.ConnectionClosed:
                logger.info("Connection closed during heartbeat")
                break
            except Exception as e:
                logger.error(f"Heartbeat error: {e}")
                break

    async def shutdown(self):
        """Graceful shutdown: stop reconnecting and close the open connection."""
        logger.info("Shutting down worker...")
        self.running = False

        if self.websocket:
            await self.websocket.close()
|