cellfinder 1.1.3__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cellfinder might be problematic; review the file changes listed below for details.

Files changed (34):
  1. cellfinder/__init__.py +21 -12
  2. cellfinder/core/classify/classify.py +13 -6
  3. cellfinder/core/classify/cube_generator.py +27 -11
  4. cellfinder/core/classify/resnet.py +9 -6
  5. cellfinder/core/classify/tools.py +13 -11
  6. cellfinder/core/detect/detect.py +12 -1
  7. cellfinder/core/detect/filters/volume/ball_filter.py +198 -113
  8. cellfinder/core/detect/filters/volume/structure_detection.py +105 -41
  9. cellfinder/core/detect/filters/volume/structure_splitting.py +1 -1
  10. cellfinder/core/detect/filters/volume/volume_filter.py +48 -49
  11. cellfinder/core/download/cli.py +39 -32
  12. cellfinder/core/download/download.py +44 -56
  13. cellfinder/core/main.py +53 -68
  14. cellfinder/core/tools/prep.py +12 -20
  15. cellfinder/core/tools/source_files.py +5 -3
  16. cellfinder/core/tools/system.py +10 -0
  17. cellfinder/core/train/train_yml.py +29 -27
  18. cellfinder/napari/curation.py +1 -1
  19. cellfinder/napari/detect/detect.py +259 -58
  20. cellfinder/napari/detect/detect_containers.py +11 -1
  21. cellfinder/napari/detect/thread_worker.py +16 -2
  22. cellfinder/napari/train/train.py +2 -9
  23. cellfinder/napari/train/train_containers.py +3 -3
  24. cellfinder/napari/utils.py +88 -47
  25. {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/METADATA +12 -11
  26. {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/RECORD +30 -34
  27. cellfinder/core/download/models.py +0 -49
  28. cellfinder/core/tools/IO.py +0 -48
  29. cellfinder/core/tools/tf.py +0 -46
  30. cellfinder/napari/images/brainglobe.png +0 -0
  31. {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/LICENSE +0 -0
  32. {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/WHEEL +0 -0
  33. {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/entry_points.txt +0 -0
  34. {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/top_level.txt +0 -0
cellfinder/core/main.py CHANGED
@@ -1,22 +1,13 @@
1
- """
2
- N.B imports are within functions to prevent tensorflow being imported before
3
- it's warnings are silenced
4
- """
5
-
6
1
  import os
7
2
  from typing import Callable, List, Optional, Tuple
8
3
 
9
4
  import numpy as np
10
- from brainglobe_utils.general.logging import suppress_specific_logs
5
+ from brainglobe_utils.cells.cells import Cell
11
6
 
12
7
  from cellfinder.core import logger
13
- from cellfinder.core.download.models import model_type
8
+ from cellfinder.core.download.download import model_type
14
9
  from cellfinder.core.train.train_yml import depth_type
15
10
 
16
- tf_suppress_log_messages = [
17
- "multiprocessing can interact badly with TensorFlow"
18
- ]
19
-
20
11
 
21
12
  def main(
22
13
  signal_array: np.ndarray,
@@ -27,7 +18,7 @@ def main(
27
18
  trained_model: Optional[os.PathLike] = None,
28
19
  model_weights: Optional[os.PathLike] = None,
29
20
  model: model_type = "resnet50_tv",
30
- batch_size: int = 32,
21
+ batch_size: int = 64,
31
22
  n_free_cpus: int = 2,
32
23
  network_voxel_sizes: Tuple[int, int, int] = (5, 1, 1),
33
24
  soma_diameter: int = 16,
@@ -42,6 +33,9 @@ def main(
42
33
  cube_height: int = 50,
43
34
  cube_depth: int = 20,
44
35
  network_depth: depth_type = "50",
36
+ skip_detection: bool = False,
37
+ skip_classification: bool = False,
38
+ detected_cells: List[Cell] = None,
45
39
  *,
46
40
  detect_callback: Optional[Callable[[int], None]] = None,
47
41
  classify_callback: Optional[Callable[[int], None]] = None,
@@ -54,73 +48,64 @@ def main(
54
48
  Called every time a plane has finished being processed during the
55
49
  detection stage. Called with the plane number that has finished.
56
50
  classify_callback : Callable[int], optional
57
- Called every time tensorflow has finished classifying a point.
51
+ Called every time a point has finished being classified.
58
52
  Called with the batch number that has just finished.
59
53
  detect_finished_callback : Callable[list], optional
60
54
  Called after detection is finished with the list of detected points.
61
55
  """
62
- suppress_tf_logging(tf_suppress_log_messages)
63
-
64
56
  from cellfinder.core.classify import classify
65
57
  from cellfinder.core.detect import detect
66
58
  from cellfinder.core.tools import prep
67
59
 
68
- logger.info("Detecting cell candidates")
60
+ if not skip_detection:
61
+ logger.info("Detecting cell candidates")
69
62
 
70
- points = detect.main(
71
- signal_array,
72
- start_plane,
73
- end_plane,
74
- voxel_sizes,
75
- soma_diameter,
76
- max_cluster_size,
77
- ball_xy_size,
78
- ball_z_size,
79
- ball_overlap_fraction,
80
- soma_spread_factor,
81
- n_free_cpus,
82
- log_sigma_size,
83
- n_sds_above_mean_thresh,
84
- callback=detect_callback,
85
- )
86
-
87
- if detect_finished_callback is not None:
88
- detect_finished_callback(points)
89
-
90
- install_path = None
91
- model_weights = prep.prep_model_weights(
92
- model_weights, install_path, model, n_free_cpus
93
- )
94
- if len(points) > 0:
95
- logger.info("Running classification")
96
- points = classify.main(
97
- points,
63
+ points = detect.main(
98
64
  signal_array,
99
- background_array,
100
- n_free_cpus,
65
+ start_plane,
66
+ end_plane,
101
67
  voxel_sizes,
102
- network_voxel_sizes,
103
- batch_size,
104
- cube_height,
105
- cube_width,
106
- cube_depth,
107
- trained_model,
108
- model_weights,
109
- network_depth,
110
- callback=classify_callback,
68
+ soma_diameter,
69
+ max_cluster_size,
70
+ ball_xy_size,
71
+ ball_z_size,
72
+ ball_overlap_fraction,
73
+ soma_spread_factor,
74
+ n_free_cpus,
75
+ log_sigma_size,
76
+ n_sds_above_mean_thresh,
77
+ callback=detect_callback,
111
78
  )
112
- else:
113
- logger.info("No candidates, skipping classification")
114
- return points
115
-
116
79
 
117
- def suppress_tf_logging(tf_suppress_log_messages: List[str]) -> None:
118
- """
119
- Prevents many lines of logs such as:
120
- "2019-10-24 16:54:41.363978: I tensorflow/stream_executor/platform/default
121
- /dso_loader.cc:44] Successfully opened dynamic library libcuda.so.1"
122
- """
123
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
80
+ if detect_finished_callback is not None:
81
+ detect_finished_callback(points)
82
+ else:
83
+ points = detected_cells or [] # if None
84
+ detect_finished_callback(points)
124
85
 
125
- for message in tf_suppress_log_messages:
126
- suppress_specific_logs("tensorflow", message)
86
+ if not skip_classification:
87
+ install_path = None
88
+ model_weights = prep.prep_model_weights(
89
+ model_weights, install_path, model
90
+ )
91
+ if len(points) > 0:
92
+ logger.info("Running classification")
93
+ points = classify.main(
94
+ points,
95
+ signal_array,
96
+ background_array,
97
+ n_free_cpus,
98
+ voxel_sizes,
99
+ network_voxel_sizes,
100
+ batch_size,
101
+ cube_height,
102
+ cube_width,
103
+ cube_depth,
104
+ trained_model,
105
+ model_weights,
106
+ network_depth,
107
+ callback=classify_callback,
108
+ )
109
+ else:
110
+ logger.info("No candidates, skipping classification")
111
+ return points
@@ -9,42 +9,34 @@ from pathlib import Path
9
9
  from typing import Optional
10
10
 
11
11
  from brainglobe_utils.general.config import get_config_obj
12
- from brainglobe_utils.general.system import get_num_processes
13
12
 
14
- import cellfinder.core.tools.tf as tf_tools
15
13
  from cellfinder.core import logger
16
- from cellfinder.core.download import models as model_download
17
- from cellfinder.core.download.download import amend_user_configuration
14
+ from cellfinder.core.download.download import (
15
+ DEFAULT_DOWNLOAD_DIRECTORY,
16
+ amend_user_configuration,
17
+ download_models,
18
+ model_type,
19
+ )
18
20
  from cellfinder.core.tools.source_files import user_specific_configuration_path
19
21
 
20
- home = Path.home()
21
- DEFAULT_INSTALL_PATH = home / ".cellfinder"
22
-
23
22
 
24
23
  def prep_model_weights(
25
24
  model_weights: Optional[os.PathLike],
26
25
  install_path: Optional[os.PathLike],
27
- model_name: model_download.model_type,
28
- n_free_cpus: int,
26
+ model_name: model_type,
29
27
  ) -> Path:
30
- n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus)
31
- prep_tensorflow(n_processes)
28
+ # prepare models (get default weights or provided ones)
32
29
  model_weights = prep_models(model_weights, install_path, model_name)
33
30
 
34
31
  return model_weights
35
32
 
36
33
 
37
- def prep_tensorflow(max_threads: int) -> None:
38
- tf_tools.set_tf_threads(max_threads)
39
- tf_tools.allow_gpu_memory_growth()
40
-
41
-
42
34
  def prep_models(
43
35
  model_weights_path: Optional[os.PathLike],
44
36
  install_path: Optional[os.PathLike],
45
- model_name: model_download.model_type,
37
+ model_name: model_type,
46
38
  ) -> Path:
47
- install_path = install_path or DEFAULT_INSTALL_PATH
39
+ install_path = install_path or DEFAULT_DOWNLOAD_DIRECTORY
48
40
  # if no model or weights, set default weights
49
41
  if model_weights_path is None:
50
42
  logger.debug("No model supplied, so using the default")
@@ -53,13 +45,13 @@ def prep_models(
53
45
 
54
46
  if not Path(config_file).exists():
55
47
  logger.debug("Custom config does not exist, downloading models")
56
- model_path = model_download.main(model_name, install_path)
48
+ model_path = download_models(model_name, install_path)
57
49
  amend_user_configuration(new_model_path=model_path)
58
50
 
59
51
  model_weights = get_model_weights(config_file)
60
52
  if not model_weights.exists():
61
53
  logger.debug("Model weights do not exist, downloading")
62
- model_path = model_download.main(model_name, install_path)
54
+ model_path = download_models(model_name, install_path)
63
55
  amend_user_configuration(new_model_path=model_path)
64
56
  model_weights = get_model_weights(config_file)
65
57
  else:
@@ -1,5 +1,7 @@
1
1
  from pathlib import Path
2
2
 
3
+ from cellfinder import DEFAULT_CELLFINDER_DIRECTORY
4
+
3
5
 
4
6
  def default_configuration_path():
5
7
  """
@@ -17,11 +19,11 @@ def user_specific_configuration_path():
17
19
 
18
20
  This function returns the path to the user-specific configuration file
19
21
  for cellfinder. The user-specific configuration file is located in the
20
- user's home directory under the ".cellfinder" folder and is named
21
- "cellfinder.conf.custom".
22
+ user's home directory under the ".brainglobe/cellfinder" folder
23
+ and is named "cellfinder.conf.custom".
22
24
 
23
25
  Returns:
24
26
  Path: The path to the custom configuration file.
25
27
 
26
28
  """
27
- return Path.home() / ".cellfinder" / "cellfinder.conf.custom"
29
+ return DEFAULT_CELLFINDER_DIRECTORY / "cellfinder.conf.custom"
@@ -1,5 +1,6 @@
1
1
  from pathlib import Path
2
2
 
3
+ import keras
3
4
  from brainglobe_utils.general.exceptions import CommandLineInputError
4
5
 
5
6
 
@@ -80,3 +81,12 @@ def memory_in_bytes(memory_amount, unit):
80
81
  )
81
82
  else:
82
83
  return memory_amount * 10 ** supported_units[unit]
84
+
85
+
86
+ def force_cpu():
87
+ """
88
+ Forces the CPU to be used, even if a GPU is available
89
+ """
90
+ keras.src.backend.common.global_state.set_global_attribute(
91
+ "torch_device", "cpu"
92
+ )
@@ -22,7 +22,10 @@ from brainglobe_utils.general.numerical import (
22
22
  check_positive_float,
23
23
  check_positive_int,
24
24
  )
25
- from brainglobe_utils.general.system import ensure_directory_exists
25
+ from brainglobe_utils.general.system import (
26
+ ensure_directory_exists,
27
+ get_num_processes,
28
+ )
26
29
  from brainglobe_utils.IO.cells import find_relevant_tiffs
27
30
  from brainglobe_utils.IO.yaml import read_yaml_section
28
31
  from fancylog import fancylog
@@ -31,12 +34,7 @@ from sklearn.model_selection import train_test_split
31
34
  import cellfinder.core as program_for_log
32
35
  from cellfinder.core import logger
33
36
  from cellfinder.core.classify.resnet import layer_type
34
- from cellfinder.core.tools.prep import DEFAULT_INSTALL_PATH
35
-
36
- tf_suppress_log_messages = [
37
- "sample_weight modes were coerced from",
38
- "multiprocessing can interact badly with TensorFlow",
39
- ]
37
+ from cellfinder.core.download.download import DEFAULT_DOWNLOAD_DIRECTORY
40
38
 
41
39
  depth_type = Literal["18", "34", "50", "101", "152"]
42
40
 
@@ -112,8 +110,7 @@ def misc_parse(parser):
112
110
 
113
111
  def training_parse():
114
112
  from cellfinder.core.download.cli import (
115
- download_directory_parser,
116
- model_parser,
113
+ download_parser,
117
114
  )
118
115
 
119
116
  training_parser = ArgumentParser(
@@ -223,8 +220,7 @@ def training_parse():
223
220
  )
224
221
 
225
222
  training_parser = misc_parse(training_parser)
226
- training_parser = model_parser(training_parser)
227
- training_parser = download_directory_parser(training_parser)
223
+ training_parser = download_parser(training_parser)
228
224
  args = training_parser.parse_args()
229
225
 
230
226
  return args
@@ -306,7 +302,7 @@ def run(
306
302
  n_free_cpus=2,
307
303
  trained_model=None,
308
304
  model_weights=None,
309
- install_path=DEFAULT_INSTALL_PATH,
305
+ install_path=DEFAULT_DOWNLOAD_DIRECTORY,
310
306
  model="resnet50_tv",
311
307
  network_depth="50",
312
308
  learning_rate=0.0001,
@@ -320,11 +316,7 @@ def run(
320
316
  save_progress=False,
321
317
  epochs=100,
322
318
  ):
323
- from cellfinder.core.main import suppress_tf_logging
324
-
325
- suppress_tf_logging(tf_suppress_log_messages)
326
-
327
- from tensorflow.keras.callbacks import (
319
+ from keras.callbacks import (
328
320
  CSVLogger,
329
321
  ModelCheckpoint,
330
322
  TensorBoard,
@@ -341,7 +333,6 @@ def run(
341
333
  model_weights=model_weights,
342
334
  install_path=install_path,
343
335
  model_name=model,
344
- n_free_cpus=n_free_cpus,
345
336
  )
346
337
 
347
338
  yaml_contents = parse_yaml(yaml_file)
@@ -363,6 +354,7 @@ def run(
363
354
 
364
355
  signal_train, background_train, labels_train = make_lists(tiff_files)
365
356
 
357
+ n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus)
366
358
  if test_fraction > 0:
367
359
  logger.info("Splitting data into training and validation datasets")
368
360
  (
@@ -389,15 +381,17 @@ def run(
389
381
  labels=labels_test,
390
382
  batch_size=batch_size,
391
383
  train=True,
384
+ use_multiprocessing=False,
385
+ workers=n_processes,
392
386
  )
393
387
 
394
388
  # for saving checkpoints
395
- base_checkpoint_file_name = "-epoch.{epoch:02d}-loss-{val_loss:.3f}.h5"
389
+ base_checkpoint_file_name = "-epoch.{epoch:02d}-loss-{val_loss:.3f}"
396
390
 
397
391
  else:
398
392
  logger.info("No validation data selected.")
399
393
  validation_generator = None
400
- base_checkpoint_file_name = "-epoch.{epoch:02d}.h5"
394
+ base_checkpoint_file_name = "-epoch.{epoch:02d}"
401
395
 
402
396
  training_generator = CubeGeneratorFromDisk(
403
397
  signal_train,
@@ -407,6 +401,8 @@ def run(
407
401
  shuffle=True,
408
402
  train=True,
409
403
  augment=not no_augment,
404
+ use_multiprocessing=False,
405
+ workers=n_processes,
410
406
  )
411
407
  callbacks = []
412
408
 
@@ -423,9 +419,14 @@ def run(
423
419
 
424
420
  if not no_save_checkpoints:
425
421
  if save_weights:
426
- filepath = str(output_dir / ("weight" + base_checkpoint_file_name))
422
+ filepath = str(
423
+ output_dir
424
+ / ("weight" + base_checkpoint_file_name + ".weights.h5")
425
+ )
427
426
  else:
428
- filepath = str(output_dir / ("model" + base_checkpoint_file_name))
427
+ filepath = str(
428
+ output_dir / ("model" + base_checkpoint_file_name + ".keras")
429
+ )
429
430
 
430
431
  checkpoints = ModelCheckpoint(
431
432
  filepath,
@@ -434,25 +435,26 @@ def run(
434
435
  callbacks.append(checkpoints)
435
436
 
436
437
  if save_progress:
437
- filepath = str(output_dir / "training.csv")
438
- csv_logger = CSVLogger(filepath)
438
+ csv_filepath = str(output_dir / "training.csv")
439
+ csv_logger = CSVLogger(csv_filepath)
439
440
  callbacks.append(csv_logger)
440
441
 
441
442
  logger.info("Beginning training.")
443
+ # Keras 3.0: `use_multiprocessing` input is set in the
444
+ # `training_generator` (False by default)
442
445
  model.fit(
443
446
  training_generator,
444
447
  validation_data=validation_generator,
445
- use_multiprocessing=False,
446
448
  epochs=epochs,
447
449
  callbacks=callbacks,
448
450
  )
449
451
 
450
452
  if save_weights:
451
453
  logger.info("Saving model weights")
452
- model.save_weights(str(output_dir / "model_weights.h5"))
454
+ model.save_weights(output_dir / "model.weights.h5")
453
455
  else:
454
456
  logger.info("Saving model")
455
- model.save(output_dir / "model.h5")
457
+ model.save(output_dir / "model.keras")
456
458
 
457
459
  logger.info(
458
460
  "Finished training, " "Total time taken: %s",
@@ -54,7 +54,7 @@ class CurationWidget(QWidget):
54
54
  self.save_empty_cubes = save_empty_cubes
55
55
  self.max_ram = max_ram
56
56
  self.voxel_sizes = [5, 2, 2]
57
- self.batch_size = 32
57
+ self.batch_size = 64
58
58
  self.viewer = viewer
59
59
 
60
60
  self.signal_layer = None