cellfinder 1.1.3__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cellfinder/core/main.py CHANGED
@@ -7,10 +7,11 @@ import os
7
7
  from typing import Callable, List, Optional, Tuple
8
8
 
9
9
  import numpy as np
10
+ from brainglobe_utils.cells.cells import Cell
10
11
  from brainglobe_utils.general.logging import suppress_specific_logs
11
12
 
12
13
  from cellfinder.core import logger
13
- from cellfinder.core.download.models import model_type
14
+ from cellfinder.core.download.download import model_type
14
15
  from cellfinder.core.train.train_yml import depth_type
15
16
 
16
17
  tf_suppress_log_messages = [
@@ -42,6 +43,9 @@ def main(
42
43
  cube_height: int = 50,
43
44
  cube_depth: int = 20,
44
45
  network_depth: depth_type = "50",
46
+ skip_detection: bool = False,
47
+ skip_classification: bool = False,
48
+ detected_cells: List[Cell] = None,
45
49
  *,
46
50
  detect_callback: Optional[Callable[[int], None]] = None,
47
51
  classify_callback: Optional[Callable[[int], None]] = None,
@@ -65,52 +69,58 @@ def main(
65
69
  from cellfinder.core.detect import detect
66
70
  from cellfinder.core.tools import prep
67
71
 
68
- logger.info("Detecting cell candidates")
72
+ if not skip_detection:
73
+ logger.info("Detecting cell candidates")
69
74
 
70
- points = detect.main(
71
- signal_array,
72
- start_plane,
73
- end_plane,
74
- voxel_sizes,
75
- soma_diameter,
76
- max_cluster_size,
77
- ball_xy_size,
78
- ball_z_size,
79
- ball_overlap_fraction,
80
- soma_spread_factor,
81
- n_free_cpus,
82
- log_sigma_size,
83
- n_sds_above_mean_thresh,
84
- callback=detect_callback,
85
- )
86
-
87
- if detect_finished_callback is not None:
88
- detect_finished_callback(points)
89
-
90
- install_path = None
91
- model_weights = prep.prep_model_weights(
92
- model_weights, install_path, model, n_free_cpus
93
- )
94
- if len(points) > 0:
95
- logger.info("Running classification")
96
- points = classify.main(
97
- points,
75
+ points = detect.main(
98
76
  signal_array,
99
- background_array,
100
- n_free_cpus,
77
+ start_plane,
78
+ end_plane,
101
79
  voxel_sizes,
102
- network_voxel_sizes,
103
- batch_size,
104
- cube_height,
105
- cube_width,
106
- cube_depth,
107
- trained_model,
108
- model_weights,
109
- network_depth,
110
- callback=classify_callback,
80
+ soma_diameter,
81
+ max_cluster_size,
82
+ ball_xy_size,
83
+ ball_z_size,
84
+ ball_overlap_fraction,
85
+ soma_spread_factor,
86
+ n_free_cpus,
87
+ log_sigma_size,
88
+ n_sds_above_mean_thresh,
89
+ callback=detect_callback,
111
90
  )
91
+
92
+ if detect_finished_callback is not None:
93
+ detect_finished_callback(points)
112
94
  else:
113
- logger.info("No candidates, skipping classification")
95
+ points = detected_cells or [] # if None
96
+ detect_finished_callback(points)
97
+
98
+ if not skip_classification:
99
+ install_path = None
100
+ model_weights = prep.prep_model_weights(
101
+ model_weights, install_path, model, n_free_cpus
102
+ )
103
+ if len(points) > 0:
104
+ logger.info("Running classification")
105
+ points = classify.main(
106
+ points,
107
+ signal_array,
108
+ background_array,
109
+ n_free_cpus,
110
+ voxel_sizes,
111
+ network_voxel_sizes,
112
+ batch_size,
113
+ cube_height,
114
+ cube_width,
115
+ cube_depth,
116
+ trained_model,
117
+ model_weights,
118
+ network_depth,
119
+ callback=classify_callback,
120
+ )
121
+ else:
122
+ logger.info("No candidates, skipping classification")
123
+
114
124
  return points
115
125
 
116
126
 
@@ -13,18 +13,19 @@ from brainglobe_utils.general.system import get_num_processes
13
13
 
14
14
  import cellfinder.core.tools.tf as tf_tools
15
15
  from cellfinder.core import logger
16
- from cellfinder.core.download import models as model_download
17
- from cellfinder.core.download.download import amend_user_configuration
16
+ from cellfinder.core.download.download import (
17
+ DEFAULT_DOWNLOAD_DIRECTORY,
18
+ amend_user_configuration,
19
+ download_models,
20
+ model_type,
21
+ )
18
22
  from cellfinder.core.tools.source_files import user_specific_configuration_path
19
23
 
20
- home = Path.home()
21
- DEFAULT_INSTALL_PATH = home / ".cellfinder"
22
-
23
24
 
24
25
  def prep_model_weights(
25
26
  model_weights: Optional[os.PathLike],
26
27
  install_path: Optional[os.PathLike],
27
- model_name: model_download.model_type,
28
+ model_name: model_type,
28
29
  n_free_cpus: int,
29
30
  ) -> Path:
30
31
  n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus)
@@ -42,9 +43,9 @@ def prep_tensorflow(max_threads: int) -> None:
42
43
  def prep_models(
43
44
  model_weights_path: Optional[os.PathLike],
44
45
  install_path: Optional[os.PathLike],
45
- model_name: model_download.model_type,
46
+ model_name: model_type,
46
47
  ) -> Path:
47
- install_path = install_path or DEFAULT_INSTALL_PATH
48
+ install_path = install_path or DEFAULT_DOWNLOAD_DIRECTORY
48
49
  # if no model or weights, set default weights
49
50
  if model_weights_path is None:
50
51
  logger.debug("No model supplied, so using the default")
@@ -53,13 +54,13 @@ def prep_models(
53
54
 
54
55
  if not Path(config_file).exists():
55
56
  logger.debug("Custom config does not exist, downloading models")
56
- model_path = model_download.main(model_name, install_path)
57
+ model_path = download_models(model_name, install_path)
57
58
  amend_user_configuration(new_model_path=model_path)
58
59
 
59
60
  model_weights = get_model_weights(config_file)
60
61
  if not model_weights.exists():
61
62
  logger.debug("Model weights do not exist, downloading")
62
- model_path = model_download.main(model_name, install_path)
63
+ model_path = download_models(model_name, install_path)
63
64
  amend_user_configuration(new_model_path=model_path)
64
65
  model_weights = get_model_weights(config_file)
65
66
  else:
@@ -1,5 +1,7 @@
1
1
  from pathlib import Path
2
2
 
3
+ from cellfinder import DEFAULT_CELLFINDER_DIRECTORY
4
+
3
5
 
4
6
  def default_configuration_path():
5
7
  """
@@ -17,11 +19,11 @@ def user_specific_configuration_path():
17
19
 
18
20
  This function returns the path to the user-specific configuration file
19
21
  for cellfinder. The user-specific configuration file is located in the
20
- user's home directory under the ".cellfinder" folder and is named
21
- "cellfinder.conf.custom".
22
+ user's home directory under the ".brainglobe/cellfinder" folder
23
+ and is named "cellfinder.conf.custom".
22
24
 
23
25
  Returns:
24
26
  Path: The path to the custom configuration file.
25
27
 
26
28
  """
27
- return Path.home() / ".cellfinder" / "cellfinder.conf.custom"
29
+ return DEFAULT_CELLFINDER_DIRECTORY / "cellfinder.conf.custom"
@@ -31,7 +31,7 @@ from sklearn.model_selection import train_test_split
31
31
  import cellfinder.core as program_for_log
32
32
  from cellfinder.core import logger
33
33
  from cellfinder.core.classify.resnet import layer_type
34
- from cellfinder.core.tools.prep import DEFAULT_INSTALL_PATH
34
+ from cellfinder.core.download.download import DEFAULT_DOWNLOAD_DIRECTORY
35
35
 
36
36
  tf_suppress_log_messages = [
37
37
  "sample_weight modes were coerced from",
@@ -112,8 +112,7 @@ def misc_parse(parser):
112
112
 
113
113
  def training_parse():
114
114
  from cellfinder.core.download.cli import (
115
- download_directory_parser,
116
- model_parser,
115
+ download_parser,
117
116
  )
118
117
 
119
118
  training_parser = ArgumentParser(
@@ -223,8 +222,7 @@ def training_parse():
223
222
  )
224
223
 
225
224
  training_parser = misc_parse(training_parser)
226
- training_parser = model_parser(training_parser)
227
- training_parser = download_directory_parser(training_parser)
225
+ training_parser = download_parser(training_parser)
228
226
  args = training_parser.parse_args()
229
227
 
230
228
  return args
@@ -306,7 +304,7 @@ def run(
306
304
  n_free_cpus=2,
307
305
  trained_model=None,
308
306
  model_weights=None,
309
- install_path=DEFAULT_INSTALL_PATH,
307
+ install_path=DEFAULT_DOWNLOAD_DIRECTORY,
310
308
  model="resnet50_tv",
311
309
  network_depth="50",
312
310
  learning_rate=0.0001,
@@ -1,8 +1,11 @@
1
+ from functools import partial
1
2
  from math import ceil
2
3
  from pathlib import Path
3
- from typing import Optional
4
+ from typing import Any, Callable, Dict, Optional, Tuple
4
5
 
5
6
  import napari
7
+ import napari.layers
8
+ from brainglobe_utils.cells.cells import Cell
6
9
  from magicgui import magicgui
7
10
  from magicgui.widgets import FunctionGui, ProgressBar
8
11
  from napari.utils.notifications import show_info
@@ -10,10 +13,11 @@ from qtpy.QtWidgets import QScrollArea
10
13
 
11
14
  from cellfinder.core.classify.cube_generator import get_cube_depth_min_max
12
15
  from cellfinder.napari.utils import (
13
- add_layers,
14
- header_label_widget,
16
+ add_classified_layers,
17
+ add_single_layer,
18
+ cellfinder_header,
15
19
  html_label_widget,
16
- widget_header,
20
+ napari_array_to_cells,
17
21
  )
18
22
 
19
23
  from .detect_containers import (
@@ -33,14 +37,194 @@ CUBE_DEPTH = 20
33
37
  MIN_PLANES_ANALYSE = 0
34
38
 
35
39
 
40
def get_heavy_widgets(
    options: Dict[str, Any]
) -> Tuple[Callable, Callable, Callable]:
    """
    Build the signal-image, background-image and cell-layer selectors.

    Each selector is a small magicgui widget with ``auto_call=True``. Whenever
    its value changes it writes the current selection into ``options``, so
    the caller reads the chosen layers from that dict rather than from the
    root widget's arguments.

    The selectors are kept out of the main magicgui parameter list on
    purpose. They are added to the root as separate sub-containers because,
    when image parameters are part of the root widget, the GUI freezes for a
    while every time *any* parameter changes. This is most likely because
    magicgui processes all parameters on each change, and image parameters
    are particularly slow to process.

    Parameters
    ----------
    options : dict
        Mutable mapping that receives the ``"signal_image"``, ``"viewer"``,
        ``"background_image"`` and ``"cell_layer"`` selections.

    Returns
    -------
    tuple
        ``(signal_image_opt, background_image_opt, cell_layer_opt)``.
    """
    # Identical settings for all three selectors: no run button, no
    # persistence, no labels, and the function fires on every value change.
    selector_settings = dict(
        call_button=False,
        persist=False,
        scrollable=False,
        labels=False,
        auto_call=True,
    )

    # NOTE: the inner function names, parameter names and docstrings are
    # introspected by magicgui, so they are deliberately left unchanged.
    @magicgui(**selector_settings)
    def signal_image_opt(
        viewer: napari.Viewer,
        signal_image: napari.layers.Image,
    ):
        """
        magicgui widget for setting the signal_image parameter.

        Parameters
        ----------
        signal_image : napari.layers.Image
            Image layer containing the labelled cells
        """
        options.update(signal_image=signal_image, viewer=viewer)

    @magicgui(**selector_settings)
    def background_image_opt(
        background_image: napari.layers.Image,
    ):
        """
        magicgui widget for setting the background image parameter.

        Parameters
        ----------
        background_image : napari.layers.Image
            Image layer without labelled cells
        """
        options.update(background_image=background_image)

    @magicgui(**selector_settings)
    def cell_layer_opt(
        cell_layer: napari.layers.Points,
    ):
        """
        magicgui widget for setting the cell layer input when detection is
        skipped.

        Parameters
        ----------
        cell_layer : napari.layers.Points
            If detection is skipped, select the cell layer containing the
            detected cells to use for classification
        """
        options.update(cell_layer=cell_layer)

    return signal_image_opt, background_image_opt, cell_layer_opt
116
+
117
+
118
def add_heavy_widgets(
    root: FunctionGui,
    widgets: Tuple[FunctionGui, ...],
    new_names: Tuple[str, ...],
    insertions: Tuple[str, ...],
) -> None:
    """
    Embed sub-widgets into ``root``, each right after a named anchor widget.

    Parameters
    ----------
    root : FunctionGui
        Container that receives the sub-widgets.
    widgets : tuple of FunctionGui
        Sub-widgets to insert, processed in order.
    new_names : tuple of str
        Label given to each inserted sub-widget.
    insertions : tuple of str
        Name of the existing widget in ``root`` after which each sub-widget
        is inserted.

    Notes
    -----
    The three tuples are consumed in lockstep with ``zip``, so surplus
    entries in a longer tuple are ignored. The anchor's position is looked
    up again for every sub-widget. As a result, when several sub-widgets
    share one anchor, the one processed last ends up directly after it.
    """
    for sub_gui, label, anchor in zip(widgets, new_names, insertions):
        # Zero margins make the sub-widget look like part of the root.
        sub_gui.margins = 0, 0, 0, 0
        # The sub-widget's values reach the code only through `auto_call`.
        # If gui_only were False, magicgui would pass them as arguments to
        # root's function, whose signature does not declare them.
        sub_gui.gui_only = True
        position = root.index(anchor) + 1
        root.insert(position, sub_gui)
        getattr(root, sub_gui.name).label = label
133
+
134
+
135
def restore_options_defaults(widget: FunctionGui) -> None:
    """
    Reset each parameter of ``widget`` to its default value.

    Defaults are collected from ``DataInputs``, ``DetectionInputs``,
    ``ClassificationInputs`` and ``MiscInputs``, in that order. If a name
    appears in more than one container, the later container's value wins.
    Names whose default is ``None`` are left untouched.

    Parameters
    ----------
    widget : FunctionGui
        Widget exposing one attribute per default name, each with a
        ``.value`` that is assigned.
    """
    merged = {}
    for container in (
        DataInputs,
        DetectionInputs,
        ClassificationInputs,
        MiscInputs,
    ):
        merged.update(container.defaults())

    for field_name, default in merged.items():
        if default is None:
            # No default defined for this field; keep the current value.
            continue
        getattr(widget, field_name).value = default
148
+
149
+
150
def get_results_callback(
    skip_classification: bool, viewer: napari.Viewer
) -> Callable:
    """
    Return the callback that displays the pipeline's output points.

    Parameters
    ----------
    skip_classification : bool
        If True, every point is added to a single "Cell candidates" layer
        with cell type ``Cell.UNKNOWN``, because nothing has been
        classified. Otherwise the points are passed to
        ``add_classified_layers`` with the layer names "Detected" and
        "Rejected". Classified points are either cells or unknown.
    viewer : napari.Viewer
        Viewer that the layers are added to.

    Returns
    -------
    Callable
        A one-argument function taking the output points.
    """

    def _show_candidates(points):
        add_single_layer(
            points,
            viewer=viewer,
            name="Cell candidates",
            cell_type=Cell.UNKNOWN,
        )

    def _show_classified(points):
        add_classified_layers(
            points,
            viewer=viewer,
            unknown_name="Rejected",
            cell_name="Detected",
        )

    return _show_candidates if skip_classification else _show_classified
178
+
179
+
180
def find_local_planes(
    viewer: napari.Viewer,
    voxel_size_z: float,
    signal_image: napari.layers.Image,
) -> Tuple[int, int]:
    """
    Return the (start_plane, end_plane) range used for local analysis.

    The range spans ``planes_needed`` planes around the viewer's current
    plane (``viewer.dims.current_step[0]``), as computed by
    ``get_cube_depth_min_max``. It is then clamped to the planes that exist
    in ``signal_image``.

    Parameters
    ----------
    viewer : napari.Viewer
        Viewer whose current step on the first axis selects the plane.
    voxel_size_z : float
        Voxel size of the data along the first (plane) axis.
    signal_image : napari.layers.Image
        Signal layer; ``len(signal_image.data)`` bounds the end plane.

    Returns
    -------
    tuple of int
        ``(start_plane, end_plane)``.
    """
    plane = viewer.dims.current_step[0]

    # Enough planes so that a reasonable number of cells in the displayed
    # plane are detected.
    network_depth = CUBE_DEPTH * NETWORK_VOXEL_SIZES[0]
    planes_needed = MIN_PLANES_ANALYSE + int(ceil(network_depth / voxel_size_z))

    lower, upper = get_cube_depth_min_max(plane, planes_needed)
    # Keep the range inside the image's plane count.
    return max(0, lower), min(len(signal_image.data), upper)
202
+
203
+
204
def reraise(e: Exception) -> None:
    """
    Raise a new ``Exception`` chained to ``e``; never returns normally.

    ``detect_widget`` connects this to its worker's ``errored`` signal, so
    that an error emitted by the worker is propagated to this thread.

    Parameters
    ----------
    e : Exception
        The error to propagate. It becomes ``__cause__`` of the raised
        exception.

    Raises
    ------
    Exception
        Always. Its message repeats the original exception's type and text.
        Without a message, the outermost (most visible) entry in the
        traceback would be a bare, uninformative ``Exception``.
    """
    raise Exception(f"{type(e).__name__}: {e}") from e
207
+
208
+
36
209
  def detect_widget() -> FunctionGui:
37
210
  """
38
211
  Create a detection plugin GUI.
39
212
  """
40
213
  progress_bar = ProgressBar()
41
214
 
215
+ # options that is filled in from the gui
216
+ options = {
217
+ "signal_image": None,
218
+ "background_image": None,
219
+ "viewer": None,
220
+ "cell_layer": None,
221
+ }
222
+
223
+ signal_image_opt, background_image_opt, cell_layer_opt = get_heavy_widgets(
224
+ options
225
+ )
226
+
42
227
  @magicgui(
43
- header=header_label_widget,
44
228
  detection_label=html_label_widget("Cell detection", tag="h3"),
45
229
  **DataInputs.widget_representation(),
46
230
  **DetectionInputs.widget_representation(),
@@ -52,16 +236,13 @@ def detect_widget() -> FunctionGui:
52
236
  scrollable=True,
53
237
  )
54
238
  def widget(
55
- header,
56
239
  detection_label,
57
240
  data_options,
58
- viewer: napari.Viewer,
59
- signal_image: napari.layers.Image,
60
- background_image: napari.layers.Image,
61
241
  voxel_size_z: float,
62
242
  voxel_size_y: float,
63
243
  voxel_size_x: float,
64
244
  detection_options,
245
+ skip_detection: bool,
65
246
  soma_diameter: float,
66
247
  ball_xy_size: float,
67
248
  ball_z_size: float,
@@ -71,6 +252,7 @@ def detect_widget() -> FunctionGui:
71
252
  soma_spread_factor: float,
72
253
  max_cluster_size: int,
73
254
  classification_options,
255
+ skip_classification: bool,
74
256
  trained_model: Optional[Path],
75
257
  use_pre_trained_weights: bool,
76
258
  misc_options,
@@ -86,16 +268,16 @@ def detect_widget() -> FunctionGui:
86
268
 
87
269
  Parameters
88
270
  ----------
89
- signal_image : napari.layers.Image
90
- Image layer containing the labelled cells
91
- background_image : napari.layers.Image
92
- Image layer without labelled cells
93
271
  voxel_size_z : float
94
272
  Size of your voxels in the axial dimension
95
273
  voxel_size_y : float
96
274
  Size of your voxels in the y direction (top to bottom)
97
275
  voxel_size_x : float
98
276
  Size of your voxels in the x direction (left to right)
277
+ skip_detection : bool
278
+ If selected, the detection step is skipped and instead we get the
279
+ detected cells from the cell layer below (from a previous
280
+ detection run or import)
99
281
  soma_diameter : float
100
282
  The expected in-plane soma diameter (microns)
101
283
  ball_xy_size : float
@@ -116,6 +298,9 @@ def detect_widget() -> FunctionGui:
116
298
  should be attempted
117
299
  use_pre_trained_weights : bool
118
300
  Select to use pre-trained model weights
301
+ skip_classification : bool
302
+ If selected, the classification step is skipped and all cells from
303
+ the detection stage are added
119
304
  trained_model : Optional[Path]
120
305
  Trained model file path (home directory (default) -> pretrained
121
306
  weights)
@@ -132,18 +317,48 @@ def detect_widget() -> FunctionGui:
132
317
  reset_button :
133
318
  Reset parameters to default
134
319
  """
135
- if signal_image is None or background_image is None:
320
+ # we must manually call so that the parameters of these functions are
321
+ # initialized and updated. Because, if the images are open in napari
322
+ # before we open cellfinder, then these functions may never be called,
323
+ # even though the image filenames are shown properly in the parameters
324
+ # in the gui. Likely auto_call doesn't make magicgui call the functions
325
+ # in this circumstance, only if the parameters are updated once
326
+ # cellfinder plugin is fully open and initialized
327
+ signal_image_opt()
328
+ background_image_opt()
329
+ cell_layer_opt()
330
+
331
+ signal_image = options["signal_image"]
332
+
333
+ if signal_image is None or options["background_image"] is None:
136
334
  show_info("Both signal and background images must be specified.")
137
335
  return
336
+
337
+ detected_cells = []
338
+ if skip_detection:
339
+ if options["cell_layer"] is None:
340
+ show_info(
341
+ "Skip detection selected, but no existing cell layer "
342
+ "is selected."
343
+ )
344
+ return
345
+
346
+ # set cells as unknown so that classification will process them
347
+ detected_cells = napari_array_to_cells(
348
+ options["cell_layer"], Cell.UNKNOWN
349
+ )
350
+
138
351
  data_inputs = DataInputs(
139
352
  signal_image.data,
140
- background_image.data,
353
+ options["background_image"].data,
141
354
  voxel_size_z,
142
355
  voxel_size_y,
143
356
  voxel_size_x,
144
357
  )
145
358
 
146
359
  detection_inputs = DetectionInputs(
360
+ skip_detection,
361
+ detected_cells,
147
362
  soma_diameter,
148
363
  ball_xy_size,
149
364
  ball_z_size,
@@ -157,24 +372,15 @@ def detect_widget() -> FunctionGui:
157
372
  if use_pre_trained_weights:
158
373
  trained_model = None
159
374
  classification_inputs = ClassificationInputs(
160
- use_pre_trained_weights, trained_model
375
+ skip_classification, use_pre_trained_weights, trained_model
161
376
  )
162
377
 
163
- end_plane = len(signal_image.data) if end_plane == 0 else end_plane
164
-
165
378
  if analyse_local:
166
- current_plane = viewer.dims.current_step[0]
167
-
168
- # so a reasonable number of cells in the plane are detected
169
- planes_needed = MIN_PLANES_ANALYSE + int(
170
- ceil((CUBE_DEPTH * NETWORK_VOXEL_SIZES[0]) / voxel_size_z)
379
+ start_plane, end_plane = find_local_planes(
380
+ options["viewer"], voxel_size_z, signal_image
171
381
  )
172
-
173
- start_plane, end_plane = get_cube_depth_min_max(
174
- current_plane, planes_needed
175
- )
176
- start_plane = max(0, start_plane)
177
- end_plane = min(len(signal_image.data), end_plane)
382
+ elif not end_plane:
383
+ end_plane = len(signal_image.data)
178
384
 
179
385
  misc_inputs = MiscInputs(
180
386
  start_plane, end_plane, n_free_cpus, analyse_local, debug
@@ -186,45 +392,34 @@ def detect_widget() -> FunctionGui:
186
392
  classification_inputs,
187
393
  misc_inputs,
188
394
  )
395
+
189
396
  worker.returned.connect(
190
- lambda points: add_layers(points, viewer=viewer)
397
+ get_results_callback(skip_classification, options["viewer"])
191
398
  )
192
-
193
399
  # Make sure if the worker emits an error, it is propagated to this
194
400
  # thread
195
- def reraise(e):
196
- raise Exception from e
197
-
198
401
  worker.errored.connect(reraise)
402
+ worker.connect_progress_bar_callback(progress_bar)
199
403
 
200
- def update_progress_bar(label: str, max: int, value: int):
201
- progress_bar.label = label
202
- progress_bar.max = max
203
- progress_bar.value = value
204
-
205
- worker.update_progress_bar.connect(update_progress_bar)
206
404
  worker.start()
207
405
 
208
- widget.header.value = widget_header
209
- widget.header.native.setOpenExternalLinks(True)
406
+ widget.native.layout().insertWidget(0, cellfinder_header())
210
407
 
211
- @widget.reset_button.changed.connect
212
- def restore_defaults():
213
- """
214
- Restore default widget values.
215
- """
216
- defaults = {
217
- **DataInputs.defaults(),
218
- **DetectionInputs.defaults(),
219
- **ClassificationInputs.defaults(),
220
- **MiscInputs.defaults(),
221
- }
222
- for name, value in defaults.items():
223
- if value is not None: # ignore fields with no default
224
- getattr(widget, name).value = value
408
+ # reset restores defaults
409
+ widget.reset_button.changed.connect(
410
+ partial(restore_options_defaults, widget)
411
+ )
225
412
 
226
413
  # Insert progress bar before the run and reset buttons
227
- widget.insert(-3, progress_bar)
414
+ widget.insert(widget.index("debug") + 1, progress_bar)
415
+
416
+ # add the signal and background image etc.
417
+ add_heavy_widgets(
418
+ widget,
419
+ (background_image_opt, signal_image_opt, cell_layer_opt),
420
+ ("Background image", "Signal image", "Candidate cell layer"),
421
+ ("voxel_size_z", "voxel_size_z", "soma_diameter"),
422
+ )
228
423
 
229
424
  scroll = QScrollArea()
230
425
  scroll.setWidget(widget._widget._qwidget)