solarviewer-1.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- solar_radio_image_viewer/__init__.py +12 -0
- solar_radio_image_viewer/assets/add_tab_default.png +0 -0
- solar_radio_image_viewer/assets/add_tab_default_light.png +0 -0
- solar_radio_image_viewer/assets/add_tab_hover.png +0 -0
- solar_radio_image_viewer/assets/add_tab_hover_light.png +0 -0
- solar_radio_image_viewer/assets/browse.png +0 -0
- solar_radio_image_viewer/assets/browse_light.png +0 -0
- solar_radio_image_viewer/assets/close_tab_default.png +0 -0
- solar_radio_image_viewer/assets/close_tab_default_light.png +0 -0
- solar_radio_image_viewer/assets/close_tab_hover.png +0 -0
- solar_radio_image_viewer/assets/close_tab_hover_light.png +0 -0
- solar_radio_image_viewer/assets/ellipse_selection.png +0 -0
- solar_radio_image_viewer/assets/ellipse_selection_light.png +0 -0
- solar_radio_image_viewer/assets/icons8-ellipse-90.png +0 -0
- solar_radio_image_viewer/assets/icons8-ellipse-90_light.png +0 -0
- solar_radio_image_viewer/assets/icons8-info-90.png +0 -0
- solar_radio_image_viewer/assets/icons8-info-90_light.png +0 -0
- solar_radio_image_viewer/assets/profile.png +0 -0
- solar_radio_image_viewer/assets/profile_light.png +0 -0
- solar_radio_image_viewer/assets/rectangle_selection.png +0 -0
- solar_radio_image_viewer/assets/rectangle_selection_light.png +0 -0
- solar_radio_image_viewer/assets/reset.png +0 -0
- solar_radio_image_viewer/assets/reset_light.png +0 -0
- solar_radio_image_viewer/assets/ruler.png +0 -0
- solar_radio_image_viewer/assets/ruler_light.png +0 -0
- solar_radio_image_viewer/assets/search.png +0 -0
- solar_radio_image_viewer/assets/search_light.png +0 -0
- solar_radio_image_viewer/assets/settings.png +0 -0
- solar_radio_image_viewer/assets/settings_light.png +0 -0
- solar_radio_image_viewer/assets/splash.fits +0 -0
- solar_radio_image_viewer/assets/zoom_60arcmin.png +0 -0
- solar_radio_image_viewer/assets/zoom_60arcmin_light.png +0 -0
- solar_radio_image_viewer/assets/zoom_in.png +0 -0
- solar_radio_image_viewer/assets/zoom_in_light.png +0 -0
- solar_radio_image_viewer/assets/zoom_out.png +0 -0
- solar_radio_image_viewer/assets/zoom_out_light.png +0 -0
- solar_radio_image_viewer/create_video.py +1345 -0
- solar_radio_image_viewer/dialogs.py +2665 -0
- solar_radio_image_viewer/from_simpl/__init__.py +184 -0
- solar_radio_image_viewer/from_simpl/caltable_visualizer.py +1001 -0
- solar_radio_image_viewer/from_simpl/dynamic_spectra_dialog.py +332 -0
- solar_radio_image_viewer/from_simpl/make_dynamic_spectra.py +351 -0
- solar_radio_image_viewer/from_simpl/pipeline_logger_gui.py +1232 -0
- solar_radio_image_viewer/from_simpl/simpl_theme.py +352 -0
- solar_radio_image_viewer/from_simpl/utils.py +984 -0
- solar_radio_image_viewer/from_simpl/view_dynamic_spectra_GUI.py +1975 -0
- solar_radio_image_viewer/helioprojective.py +1916 -0
- solar_radio_image_viewer/helioprojective_viewer.py +817 -0
- solar_radio_image_viewer/helioviewer_browser.py +1514 -0
- solar_radio_image_viewer/main.py +148 -0
- solar_radio_image_viewer/move_phasecenter.py +1269 -0
- solar_radio_image_viewer/napari_viewer.py +368 -0
- solar_radio_image_viewer/noaa_events/__init__.py +32 -0
- solar_radio_image_viewer/noaa_events/noaa_events.py +430 -0
- solar_radio_image_viewer/noaa_events/noaa_events_gui.py +1922 -0
- solar_radio_image_viewer/norms.py +293 -0
- solar_radio_image_viewer/radio_data_downloader/__init__.py +25 -0
- solar_radio_image_viewer/radio_data_downloader/radio_data_downloader.py +756 -0
- solar_radio_image_viewer/radio_data_downloader/radio_data_downloader_gui.py +528 -0
- solar_radio_image_viewer/searchable_combobox.py +220 -0
- solar_radio_image_viewer/solar_context/__init__.py +41 -0
- solar_radio_image_viewer/solar_context/active_regions.py +371 -0
- solar_radio_image_viewer/solar_context/cme_alerts.py +234 -0
- solar_radio_image_viewer/solar_context/context_images.py +297 -0
- solar_radio_image_viewer/solar_context/realtime_data.py +528 -0
- solar_radio_image_viewer/solar_data_downloader/__init__.py +35 -0
- solar_radio_image_viewer/solar_data_downloader/solar_data_downloader.py +1667 -0
- solar_radio_image_viewer/solar_data_downloader/solar_data_downloader_cli.py +901 -0
- solar_radio_image_viewer/solar_data_downloader/solar_data_downloader_gui.py +1210 -0
- solar_radio_image_viewer/styles.py +643 -0
- solar_radio_image_viewer/utils/__init__.py +32 -0
- solar_radio_image_viewer/utils/rate_limiter.py +255 -0
- solar_radio_image_viewer/utils.py +952 -0
- solar_radio_image_viewer/video_dialog.py +2629 -0
- solar_radio_image_viewer/video_utils.py +656 -0
- solar_radio_image_viewer/viewer.py +11174 -0
- solarviewer-1.0.2.dist-info/METADATA +343 -0
- solarviewer-1.0.2.dist-info/RECORD +82 -0
- solarviewer-1.0.2.dist-info/WHEEL +5 -0
- solarviewer-1.0.2.dist-info/entry_points.txt +8 -0
- solarviewer-1.0.2.dist-info/licenses/LICENSE +21 -0
- solarviewer-1.0.2.dist-info/top_level.txt +1 -0
solar_radio_image_viewer/from_simpl/utils.py

@@ -0,0 +1,984 @@
"""
Utility functions for the LOFAR Solar Imaging Pipeline.

This module provides various utility functions for the pipeline, including
a centralized logging system that can be used across all pipeline components.
"""

import os
import sys
import logging
import json
from datetime import datetime
from logging.handlers import RotatingFileHandler
import threading
from queue import Queue


class LogRecord:
    """A simple container for log records to be passed to GUI listeners."""

    def __init__(self, level, name, message, timestamp=None):
        self.level = level
        self.name = name
        self.message = message
        self.timestamp = timestamp or datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def to_dict(self):
        """Convert to dictionary for JSON serialization."""
        return {
            "level": self.level,
            "name": self.name,
            "message": self.message,
            "timestamp": self.timestamp
        }


class PipelineLoggerManager:
    """
    Centralized logger manager for the LOFAR pipeline.

    This class manages logger instances, file handlers, and GUI listeners,
    ensuring consistent logging behavior across the pipeline.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(PipelineLoggerManager, cls).__new__(cls)
                cls._instance._initialized = False
        return cls._instance

    def __init__(self, log_dir=None, log_level=logging.INFO):
        if self._initialized and log_dir is None:
            return

        # If we're already initialized but a new log_dir is provided, update it
        if self._initialized and log_dir is not None and log_dir != self.log_dir:
            self.update_log_dir(log_dir)
            return

        # Set log directory
        self.log_dir = log_dir or os.path.join(os.getcwd(), "logs")
        os.makedirs(self.log_dir, exist_ok=True)

        # Default log level
        self.log_level = log_level

        # Generate log file name with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_file = os.path.join(self.log_dir, f"pipeline_{timestamp}.log")

        # Set up the root logger
        self.root_logger = logging.getLogger()
        self.root_logger.setLevel(self.log_level)

        # Create handlers
        self._create_handlers()

        # Store registered loggers
        self.loggers = {}

        # Queue and listeners for GUI updates
        self.log_queue = Queue()
        self.gui_listeners = []

        self._initialized = True

    def update_log_dir(self, new_log_dir):
        """Update the log directory and recreate file handlers."""
        # Store old handlers for removal
        old_file_handler = self.file_handler if hasattr(self, 'file_handler') else None

        # Update log directory
        self.log_dir = new_log_dir
        os.makedirs(self.log_dir, exist_ok=True)

        # Generate a new log file name
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_file = os.path.join(self.log_dir, f"pipeline_{timestamp}.log")

        # Create a new file handler
        file_handler = RotatingFileHandler(
            self.log_file,
            maxBytes=10*1024*1024,  # 10MB
            backupCount=5
        )
        file_handler.setLevel(self.log_level)
        file_format = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
        file_handler.setFormatter(file_format)

        # Remove old file handler and add new one
        if old_file_handler:
            self.root_logger.removeHandler(old_file_handler)
        self.root_logger.addHandler(file_handler)

        # Update reference to the file handler
        self.file_handler = file_handler

        # Log the change to both handlers
        self.root_logger.info(f"Log directory updated to: {self.log_dir}")
        self.root_logger.info(f"New log file: {self.log_file}")

    def _create_handlers(self):
        """Create and configure log handlers."""
        # Clear existing handlers
        for handler in self.root_logger.handlers[:]:
            self.root_logger.removeHandler(handler)

        # Console handler
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(self.log_level)
        console_format = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
        console_handler.setFormatter(console_format)

        # File handler (rotating)
        file_handler = RotatingFileHandler(
            self.log_file,
            maxBytes=10*1024*1024,  # 10MB
            backupCount=5
        )
        file_handler.setLevel(self.log_level)
        file_format = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
        file_handler.setFormatter(file_format)

        # Add handlers to root logger
        self.root_logger.addHandler(console_handler)
        self.root_logger.addHandler(file_handler)

        # Store handlers for later access
        self.console_handler = console_handler
        self.file_handler = file_handler

    def set_log_level(self, level):
        """Set the logging level for all handlers."""
        self.log_level = level
        self.root_logger.setLevel(level)
        self.console_handler.setLevel(level)
        self.file_handler.setLevel(level)

    def get_logger(self, name):
        """
        Get or create a logger with the specified name.

        Parameters
        ----------
        name : str
            Name of the logger.

        Returns
        -------
        PipelineLogger
            Logger instance with the specified name.
        """
        if name in self.loggers:
            return self.loggers[name]

        # Create a new logger
        logger = logging.getLogger(name)
        logger.setLevel(self.log_level)

        # Wrap with our custom logger
        pipeline_logger = PipelineLogger(logger, self)
        self.loggers[name] = pipeline_logger

        return pipeline_logger

    def register_gui_listener(self, callback):
        """
        Register a GUI callback to receive log updates.

        Parameters
        ----------
        callback : callable
            Function to call with new log records.
        """
        self.gui_listeners.append(callback)

    def unregister_gui_listener(self, callback):
        """
        Unregister a GUI callback.

        Parameters
        ----------
        callback : callable
            The callback to unregister.
        """
        if callback in self.gui_listeners:
            self.gui_listeners.remove(callback)

    def notify_listeners(self, log_record):
        """
        Notify all GUI listeners of a new log record.

        Parameters
        ----------
        log_record : LogRecord
            The log record to notify listeners about.
        """
        # Add to queue
        self.log_queue.put(log_record)

        # Notify all listeners
        for listener in self.gui_listeners:
            try:
                listener(log_record)
            except Exception as e:
                # Don't let errors in GUI callbacks break the logging
                sys.stderr.write(f"Error in log listener: {e}\n")

    def get_log_queue(self):
        """Get the queue containing log records for GUI consumption."""
        return self.log_queue


class PipelineLogger:
    """
    Wrapper around a standard logger that notifies GUI listeners.
    """

    def __init__(self, logger, manager):
        self._logger = logger
        self._manager = manager

    def _log_and_notify(self, level, message, *args, **kwargs):
        """Log a message and notify GUI listeners."""
        # Format the message with args and kwargs
        if args or kwargs:
            formatted_message = message % args if args else message
        else:
            formatted_message = message

        # Create a log record
        log_record = LogRecord(
            level=logging.getLevelName(level),
            name=self._logger.name,
            message=formatted_message
        )

        # Notify listeners
        self._manager.notify_listeners(log_record)

        # Log the message
        self._logger.log(level, message, *args, **kwargs)

    def debug(self, message, *args, **kwargs):
        """Log a debug message."""
        self._log_and_notify(logging.DEBUG, message, *args, **kwargs)

    def info(self, message, *args, **kwargs):
        """Log an info message."""
        self._log_and_notify(logging.INFO, message, *args, **kwargs)

    def warning(self, message, *args, **kwargs):
        """Log a warning message."""
        self._log_and_notify(logging.WARNING, message, *args, **kwargs)

    def error(self, message, *args, **kwargs):
        """Log an error message."""
        self._log_and_notify(logging.ERROR, message, *args, **kwargs)

    def critical(self, message, *args, **kwargs):
        """Log a critical message."""
        self._log_and_notify(logging.CRITICAL, message, *args, **kwargs)

    def exception(self, message, *args, **kwargs):
        """Log an exception message."""
        self._log_and_notify(logging.ERROR, message, *args, **kwargs)


# Global instance for easy access
def get_logger_manager(log_dir=None, log_level=None):
    """
    Get the singleton logger manager instance.

    Parameters
    ----------
    log_dir : str, optional
        Directory to store log files. Only used on first call.
    log_level : int, optional
        Logging level. Only used on first call.

    Returns
    -------
    PipelineLoggerManager
        Singleton logger manager instance.
    """
    if log_level is None:
        log_level = logging.INFO

    manager = PipelineLoggerManager(log_dir, log_level)
    return manager


def get_logger(name, log_dir=None, log_level=None):
    """
    Get a logger with the specified name.

    Parameters
    ----------
    name : str
        Name of the logger.
    log_dir : str, optional
        Directory to store log files. Only used on first call.
    log_level : int, optional
        Logging level. Only used on first call.

    Returns
    -------
    PipelineLogger
        Logger instance with the specified name.
    """
    manager = get_logger_manager(log_dir, log_level)
    return manager.get_logger(name)
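
A minimal usage sketch for the logging half of this module (not part of the package itself; the component name "imaging", the log directory, and the on_log callback are illustrative assumptions):

# Hypothetical usage sketch; the names below ("imaging", on_log,
# /tmp/pipeline_logs) are illustrative, not part of the package.
from solar_radio_image_viewer.from_simpl.utils import get_logger, get_logger_manager

# The first call fixes the log directory and level for the whole process.
logger = get_logger("imaging", log_dir="/tmp/pipeline_logs")
logger.info("Processing chunk %d", 3)  # %-style args, as in stdlib logging

# A GUI can subscribe to every record routed through a PipelineLogger.
def on_log(record):
    print(record.to_dict())  # e.g. {"level": "INFO", "name": "imaging", ...}

get_logger_manager().register_gui_listener(on_log)
logger.warning("Listeners receive this record as well")
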
"""
|
|
340
|
+
LOFAR Solar Imaging Pipeline Checkpoint Manager
|
|
341
|
+
|
|
342
|
+
This module manages checkpoints to allow the pipeline to resume from where it left off
|
|
343
|
+
if interrupted. It provides functions to save, load, and check checkpoint status.
|
|
344
|
+
|
|
345
|
+
Checkpoints are stored as JSON files in a dedicated directory structure.
|
|
346
|
+
"""
|
|
347
|
+
|
|
348
|
+
import os
|
|
349
|
+
import json
|
|
350
|
+
import shutil
|
|
351
|
+
import logging
|
|
352
|
+
import glob
|
|
353
|
+
import time
|
|
354
|
+
from typing import Dict, List, Any, Union, Optional, Tuple
|
|
355
|
+
import sys
|
|
356
|
+
|
|
357
|
+
class CheckpointManager:
|
|
358
|
+
"""
|
|
359
|
+
Manages checkpoints for the LOFAR Solar Imaging Pipeline to enable resuming
|
|
360
|
+
from interruptions.
|
|
361
|
+
|
|
362
|
+
Attributes:
|
|
363
|
+
working_dir (str): Pipeline working directory
|
|
364
|
+
checkpoint_dir (str): Directory to store checkpoint files
|
|
365
|
+
logger (logging.Logger): Logger for checkpoint operations
|
|
366
|
+
checkpoint_data (dict): Current checkpoint data
|
|
367
|
+
resume_mode (bool): Whether pipeline is running in resume mode
|
|
368
|
+
"""
|
|
369
|
+
|
|
370
|
+
def __init__(self, working_dir: str, resume: bool = False, logger: Optional[logging.Logger] = None):
|
|
371
|
+
"""
|
|
372
|
+
Initialize the CheckpointManager.
|
|
373
|
+
|
|
374
|
+
Args:
|
|
375
|
+
working_dir (str): Pipeline working directory
|
|
376
|
+
resume (bool): Whether to resume from checkpoints
|
|
377
|
+
logger (logging.Logger, optional): Logger instance
|
|
378
|
+
"""
|
|
379
|
+
self.working_dir = working_dir
|
|
380
|
+
self.checkpoint_dir = os.path.join(working_dir, ".checkpoints")
|
|
381
|
+
self.resume_mode = resume
|
|
382
|
+
self.logger = logger or logging.getLogger("checkpoint_manager")
|
|
383
|
+
|
|
384
|
+
# Create checkpoint directory if it doesn't exist
|
|
385
|
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
|
386
|
+
|
|
387
|
+
# Main checkpoint file location
|
|
388
|
+
self.checkpoint_file = os.path.join(self.checkpoint_dir, "pipeline_checkpoints.json")
|
|
389
|
+
|
|
390
|
+
# Initialize or load checkpoint data
|
|
391
|
+
if resume and os.path.exists(self.checkpoint_file):
|
|
392
|
+
self.checkpoint_data = self._load_checkpoints()
|
|
393
|
+
self.logger.info(f"Resumed checkpoint data from {self.checkpoint_file}")
|
|
394
|
+
else:
|
|
395
|
+
# Initialize fresh checkpoint data
|
|
396
|
+
self.checkpoint_data = {
|
|
397
|
+
"pipeline_info": {
|
|
398
|
+
"start_time": datetime.now().isoformat(),
|
|
399
|
+
"last_update": datetime.now().isoformat(),
|
|
400
|
+
"working_dir": working_dir,
|
|
401
|
+
"resume_count": 0
|
|
402
|
+
},
|
|
403
|
+
"stages": {
|
|
404
|
+
"dynamic_spectra": {"status": "pending", "timestamp": None},
|
|
405
|
+
"calibrator_processing": {"status": "pending", "timestamp": None},
|
|
406
|
+
"calibrator_calibration": {"status": "pending", "timestamp": None},
|
|
407
|
+
"solar_preparation": {"status": "pending", "timestamp": None},
|
|
408
|
+
"selfcal": {"status": "pending", "timestamp": None, "chunks": {}},
|
|
409
|
+
"imaging": {"status": "pending", "timestamp": None, "chunks": {}}
|
|
410
|
+
},
|
|
411
|
+
"results": {
|
|
412
|
+
"calib_table": None,
|
|
413
|
+
"for_imaging_ms": [],
|
|
414
|
+
"selfcal_chunks": {},
|
|
415
|
+
"selfcal_results": [],
|
|
416
|
+
"imaging_results": [],
|
|
417
|
+
"imaging_chunks": {},
|
|
418
|
+
"dynamic_spectra_files": {}
|
|
419
|
+
}
|
|
420
|
+
}
|
|
421
|
+
|
|
422
|
+
if resume:
|
|
423
|
+
self.logger.warning(f"Resume mode requested but no checkpoint file found at {self.checkpoint_file}. Starting fresh.")
|
|
424
|
+
# Increment resume counter to indicate a fresh start in resume mode
|
|
425
|
+
self.checkpoint_data["pipeline_info"]["resume_count"] = 1
|
|
426
|
+
else:
|
|
427
|
+
self.logger.info("Starting fresh pipeline run with new checkpoints.")
|
|
428
|
+
|
|
429
|
+
# Save initial checkpoint data
|
|
430
|
+
self._save_checkpoints()
|
|
431
|
+
|
|
432
|
+
def _save_checkpoints(self) -> None:
|
|
433
|
+
"""Save current checkpoint data to disk."""
|
|
434
|
+
# Update timestamp
|
|
435
|
+
self.checkpoint_data["pipeline_info"]["last_update"] = datetime.now().isoformat()
|
|
436
|
+
|
|
437
|
+
# Write to temporary file first, then rename to avoid corruption if interrupted
|
|
438
|
+
temp_file = f"{self.checkpoint_file}.tmp"
|
|
439
|
+
try:
|
|
440
|
+
with open(temp_file, 'w') as f:
|
|
441
|
+
json.dump(self.checkpoint_data, f, indent=2, cls=CustomJSONEncoder)
|
|
442
|
+
|
|
443
|
+
# Atomic replacement
|
|
444
|
+
shutil.move(temp_file, self.checkpoint_file)
|
|
445
|
+
|
|
446
|
+
# Log at debug level to avoid excessive messages for frequent saves
|
|
447
|
+
self.logger.debug(f"Checkpoint data saved to {self.checkpoint_file}")
|
|
448
|
+
except Exception as e:
|
|
449
|
+
self.logger.error(f"Error saving checkpoint data: {str(e)}")
|
|
450
|
+
if os.path.exists(temp_file):
|
|
451
|
+
try:
|
|
452
|
+
os.remove(temp_file)
|
|
453
|
+
except:
|
|
454
|
+
pass
|
|
455
|
+
# Don't re-raise this error as it would disrupt the pipeline
|
|
456
|
+
|
|
457
|
+
def _load_checkpoints(self) -> Dict[str, Any]:
|
|
458
|
+
"""Load checkpoint data from disk."""
|
|
459
|
+
try:
|
|
460
|
+
with open(self.checkpoint_file, 'r') as f:
|
|
461
|
+
data = json.load(f)
|
|
462
|
+
|
|
463
|
+
# Increment resume counter
|
|
464
|
+
if "pipeline_info" in data and "resume_count" in data["pipeline_info"]:
|
|
465
|
+
data["pipeline_info"]["resume_count"] += 1
|
|
466
|
+
else:
|
|
467
|
+
# Initialize if missing
|
|
468
|
+
data.setdefault("pipeline_info", {})
|
|
469
|
+
data["pipeline_info"]["resume_count"] = 1
|
|
470
|
+
|
|
471
|
+
# Updated for this resume
|
|
472
|
+
data["pipeline_info"]["last_update"] = datetime.now().isoformat()
|
|
473
|
+
|
|
474
|
+
# Auto-repair any list-format chunk data
|
|
475
|
+
#self._fix_chunk_data_format(data)
|
|
476
|
+
|
|
477
|
+
return data
|
|
478
|
+
except (json.JSONDecodeError, IOError, ValueError) as e:
|
|
479
|
+
self.logger.error(f"Error loading checkpoint file: {str(e)}")
|
|
480
|
+
self.logger.warning("Starting with fresh checkpoints due to loading error.")
|
|
481
|
+
return {
|
|
482
|
+
"pipeline_info": {
|
|
483
|
+
"start_time": datetime.now().isoformat(),
|
|
484
|
+
"last_update": datetime.now().isoformat(),
|
|
485
|
+
"working_dir": self.working_dir,
|
|
486
|
+
"resume_count": 1
|
|
487
|
+
},
|
|
488
|
+
"stages": {},
|
|
489
|
+
"results": {}
|
|
490
|
+
}
|
|
491
|
+
|
|
492
|
+
def _fix_chunk_data_format(self, data: Dict[str, Any]) -> None:
|
|
493
|
+
"""
|
|
494
|
+
Convert any list-format chunk data to dictionary format.
|
|
495
|
+
This ensures backward compatibility with older checkpoint formats.
|
|
496
|
+
|
|
497
|
+
Args:
|
|
498
|
+
data (Dict[str, Any]): Checkpoint data to fix
|
|
499
|
+
"""
|
|
500
|
+
try:
|
|
501
|
+
# Fix results section
|
|
502
|
+
if "results" in data:
|
|
503
|
+
for stage_name in ["selfcal", "imaging"]:
|
|
504
|
+
chunk_key = f"{stage_name}_chunks"
|
|
505
|
+
if chunk_key in data["results"] and isinstance(data["results"][chunk_key], list):
|
|
506
|
+
self.logger.info(f"Converting {chunk_key} from list to dictionary format for compatibility")
|
|
507
|
+
|
|
508
|
+
# Create new dictionary from list items
|
|
509
|
+
new_chunks = {}
|
|
510
|
+
for idx, item in enumerate(data["results"][chunk_key]):
|
|
511
|
+
if item is not None:
|
|
512
|
+
new_chunks[str(idx)] = item
|
|
513
|
+
|
|
514
|
+
# Replace list with dictionary
|
|
515
|
+
data["results"][chunk_key] = new_chunks
|
|
516
|
+
|
|
517
|
+
# If stage data exists, ensure chunks are marked as completed
|
|
518
|
+
if ("stages" in data and stage_name in data["stages"] and
|
|
519
|
+
isinstance(data["stages"][stage_name], dict)):
|
|
520
|
+
if "chunks" not in data["stages"][stage_name]:
|
|
521
|
+
data["stages"][stage_name]["chunks"] = {}
|
|
522
|
+
|
|
523
|
+
for chunk_idx in new_chunks.keys():
|
|
524
|
+
if chunk_idx not in data["stages"][stage_name]["chunks"]:
|
|
525
|
+
data["stages"][stage_name]["chunks"][chunk_idx] = {
|
|
526
|
+
"status": "completed",
|
|
527
|
+
"timestamp": datetime.now().isoformat()
|
|
528
|
+
}
|
|
529
|
+
except Exception as e:
|
|
530
|
+
# Log but don't raise - we don't want to prevent loading checkpoint just because repair failed
|
|
531
|
+
self.logger.warning(f"Error fixing chunk data format: {str(e)}. Some features may not work correctly.")
|
|
532
|
+
|
|
533
|
+
def mark_stage_complete(self, stage_name: str, result: Any = None) -> None:
|
|
534
|
+
"""
|
|
535
|
+
Mark a pipeline stage as complete and save the result if provided.
|
|
536
|
+
|
|
537
|
+
Args:
|
|
538
|
+
stage_name (str): Name of the completed stage
|
|
539
|
+
result (Any, optional): Result data to store for this stage
|
|
540
|
+
"""
|
|
541
|
+
try:
|
|
542
|
+
# Update stage status
|
|
543
|
+
self.checkpoint_data["stages"].setdefault(stage_name, {})
|
|
544
|
+
self.checkpoint_data["stages"][stage_name]["status"] = "completed"
|
|
545
|
+
self.checkpoint_data["stages"][stage_name]["timestamp"] = datetime.now().isoformat()
|
|
546
|
+
|
|
547
|
+
# Store result if provided
|
|
548
|
+
if result is not None:
|
|
549
|
+
self.checkpoint_data["results"][stage_name] = result
|
|
550
|
+
|
|
551
|
+
self.logger.info(f"Stage '{stage_name}' marked as complete.")
|
|
552
|
+
self._save_checkpoints()
|
|
553
|
+
except Exception as e:
|
|
554
|
+
self.logger.error(f"Error marking stage {stage_name} as complete: {str(e)}")
|
|
555
|
+
# Re-raise the error to allow caller to handle it
|
|
556
|
+
raise
|
|
557
|
+
|
|
558
|
+
def mark_stage_in_progress(self, stage_name: str) -> None:
|
|
559
|
+
"""
|
|
560
|
+
Mark a pipeline stage as in progress.
|
|
561
|
+
|
|
562
|
+
Args:
|
|
563
|
+
stage_name (str): Name of the stage
|
|
564
|
+
"""
|
|
565
|
+
try:
|
|
566
|
+
self.checkpoint_data["stages"].setdefault(stage_name, {})
|
|
567
|
+
self.checkpoint_data["stages"][stage_name]["status"] = "in_progress"
|
|
568
|
+
self.checkpoint_data["stages"][stage_name]["timestamp"] = datetime.now().isoformat()
|
|
569
|
+
|
|
570
|
+
self.logger.debug(f"Stage '{stage_name}' marked as in-progress.")
|
|
571
|
+
self._save_checkpoints()
|
|
572
|
+
except Exception as e:
|
|
573
|
+
self.logger.error(f"Error marking stage {stage_name} as in progress: {str(e)}")
|
|
574
|
+
# Re-raise the error to allow caller to handle it
|
|
575
|
+
raise
|
|
576
|
+
|
|
577
|
+
def mark_chunk_complete(self, stage_name: str, chunk_index: int, chunk_result: Any = None) -> None:
|
|
578
|
+
"""
|
|
579
|
+
Mark a chunk within a stage (like self-calibration or imaging) as complete.
|
|
580
|
+
|
|
581
|
+
Args:
|
|
582
|
+
stage_name (str): Name of the parent stage
|
|
583
|
+
chunk_index (int): Index of the chunk
|
|
584
|
+
chunk_result (Any, optional): Result data for this chunk
|
|
585
|
+
"""
|
|
586
|
+
try:
|
|
587
|
+
# Initialize if needed
|
|
588
|
+
self.checkpoint_data["stages"].setdefault(stage_name, {})
|
|
589
|
+
self.checkpoint_data["stages"][stage_name].setdefault("chunks", {})
|
|
590
|
+
|
|
591
|
+
# Make sure to use string key for storing in JSON
|
|
592
|
+
chunk_key = str(chunk_index)
|
|
593
|
+
self.checkpoint_data["stages"][stage_name]["chunks"][chunk_key] = {
|
|
594
|
+
"status": "completed",
|
|
595
|
+
"timestamp": datetime.now().isoformat()
|
|
596
|
+
}
|
|
597
|
+
|
|
598
|
+
# Store result
|
|
599
|
+
if chunk_result is not None:
|
|
600
|
+
# Make sure results section is initialized
|
|
601
|
+
self.checkpoint_data["results"].setdefault(f"{stage_name}_chunks", {})
|
|
602
|
+
|
|
603
|
+
# Handle different types of chunk_result
|
|
604
|
+
if isinstance(chunk_result, dict):
|
|
605
|
+
# Dictionary is the expected format - use as is
|
|
606
|
+
self.checkpoint_data["results"][f"{stage_name}_chunks"][chunk_key] = chunk_result
|
|
607
|
+
elif isinstance(chunk_result, str):
|
|
608
|
+
# String is often a path to the output file - wrap in a dictionary
|
|
609
|
+
self.logger.warning(f"Converting string chunk result to dictionary for chunk {chunk_index}")
|
|
610
|
+
self.checkpoint_data["results"][f"{stage_name}_chunks"][chunk_key] = {
|
|
611
|
+
"output_ms": chunk_result,
|
|
612
|
+
"original_string_value": chunk_result
|
|
613
|
+
}
|
|
614
|
+
elif isinstance(chunk_result, list):
|
|
615
|
+
# Convert list to dictionary
|
|
616
|
+
self.logger.warning(f"Converting list chunk result to dictionary for chunk {chunk_index}")
|
|
617
|
+
result_dict = {}
|
|
618
|
+
for i, item in enumerate(chunk_result):
|
|
619
|
+
result_dict[f"item_{i}"] = item
|
|
620
|
+
self.checkpoint_data["results"][f"{stage_name}_chunks"][chunk_key] = result_dict
|
|
621
|
+
else:
|
|
622
|
+
# For any other type, store in a wrapper dictionary
|
|
623
|
+
self.logger.warning(f"Wrapping chunk result of type {type(chunk_result)} in dictionary for chunk {chunk_index}")
|
|
624
|
+
self.checkpoint_data["results"][f"{stage_name}_chunks"][chunk_key] = {
|
|
625
|
+
"value": str(chunk_result),
|
|
626
|
+
"original_type": str(type(chunk_result))
|
|
627
|
+
}
|
|
628
|
+
|
|
629
|
+
self.logger.info(f"Chunk {chunk_index} of stage '{stage_name}' marked as complete.")
|
|
630
|
+
self._save_checkpoints()
|
|
631
|
+
except Exception as e:
|
|
632
|
+
self.logger.error(f"Error marking chunk {chunk_index} of stage {stage_name} as complete: {str(e)}")
|
|
633
|
+
# Re-raise the error to allow caller to handle it
|
|
634
|
+
raise
|
|
635
|
+
|
|
636
|
+
def mark_chunk_in_progress(self, stage_name: str, chunk_index: int) -> None:
|
|
637
|
+
"""
|
|
638
|
+
Mark a chunk within a stage as in progress.
|
|
639
|
+
|
|
640
|
+
Args:
|
|
641
|
+
stage_name (str): Name of the parent stage
|
|
642
|
+
chunk_index (int): Index of the chunk
|
|
643
|
+
"""
|
|
644
|
+
try:
|
|
645
|
+
# Initialize if needed
|
|
646
|
+
self.checkpoint_data["stages"].setdefault(stage_name, {})
|
|
647
|
+
self.checkpoint_data["stages"][stage_name].setdefault("chunks", {})
|
|
648
|
+
|
|
649
|
+
# Make sure to use string key for storing in JSON
|
|
650
|
+
chunk_key = str(chunk_index)
|
|
651
|
+
self.checkpoint_data["stages"][stage_name]["chunks"][chunk_key] = {
|
|
652
|
+
"status": "in_progress",
|
|
653
|
+
"timestamp": datetime.now().isoformat()
|
|
654
|
+
}
|
|
655
|
+
|
|
656
|
+
self.logger.debug(f"Chunk {chunk_index} of stage '{stage_name}' marked as in-progress.")
|
|
657
|
+
self._save_checkpoints()
|
|
658
|
+
except Exception as e:
|
|
659
|
+
self.logger.error(f"Error marking chunk {chunk_index} of stage {stage_name} as in progress: {str(e)}")
|
|
660
|
+
# Re-raise the error to allow caller to handle it
|
|
661
|
+
raise
|
|
662
|
+
|
|
663
|
+
def is_stage_completed(self, stage_name: str) -> bool:
|
|
664
|
+
"""
|
|
665
|
+
Check if a pipeline stage has been completed.
|
|
666
|
+
|
|
667
|
+
Args:
|
|
668
|
+
stage_name (str): Name of the stage to check
|
|
669
|
+
|
|
670
|
+
Returns:
|
|
671
|
+
bool: True if stage is completed, False otherwise
|
|
672
|
+
"""
|
|
673
|
+
if not self.resume_mode:
|
|
674
|
+
return False
|
|
675
|
+
|
|
676
|
+
try:
|
|
677
|
+
return (stage_name in self.checkpoint_data["stages"] and
|
|
678
|
+
self.checkpoint_data["stages"][stage_name]["status"] == "completed")
|
|
679
|
+
except (KeyError, TypeError):
|
|
680
|
+
return False
|
|
681
|
+
|
|
682
|
+
def is_chunk_completed(self, stage_name: str, chunk_index: int) -> bool:
|
|
683
|
+
"""
|
|
684
|
+
Check if a chunk within a stage has been completed.
|
|
685
|
+
|
|
686
|
+
Args:
|
|
687
|
+
stage_name (str): Name of the parent stage
|
|
688
|
+
chunk_index (int): Index of the chunk
|
|
689
|
+
|
|
690
|
+
Returns:
|
|
691
|
+
bool: True if chunk is completed, False otherwise
|
|
692
|
+
"""
|
|
693
|
+
if not self.resume_mode:
|
|
694
|
+
return False
|
|
695
|
+
|
|
696
|
+
try:
|
|
697
|
+
# Make sure we're using a string key for lookup
|
|
698
|
+
chunk_key = str(chunk_index)
|
|
699
|
+
return (stage_name in self.checkpoint_data["stages"] and
|
|
700
|
+
"chunks" in self.checkpoint_data["stages"][stage_name] and
|
|
701
|
+
chunk_key in self.checkpoint_data["stages"][stage_name]["chunks"] and
|
|
702
|
+
self.checkpoint_data["stages"][stage_name]["chunks"][chunk_key]["status"] == "completed")
|
|
703
|
+
except (KeyError, TypeError, ValueError) as e:
|
|
704
|
+
self.logger.error(f"Error checking if chunk {chunk_index} is completed for stage {stage_name}: {str(e)}")
|
|
705
|
+
return False
|
|
706
|
+
|
|
707
|
+
def get_completed_chunks(self, stage_name: str) -> List[int]:
|
|
708
|
+
"""
|
|
709
|
+
Get list of completed chunk indices for a stage.
|
|
710
|
+
|
|
711
|
+
Args:
|
|
712
|
+
stage_name (str): Name of the stage
|
|
713
|
+
|
|
714
|
+
Returns:
|
|
715
|
+
List[int]: List of completed chunk indices
|
|
716
|
+
"""
|
|
717
|
+
if not self.resume_mode:
|
|
718
|
+
return []
|
|
719
|
+
|
|
720
|
+
try:
|
|
721
|
+
completed = []
|
|
722
|
+
if stage_name in self.checkpoint_data["stages"]:
|
|
723
|
+
stage_data = self.checkpoint_data["stages"][stage_name]
|
|
724
|
+
|
|
725
|
+
# Handle the standard case where chunks are stored in a dictionary
|
|
726
|
+
if isinstance(stage_data, dict) and "chunks" in stage_data and isinstance(stage_data["chunks"], dict):
|
|
727
|
+
chunks = stage_data["chunks"]
|
|
728
|
+
# Ensure we're returning integers, not strings
|
|
729
|
+
completed = [
|
|
730
|
+
int(k) for k, v in chunks.items()
|
|
731
|
+
if isinstance(v, dict) and v.get("status") == "completed"
|
|
732
|
+
]
|
|
733
|
+
|
|
734
|
+
# Also check for chunks in the results section for backward compatibility
|
|
735
|
+
result_chunks = self.checkpoint_data["results"].get(f"{stage_name}_chunks")
|
|
736
|
+
if isinstance(result_chunks, dict):
|
|
737
|
+
# Add any indices from results not already in completed list
|
|
738
|
+
for k in result_chunks.keys():
|
|
739
|
+
try:
|
|
740
|
+
chunk_idx = int(k)
|
|
741
|
+
if chunk_idx not in completed:
|
|
742
|
+
# Verify it's actually completed by checking the stage data
|
|
743
|
+
if self.is_chunk_completed(stage_name, chunk_idx):
|
|
744
|
+
completed.append(chunk_idx)
|
|
745
|
+
except (ValueError, TypeError):
|
|
746
|
+
pass
|
|
747
|
+
elif isinstance(result_chunks, list):
|
|
748
|
+
# If stored as a list, indices with data are considered completed
|
|
749
|
+
for i, item in enumerate(result_chunks):
|
|
750
|
+
if item is not None and i not in completed:
|
|
751
|
+
# Verify it's actually completed by checking the stage data
|
|
752
|
+
if self.is_chunk_completed(stage_name, i):
|
|
753
|
+
completed.append(i)
|
|
754
|
+
|
|
755
|
+
return sorted(completed)
|
|
756
|
+
except (KeyError, TypeError, ValueError) as e:
|
|
757
|
+
self.logger.error(f"Error getting completed chunks for stage {stage_name}: {str(e)}")
|
|
758
|
+
return []
|
|
759
|
+
|
|
760
|
+
def get_stage_result(self, stage_name: str) -> Any:
|
|
761
|
+
"""
|
|
762
|
+
Get the stored result for a completed stage.
|
|
763
|
+
|
|
764
|
+
Args:
|
|
765
|
+
stage_name (str): Name of the stage
|
|
766
|
+
|
|
767
|
+
Returns:
|
|
768
|
+
Any: The stored result, or None if not found
|
|
769
|
+
"""
|
|
770
|
+
try:
|
|
771
|
+
return self.checkpoint_data["results"].get(stage_name)
|
|
772
|
+
except (KeyError, TypeError):
|
|
773
|
+
return None
|
|
774
|
+
|
|
775
|
+
def get_chunk_result(self, stage_name: str, chunk_index: int) -> Any:
|
|
776
|
+
"""
|
|
777
|
+
Get the stored result for a completed chunk.
|
|
778
|
+
|
|
779
|
+
Args:
|
|
780
|
+
stage_name (str): Name of the parent stage
|
|
781
|
+
chunk_index (int): Index of the chunk
|
|
782
|
+
|
|
783
|
+
Returns:
|
|
784
|
+
Any: The stored result, or None if not found
|
|
785
|
+
"""
|
|
786
|
+
try:
|
|
787
|
+
chunk_key = str(chunk_index)
|
|
788
|
+
chunk_results = self.checkpoint_data["results"].get(f"{stage_name}_chunks")
|
|
789
|
+
|
|
790
|
+
# Check if chunk_results is a dictionary (expected format)
|
|
791
|
+
if isinstance(chunk_results, dict):
|
|
792
|
+
result = chunk_results.get(chunk_key)
|
|
793
|
+
|
|
794
|
+
# if result is a dictionary, return it
|
|
795
|
+
if isinstance(result, dict):
|
|
796
|
+
print(f"Chunk result is a dictionary: {result}")
|
|
797
|
+
return result
|
|
798
|
+
|
|
799
|
+
# If result is a string, convert to dictionary for compatibility with pipeline
|
|
800
|
+
elif isinstance(result, str):
|
|
801
|
+
self.logger.debug(f"Converting string chunk result to dictionary for chunk {chunk_index}")
|
|
802
|
+
return {
|
|
803
|
+
"output_ms": result,
|
|
804
|
+
"original_string_value": result
|
|
805
|
+
}
|
|
806
|
+
return result
|
|
807
|
+
|
|
808
|
+
# Check if it's a list (unexpected but possible in some cases)
|
|
809
|
+
elif isinstance(chunk_results, list) and 0 <= chunk_index < len(chunk_results):
|
|
810
|
+
result = chunk_results[chunk_index]
|
|
811
|
+
|
|
812
|
+
# If result is a string, convert to dictionary for compatibility with pipeline
|
|
813
|
+
if isinstance(result, str):
|
|
814
|
+
self.logger.debug(f"Converting string chunk result to dictionary for chunk {chunk_index}")
|
|
815
|
+
return {
|
|
816
|
+
"output_ms": result,
|
|
817
|
+
"original_string_value": result
|
|
818
|
+
}
|
|
819
|
+
return result
|
|
820
|
+
else:
|
|
821
|
+
return None
|
|
822
|
+
except (KeyError, TypeError, IndexError) as e:
|
|
823
|
+
self.logger.error(f"Error getting chunk result for stage {stage_name}, chunk {chunk_index}: {str(e)}")
|
|
824
|
+
return None
|
|
825
|
+
|
|
826
|
+
def store_data(self, key: str, data: Any) -> None:
|
|
827
|
+
"""
|
|
828
|
+
Store arbitrary data in the checkpoint.
|
|
829
|
+
|
|
830
|
+
Args:
|
|
831
|
+
key (str): Key to store the data under
|
|
832
|
+
data (Any): Data to store
|
|
833
|
+
"""
|
|
834
|
+
try:
|
|
835
|
+
# Special handling for chunk-related data to ensure consistency
|
|
836
|
+
if key.endswith("_chunks") and not isinstance(data, dict):
|
|
837
|
+
# If trying to store a list for a chunks key, store it as a dictionary instead
|
|
838
|
+
if isinstance(data, list):
|
|
839
|
+
self.logger.debug(f"Converting {key} from list to dictionary for storage")
|
|
840
|
+
new_data = {}
|
|
841
|
+
for idx, item in enumerate(data):
|
|
842
|
+
if item is not None:
|
|
843
|
+
new_data[str(idx)] = item
|
|
844
|
+
data = new_data
|
|
845
|
+
# If it's neither a list nor dictionary, wrap it in a dictionary with key "0"
|
|
846
|
+
elif data is not None:
|
|
847
|
+
self.logger.debug(f"Wrapping {key} data in a dictionary for storage")
|
|
848
|
+
data = {"0": data}
|
|
849
|
+
|
|
850
|
+
self.checkpoint_data["results"][key] = data
|
|
851
|
+
self._save_checkpoints()
|
|
852
|
+
except Exception as e:
|
|
853
|
+
self.logger.error(f"Error storing data under key '{key}': {str(e)}")
|
|
854
|
+
# Re-raise to allow caller to handle
|
|
855
|
+
raise
|
|
856
|
+
|
|
857
|
+
def get_data(self, key: str, default: Any = None) -> Any:
|
|
858
|
+
"""
|
|
859
|
+
Retrieve stored data from the checkpoint.
|
|
860
|
+
|
|
861
|
+
Args:
|
|
862
|
+
key (str): Key to retrieve
|
|
863
|
+
default (Any, optional): Default value if key not found
|
|
864
|
+
|
|
865
|
+
Returns:
|
|
866
|
+
Any: The stored data, or default if not found
|
|
867
|
+
"""
|
|
868
|
+
try:
|
|
869
|
+
return self.checkpoint_data["results"].get(key, default)
|
|
870
|
+
except Exception as e:
|
|
871
|
+
self.logger.error(f"Error retrieving data for key '{key}': {str(e)}")
|
|
872
|
+
return default
|
|
873
|
+
|
|
874
|
+
def verify_file_exists(self, filepath: str) -> bool:
|
|
875
|
+
"""
|
|
876
|
+
Verify that a file exists as part of checkpoint validation.
|
|
877
|
+
|
|
878
|
+
Args:
|
|
879
|
+
filepath (str): Path to the file to check
|
|
880
|
+
|
|
881
|
+
Returns:
|
|
882
|
+
bool: True if file exists, False otherwise
|
|
883
|
+
"""
|
|
884
|
+
return os.path.exists(filepath)
|
|
885
|
+
|
|
886
|
+
def get_resume_info(self) -> Dict[str, Any]:
|
|
887
|
+
"""
|
|
888
|
+
Get information about the current pipeline resumption.
|
|
889
|
+
|
|
890
|
+
Returns:
|
|
891
|
+
Dict[str, Any]: Resume information dictionary
|
|
892
|
+
"""
|
|
893
|
+
return {
|
|
894
|
+
"resume_mode": self.resume_mode,
|
|
895
|
+
"resume_count": self.checkpoint_data["pipeline_info"].get("resume_count", 0),
|
|
896
|
+
"original_start_time": self.checkpoint_data["pipeline_info"].get("start_time"),
|
|
897
|
+
"last_update": self.checkpoint_data["pipeline_info"].get("last_update")
|
|
898
|
+
}
|
|
899
|
+
|
|
900
|
+
def clear_checkpoints(self) -> None:
|
|
901
|
+
"""Clear all checkpoints and reset to initial state."""
|
|
902
|
+
if os.path.exists(self.checkpoint_file):
|
|
903
|
+
# Backup the checkpoint file before deleting
|
|
904
|
+
backup_file = f"{self.checkpoint_file}.bak.{int(time.time())}"
|
|
905
|
+
shutil.copy2(self.checkpoint_file, backup_file)
|
|
906
|
+
self.logger.info(f"Backed up checkpoint file to {backup_file}")
|
|
907
|
+
|
|
908
|
+
# Delete the checkpoint file
|
|
909
|
+
os.remove(self.checkpoint_file)
|
|
910
|
+
|
|
911
|
+
# Reset the checkpoint data
|
|
912
|
+
self.checkpoint_data = {
|
|
913
|
+
"pipeline_info": {
|
|
914
|
+
"start_time": datetime.now().isoformat(),
|
|
915
|
+
"last_update": datetime.now().isoformat(),
|
|
916
|
+
"working_dir": self.working_dir,
|
|
917
|
+
"resume_count": 0
|
|
918
|
+
},
|
|
919
|
+
"stages": {},
|
|
920
|
+
"results": {}
|
|
921
|
+
}
|
|
922
|
+
|
|
923
|
+
self.logger.info("All checkpoints cleared.")
|
|
924
|
+
self._save_checkpoints()
|
|
925
|
+
|
|
926
|
+

class CustomJSONEncoder(json.JSONEncoder):
    """
    Custom JSON encoder that can handle numpy arrays and other special types.

    This makes the checkpoint system more robust when saving complex results.
    """

    def default(self, obj):
        try:
            # Handle NumPy arrays
            if 'numpy' in sys.modules and hasattr(sys.modules['numpy'], 'ndarray'):
                if isinstance(obj, sys.modules['numpy'].ndarray):
                    return obj.tolist()

            # Handle NumPy data types
            if 'numpy' in sys.modules and hasattr(sys.modules['numpy'], 'integer'):
                if isinstance(obj, sys.modules['numpy'].integer):
                    return int(obj)

            if 'numpy' in sys.modules and hasattr(sys.modules['numpy'], 'floating'):
                if isinstance(obj, sys.modules['numpy'].floating):
                    return float(obj)

            # Handle sets
            if isinstance(obj, set):
                return list(obj)

            # Handle datetimes (the file imports the datetime class directly,
            # so the check is against `datetime`, not `datetime.datetime`)
            if isinstance(obj, datetime):
                return obj.isoformat()

            # Handle custom objects with __dict__
            if hasattr(obj, '__dict__'):
                return {
                    '__class__': obj.__class__.__name__,
                    '__module__': obj.__module__,
                    'attributes': obj.__dict__
                }

            # For any other types, use string representation
            return str(obj)
        except Exception:
            return str(obj)

        # Fallback to the base class (unreachable in practice,
        # since the try block above always returns)
        return json.JSONEncoder.default(self, obj)
|
|
972
|
+
def create_checkpoint_manager(working_dir: str, resume: bool = False, logger: Optional[logging.Logger] = None) -> CheckpointManager:
|
|
973
|
+
"""
|
|
974
|
+
Factory function to create a CheckpointManager instance.
|
|
975
|
+
|
|
976
|
+
Args:
|
|
977
|
+
working_dir (str): Pipeline working directory
|
|
978
|
+
resume (bool): Whether to resume from checkpoints
|
|
979
|
+
logger (logging.Logger, optional): Logger instance
|
|
980
|
+
|
|
981
|
+
Returns:
|
|
982
|
+
CheckpointManager: Initialized checkpoint manager
|
|
983
|
+
"""
|
|
984
|
+
return CheckpointManager(working_dir, resume, logger)
|