imagebaker 0.0.41__py3-none-any.whl → 0.0.48__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. imagebaker/__init__.py +1 -1
  2. imagebaker/core/__init__.py +0 -0
  3. imagebaker/core/configs/__init__.py +1 -0
  4. imagebaker/core/configs/configs.py +156 -0
  5. imagebaker/core/defs/__init__.py +1 -0
  6. imagebaker/core/defs/defs.py +258 -0
  7. imagebaker/core/plugins/__init__.py +0 -0
  8. imagebaker/core/plugins/base_plugin.py +39 -0
  9. imagebaker/core/plugins/cosine_plugin.py +39 -0
  10. imagebaker/layers/__init__.py +3 -0
  11. imagebaker/layers/annotable_layer.py +847 -0
  12. imagebaker/layers/base_layer.py +724 -0
  13. imagebaker/layers/canvas_layer.py +1007 -0
  14. imagebaker/list_views/__init__.py +3 -0
  15. imagebaker/list_views/annotation_list.py +203 -0
  16. imagebaker/list_views/canvas_list.py +185 -0
  17. imagebaker/list_views/image_list.py +138 -0
  18. imagebaker/list_views/layer_list.py +390 -0
  19. imagebaker/list_views/layer_settings.py +219 -0
  20. imagebaker/models/__init__.py +0 -0
  21. imagebaker/models/base_model.py +150 -0
  22. imagebaker/tabs/__init__.py +2 -0
  23. imagebaker/tabs/baker_tab.py +496 -0
  24. imagebaker/tabs/layerify_tab.py +837 -0
  25. imagebaker/utils/__init__.py +0 -0
  26. imagebaker/utils/image.py +105 -0
  27. imagebaker/utils/state_utils.py +92 -0
  28. imagebaker/utils/transform_mask.py +107 -0
  29. imagebaker/window/__init__.py +1 -0
  30. imagebaker/window/app.py +136 -0
  31. imagebaker/window/main_window.py +181 -0
  32. imagebaker/workers/__init__.py +3 -0
  33. imagebaker/workers/baker_worker.py +247 -0
  34. imagebaker/workers/layerify_worker.py +91 -0
  35. imagebaker/workers/model_worker.py +54 -0
  36. {imagebaker-0.0.41.dist-info → imagebaker-0.0.48.dist-info}/METADATA +6 -6
  37. imagebaker-0.0.48.dist-info/RECORD +41 -0
  38. {imagebaker-0.0.41.dist-info → imagebaker-0.0.48.dist-info}/WHEEL +1 -1
  39. imagebaker-0.0.41.dist-info/RECORD +0 -7
  40. {imagebaker-0.0.41.dist-info/licenses → imagebaker-0.0.48.dist-info}/LICENSE +0 -0
  41. {imagebaker-0.0.41.dist-info → imagebaker-0.0.48.dist-info}/entry_points.txt +0 -0
  42. {imagebaker-0.0.41.dist-info → imagebaker-0.0.48.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,150 @@
1
+ from abc import ABC, abstractmethod
2
+ from loguru import logger
3
+ import numpy as np
4
+ import time
5
+ import cv2
6
+
7
+ from imagebaker.core.defs.defs import ModelType, PredictionResult
8
+ from imagebaker.core.configs import DefaultModelConfig
9
+
10
+
11
class BaseModel(ABC):
    """Common prediction pipeline shared by every model wrapper.

    ``predict`` runs ``preprocess`` -> one task-specific ``predict_*`` hook
    chosen by ``config.model_type`` -> ``postprocess`` and logs how long each
    stage took.  The ``predict_*`` hooks are not defined on this class; a
    concrete model must provide the one matching its configured model type.
    """

    def __init__(self, config: DefaultModelConfig):
        """
        A base class for all models.

        Args:
            config (DefaultModelConfig): Model configuration.
        """
        self.config = config
        self.model = None
        # First two dimensions of the most recent image passed to ``predict``.
        self.image_shape: tuple | None = None

        self.setup()

    @property
    def name(self):
        """Return the concrete class name of this model."""
        return self.__class__.__name__

    def __repr__(self):
        return f"{self.config.model_name} v{self.config.model_version}"

    def setup(self):
        """Optional hook called at the end of ``__init__``; no-op by default."""
        pass

    def preprocess(self, image: np.ndarray):
        """Optional hook applied before prediction; returns ``image`` as is."""
        return image

    def postprocess(self, output) -> PredictionResult:
        """Optional hook applied to the raw prediction; returns it as is."""
        return output

    def predict(
        self,
        image: np.ndarray,
        points: list[int] | None = None,
        rectangles: list[list[int]] | None = None,
        polygons: list[list[int]] | None = None,
        label_hints: list[int] | None = None,
    ) -> list[PredictionResult]:
        """Run the full preprocess / predict / postprocess pipeline.

        Args:
            image: Input image.  A 3-D array with four channels is converted
                from RGBA to RGB; any other array is passed on unchanged.
            points: Prompt points, only used for ``ModelType.PROMPT``.
            rectangles: Prompt rectangles, only used for ``ModelType.PROMPT``.
            polygons: Accepted for interface compatibility; currently not
                forwarded to any hook.
            label_hints: Only used for ``ModelType.PROMPT``.

        Returns:
            Whatever ``postprocess`` returns for the raw model output.

        Raises:
            ValueError: If ``config.model_type`` is not one of the supported
                ``ModelType`` members.
        """
        t0 = time.time()
        self.image_shape = image.shape[:2]
        # A 2-D (grayscale) array has no channel axis, so check the rank
        # before looking at the channel count.
        if image.ndim == 3 and image.shape[2] == 4:
            logger.info("Converting image from RGBA to RGB")
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        preprocessed_image = self.preprocess(image)
        t1 = time.time()
        logger.info(f"Preprocessing time: {t1-t0:.4f} seconds")

        model_type = self.config.model_type
        if model_type == ModelType.DETECTION:
            output = self.predict_boxes(preprocessed_image)
        elif model_type == ModelType.SEGMENTATION:
            output = self.predict_mask(preprocessed_image)
        elif model_type == ModelType.CLASSIFICATION:
            output = self.predict_class(preprocessed_image)
        elif model_type == ModelType.PROMPT:
            # NOTE(review): the fourth positional argument is ``label_hints``
            # although ``BasePromptModel.predict_prompt`` names that slot
            # ``polygons``.  Kept as is so existing implementations keep
            # receiving the same values - confirm the intended contract.
            output = self.predict_prompt(
                preprocessed_image, points, rectangles, label_hints
            )
        else:
            # Fail here with a clear message instead of hitting an unbound
            # ``output`` further down.
            raise ValueError(f"Unsupported model type: {model_type!r}")

        t2 = time.time()
        logger.info(f"Prediction time: {t2-t1:.4f} seconds")

        result = self.postprocess(output)
        t3 = time.time()
        logger.info(f"Postprocessing time: {t3-t2:.4f} seconds")

        return result
85
+
86
+
87
class BaseSegmentationModel(BaseModel):
    """Base class for segmentation models.

    Construction is inherited unchanged from ``BaseModel``; subclasses only
    have to provide ``predict_mask``.
    """

    @abstractmethod
    def predict_mask(self, image):
        """Produce the segmentation output for ``image``.

        ``BaseModel.predict`` calls this with the preprocessed image when the
        configured model type is ``ModelType.SEGMENTATION``.
        """
94
+
95
+
96
class BaseDetectionModel(BaseModel):
    """Base class for detection models.

    Construction is inherited unchanged from ``BaseModel``; subclasses only
    have to provide ``predict_boxes``.
    """

    @abstractmethod
    def predict_boxes(self, image):
        """Produce the detection output for ``image``.

        ``BaseModel.predict`` calls this with the preprocessed image when the
        configured model type is ``ModelType.DETECTION``.
        """
103
+
104
+
105
class BaseClassificationModel(BaseModel):
    """Base class for classification models.

    Construction is inherited unchanged from ``BaseModel``; subclasses only
    have to provide ``predict_class``.
    """

    @abstractmethod
    def predict_class(self, image):
        """Produce the classification output for ``image``.

        ``BaseModel.predict`` calls this with the preprocessed image when the
        configured model type is ``ModelType.CLASSIFICATION``.
        """
112
+
113
+
114
class BasePromptModel(BaseModel):
    """Base class for prompt-driven models.

    Construction is inherited unchanged from ``BaseModel``; subclasses only
    have to provide ``predict_prompt``.
    """

    @abstractmethod
    def predict_prompt(self, image, points, rectangles, polygons):
        """Produce the prediction for ``image`` guided by the given prompts.

        ``BaseModel.predict`` calls this with the preprocessed image when the
        configured model type is ``ModelType.PROMPT``.

        NOTE(review): ``BaseModel.predict`` passes its ``label_hints``
        argument in the fourth position, which this signature names
        ``polygons`` - confirm which of the two is intended.
        """
121
+
122
+
123
def get_dummy_prediction_result(result_type: ModelType) -> PredictionResult:
    """Build a placeholder ``PredictionResult`` for the given model type.

    Detection results carry a randomly generated rectangle, segmentation and
    prompt results carry a fixed four-point square as mask, classification
    results carry only class name, id and score.  Any other ``result_type``
    yields ``None``.
    """
    # Fields shared by every dummy result.
    common = {"class_id": 0, "score": 0.99}

    if result_type == ModelType.DETECTION:
        # Return a random rectangle: first pair drawn from [0, 1000),
        # second pair from [200, 500).
        x1, y1 = np.random.randint(0, 1000, 2)
        x2, y2 = np.random.randint(200, 500, 2)
        return PredictionResult(
            class_name="dummy", rectangle=[x1, y1, x2, y2], **common
        )

    if result_type == ModelType.SEGMENTATION:
        square = [[0, 0], [0, 100], [100, 100], [100, 0]]
        return PredictionResult(class_name="dummy", mask=square, **common)

    if result_type == ModelType.CLASSIFICATION:
        return PredictionResult(class_name="dummy", **common)

    if result_type == ModelType.PROMPT:
        square = [[0, 0], [0, 100], [100, 100], [100, 0]]
        return PredictionResult(prompt="dummy", mask=square, **common)

    return None
@@ -0,0 +1,2 @@
1
+ from .baker_tab import BakerTab # noqa
2
+ from .layerify_tab import LayerifyTab # noqa
@@ -0,0 +1,496 @@
1
+ from imagebaker.list_views import LayerList, LayerSettings
2
+ from imagebaker.list_views.canvas_list import CanvasList
3
+ from imagebaker.layers.canvas_layer import CanvasLayer
4
+ from imagebaker.core.defs import BakingResult, MouseMode, Annotation
5
+ from imagebaker.core.configs import CanvasConfig
6
+ from imagebaker import logger
7
+
8
+ from PySide6.QtCore import Qt, Signal
9
+ from PySide6.QtGui import QPixmap
10
+ from PySide6.QtWidgets import (
11
+ QColorDialog,
12
+ QWidget,
13
+ QVBoxLayout,
14
+ QHBoxLayout,
15
+ QPushButton,
16
+ QSizePolicy,
17
+ QDockWidget,
18
+ QSlider,
19
+ QLabel,
20
+ QSpinBox,
21
+ QComboBox,
22
+ )
23
+ from collections import deque
24
+
25
+
26
class BakerTab(QWidget):
    """Baker tab: hosts the canvases and the docks that operate on them.

    The tab keeps a bounded deque of ``CanvasLayer`` widgets, of which one is
    the current canvas, and wires it to a ``CanvasList``, a ``LayerList``, a
    ``LayerSettings`` panel and a toolbar dock on the main window.

    Signals:
        messageSignal (str): Status text intended for the user.
        bakingResult (BakingResult): Re-emitted from the connected canvases.
    """

    messageSignal = Signal(str)
    bakingResult = Signal(BakingResult)

    def __init__(self, main_window, config: CanvasConfig):
        """Initialize the Baker Tab.

        Args:
            main_window: Window that owns this tab and receives the docks.
            config (CanvasConfig): Canvas configuration; also supplies the
                deque size and the layer-settings limits.
        """
        super().__init__(main_window)
        self.main_window = main_window
        self.config = config
        self.toolbar = None
        self.main_layout = QVBoxLayout(self)

        # Deque to store multiple CanvasLayer objects with a fixed size
        self.canvases = deque(maxlen=self.config.deque_maxlen)

        # Currently selected canvas (None once every canvas was deleted)
        self.current_canvas = None

        self.init_ui()

    def _has_canvas(self) -> bool:
        """Return True if a canvas is selected; otherwise notify the user."""
        if self.current_canvas is None:
            self.messageSignal.emit("No canvases available.")
            return False
        return True

    def init_ui(self):
        """Initialize the UI components."""
        # Create toolbar
        self.create_toolbar()

        # Create a single canvas for now; it is shown immediately
        self.current_canvas = CanvasLayer(parent=self.main_window, config=self.config)
        self.current_canvas.setVisible(True)
        self.canvases.append(self.current_canvas)
        self.main_layout.addWidget(self.current_canvas)

        # Create and add CanvasList
        self.canvas_list = CanvasList(self.canvases, parent=self.main_window)
        self.main_window.addDockWidget(Qt.LeftDockWidgetArea, self.canvas_list)

        # Create and add LayerSettings and LayerList
        self.layer_settings = LayerSettings(
            parent=self.main_window,
            max_xpos=self.config.max_xpos,
            max_ypos=self.config.max_ypos,
            max_scale=self.config.max_scale,
        )
        self.layer_list = LayerList(
            canvas=self.current_canvas,
            parent=self.main_window,
            layer_settings=self.layer_settings,
        )
        self.layer_settings.setVisible(False)
        self.main_window.addDockWidget(Qt.RightDockWidgetArea, self.layer_list)
        self.main_window.addDockWidget(Qt.RightDockWidgetArea, self.layer_settings)

        # Create a dock widget for the toolbar
        self.toolbar_dock = QDockWidget("Tools", self)
        self.toolbar_dock.setWidget(self.toolbar)
        self.toolbar_dock.setFeatures(
            QDockWidget.DockWidgetMovable | QDockWidget.DockWidgetFloatable
        )
        self.main_window.addDockWidget(Qt.BottomDockWidgetArea, self.toolbar_dock)

        # Connections
        self.layer_settings.messageSignal.connect(self.messageSignal.emit)
        self.current_canvas.bakingResult.connect(self.bakingResult.emit)
        self.current_canvas.layersChanged.connect(self.update_list)
        self.current_canvas.layerRemoved.connect(self.update_list)

        self.canvas_list.canvasSelected.connect(self.on_canvas_selected)
        self.canvas_list.canvasAdded.connect(self.on_canvas_added)
        self.canvas_list.canvasDeleted.connect(self.on_canvas_deleted)

    def update_slider_range(self, steps):
        """Update the slider range based on the number of steps."""
        self.timeline_slider.setMaximum(steps - 1)
        self.messageSignal.emit(f"Updated steps to {steps}")
        self.timeline_slider.setEnabled(False)  # Disable the slider
        self.timeline_slider.update()

    def generate_state_previews(self):
        """Generate a clickable preview button for each saved step.

        NOTE(review): this method uses ``self.preview_layout`` and
        ``self.preview_panel``, which are not created anywhere in this class,
        and nothing in this class calls it - confirm before wiring it up.
        It also adds one new ``thumbnailsAvailable`` connection per step on
        every call.
        """
        # Clear existing previews
        for i in reversed(range(self.preview_layout.count())):
            widget = self.preview_layout.itemAt(i).widget()
            if widget:
                widget.deleteLater()

        # Generate a preview for each state
        for step, states in sorted(self.current_canvas.states.items()):
            # Create a container widget for the preview
            preview_widget = QWidget()
            preview_layout = QVBoxLayout(preview_widget)
            preview_layout.setContentsMargins(0, 0, 0, 0)
            preview_layout.setSpacing(2)

            # Placeholder thumbnail
            placeholder = QPixmap(50, 50)
            placeholder.fill(Qt.gray)  # Gray placeholder
            thumbnail_label = QLabel()
            thumbnail_label.setPixmap(placeholder)
            thumbnail_label.setFixedSize(50, 50)  # Set a fixed size for the thumbnail
            thumbnail_label.setScaledContents(True)

            # Add the step number on top of the thumbnail
            step_label = QLabel(f"Step {step}")
            step_label.setAlignment(Qt.AlignCenter)
            step_label.setStyleSheet("font-weight: bold; font-size: 10px;")

            # Add a button to make the preview clickable
            preview_button = QPushButton()
            preview_button.setFixedSize(
                50, 70
            )  # Match the size of the thumbnail + step label
            preview_button.setStyleSheet("background: transparent; border: none;")
            # ``s=step`` binds the current step; a plain closure would see
            # only the last value of the loop variable.
            preview_button.clicked.connect(lambda _, s=step: self.seek_state(s))

            # Add the thumbnail and step label to the layout
            preview_layout.addWidget(thumbnail_label)
            preview_layout.addWidget(step_label)

            # Add the preview widget to the button
            preview_button.setLayout(preview_layout)

            # Add the button to the preview panel
            self.preview_layout.addWidget(preview_button)

            # Update the thumbnail dynamically when it becomes available
            self.current_canvas.thumbnailsAvailable.connect(
                lambda step=step, label=thumbnail_label: self.update_thumbnail(
                    step, label
                )
            )

        # Refresh the preview panel
        self.preview_panel.update()

    def update_thumbnail(self, step, thumbnail_label):
        """Update the thumbnail for a specific step."""
        if step in self.current_canvas.state_thumbnail:
            thumbnail = self.current_canvas.state_thumbnail[step]
            thumbnail_label.setPixmap(thumbnail)
            thumbnail_label.update()

    def update_list(self, layer=None):
        """Update the layer list and layer settings.

        The layer list is re-pointed at the current canvas' layers only when
        a truthy ``layer`` is passed; the widgets are refreshed either way.
        """
        if layer:
            self.layer_list.layers = self.current_canvas.layers
        self.layer_list.update_list()
        self.layer_settings.update_sliders()
        self.update()

    def on_canvas_deleted(self, canvas: CanvasLayer):
        """Handle the deletion of a canvas.

        Falls back to the last remaining canvas, or clears the layer views
        when no canvas is left.
        """
        if self.canvases:
            self.current_canvas = self.canvases[-1]  # Select the last canvas
            self.layer_list.canvas = self.current_canvas
            self.layer_list.layers = self.current_canvas.layers
            # Keep the settings panel in sync with the canvas now shown, the
            # same way on_canvas_selected does, so it no longer refers to a
            # layer of the deleted canvas.
            self.layer_settings.selected_layer = self.current_canvas.selected_layer
            self.current_canvas.setVisible(True)  # Show the last canvas
        else:
            self.current_canvas = None  # No canvases left
            self.messageSignal.emit("No canvases available.")  # Notify the user
            self.layer_list.canvas = None
            self.layer_list.layers = []
            self.layer_settings.selected_layer = None
        self.layer_settings.update_sliders()
        self.canvas_list.update_canvas_list()  # Update the canvas list
        self.layer_list.update_list()
        self.update()

    def on_canvas_selected(self, canvas: CanvasLayer):
        """Handle canvas selection from the CanvasList."""
        # Hide all canvases and show only the selected one
        for layer in self.canvases:
            layer.setVisible(layer == canvas)

        # Update the current canvas
        self.current_canvas = canvas
        self.layer_list.canvas = canvas
        self.layer_list.layers = canvas.layers
        self.layer_settings.selected_layer = canvas.selected_layer
        self.layer_list.layer_settings = self.layer_settings

        self.layer_list.update_list()
        self.layer_settings.update_sliders()

        logger.info(f"Selected canvas: {canvas.layer_name}")
        self.update()

    def on_canvas_added(self, new_canvas: CanvasLayer):
        """Handle the addition of a new canvas.

        The canvas is not appended to ``self.canvases`` here; presumably the
        CanvasList, which shares that deque, has already done so - verify
        against CanvasList.
        """
        logger.info(f"New canvas added: {new_canvas.layer_name}")
        self.main_layout.addWidget(new_canvas)  # Add the new canvas to the layout
        if self.current_canvas is not None:
            self.current_canvas.setVisible(False)  # Hide the previous canvas

        # Make the new canvas current and show it
        self.layer_list.canvas = new_canvas
        self.current_canvas = new_canvas
        self.canvas_list.update_canvas_list()  # Update the canvas list
        new_canvas.setVisible(True)

        self.current_canvas.bakingResult.connect(self.bakingResult.emit)
        self.current_canvas.layersChanged.connect(self.update_list)
        self.current_canvas.layerRemoved.connect(self.update_list)

        self.current_canvas.update()
        self.layer_list.layers = new_canvas.layers
        self.layer_list.update_list()
        self.layer_settings.selected_layer = None
        self.layer_settings.update_sliders()

    def create_toolbar(self):
        """Create Baker-specific toolbar"""
        self.toolbar = QWidget()
        baker_toolbar_layout = QHBoxLayout(self.toolbar)
        baker_toolbar_layout.setContentsMargins(5, 5, 5, 5)
        baker_toolbar_layout.setSpacing(10)

        # Add a label for "Steps"
        steps_label = QLabel("Steps:")
        steps_label.setStyleSheet("font-weight: bold;")
        baker_toolbar_layout.addWidget(steps_label)

        # Add a spin box for entering the number of steps.  The value is set
        # before valueChanged is connected, so update_slider_range is not
        # triggered while the slider does not exist yet.
        self.steps_spinbox = QSpinBox()
        self.steps_spinbox.setMinimum(1)
        self.steps_spinbox.setMaximum(1000)  # Arbitrary maximum value
        self.steps_spinbox.setValue(1)  # Default value
        self.steps_spinbox.valueChanged.connect(self.update_slider_range)
        baker_toolbar_layout.addWidget(self.steps_spinbox)

        # Add buttons for Baker modes with emojis
        baker_modes = [
            ("📤 Export Current State", self.export_current_state),
            ("💾 Save State", self.save_current_state),
            ("🔮 Predict State", self.predict_state),
            ("▶️ Play States", self.play_saved_states),
            ("🗑️ Clear States", self.clear_states),
            ("📤 Annotate States", self.export_for_annotation),
            ("📤 Export States", self.export_locally),
        ]

        for text, callback in baker_modes:
            btn = QPushButton(text)
            btn.clicked.connect(callback)
            baker_toolbar_layout.addWidget(btn)

            # If the button is "Play States", add the slider beside it
            if text == "▶️ Play States":
                self.timeline_slider = QSlider(Qt.Horizontal)  # Create the slider
                self.timeline_slider.setMinimum(0)
                self.timeline_slider.setMaximum(0)  # Will be updated dynamically
                self.timeline_slider.setValue(0)
                self.timeline_slider.setSingleStep(
                    1
                )  # Set the granularity of the slider
                self.timeline_slider.setPageStep(1)  # Allow smoother jumps
                self.timeline_slider.setEnabled(False)  # Initially disabled
                self.timeline_slider.valueChanged.connect(self.seek_state)
                baker_toolbar_layout.addWidget(self.timeline_slider)

        # Add a drawing button
        draw_button = QPushButton("✏️ Draw")
        draw_button.setCheckable(True)  # Make it toggleable
        draw_button.clicked.connect(self.toggle_drawing_mode)
        baker_toolbar_layout.addWidget(draw_button)

        # Add an erase button
        erase_button = QPushButton("🧹 Erase")
        erase_button.setCheckable(True)  # Make it toggleable
        erase_button.clicked.connect(self.toggle_erase_mode)
        baker_toolbar_layout.addWidget(erase_button)

        # Add a color picker button
        color_picker_button = QPushButton("🎨")
        color_picker_button.clicked.connect(self.open_color_picker)
        baker_toolbar_layout.addWidget(color_picker_button)

        # Add a spacer to push the rest of the elements to the right
        spacer = QWidget()
        spacer.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Preferred)
        baker_toolbar_layout.addWidget(spacer)

        # Add the toolbar to the main layout
        self.main_layout.addWidget(self.toolbar)

    def toggle_drawing_mode(self):
        """Toggle drawing mode on the current canvas."""
        if self.current_canvas:
            self.current_canvas.mouse_mode = (
                MouseMode.DRAW
                if self.current_canvas.mouse_mode != MouseMode.DRAW
                else MouseMode.IDLE
            )
            mode = self.current_canvas.mouse_mode.name.lower()
            self.messageSignal.emit(f"Drawing mode {mode}.")

    def toggle_erase_mode(self):
        """Toggle erase mode on the current canvas."""
        if self.current_canvas:
            self.current_canvas.mouse_mode = (
                MouseMode.ERASE
                if self.current_canvas.mouse_mode != MouseMode.ERASE
                else MouseMode.IDLE
            )
            mode = self.current_canvas.mouse_mode.name.lower()
            self.messageSignal.emit(f"Erasing mode {mode}.")

    def open_color_picker(self):
        """Open a color picker dialog to select a custom color."""
        if not self._has_canvas():
            return
        color = QColorDialog.getColor()
        if color.isValid():
            self.current_canvas.drawing_color = color
            self.messageSignal.emit(f"Selected custom color: {color.name()}")

    def export_for_annotation(self):
        """Export the baked states for annotation."""
        if not self._has_canvas():
            return
        self.messageSignal.emit("Exporting states for prediction...")
        self.current_canvas.export_baked_states(export_to_annotation_tab=True)

    def export_locally(self):
        """Export the baked states locally."""
        if not self._has_canvas():
            return
        self.messageSignal.emit("Exporting baked states...")
        self.current_canvas.export_baked_states()

    def play_saved_states(self):
        """Play the saved states in sequence."""
        if not self._has_canvas():
            return
        self.messageSignal.emit("Playing saved state...")

        # Update the slider range based on the number of states
        if self.current_canvas.states:
            num_states = len(self.current_canvas.states)
            self.timeline_slider.setMaximum(num_states - 1)
            # Sync the spinbox with the number of states.  A value change
            # triggers update_slider_range, which disables the slider, so the
            # slider is enabled only afterwards.
            self.steps_spinbox.setValue(num_states)
            self.timeline_slider.setEnabled(True)
        else:
            self.timeline_slider.setMaximum(0)
            self.steps_spinbox.setValue(1)
            self.messageSignal.emit("No saved states available.")
            self.timeline_slider.setEnabled(False)

        self.timeline_slider.update()
        # Start playing the states
        self.current_canvas.play_states()

    def save_current_state(self):
        """Save the current state of the canvas."""
        if not self._has_canvas():
            return
        self.messageSignal.emit("Saving current state...")
        logger.info(f"Saving current state for {self.steps_spinbox.value()}...")

        self.current_canvas.save_current_state(steps=self.steps_spinbox.value())
        self.messageSignal.emit(
            "Current state saved. Total states: {}".format(
                len(self.current_canvas.states)
            )
        )

        self.steps_spinbox.setValue(1)  # Reset the spinbox value
        self.steps_spinbox.update()
        # Disable the timeline slider
        self.timeline_slider.setEnabled(False)
        self.timeline_slider.update()

    def clear_states(self):
        """Clear all saved states and disable the timeline slider."""
        self.messageSignal.emit("Clearing all saved states...")
        if not self.current_canvas:
            return

        self.current_canvas.previous_state = None
        self.current_canvas.current_step = 0
        self.current_canvas.states.clear()  # Clear all saved states
        self.timeline_slider.setEnabled(False)  # Disable the slider
        self.timeline_slider.setMaximum(0)  # Reset the slider range
        self.timeline_slider.setValue(0)  # Reset the slider position
        self.messageSignal.emit("All states cleared.")
        self.steps_spinbox.setValue(1)  # Reset the spinbox value

        self.steps_spinbox.update()
        self.timeline_slider.update()
        self.current_canvas.update()

    def seek_state(self, step):
        """Seek to a specific state using the timeline slider."""
        if not self._has_canvas():
            return
        self.messageSignal.emit(f"Seeking to step {step}")
        logger.info(f"Seeking to step {step}")

        # Apply the states saved for the selected step to their layers
        if step in self.current_canvas.states:
            for state in self.current_canvas.states[step]:
                layer = self.current_canvas.get_layer(state.layer_id)
                if layer:
                    layer.layer_state = state
                    layer.update()

        # Update the canvas
        self.current_canvas.update()

    def export_current_state(self):
        """Export the current state as an image."""
        if not self._has_canvas():
            return
        self.messageSignal.emit("Exporting current state...")
        self.current_canvas.export_current_state()

    def predict_state(self):
        """Pass the current state to predict."""
        if not self._has_canvas():
            return
        self.messageSignal.emit("Predicting state...")

        self.current_canvas.predict_state()

    def add_layer(self, layer: CanvasLayer):
        """Add a new layer to the canvas."""
        self.layer_list.add_layer(layer)
        self.layer_settings.selected_layer = (
            self.current_canvas.selected_layer
            if self.current_canvas is not None
            else None
        )
        self.layer_settings.update_sliders()

    def keyPressEvent(self, event):
        """Handle key press events.

        Shortcuts: Ctrl+S saves the current state, Ctrl+D toggles drawing,
        Ctrl+E toggles erasing, Delete removes the selected layer and Ctrl+N
        adds a blank layer.  The event is always passed on to the base class.
        """
        # Every shortcut acts on the current canvas; without one there is
        # nothing to do except the default handling.
        if self.current_canvas is None:
            return super().keyPressEvent(event)

        key = event.key()
        ctrl_pressed = event.modifiers() == Qt.ControlModifier

        # Ctrl + S: Save the current state, keeping the active mouse mode
        curr_mode = self.current_canvas.mouse_mode
        if key == Qt.Key_S and ctrl_pressed:
            self.save_current_state()
            self.current_canvas.mouse_mode = curr_mode
            self.current_canvas.update()
        # Ctrl + D: Toggle drawing mode
        if key == Qt.Key_D and ctrl_pressed:
            self.toggle_drawing_mode()
            self.current_canvas.update()
        # Ctrl + E: Toggle erase mode
        if key == Qt.Key_E and ctrl_pressed:
            self.toggle_erase_mode()
            self.current_canvas.update()

        # Delete: Delete the selected layer
        if key == Qt.Key_Delete:
            if (
                self.current_canvas.selected_layer
                and self.current_canvas.selected_layer in self.current_canvas.layers
            ):
                self.current_canvas.layers.remove(self.current_canvas.selected_layer)
                self.current_canvas.selected_layer = None
                self.layer_settings.selected_layer = None

                self.current_canvas.update()
                self.layer_list.update_list()
                self.layer_settings.update_sliders()
                self.messageSignal.emit("Deleted selected layer")

        # Ctrl + N: Add a new layer to the current canvas
        if key == Qt.Key_N and ctrl_pressed:
            new_layer = CanvasLayer(parent=self.current_canvas)
            new_layer.layer_name = f"Layer {len(self.current_canvas.layers) + 1}"
            new_layer.annotations = [
                Annotation(annotation_id=0, label="New Annotation"),
            ]
            # Start the layer with a fully transparent image of canvas size
            blank_pixmap = QPixmap(self.current_canvas.size())
            blank_pixmap.fill(Qt.transparent)
            new_layer.set_image(blank_pixmap)
            self.current_canvas.layers.append(new_layer)
            self.current_canvas.update()
            self.layer_list.update_list()
            self.messageSignal.emit(f"Added new layer: {new_layer.layer_name}")

        self.update()
        return super().keyPressEvent(event)