celldetective 1.5.0b7__py3-none-any.whl → 1.5.0b9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. celldetective/_version.py +1 -1
  2. celldetective/event_detection_models.py +2463 -0
  3. celldetective/gui/base/channel_norm_generator.py +19 -3
  4. celldetective/gui/base/figure_canvas.py +1 -1
  5. celldetective/gui/base/list_widget.py +1 -1
  6. celldetective/gui/base_annotator.py +2 -5
  7. celldetective/gui/event_annotator.py +248 -138
  8. celldetective/gui/generic_signal_plot.py +14 -14
  9. celldetective/gui/gui_utils.py +27 -6
  10. celldetective/gui/pair_event_annotator.py +146 -20
  11. celldetective/gui/plot_signals_ui.py +32 -15
  12. celldetective/gui/process_block.py +2 -2
  13. celldetective/gui/seg_model_loader.py +4 -4
  14. celldetective/gui/settings/_settings_event_model_training.py +32 -14
  15. celldetective/gui/settings/_settings_segmentation_model_training.py +5 -5
  16. celldetective/gui/settings/_settings_signal_annotator.py +0 -19
  17. celldetective/gui/survival_ui.py +39 -11
  18. celldetective/gui/tableUI.py +69 -148
  19. celldetective/gui/thresholds_gui.py +45 -5
  20. celldetective/gui/viewers/base_viewer.py +17 -20
  21. celldetective/gui/viewers/spot_detection_viewer.py +136 -27
  22. celldetective/processes/train_signal_model.py +1 -1
  23. celldetective/scripts/train_signal_model.py +1 -1
  24. celldetective/signals.py +4 -2426
  25. celldetective/utils/event_detection/__init__.py +1 -1
  26. {celldetective-1.5.0b7.dist-info → celldetective-1.5.0b9.dist-info}/METADATA +1 -1
  27. {celldetective-1.5.0b7.dist-info → celldetective-1.5.0b9.dist-info}/RECORD +33 -31
  28. tests/gui/test_spot_detection_viewer.py +187 -0
  29. tests/test_signals.py +135 -116
  30. {celldetective-1.5.0b7.dist-info → celldetective-1.5.0b9.dist-info}/WHEEL +0 -0
  31. {celldetective-1.5.0b7.dist-info → celldetective-1.5.0b9.dist-info}/entry_points.txt +0 -0
  32. {celldetective-1.5.0b7.dist-info → celldetective-1.5.0b9.dist-info}/licenses/LICENSE +0 -0
  33. {celldetective-1.5.0b7.dist-info → celldetective-1.5.0b9.dist-info}/top_level.txt +0 -0
@@ -74,6 +74,7 @@ class ThresholdConfigWizard(CelldetectiveMainWindow):
74
74
 
75
75
  super().__init__()
76
76
  self.parent_window = parent_window
77
+ # Navigate explicit parent chain: SegModelLoader -> ControlPanel -> ProcessPanel -> MainWindow
77
78
  self.screen_height = (
78
79
  self.parent_window.parent_window.parent_window.parent_window.screen_height
79
80
  )
@@ -124,6 +125,17 @@ class ThresholdConfigWizard(CelldetectiveMainWindow):
124
125
  self.bg_loader = BackgroundLoader()
125
126
  self.bg_loader.start()
126
127
 
128
+ def closeEvent(self, event):
129
+ """Clean up resources on close."""
130
+ if hasattr(self, "bg_loader") and self.bg_loader.isRunning():
131
+ self.bg_loader.quit()
132
+ self.bg_loader.wait()
133
+ # Clear large arrays
134
+ for attr in ["img", "labels", "edt_map", "props", "coords"]:
135
+ if hasattr(self, attr):
136
+ delattr(self, attr)
137
+ super().closeEvent(event)
138
+
127
139
  def _create_menu_bar(self):
128
140
  menu_bar = self.menuBar()
129
141
  # Creating menus using a QMenu object
@@ -810,7 +822,8 @@ class ThresholdConfigWizard(CelldetectiveMainWindow):
810
822
  for i in range(2):
811
823
  try:
812
824
  self.features_cb[i].disconnect()
813
- except Exception as _:
825
+ except TypeError:
826
+ # No connections to disconnect
814
827
  pass
815
828
  self.features_cb[i].clear()
816
829
 
@@ -879,12 +892,39 @@ class ThresholdConfigWizard(CelldetectiveMainWindow):
879
892
  self.exp_dir + f"configs/threshold_config_{self.mode}.json",
880
893
  "JSON (*.json)",
881
894
  )[0]
882
- with open(self.previous_instruction_file, "r") as f:
883
- threshold_instructions = json.load(f)
895
+
896
+ if not self.previous_instruction_file:
897
+ return # User cancelled
898
+
899
+ try:
900
+ with open(self.previous_instruction_file, "r") as f:
901
+ threshold_instructions = json.load(f)
902
+ except (FileNotFoundError, json.JSONDecodeError) as e:
903
+ generic_message(f"Could not load config: {e}")
904
+ return
905
+
906
+ # Validate required keys
907
+ required_keys = [
908
+ "target_channel",
909
+ "filters",
910
+ "thresholds",
911
+ "marker_footprint_size",
912
+ "marker_min_distance",
913
+ "feature_queries",
914
+ ]
915
+ missing_keys = [k for k in required_keys if k not in threshold_instructions]
916
+ if missing_keys:
917
+ generic_message(f"Config file is missing required keys: {missing_keys}")
918
+ return
884
919
 
885
920
  target_channel = threshold_instructions["target_channel"]
886
- index = self.viewer.channels_cb.findText(target_channel)
887
- self.viewer.channels_cb.setCurrentIndex(index)
921
+ index = self.viewer.channel_cb.findText(target_channel)
922
+ if index >= 0:
923
+ self.viewer.channel_cb.setCurrentIndex(index)
924
+ else:
925
+ logger.warning(
926
+ f"Channel '{target_channel}' not found in available channels"
927
+ )
888
928
 
889
929
  filters = threshold_instructions["filters"]
890
930
  items_to_add = [f[0] + "_filter" for f in filters]
@@ -6,6 +6,7 @@ from PyQt5.QtWidgets import QHBoxLayout, QAction, QLabel, QComboBox
6
6
  from fonticon_mdi6 import MDI6
7
7
  from superqt import QLabeledDoubleRangeSlider, QLabeledSlider
8
8
  from superqt.fonticon import icon
9
+ import matplotlib.gridspec as gridspec
9
10
 
10
11
  from celldetective.gui.base.components import CelldetectiveWidget
11
12
  from celldetective.gui.base.utils import center_window
@@ -102,8 +103,7 @@ class StackLoader(QThread):
102
103
  self.mutex.unlock()
103
104
 
104
105
  except Exception as e:
105
- pass
106
- # logger.error(f"Error loading frame {frame_to_load}: {e}")
106
+ logger.debug(f"Error loading frame {frame_to_load}: {e}")
107
107
  # Prepare to wait to avoid spin loop on error
108
108
  self.msleep(100)
109
109
 
@@ -169,10 +169,14 @@ class StackVisualizer(CelldetectiveWidget):
169
169
  window_title="View",
170
170
  PxToUm=None,
171
171
  background_color="transparent",
172
- imshow_kwargs={},
172
+ imshow_kwargs=None,
173
173
  ):
174
174
  super().__init__()
175
175
 
176
+ # Default mutable argument handling
177
+ if imshow_kwargs is None:
178
+ imshow_kwargs = {}
179
+
176
180
  # self.setWindowTitle(window_title)
177
181
  self.window_title = window_title
178
182
 
@@ -265,7 +269,6 @@ class StackVisualizer(CelldetectiveWidget):
265
269
  self.canvas.toolbar.insertAction(insert_before, self.lock_y_action)
266
270
  else:
267
271
  if len(actions) > 5:
268
- self.canvas.toolbar.insertAction(actions[5], self.line_action)
269
272
  self.canvas.toolbar.insertAction(actions[5], self.line_action)
270
273
  self.canvas.toolbar.insertAction(actions[5], self.lock_y_action)
271
274
  else:
@@ -313,8 +316,6 @@ class StackVisualizer(CelldetectiveWidget):
313
316
  # Use GridSpec for robust layout
314
317
  # 2 rows: Main Image (top, ~75%), Profile (bottom, ~25%)
315
318
  # Add margins to ensure axis labels and text are visible
316
- import matplotlib.gridspec as gridspec
317
-
318
319
  gs = gridspec.GridSpec(
319
320
  2,
320
321
  1,
@@ -406,10 +407,6 @@ class StackVisualizer(CelldetectiveWidget):
406
407
  self.ax_profile = None
407
408
 
408
409
  # Restore original layout
409
- # if hasattr(self, "ax_original_pos"):
410
- # standard 1x1 GridSpec or manual restore
411
- import matplotlib.gridspec as gridspec
412
-
413
410
  gs = gridspec.GridSpec(1, 1)
414
411
  self.ax.set_subplotspec(gs[0])
415
412
  # self.ax.set_position(gs[0].get_position(self.fig))
@@ -505,9 +502,8 @@ class StackVisualizer(CelldetectiveWidget):
505
502
  profile = np.zeros_like(profile)
506
503
  profile[:] = np.nan
507
504
 
508
- # Distance in microns if available
505
+ # Distance in pixels
509
506
  dist_axis = np.arange(num_points)
510
- x_label = "Distance (px)"
511
507
 
512
508
  # Only show pixel length, rounded to integer
513
509
  title_str = f"{round(length_px,2)} [px]"
@@ -525,11 +521,8 @@ class StackVisualizer(CelldetectiveWidget):
525
521
  if hasattr(self, "profile_line") and self.profile_line:
526
522
  try:
527
523
  self.profile_line.remove()
528
- except:
529
- pass
530
-
531
- # Distance in microns if available
532
- dist_axis = np.arange(num_points)
524
+ except ValueError:
525
+ pass # Already removed
533
526
 
534
527
  (self.profile_line,) = self.ax_profile.plot(
535
528
  dist_axis, profile, color="black", linestyle="-"
@@ -850,12 +843,16 @@ class StackVisualizer(CelldetectiveWidget):
850
843
 
851
844
  if curr_min < self._min:
852
845
  self._min = curr_min
853
- rescale_constrast = True
846
+ rescale_contrast = True
854
847
  if curr_max > self._max:
855
848
  self._max = curr_max
856
849
  rescale_contrast = True
857
850
 
858
- if rescale_contrast:
851
+ if (
852
+ rescale_contrast
853
+ and self.create_contrast_slider
854
+ and hasattr(self, "contrast_slider")
855
+ ):
859
856
  self.contrast_slider.setRange(self._min, self._max)
860
857
  self.canvas.canvas.draw_idle()
861
858
  self.update_profile()
@@ -888,5 +885,5 @@ class StackVisualizer(CelldetectiveWidget):
888
885
  try:
889
886
  if hasattr(self, "loader_thread") and self.loader_thread:
890
887
  self.loader_thread.stop()
891
- except:
888
+ except Exception:
892
889
  pass
@@ -5,18 +5,34 @@ from pathlib import Path
5
5
  import numpy as np
6
6
  from PyQt5.QtCore import QSize
7
7
  from PyQt5.QtGui import QDoubleValidator
8
- from PyQt5.QtWidgets import QMessageBox, QHBoxLayout, QLabel, QComboBox, QLineEdit, QPushButton
8
+ from PyQt5.QtWidgets import (
9
+ QMessageBox,
10
+ QHBoxLayout,
11
+ QLabel,
12
+ QComboBox,
13
+ QLineEdit,
14
+ QPushButton,
15
+ QCheckBox,
16
+ QSizePolicy,
17
+ QWidget,
18
+ )
19
+ from PyQt5.QtCore import Qt
20
+ from celldetective.gui.base.utils import center_window
9
21
  from fonticon_mdi6 import MDI6
10
22
  from natsort import natsorted
11
23
  from superqt.fonticon import icon
12
24
 
13
- from celldetective.gui.gui_utils import PreprocessingLayout2
14
25
  from celldetective.gui.viewers.base_viewer import StackVisualizer
26
+ from celldetective.gui.gui_utils import PreprocessingLayout2
15
27
  from celldetective.utils.image_loaders import load_frames
28
+ from celldetective.measure import extract_blobs_in_image
29
+ from celldetective.filters import filter_image
16
30
  from celldetective import get_logger
31
+ from tifffile import imread
17
32
 
18
33
  logger = get_logger(__name__)
19
34
 
35
+
20
36
  class SpotDetectionVisualizer(StackVisualizer):
21
37
 
22
38
  def __init__(
@@ -37,6 +53,7 @@ class SpotDetectionVisualizer(StackVisualizer):
37
53
  self.labels = labels
38
54
  self.detection_channel = self.target_channel
39
55
  self.switch_from_channel = False
56
+ self.preview_preprocessing = False
40
57
 
41
58
  self.parent_channel_cb = parent_channel_cb
42
59
  self.parent_diameter_le = parent_diameter_le
@@ -47,6 +64,35 @@ class SpotDetectionVisualizer(StackVisualizer):
47
64
  self.floatValidator = QDoubleValidator()
48
65
  self.init_scatter()
49
66
 
67
+ self.setWindowTitle(self.window_title)
68
+ self.resize(1200, 800)
69
+
70
+ # Main Layout (Horizontal split)
71
+ self.main_layout = QHBoxLayout(self)
72
+ self.main_layout.setContentsMargins(10, 10, 10, 10)
73
+
74
+ # Left Panel (Settings) - Scrollable
75
+ from PyQt5.QtWidgets import QScrollArea, QWidget, QVBoxLayout
76
+
77
+ self.scroll_area = QScrollArea()
78
+ self.scroll_area.setWidgetResizable(True)
79
+ self.settings_widget = QWidget()
80
+ self.settings_layout = QVBoxLayout(self.settings_widget)
81
+ self.settings_layout.setContentsMargins(10, 10, 10, 10)
82
+ self.settings_layout.setSpacing(15)
83
+ self.settings_layout.setAlignment(Qt.AlignTop)
84
+ self.scroll_area.setWidget(self.settings_widget)
85
+ self.scroll_area.setFixedWidth(350) # Set a reasonable width for settings
86
+
87
+ # Add Left Panel
88
+ self.main_layout.addWidget(self.scroll_area)
89
+
90
+ # Right Panel (Image Canvas)
91
+ # self.canvas is created by super().__init__
92
+ # We allow it to expand
93
+ self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
94
+ self.main_layout.addWidget(self.canvas)
95
+
50
96
  self.generate_detection_channel()
51
97
  self.detection_channel = self.detection_channel_cb.currentIndex()
52
98
 
@@ -57,12 +103,29 @@ class SpotDetectionVisualizer(StackVisualizer):
57
103
 
58
104
  self.ax.callbacks.connect("xlim_changed", self.update_marker_sizes)
59
105
  self.ax.callbacks.connect("ylim_changed", self.update_marker_sizes)
106
+ self._axis_callbacks_connected = True
60
107
 
61
108
  self.apply_diam_btn.clicked.connect(self.detect_and_display_spots)
62
109
  self.apply_thresh_btn.clicked.connect(self.detect_and_display_spots)
63
110
 
64
- self.channel_cb.setCurrentIndex(self.target_channel)
65
- self.detection_channel_cb.setCurrentIndex(self.target_channel)
111
+ self.channel_cb.setCurrentIndex(min(self.target_channel, self.n_channels - 1))
112
+ self.detection_channel_cb.setCurrentIndex(
113
+ min(self.target_channel, self.n_channels - 1)
114
+ )
115
+
116
+ def closeEvent(self, event):
117
+ """Clean up resources on close."""
118
+ # Clear large arrays
119
+ self.target_img = None
120
+ self.init_label = None
121
+ self.spot_sizes = []
122
+
123
+ # Remove scatter
124
+ if hasattr(self, "spot_scat") and self.spot_scat is not None:
125
+ self.spot_scat.remove()
126
+ self.spot_scat = None
127
+
128
+ super().closeEvent(event)
66
129
 
67
130
  def update_marker_sizes(self, event=None):
68
131
 
@@ -105,9 +168,7 @@ class SpotDetectionVisualizer(StackVisualizer):
105
168
  if not self.switch_from_channel:
106
169
  self.reset_detection()
107
170
 
108
- if self.mode == "virtual":
109
- from tifffile import imread
110
-
171
+ if self.mode == "virtual" and hasattr(self, "mask_paths"):
111
172
  self.init_label = imread(self.mask_paths[value])
112
173
  self.target_img = load_frames(
113
174
  self.img_num_per_channel[self.detection_channel, value],
@@ -122,15 +183,11 @@ class SpotDetectionVisualizer(StackVisualizer):
122
183
 
123
184
  self.reset_detection()
124
185
  self.control_valid_parameters() # set current diam and threshold
125
- # self.change_frame(self.frame_slider.value())
126
- # self.set_detection_channel_index(self.detection_channel_cb.currentIndex())
127
186
 
128
187
  image_preprocessing = self.preprocessing.list.items
129
188
  if image_preprocessing == []:
130
189
  image_preprocessing = None
131
190
 
132
- from celldetective.measure import extract_blobs_in_image
133
-
134
191
  blobs_filtered = extract_blobs_in_image(
135
192
  self.target_img,
136
193
  self.init_label,
@@ -157,8 +214,7 @@ class SpotDetectionVisualizer(StackVisualizer):
157
214
  self.canvas.canvas.draw()
158
215
 
159
216
  def reset_detection(self):
160
-
161
- self.ax.scatter([], []).get_offsets()
217
+ """Clear spot detection display."""
162
218
  empty_offset = np.ma.masked_array([0, 0], mask=True)
163
219
  self.spot_scat.set_offsets(empty_offset)
164
220
  self.canvas.canvas.draw()
@@ -205,8 +261,6 @@ class SpotDetectionVisualizer(StackVisualizer):
205
261
  returnValue = msgBox.exec()
206
262
  self.close()
207
263
 
208
- from tifffile import imread
209
-
210
264
  self.init_label = imread(self.mask_paths[self.frame_slider.value()])
211
265
 
212
266
  def generate_detection_channel(self):
@@ -215,7 +269,7 @@ class SpotDetectionVisualizer(StackVisualizer):
215
269
  assert len(self.channel_names) == self.n_channels
216
270
 
217
271
  channel_layout = QHBoxLayout()
218
- channel_layout.setContentsMargins(15, 0, 15, 0)
272
+ channel_layout.setContentsMargins(0, 0, 0, 0)
219
273
  channel_layout.addWidget(QLabel("Detection\nchannel: "), 25)
220
274
 
221
275
  self.detection_channel_cb = QComboBox()
@@ -231,11 +285,22 @@ class SpotDetectionVisualizer(StackVisualizer):
231
285
  # self.invert_check.toggled.connect(self.set_invert)
232
286
  # channel_layout.addWidget(self.invert_check, 10)
233
287
 
234
- self.canvas.layout.addLayout(channel_layout)
288
+ self.settings_layout.addLayout(channel_layout)
235
289
 
236
- self.preprocessing = PreprocessingLayout2(fraction=25, parent_window=self)
237
- self.preprocessing.setContentsMargins(15, 0, 15, 0)
238
- self.canvas.layout.addLayout(self.preprocessing)
290
+ self.preview_cb = QCheckBox("Preview")
291
+ self.preview_cb.toggled.connect(self.toggle_preprocessing_preview)
292
+
293
+ self.preprocessing = PreprocessingLayout2(
294
+ fraction=25, parent_window=self, extra_widget=self.preview_cb
295
+ )
296
+ self.preprocessing.setContentsMargins(0, 10, 0, 10)
297
+ self.preprocessing.list.list_widget.model().rowsInserted.connect(
298
+ self.update_preview_if_active
299
+ )
300
+ self.preprocessing.list.list_widget.model().rowsRemoved.connect(
301
+ self.update_preview_if_active
302
+ )
303
+ self.settings_layout.addLayout(self.preprocessing)
239
304
 
240
305
  # def set_invert(self):
241
306
  # if self.invert_check.isChecked():
@@ -272,19 +337,22 @@ class SpotDetectionVisualizer(StackVisualizer):
272
337
  self.spot_diam_le.textChanged.connect(self.control_valid_parameters)
273
338
  self.spot_thresh_le.textChanged.connect(self.control_valid_parameters)
274
339
 
340
+ self.apply_diam_btn.clicked.connect(self.detect_and_display_spots)
341
+ self.apply_thresh_btn.clicked.connect(self.detect_and_display_spots)
342
+
275
343
  spot_diam_layout = QHBoxLayout()
276
- spot_diam_layout.setContentsMargins(15, 0, 15, 0)
344
+ spot_diam_layout.setContentsMargins(0, 0, 0, 0)
277
345
  spot_diam_layout.addWidget(QLabel("Spot diameter: "), 25)
278
346
  spot_diam_layout.addWidget(self.spot_diam_le, 65)
279
347
  spot_diam_layout.addWidget(self.apply_diam_btn, 10)
280
- self.canvas.layout.addLayout(spot_diam_layout)
348
+ self.settings_layout.addLayout(spot_diam_layout)
281
349
 
282
350
  spot_thresh_layout = QHBoxLayout()
283
- spot_thresh_layout.setContentsMargins(15, 0, 15, 0)
351
+ spot_thresh_layout.setContentsMargins(0, 0, 0, 0)
284
352
  spot_thresh_layout.addWidget(QLabel("Detection\nthreshold: "), 25)
285
353
  spot_thresh_layout.addWidget(self.spot_thresh_le, 65)
286
354
  spot_thresh_layout.addWidget(self.apply_thresh_btn, 10)
287
- self.canvas.layout.addLayout(spot_thresh_layout)
355
+ self.settings_layout.addLayout(spot_thresh_layout)
288
356
 
289
357
  def generate_add_measurement_btn(self):
290
358
 
@@ -297,7 +365,48 @@ class SpotDetectionVisualizer(StackVisualizer):
297
365
  add_hbox.addWidget(QLabel(""), 33)
298
366
  add_hbox.addWidget(self.add_measurement_btn, 33)
299
367
  add_hbox.addWidget(QLabel(""), 33)
300
- self.canvas.layout.addLayout(add_hbox)
368
+ self.settings_layout.addLayout(add_hbox)
369
+
370
+ def show(self):
371
+ QWidget.show(self)
372
+ center_window(self)
373
+
374
+ def update_preview_if_active(self):
375
+ if self.preview_cb.isChecked():
376
+ self.toggle_preprocessing_preview()
377
+
378
+ def toggle_preprocessing_preview(self):
379
+
380
+ image_preprocessing = self.preprocessing.list.items
381
+ if image_preprocessing == []:
382
+ image_preprocessing = None
383
+
384
+ if self.preview_cb.isChecked() and image_preprocessing is not None:
385
+ # Apply preprocessing
386
+ try:
387
+ preprocessed_img = filter_image(
388
+ self.target_img.copy(), filters=image_preprocessing
389
+ )
390
+ self.im.set_data(preprocessed_img)
391
+
392
+ # Update contrast to match new range
393
+ p01 = np.nanpercentile(preprocessed_img, 0.1)
394
+ p99 = np.nanpercentile(preprocessed_img, 99.9)
395
+ self.im.set_clim(vmin=p01, vmax=p99)
396
+ if hasattr(self, "contrast_slider"):
397
+ self.contrast_slider.setValue((p01, p99))
398
+ self.canvas.draw()
399
+ except Exception as e:
400
+ logger.error(f"Preprocessing preview failed: {e}")
401
+ else:
402
+ # Restore original
403
+ self.im.set_data(self.target_img)
404
+ p01 = np.nanpercentile(self.target_img, 0.1)
405
+ p99 = np.nanpercentile(self.target_img, 99.9)
406
+ self.im.set_clim(vmin=p01, vmax=p99)
407
+ if hasattr(self, "contrast_slider"):
408
+ self.contrast_slider.setValue((p01, p99))
409
+ self.canvas.draw()
301
410
 
302
411
  def control_valid_parameters(self):
303
412
 
@@ -305,14 +414,14 @@ class SpotDetectionVisualizer(StackVisualizer):
305
414
  try:
306
415
  self.diameter = float(self.spot_diam_le.text().replace(",", "."))
307
416
  valid_diam = True
308
- except:
417
+ except ValueError:
309
418
  valid_diam = False
310
419
 
311
420
  valid_thresh = False
312
421
  try:
313
422
  self.thresh = float(self.spot_thresh_le.text().replace(",", "."))
314
423
  valid_thresh = True
315
- except:
424
+ except ValueError:
316
425
  valid_thresh = False
317
426
 
318
427
  if valid_diam and valid_thresh:
@@ -7,7 +7,7 @@ import numpy as np
7
7
  from art import tprint
8
8
  from tensorflow.python.keras.callbacks import Callback
9
9
 
10
- from celldetective.signals import SignalDetectionModel
10
+ from celldetective.event_detection_models import SignalDetectionModel
11
11
  from celldetective.log_manager import get_logger
12
12
  from celldetective.utils.model_loaders import locate_signal_model
13
13
 
@@ -8,7 +8,7 @@ import json
8
8
  from glob import glob
9
9
  import numpy as np
10
10
  from art import tprint
11
- from celldetective.signals import SignalDetectionModel
11
+ from celldetective.event_detection_models import SignalDetectionModel
12
12
  from celldetective.utils.model_loaders import locate_signal_model
13
13
 
14
14
  tprint("Train")