celldetective 1.5.0b2__py3-none-any.whl → 1.5.0b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,351 @@
1
+ import pytest
2
+ import os
3
+ import pandas as pd
4
+ import numpy as np
5
+ import logging
6
+ from PyQt5 import QtCore
7
+ from PyQt5.QtWidgets import QApplication
8
+ from celldetective.gui.InitWindow import AppInitWindow
9
+ from celldetective.gui.measure_annotator import MeasureAnnotator
10
+ from celldetective import get_software_location
11
+ from unittest.mock import patch
12
+ import shutil
13
+ import json
14
+ import tifffile
15
+
16
+ software_location = get_software_location()
17
+
18
+
19
@pytest.fixture(autouse=True)
def disable_logging():
    """Disable all logging for the duration of each test.

    Logging is switched off globally (``logging.disable(logging.CRITICAL)``)
    to avoid a Windows OSError triggered by logging under pytest's output
    capture. It is re-enabled (``logging.NOTSET``) once the test finishes,
    whether it passed or failed.
    """
    logging.disable(logging.CRITICAL)
    try:
        yield
    finally:
        logging.disable(logging.NOTSET)
28
+
29
+
30
@pytest.fixture
def app(qtbot):
    """Provide a fresh AppInitWindow registered with qtbot.

    Registration via ``qtbot.addWidget`` hands the window to pytest-qt,
    which closes registered widgets at test teardown.
    """
    main_window = AppInitWindow(software_location=software_location)
    qtbot.addWidget(main_window)
    return main_window
35
+
36
+
37
def create_dummy_movie(exp_dir, well="W1", pos="100", prefix="sample", frames=5):
    """Write an all-zero uint16 movie of shape ``(frames, 100, 100)``.

    The movie is stored as a single multi-page TIFF at
    ``<exp_dir>/<well>/<pos>/movie/<prefix>.tif``, the layout expected by
    ``locate_stack``. Missing parent directories are created.
    """
    target_dir = os.path.join(exp_dir, well, pos, "movie")
    os.makedirs(target_dir, exist_ok=True)
    blank_stack = np.zeros(shape=(frames, 100, 100), dtype=np.uint16)
    tifffile.imwrite(os.path.join(target_dir, f"{prefix}.tif"), blank_stack)
44
+
45
+
46
def test_measure_annotator_enhancements(app, qtbot, tmp_path):
    """
    Test that MeasureAnnotator correctly discovers group_* and status_* columns.

    The trajectory table only carries ``group_experimental`` and
    ``class_firstdetection``. The annotator must expose
    ``group_experimental`` and ``status_firstdetection`` among its
    ``class_cols``. Well/position selection is patched to bypass
    QCheckableComboBox headless interaction issues.
    """
    exp_dir = str(tmp_path / "Experiment")
    os.makedirs(os.path.join(exp_dir, "W1", "100", "output", "tables"), exist_ok=True)
    os.makedirs(os.path.join(exp_dir, "configs"), exist_ok=True)

    with open(os.path.join(exp_dir, "config.ini"), "w") as f:
        f.write(
            "[MovieSettings]\nmovie_prefix = sample\nlen_movie = 10\nshape_x = 100\nshape_y = 100\npxtoum = 1.0\nframetomin = 1.0\n"
        )
        f.write(
            "[Labels]\nconcentrations = 0\ncell_types = dummy\nantibodies = none\npharmaceutical_agents = none\n[Channels]\nChannel1 = 0\n"
        )

    create_dummy_movie(exp_dir, well="W1", pos="100", prefix="sample", frames=10)

    # Use numeric class/status to avoid color_from_state failure
    df = pd.DataFrame(
        {
            "TRACK_ID": [1, 1, 2, 2],
            "FRAME": [0, 1, 0, 1],
            "group_experimental": ["A", "A", "B", "B"],
            "class_firstdetection": [0, 0, 1, 1],
            "area": [100.0, 110.0, 105.0, 115.0],  # Needed for MinMaxScaler
            "POSITION_X": [10, 12, 10, 12],
            "POSITION_Y": [10, 12, 10, 12],
        }
    )
    traj_path = os.path.join(
        exp_dir, "W1", "100", "output", "tables", "trajectories_effectors.csv"
    )
    df.to_csv(traj_path, index=False)

    app.experiment_path_selection.setText(exp_dir)
    qtbot.mouseClick(app.validate_button, QtCore.Qt.LeftButton)
    qtbot.waitUntil(lambda: hasattr(app, "control_panel"), timeout=30000)

    cp = app.control_panel
    p0 = cp.ProcessPopulations[0]

    qtbot.waitUntil(lambda: cp.well_list.count() > 0, timeout=30000)

    # Force the first well/position as the selection while the GUI reacts.
    with patch.object(cp.well_list, "getSelectedIndices", return_value=[0]):
        with patch.object(cp.position_list, "getSelectedIndices", return_value=[0]):

            cp.update_position_options()
            qtbot.wait(500)

            qtbot.mouseClick(p0.check_measurements_btn, QtCore.Qt.LeftButton)

            try:
                qtbot.waitUntil(lambda: hasattr(p0, "measure_annotator"), timeout=15000)
            except Exception:
                print("DEBUG: measure_annotator not found on p0.")
                raise

            annotator = p0.measure_annotator
            qtbot.wait(1000)

            assert annotator is not None
            try:
                cols = annotator.class_cols
            except RuntimeError:
                # The underlying Qt object is already gone; nothing left to close.
                pytest.fail("MeasureAnnotator closed unexpectedly (RuntimeError).")

            # Close the window even when an assertion fails, so it does not
            # linger into subsequent tests.
            try:
                assert "group_experimental" in cols
                assert "status_firstdetection" in cols
            finally:
                annotator.close()
118
+
119
+
120
def test_event_mapping_anticipation(app, qtbot, tmp_path):
    """
    Test that SignalModelParamsWidget anticipates signals based on metadata.

    A temporary signal model ("DummyModelTest", whose config_input.json lists
    the DAPI and GFP channels) is written into the installed
    ``celldetective/models/signal_detection`` folder for the duration of the
    test and removed afterwards. After submitting signal analysis, the first
    channel combo box of the signal widget must list measurements implied by
    the experiment channels and by the measurement instructions.
    """
    exp_dir = str(tmp_path / "ExperimentAnticipation")
    os.makedirs(os.path.join(exp_dir, "W1", "100"), exist_ok=True)
    os.makedirs(os.path.join(exp_dir, "configs"), exist_ok=True)

    with open(os.path.join(exp_dir, "config.ini"), "w") as f:
        f.write(
            "[MovieSettings]\nmovie_prefix = sample\nlen_movie = 10\nshape_x = 100\nshape_y = 100\npxtoum = 1.0\nframetomin = 1.0\n"
        )
        f.write(
            "[Labels]\nconcentrations = 0\ncell_types = dummy\nantibodies = none\npharmaceutical_agents = none\n"
        )
        f.write("[Channels]\nDAPI = 0\nGFP = 1\n")

    # Measurement instructions from which circle (radius 5) and edge (10 px)
    # intensity features should be anticipated.
    instructions = {
        "features": [],
        "intensity_measurement_radii": [5],
        "isotropic_operations": ["mean"],
        "border_distances": [10],
    }
    # Create for both targets and effectors to be safe
    for pop in ["targets", "effectors"]:
        with open(
            os.path.join(exp_dir, "configs", f"measurement_instructions_{pop}.json"),
            "w",
        ) as f:
            json.dump(instructions, f)

    create_dummy_movie(exp_dir, well="W1", pos="100", prefix="sample", frames=10)

    dummy_model_name = "DummyModelTest"
    models_dir = os.path.join(
        software_location, "celldetective", "models", "signal_detection"
    )
    model_path = os.path.join(models_dir, dummy_model_name)

    try:
        # Created inside the try so that the installation folder is cleaned up
        # even if writing the model config fails.
        os.makedirs(model_path, exist_ok=True)
        with open(os.path.join(model_path, "config_input.json"), "w") as f:
            json.dump({"channels": ["DAPI", "GFP"]}, f)

        app.experiment_path_selection.setText(exp_dir)
        qtbot.mouseClick(app.validate_button, QtCore.Qt.LeftButton)
        qtbot.waitUntil(lambda: hasattr(app, "control_panel"), timeout=30000)

        cp = app.control_panel
        p0 = cp.ProcessPopulations[0]

        qtbot.waitUntil(lambda: cp.well_list.count() > 0, timeout=30000)

        with patch.object(cp.well_list, "getSelectedIndices", return_value=[0]):
            with patch.object(
                cp.position_list, "getSelectedIndices", return_value=[0]
            ):
                cp.update_position_options()
                qtbot.wait(500)

                p0.signal_analysis_action.setChecked(True)
                p0.refresh_signal_models()
                idx = p0.signal_models_list.findText(dummy_model_name)
                if idx >= 0:
                    p0.signal_models_list.setCurrentIndex(idx)

                qtbot.mouseClick(p0.submit_btn, QtCore.Qt.LeftButton)

                qtbot.waitUntil(
                    lambda: hasattr(p0, "signalChannelWidget"), timeout=30000
                )
                widget = p0.signalChannelWidget
                assert widget is not None

                try:
                    qtbot.wait(1000)
                    combo = widget.channel_cbs[0]
                    items = [combo.itemText(i) for i in range(combo.count())]

                    assert "area" in items
                    assert "dapi_mean" in items or "DAPI_mean" in items
                    assert "gfp_mean" in items or "GFP_mean" in items

                    # Anticipated measurements from the instructions, matched
                    # case-insensitively against the combo box entries.
                    lowered_items = {item.lower() for item in items}
                    anticipated = [
                        "DAPI_circle_5_mean",
                        "GFP_circle_5_mean",
                        "DAPI_mean_edge_10px",
                        "GFP_mean_edge_10px",
                    ]
                    for ant in anticipated:
                        assert (
                            ant.lower() in lowered_items
                        ), f"Anticipated measurement {ant} not found in {items}"
                finally:
                    widget.close()

    finally:
        # Best-effort removal of the dummy model from the installation folder.
        # A missing folder or removal errors are ignored, as before, but
        # without a bare except that would also swallow KeyboardInterrupt.
        shutil.rmtree(model_path, ignore_errors=True)
226
+
227
+
228
def test_fix_missing_labels(tmp_path):
    """
    Test that fix_missing_labels creates empty label files when they don't exist.

    The ``labels_effectors`` folder is created beforehand but left empty. This
    simulates a position where segmentation has started but no label file was
    written. With a 5-frame dummy movie, five label files are expected, and
    the first one must be an all-zero 100x100 image.
    """
    from celldetective.utils.image_loaders import fix_missing_labels, locate_stack

    exp_dir = str(tmp_path / "ExperimentFixLabels")
    well, pos = "W1", "100"
    pos_dir = os.path.join(exp_dir, well, pos)
    lbl_dir = os.path.join(pos_dir, "labels_effectors")

    # fix_missing_labels writes into labels_<population> but does not appear
    # to create that folder itself (NOTE(review): confirm), so create it here.
    # The scenario under test is "folder present, files missing".
    os.makedirs(lbl_dir, exist_ok=True)
    create_dummy_movie(exp_dir, well=well, pos=pos, prefix="sample", frames=5)

    # Precondition: no label files yet.
    assert len(os.listdir(lbl_dir)) == 0

    # Exercise locate_stack on the dummy movie before fixing labels; the
    # returned stack itself is not inspected.
    locate_stack(pos_dir, prefix="sample")

    fix_missing_labels(pos_dir, population="effectors", prefix="sample")

    # One label file per movie frame.
    files = os.listdir(lbl_dir)
    assert len(files) == 5
    assert "0000.tif" in files
    assert "0004.tif" in files

    # Labels are blank (all zeros) and match the frame size.
    img = tifffile.imread(os.path.join(lbl_dir, "0000.tif"))
    assert np.all(img == 0)
    assert img.shape == (100, 100)
277
+
278
+
279
def test_table_exploration_logic(app, qtbot, tmp_path):
    """
    Test the logic for single vs multi-position exploration.

    With one position selected, clicking ``check_measurements_btn`` must create
    ``p0.measure_annotator`` (its ``show`` is patched out). With two positions
    selected, the same button must call ``view_table_ui`` instead.
    """
    exp_dir = str(tmp_path / "ExperimentMulti")
    for pos in ["100", "101"]:
        os.makedirs(os.path.join(exp_dir, "W1", pos, "output", "tables"), exist_ok=True)
        create_dummy_movie(exp_dir, well="W1", pos=pos, prefix="sample", frames=10)

        # Every position gets its own one-row trajectory table.
        df = pd.DataFrame(
            {
                "TRACK_ID": [1],
                "FRAME": [0],
                "area": [100.0],  # Needed for MinMaxScaler
                "POSITION_X": [10],
                "POSITION_Y": [10],
            }
        )
        traj_path = os.path.join(
            exp_dir, "W1", pos, "output", "tables", "trajectories_effectors.csv"
        )
        df.to_csv(traj_path, index=False)

    os.makedirs(os.path.join(exp_dir, "configs"), exist_ok=True)
    with open(os.path.join(exp_dir, "config.ini"), "w") as f:
        f.write(
            "[MovieSettings]\nmovie_prefix = sample\nlen_movie = 10\nshape_x = 100\nshape_y = 100\npxtoum = 1.0\nframetomin = 1.0\n"
        )
        f.write(
            "[Labels]\nconcentrations = 0\ncell_types = dummy\nantibodies = none\npharmaceutical_agents = none\n[Channels]\nChannel1 = 0\n"
        )

    app.experiment_path_selection.setText(exp_dir)
    qtbot.mouseClick(app.validate_button, QtCore.Qt.LeftButton)
    qtbot.waitUntil(lambda: hasattr(app, "control_panel"), timeout=30000)

    cp = app.control_panel
    p0 = cp.ProcessPopulations[0]

    qtbot.waitUntil(lambda: cp.well_list.count() > 0, timeout=30000)

    # 1. Single selection (default): the measurement annotator is created.
    with patch.object(cp.well_list, "getSelectedIndices", return_value=[0]):
        with patch.object(cp.position_list, "getSelectedIndices", return_value=[0]):
            cp.update_position_options()
            qtbot.wait(500)

            # show() is patched so no annotator window is actually displayed.
            with patch.object(MeasureAnnotator, "show"):
                qtbot.mouseClick(p0.check_measurements_btn, QtCore.Qt.LeftButton)
                try:
                    qtbot.waitUntil(
                        lambda: hasattr(p0, "measure_annotator"), timeout=15000
                    )
                except Exception:
                    print("DEBUG: Single selection check failed.")
                    raise
                assert p0.measure_annotator is not None
                p0.measure_annotator.close()
                del p0.measure_annotator

    # 2. Multi selection: the table viewer is used instead of the annotator.
    with patch.object(cp.well_list, "getSelectedIndices", return_value=[0]):
        with patch.object(cp.position_list, "getSelectedIndices", return_value=[0, 1]):
            cp.update_position_options()
            qtbot.wait(500)

            assert p0.view_tab_btn.isEnabled()
            assert p0.check_measurements_btn.isEnabled()

            with patch.object(p0, "view_table_ui") as mock_view_tab:
                qtbot.mouseClick(p0.check_measurements_btn, QtCore.Qt.LeftButton)
                qtbot.wait(200)
                mock_view_tab.assert_called_once()
@@ -0,0 +1,130 @@
1
+ import pytest
2
+ import os
3
+ import pandas as pd
4
+ import numpy as np
5
+ import logging
6
+ from PyQt5 import QtCore
7
+ from celldetective.gui.InitWindow import AppInitWindow
8
+ from celldetective.gui.measure_annotator import MeasureAnnotator
9
+ from celldetective import get_software_location
10
+ from unittest.mock import patch
11
+ import shutil
12
+ import tifffile
13
+
14
+ software_location = get_software_location()
15
+
16
+
17
@pytest.fixture(autouse=True)
def disable_logging():
    """Disable all logging for the duration of each test.

    Logging is switched off globally (``logging.disable(logging.CRITICAL)``)
    to avoid a Windows OSError triggered by logging under pytest's output
    capture. It is re-enabled (``logging.NOTSET``) once the test finishes,
    whether it passed or failed.
    """
    logging.disable(logging.CRITICAL)
    try:
        yield
    finally:
        logging.disable(logging.NOTSET)
26
+
27
+
28
@pytest.fixture
def app(qtbot):
    """Provide a fresh AppInitWindow registered with qtbot.

    Registration via ``qtbot.addWidget`` hands the window to pytest-qt,
    which closes registered widgets at test teardown.
    """
    main_window = AppInitWindow(software_location=software_location)
    qtbot.addWidget(main_window)
    return main_window
33
+
34
+
35
def create_dummy_movie(exp_dir, well="W1", pos="100", prefix="sample", frames=5):
    """Write an all-zero uint16 movie of shape ``(frames, 100, 100)``.

    The movie is stored as a single multi-page TIFF at
    ``<exp_dir>/<well>/<pos>/movie/<prefix>.tif``, the layout expected by
    ``locate_stack``. Missing parent directories are created.
    """
    target_dir = os.path.join(exp_dir, well, pos, "movie")
    os.makedirs(target_dir, exist_ok=True)
    blank_stack = np.zeros(shape=(frames, 100, 100), dtype=np.uint16)
    tifffile.imwrite(os.path.join(target_dir, f"{prefix}.tif"), blank_stack)
42
+
43
+
44
def test_measure_annotator_colors_writable(app, qtbot, tmp_path):
    """
    Test that self.colors in MeasureAnnotator contains writable arrays.
    This verifies the fix for 'ValueError: assignment destination is read-only'.
    """
    exp_dir = str(tmp_path / "ExperimentColors")
    os.makedirs(os.path.join(exp_dir, "W1", "100", "output", "tables"), exist_ok=True)
    os.makedirs(os.path.join(exp_dir, "configs"), exist_ok=True)

    with open(os.path.join(exp_dir, "config.ini"), "w") as f:
        f.write(
            "[MovieSettings]\nmovie_prefix = sample\nlen_movie = 10\nshape_x = 100\nshape_y = 100\npxtoum = 1.0\nframetomin = 1.0\n"
        )
        f.write(
            "[Labels]\nconcentrations = 0\ncell_types = dummy\nantibodies = none\npharmaceutical_agents = none\n[Channels]\nChannel1 = 0\n"
        )

    create_dummy_movie(exp_dir, well="W1", pos="100", prefix="sample", frames=10)

    # One track over two frames. No 'group_color' column is supplied on
    # purpose: MeasureAnnotator derives it itself (it calls 'color_from_state'
    # and then assigns 'group_color').
    df = pd.DataFrame(
        {
            "TRACK_ID": [1, 1],
            "FRAME": [0, 1],
            "group_experimental": ["A", "A"],
            "area": [100.0, 110.0],
            "POSITION_X": [10, 12],
            "POSITION_Y": [10, 12],
            "status": [0, 0],  # Ensure status column exists
        }
    )

    traj_path = os.path.join(
        exp_dir, "W1", "100", "output", "tables", "trajectories_effectors.csv"
    )
    df.to_csv(traj_path, index=False)

    app.experiment_path_selection.setText(exp_dir)
    qtbot.mouseClick(app.validate_button, QtCore.Qt.LeftButton)
    qtbot.waitUntil(lambda: hasattr(app, "control_panel"), timeout=30000)

    cp = app.control_panel
    p0 = cp.ProcessPopulations[0]

    qtbot.waitUntil(lambda: cp.well_list.count() > 0, timeout=30000)

    # Force the first well/position as the selection while the GUI reacts.
    with patch.object(cp.well_list, "getSelectedIndices", return_value=[0]):
        with patch.object(cp.position_list, "getSelectedIndices", return_value=[0]):

            cp.update_position_options()
            qtbot.wait(500)

            qtbot.mouseClick(p0.check_measurements_btn, QtCore.Qt.LeftButton)

            try:
                qtbot.waitUntil(lambda: hasattr(p0, "measure_annotator"), timeout=15000)
            except Exception:
                print("DEBUG: measure_annotator not found on p0.")
                raise

            annotator = p0.measure_annotator
            qtbot.wait(1000)
            assert annotator is not None

            # Close the window even when an assertion fails, so it does not
            # linger into subsequent tests.
            try:
                # extract_scatter_from_trajectories should have populated
                # self.colors during init.
                assert hasattr(annotator, "colors")
                assert len(annotator.colors) > 0

                # Check the first frame's colors
                colors_frame_0 = annotator.colors[0]
                assert colors_frame_0.flags[
                    "WRITEABLE"
                ], "self.colors arrays must be writable"

                # In-place modification must not raise ValueError
                # ("assignment destination is read-only").
                try:
                    colors_frame_0[0] = "lime"
                except ValueError as e:
                    pytest.fail(f"Could not modify colors array: {e}")
            finally:
                annotator.close()
@@ -0,0 +1,101 @@
1
+ import unittest
2
+ from unittest.mock import MagicMock, patch
3
+ import sys
4
+
5
+ # Do not import torch here to avoid WinError 1114 if environment is broken.
6
+ # We will mock it in setUp.
7
+
8
+
9
class TestCellposeFallback(unittest.TestCase):
    """Exercise _prep_cellpose_model with torch and cellpose replaced by mocks.

    torch is never imported for real, which avoids WinError 1114 on broken
    environments. Stand-ins are injected into sys.modules in setUp.
    """

    @staticmethod
    def _fake_model():
        """Return a stand-in Cellpose model with diam_mean/diam_labels of 30.0."""
        fake = MagicMock()
        fake.diam_mean = 30.0
        fake.diam_labels = 30.0
        return fake

    def setUp(self):
        # torch stand-in: torch.device(...) yields "cpu" and CUDA reports
        # itself unavailable, simulating a CPU-only environment.
        fake_torch = MagicMock()
        fake_torch.device = MagicMock(return_value="cpu")
        fake_torch.cuda = MagicMock()
        fake_torch.cuda.is_available.return_value = False
        self.mock_torch = fake_torch

        # The stand-ins must be in sys.modules BEFORE the code under test is
        # imported, so that its `import torch` / `import cellpose` resolve to them.
        stubbed_modules = {
            "torch": fake_torch,
            "cellpose": MagicMock(),
            "cellpose.models": MagicMock(),
        }
        self.modules_patcher = patch.dict(sys.modules, stubbed_modules)
        self.modules_patcher.start()

        # Controllable replacement for cellpose.models.CellposeModel.
        self.MockCellposeModel = MagicMock()
        sys.modules["cellpose.models"].CellposeModel = self.MockCellposeModel

    def tearDown(self):
        self.modules_patcher.stop()

    def test_gpu_fallback_on_assertion_error(self):
        """
        Test that _prep_cellpose_model falls back to CPU if GPU init fails with AssertionError.
        """
        # Imported lazily so the sys.modules stubs installed in setUp apply.
        from celldetective.utils.cellpose_utils import _prep_cellpose_model

        def build_model(gpu=False, **kwargs):
            # GPU construction fails the way a CPU-only torch build does.
            if gpu:
                raise AssertionError("Torch not compiled with CUDA enabled")
            return self._fake_model()

        self.MockCellposeModel.side_effect = build_model

        # Requesting the GPU: the first construction fails, and the loader is
        # expected to warn and retry with gpu=False.
        model, scale = _prep_cellpose_model(
            model_name="fake_model", path="fake_path/", use_gpu=True, n_channels=2
        )

        self.assertEqual(self.MockCellposeModel.call_count, 2)

        first_attempt, retry = self.MockCellposeModel.call_args_list[:2]
        self.assertTrue(first_attempt[1].get("gpu"), "First call should try gpu=True")
        self.assertFalse(retry[1].get("gpu"), "Second call should retry with gpu=False")

        self.assertIsNotNone(model)

    def test_gpu_success(self):
        """
        Test that _prep_cellpose_model works normally if GPU init succeeds.
        """
        from celldetective.utils.cellpose_utils import _prep_cellpose_model

        def build_model(gpu=False, **kwargs):
            # Construction always succeeds, GPU or not.
            return self._fake_model()

        self.MockCellposeModel.side_effect = build_model

        model, scale = _prep_cellpose_model(
            model_name="fake_model", path="fake_path/", use_gpu=True, n_channels=2
        )

        self.assertEqual(self.MockCellposeModel.call_count, 1)
        _, used_kwargs = self.MockCellposeModel.call_args
        self.assertTrue(used_kwargs.get("gpu"))


if __name__ == "__main__":
    unittest.main()
tests/test_notebooks.py CHANGED
@@ -2,7 +2,8 @@
2
2
  import nbformat
3
3
  from nbclient import NotebookClient
4
4
 
5
+
5
6
def test_notebook_runs():
    """Execute the ADCC demo notebook end to end; any failing cell fails the test.

    The notebook path is resolved relative to this test file, so the notebook
    is found regardless of the directory pytest is launched from. Previously
    the path was resolved against the current working directory, which only
    worked from the ``tests`` folder.

    NOTE(review): the kernel's working directory still follows the process
    CWD, so notebook-internal relative paths are unchanged by this fix.
    """
    import os

    notebook_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        os.pardir,
        "demos",
        "ADCC_analysis_demo.ipynb",
    )
    nb = nbformat.read(notebook_path, as_version=4)
    client = NotebookClient(nb, timeout=600, kernel_name="python3")
    client.execute()  # raises exception if any cell fails