cellfinder 1.2.0__tar.gz → 1.3.0__tar.gz

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.

Potentially problematic release: this version of cellfinder might be problematic.

Files changed (72)
  1. {cellfinder-1.2.0 → cellfinder-1.3.0}/.github/workflows/test_and_deploy.yml +32 -21
  2. {cellfinder-1.2.0 → cellfinder-1.3.0}/.github/workflows/test_include_guard.yaml +10 -13
  3. {cellfinder-1.2.0 → cellfinder-1.3.0}/PKG-INFO +4 -3
  4. cellfinder-1.3.0/cellfinder/__init__.py +33 -0
  5. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/classify/classify.py +13 -6
  6. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/classify/cube_generator.py +27 -11
  7. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/classify/resnet.py +9 -6
  8. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/classify/tools.py +13 -11
  9. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/main.py +3 -28
  10. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/tools/prep.py +1 -10
  11. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/tools/system.py +10 -0
  12. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/train/train_yml.py +25 -21
  13. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/curation.py +1 -1
  14. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/detect/detect.py +8 -2
  15. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/detect/detect_containers.py +2 -0
  16. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/detect/thread_worker.py +2 -2
  17. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder.egg-info/PKG-INFO +4 -3
  18. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder.egg-info/SOURCES.txt +0 -1
  19. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder.egg-info/requires.txt +2 -6
  20. {cellfinder-1.2.0 → cellfinder-1.3.0}/pyproject.toml +9 -14
  21. cellfinder-1.2.0/cellfinder/__init__.py +0 -27
  22. cellfinder-1.2.0/cellfinder/core/tools/tf.py +0 -46
  23. {cellfinder-1.2.0 → cellfinder-1.3.0}/.gitignore +0 -0
  24. {cellfinder-1.2.0 → cellfinder-1.3.0}/.napari/config.yml +0 -0
  25. {cellfinder-1.2.0 → cellfinder-1.3.0}/CITATION.cff +0 -0
  26. {cellfinder-1.2.0 → cellfinder-1.3.0}/LICENSE +0 -0
  27. {cellfinder-1.2.0 → cellfinder-1.3.0}/MANIFEST.in +0 -0
  28. {cellfinder-1.2.0 → cellfinder-1.3.0}/README.md +0 -0
  29. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/cli_migration_warning.py +0 -0
  30. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/__init__.py +0 -0
  31. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/classify/__init__.py +0 -0
  32. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/classify/augment.py +0 -0
  33. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/config/__init__.py +0 -0
  34. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/config/cellfinder.conf +0 -0
  35. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/detect/__init__.py +0 -0
  36. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/detect/detect.py +0 -0
  37. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/detect/filters/__init__.py +0 -0
  38. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/detect/filters/plane/__init__.py +0 -0
  39. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/detect/filters/plane/classical_filter.py +0 -0
  40. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/detect/filters/plane/plane_filter.py +0 -0
  41. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/detect/filters/plane/tile_walker.py +0 -0
  42. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/detect/filters/setup_filters.py +0 -0
  43. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/detect/filters/volume/__init__.py +0 -0
  44. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/detect/filters/volume/ball_filter.py +0 -0
  45. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/detect/filters/volume/structure_detection.py +0 -0
  46. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/detect/filters/volume/structure_splitting.py +0 -0
  47. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/detect/filters/volume/volume_filter.py +0 -0
  48. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/download/__init__.py +0 -0
  49. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/download/cli.py +0 -0
  50. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/download/download.py +0 -0
  51. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/tools/__init__.py +0 -0
  52. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/tools/array_operations.py +0 -0
  53. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/tools/geometry.py +0 -0
  54. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/tools/image_processing.py +0 -0
  55. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/tools/source_files.py +0 -0
  56. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/tools/tiff.py +0 -0
  57. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/tools/tools.py +0 -0
  58. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/train/__init__.py +0 -0
  59. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/types.py +0 -0
  60. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/__init__.py +0 -0
  61. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/detect/__init__.py +0 -0
  62. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/input_container.py +0 -0
  63. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/napari.yaml +0 -0
  64. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/sample_data.py +0 -0
  65. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/train/__init__.py +0 -0
  66. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/train/train.py +0 -0
  67. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/train/train_containers.py +0 -0
  68. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/utils.py +0 -0
  69. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder.egg-info/dependency_links.txt +0 -0
  70. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder.egg-info/entry_points.txt +0 -0
  71. {cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder.egg-info/top_level.txt +0 -0
  72. {cellfinder-1.2.0 → cellfinder-1.3.0}/setup.cfg +0 -0
{cellfinder-1.2.0 → cellfinder-1.3.0}/.github/workflows/test_and_deploy.yml

@@ -37,27 +37,31 @@ jobs:
     name: Run package tests
     timeout-minutes: 60
     runs-on: ${{ matrix.os }}
+    env:
+      KERAS_BACKEND: torch
+      CELLFINDER_TEST_DEVICE: cpu
     strategy:
       matrix:
         # Run all supported Python versions on linux
         os: [ubuntu-latest]
-        python-version: ["3.9", "3.10"]
-        # Include one windows, one macos run each for M1 (latest) and Intel (13)
+        python-version: ["3.9", "3.10", "3.11"]
+        # Include one windows and two macOS (intel based and arm based) runs
         include:
           - os: macos-13
-            python-version: "3.10"
+            python-version: "3.11"
           - os: macos-latest
-            python-version: "3.10"
+            python-version: "3.11"
           - os: windows-latest
-            python-version: "3.10"
+            python-version: "3.11"

     steps:
-      # Cache the tensorflow model so we don't have to remake it every time
-      - name: Cache tensorflow model
+      - name: Cache brainglobe directory
         uses: actions/cache@v3
         with:
-          path: "~/.cellfinder"
-          key: models-${{ hashFiles('~/.brainglobe/**') }}
+          path: | # ensure we don't cache any interrupted atlas download and extraction, if e.g. we cancel the workflow manually
+            ~/.brainglobe
+            !~/.brainglobe/atlas.tar.gz
+          key: brainglobe
       # Setup pyqt libraries
       - name: Setup qtpy libraries
         uses: tlambert03/setup-qt-libs@v1
@@ -79,11 +83,13 @@ jobs:
       NUMBA_DISABLE_JIT: "1"

     steps:
-      - name: Cache tensorflow model
+      - name: Cache brainglobe directory
         uses: actions/cache@v3
         with:
-          path: "~/.cellfinder"
-          key: models-${{ hashFiles('~/.brainglobe/**') }}
+          path: | # ensure we don't cache any interrupted atlas download and extraction, if e.g. we cancel the workflow manually
+            ~/.brainglobe
+            !~/.brainglobe/atlas.tar.gz
+          key: brainglobe
       # Setup pyqt libraries
       - name: Setup qtpy libraries
         uses: tlambert03/setup-qt-libs@v1
@@ -92,7 +98,7 @@ jobs:
       # Run test suite with numba disabled
       - uses: neuroinformatics-unit/actions/test@v2
         with:
-          python-version: "3.10"
+          python-version: "3.11"
           secret-codecov-token: ${{ secrets.CODECOV_TOKEN }}
           codecov-flags: "numba"

@@ -103,28 +109,33 @@ jobs:
     name: Run brainmapper tests to check for breakages
     timeout-minutes: 60
     runs-on: ubuntu-latest
+    env:
+      KERAS_BACKEND: torch
+      CELLFINDER_TEST_DEVICE: cpu
     steps:
-      - name: Cache tensorflow model
+      - name: Cache brainglobe directory
        uses: actions/cache@v3
        with:
-          path: "~/.cellfinder"
-          key: models-${{ hashFiles('~/.brainglobe/**') }}
-
+          path: | # ensure we don't cache any interrupted atlas download and extraction, if e.g. we cancel the workflow manually
+            ~/.brainglobe
+            !~/.brainglobe/atlas.tar.gz
+          key: brainglobe
       - name: Checkout brainglobe-workflows
        uses: actions/checkout@v3
        with:
          repository: 'brainglobe/brainglobe-workflows'

-      - name: Set up Python 3.10
+      - name: Set up Python 3.11
        uses: actions/setup-python@v3
        with:
-          python-version: "3.10"
+          python-version: "3.11"

       - name: Install test dependencies
        run: |
          python -m pip install --upgrade pip wheel
-          # Install latest SHA on this brainglobe-workflows branch
-          python -m pip install git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA
+          # Install cellfinder from the latest SHA on this branch
+          python -m pip install "cellfinder @ git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA"
+
          # Install checked out copy of brainglobe-workflows
          python -m pip install .[dev]

{cellfinder-1.2.0 → cellfinder-1.3.0}/.github/workflows/test_include_guard.yaml

@@ -1,5 +1,5 @@
-name: Test Tensorflow include guards
-# These tests check that the include guards checking for tensorflow's availability
+name: Test Keras include guards
+# These tests check that the include guards checking for Keras availability
 # behave as expected on ubuntu and macOS.

 on:
@@ -9,7 +9,7 @@ on:
     - main

 jobs:
-  tensorflow_guards:
+  keras_guards:
     name: Test include guards
     strategy:
       matrix:
@@ -22,26 +22,23 @@ jobs:
       - name: Setup Python
        uses: actions/setup-python@v4
        with:
-          python-version: '3.10'
+          python-version: '3.11'

-      - name: Install via pip
-        run: python -m pip install -e .
+      - name: Install cellfinder via pip
+        run: python -m pip install -e "."

       - name: Test (working) import
        uses: jannekem/run-python-script-action@v1
+        env:
+          KERAS_BACKEND: torch
        with:
          fail-on-error: true
          script: |
            import cellfinder.core
            import cellfinder.napari

-      - name: Uninstall tensorflow-macos on Mac M1
-        if: matrix.os == 'macos-latest'
-        run: python -m pip uninstall -y tensorflow-macos
-
-      - name: Uninstall tensorflow on Ubuntu
-        if: matrix.os == 'ubuntu-latest'
-        run: python -m pip uninstall -y tensorflow
+      - name: Uninstall keras
+        run: python -m pip uninstall -y keras

       - name: Test (broken) import
         id: broken_import
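
Note: the "Test (broken) import" step exercises the package-level include guard added in this release (see the new cellfinder/__init__.py below). A minimal sketch of what that step checks, assuming keras has just been uninstalled:

    from importlib.metadata import PackageNotFoundError

    # with keras absent, importing cellfinder should fail fast at package
    # import time rather than deep inside a classification run
    try:
        import cellfinder.core  # noqa: F401
    except PackageNotFoundError as err:
        print(f"include guard tripped as expected: {err}")
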
{cellfinder-1.2.0 → cellfinder-1.3.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: cellfinder
-Version: 1.2.0
+Version: 1.3.0
 Summary: Automated 3D cell detection in large microscopy images
 Author-email: "Adam Tyson, Christian Niedworok, Charly Rousseau" <code@adamltyson.com>
 License: BSD-3-Clause
@@ -18,6 +18,7 @@ Classifier: Programming Language :: Python
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
 Classifier: Topic :: Scientific/Engineering :: Image Recognition
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
@@ -31,8 +32,8 @@ Requires-Dist: numba
 Requires-Dist: numpy
 Requires-Dist: scikit-image
 Requires-Dist: scikit-learn
-Requires-Dist: tensorflow-macos<2.12.0,>=2.5.0; platform_system == "Darwin" and platform_machine == "arm64"
-Requires-Dist: tensorflow<2.12.0,>=2.5.0; platform_system != "Darwin" or platform_machine != "arm64"
+Requires-Dist: keras>=3.0.0
+Requires-Dist: torch>=2.1.0
 Requires-Dist: tifffile
 Requires-Dist: tqdm
 Provides-Extra: dev
cellfinder-1.3.0/cellfinder/__init__.py (new file)

@@ -0,0 +1,33 @@
+import os
+from importlib.metadata import PackageNotFoundError, version
+from pathlib import Path
+
+# Check cellfinder is installed
+try:
+    __version__ = version("cellfinder")
+except PackageNotFoundError as e:
+    raise PackageNotFoundError("cellfinder package not installed") from e
+
+# If Keras is not present, tools cannot be used.
+# Throw an error in this case to prevent invocation of functions.
+try:
+    KERAS_VERSION = version("keras")
+except PackageNotFoundError as e:
+    raise PackageNotFoundError(
+        f"cellfinder tools cannot be invoked without Keras. "
+        f"Please install Keras with a backend into your environment "
+        f"to use cellfinder tools. "
+        f"For more information on Keras backends, please see "
+        f"https://keras.io/getting_started/#installing-keras-3."
+        f"For more information on brainglobe, please see "
+        f"https://github.com/brainglobe/brainglobe-meta#readme."
+    ) from e
+
+
+# Set the Keras backend to torch
+os.environ["KERAS_BACKEND"] = "torch"
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+__license__ = "BSD-3-Clause"
+
+DEFAULT_CELLFINDER_DIRECTORY = Path.home() / ".brainglobe" / "cellfinder"
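
Note: Keras 3 selects its backend from the KERAS_BACKEND environment variable when `keras` is first imported, which is presumably why the assignment above lives in the package `__init__` rather than next to the model code. A minimal sketch of the ordering requirement (assuming torch is installed):

    import os

    # must be set before the first `import keras`; changing it later has no effect
    os.environ["KERAS_BACKEND"] = "torch"

    import keras

    assert keras.backend.backend() == "torch"
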
{cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/classify/classify.py

@@ -1,10 +1,11 @@
 import os
+from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Tuple

+import keras
 import numpy as np
 from brainglobe_utils.cells.cells import Cell
 from brainglobe_utils.general.system import get_num_processes
-from tensorflow import keras

 from cellfinder.core import logger, types
 from cellfinder.core.classify.cube_generator import CubeGeneratorFromFile
@@ -48,9 +49,9 @@ def main(
     callbacks = None

     # Too many workers doesn't increase speed, and uses huge amounts of RAM
-    workers = get_num_processes(
-        min_free_cpu_cores=n_free_cpus, n_max_processes=max_workers
-    )
+    workers = get_num_processes(min_free_cpu_cores=n_free_cpus)
+
+    start_time = datetime.now()

     logger.debug("Initialising cube generator")
     inference_generator = CubeGeneratorFromFile(
@@ -63,6 +64,8 @@ def main(
         cube_width=cube_width,
         cube_height=cube_height,
         cube_depth=cube_depth,
+        use_multiprocessing=False,
+        workers=workers,
     )

     model = get_model(
@@ -73,10 +76,9 @@ def main(
     )

     logger.info("Running inference")
+    # in Keras 3.0 multiprocessing params are specified in the generator
     predictions = model.predict(
         inference_generator,
-        use_multiprocessing=True,
-        workers=workers,
         verbose=True,
         callbacks=callbacks,
     )
@@ -91,6 +93,11 @@ def main(
             cell.type = predictions[idx] + 1
             points_list.append(cell)

+    time_elapsed = datetime.now() - start_time
+    print(
+        "Classfication complete - all points done in : {}".format(time_elapsed)
+    )
+
     return points_list

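
Note: in Keras 3, `model.predict()` no longer accepts `workers`/`use_multiprocessing`; those settings now belong to the `keras.utils.Sequence` (a `PyDataset` in Keras 3) passed in, which is what the `CubeGeneratorFromFile(..., use_multiprocessing=False, workers=workers)` change implements. A runnable toy sketch of the same pattern (`ToyBatches` is illustrative, not part of cellfinder):

    import numpy as np
    import keras
    from keras.utils import Sequence


    class ToyBatches(Sequence):
        # stand-in for CubeGeneratorFromFile: extra kwargs such as
        # `workers` / `use_multiprocessing` are forwarded to the base class
        def __init__(self, n_batches, **kwargs):
            super().__init__(**kwargs)
            self.n_batches = n_batches

        def __len__(self):
            return self.n_batches

        def __getitem__(self, idx):
            return np.zeros((64, 8), dtype=np.float32)


    model = keras.Sequential([keras.layers.Dense(2)])
    model.build((None, 8))

    gen = ToyBatches(3, workers=2, use_multiprocessing=False)
    predictions = model.predict(gen, verbose=0)  # no workers kwarg on predict()
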
{cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/classify/cube_generator.py

@@ -2,13 +2,13 @@ from pathlib import Path
 from random import shuffle
 from typing import Dict, List, Optional, Tuple, Union

+import keras
 import numpy as np
-import tensorflow as tf
 from brainglobe_utils.cells.cells import Cell, group_cells_by_z
 from brainglobe_utils.general.numerical import is_even
+from keras.utils import Sequence
 from scipy.ndimage import zoom
 from skimage.io import imread
-from tensorflow.keras.utils import Sequence

 from cellfinder.core import types
 from cellfinder.core.classify.augment import AugmentationParameters, augment
@@ -40,7 +40,7 @@ class CubeGeneratorFromFile(Sequence):
         background_array: types.array,
         voxel_sizes: Tuple[int, int, int],
         network_voxel_sizes: Tuple[int, int, int],
-        batch_size: int = 16,
+        batch_size: int = 64,
         cube_width: int = 50,
         cube_height: int = 50,
         cube_depth: int = 20,
@@ -56,7 +56,14 @@ class CubeGeneratorFromFile(Sequence):
         translate: Tuple[float, float, float] = (0.05, 0.05, 0.05),
         shuffle: bool = False,
         interpolation_order: int = 2,
+        *args,
+        **kwargs,
     ):
+        # pass any additional arguments not specified in signature to the
+        # constructor of the superclass (e.g.: `use_multiprocessing` or
+        # `workers`)
+        super().__init__(*args, **kwargs)
+
         self.points = points
         self.signal_array = signal_array
         self.background_array = background_array
@@ -218,10 +225,10 @@

         if self.train:
             batch_labels = [cell.type - 1 for cell in cell_batch]
-            batch_labels = tf.keras.utils.to_categorical(
+            batch_labels = keras.utils.to_categorical(
                 batch_labels, num_classes=self.classes
             )
-            return images, batch_labels
+            return images, batch_labels.astype(np.float32)
         elif self.extract:
             batch_info = self.__get_batch_dict(cell_batch)
             return images, batch_info
@@ -252,7 +259,8 @@
                 (number_images,)
                 + (self.cube_height, self.cube_width, self.cube_depth)
                 + (self.channels,)
-            )
+            ),
+            dtype=np.float32,
         )

         for idx, cell in enumerate(cell_batch):
@@ -337,7 +345,7 @@ class CubeGeneratorFromDisk(Sequence):
         signal_list: List[Union[str, Path]],
         background_list: List[Union[str, Path]],
         labels: Optional[List[int]] = None,  # only if training or validating
-        batch_size: int = 16,
+        batch_size: int = 64,
         shape: Tuple[int, int, int] = (50, 50, 20),
         channels: int = 2,
         classes: int = 2,
@@ -350,7 +358,14 @@ class CubeGeneratorFromDisk(Sequence):
         translate: Tuple[float, float, float] = (0.2, 0.2, 0.2),
         train: bool = False,  # also return labels
         interpolation_order: int = 2,
+        *args,
+        **kwargs,
     ):
+        # pass any additional arguments not specified in signature to the
+        # constructor of the superclass (e.g.: `use_multiprocessing` or
+        # `workers`)
+        super().__init__(*args, **kwargs)
+
         self.im_shape = shape
         self.batch_size = batch_size
         self.labels = labels
@@ -410,10 +425,10 @@

         if self.train and self.labels is not None:
             batch_labels = [self.labels[k] for k in indexes]
-            batch_labels = tf.keras.utils.to_categorical(
+            batch_labels = keras.utils.to_categorical(
                 batch_labels, num_classes=self.classes
             )
-            return images, batch_labels
+            return images, batch_labels.astype(np.float32)
         else:
             return images

@@ -424,7 +439,8 @@
     ) -> np.ndarray:
         number_images = len(list_signal_tmp)
         images = np.empty(
-            ((number_images,) + self.im_shape + (self.channels,))
+            ((number_images,) + self.im_shape + (self.channels,)),
+            dtype=np.float32,
         )

         for idx, signal_im in enumerate(list_signal_tmp):
@@ -433,7 +449,7 @@
                 images, idx, signal_im, background_im
             )

-        return images.astype(np.float16)
+        return images

     def __populate_array_with_cubes(
         self,
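
Note: the dtype changes above are consistent: the old float16 cast on images was dropped, and both image cubes and one-hot labels are now produced directly as float32, the default compute dtype of the torch backend. A small sketch of the label path (`keras.utils.to_categorical` is standard Keras API; the explicit float32 cast mirrors the diff):

    import numpy as np
    from keras.utils import to_categorical

    batch_labels = [0, 1, 1, 0]
    one_hot = to_categorical(batch_labels, num_classes=2).astype(np.float32)
    print(one_hot.shape, one_hot.dtype)  # (4, 2) float32
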
{cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/classify/resnet.py

@@ -1,9 +1,11 @@
 from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

-from tensorflow import Tensor
-from tensorflow.keras import Model
-from tensorflow.keras.initializers import Initializer
-from tensorflow.keras.layers import (
+from keras import (
+    KerasTensor as Tensor,
+)
+from keras import Model
+from keras.initializers import Initializer
+from keras.layers import (
     Activation,
     Add,
     BatchNormalization,
@@ -14,7 +16,7 @@ from tensorflow.keras.layers import (
     MaxPooling3D,
     ZeroPadding3D,
 )
-from tensorflow.keras.optimizers import Adam, Optimizer
+from keras.optimizers import Adam, Optimizer

 #####################################################################
 # Define the types of ResNet
@@ -113,7 +115,7 @@
     activation: str = "relu",
     use_bias: bool = False,
     bn_epsilon: float = 1e-5,
-    pooling_padding: str = "same",
+    pooling_padding: str = "valid",
     axis: int = 3,
 ) -> Tensor:
     """
@@ -131,6 +133,7 @@
     )(x)
     x = BatchNormalization(axis=axis, epsilon=bn_epsilon, name="conv1_bn")(x)
     x = Activation(activation, name="conv1_activation")(x)
+
     x = MaxPooling3D(
         max_pool_size,
         strides=strides,
{cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/classify/tools.py

@@ -1,9 +1,10 @@
 import os
-from typing import List, Optional, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import List, Optional, Tuple, Union

+import keras
 import numpy as np
-import tensorflow as tf
-from tensorflow.keras import Model
+from keras import Model

 from cellfinder.core import logger
 from cellfinder.core.classify.resnet import build_model, layer_type
@@ -17,8 +18,7 @@ def get_model(
     inference: bool = False,
     continue_training: bool = False,
 ) -> Model:
-    """
-    Returns the correct model based on the arguments passed
+    """Returns the correct model based on the arguments passed
     :param existing_model: An existing, trained model. This is returned if it
         exists
     :param model_weights: This file is used to set the model weights if it
@@ -30,29 +30,31 @@
         by using the default one
     :param continue_training: If True, will ensure that a trained model
         exists. E.g. by using the default one
-    :return: A tf.keras model
+    :return: A keras model

     """
     if existing_model is not None or network_depth is None:
         logger.debug(f"Loading model: {existing_model}")
-        return tf.keras.models.load_model(existing_model)
+        return keras.models.load_model(existing_model)
     else:
         logger.debug(f"Creating a new instance of model: {network_depth}")
         model = build_model(
-            network_depth=network_depth, learning_rate=learning_rate
+            network_depth=network_depth,
+            learning_rate=learning_rate,
         )
         if inference or continue_training:
             logger.debug(
-                f"Setting model weights according to: {model_weights}"
+                f"Setting model weights according to: {model_weights}",
             )
             if model_weights is None:
-                raise IOError("`model_weights` must be provided")
+                raise OSError("`model_weights` must be provided")
             model.load_weights(model_weights)
         return model


 def make_lists(
-    tiff_files: Sequence, train: bool = True
+    tiff_files: Sequence,
+    train: bool = True,
 ) -> Union[Tuple[List, List], Tuple[List, List, np.ndarray]]:
     signal_list = []
     background_list = []
{cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/main.py

@@ -1,23 +1,13 @@
-"""
-N.B imports are within functions to prevent tensorflow being imported before
-it's warnings are silenced
-"""
-
 import os
 from typing import Callable, List, Optional, Tuple

 import numpy as np
 from brainglobe_utils.cells.cells import Cell
-from brainglobe_utils.general.logging import suppress_specific_logs

 from cellfinder.core import logger
 from cellfinder.core.download.download import model_type
 from cellfinder.core.train.train_yml import depth_type

-tf_suppress_log_messages = [
-    "multiprocessing can interact badly with TensorFlow"
-]
-

 def main(
     signal_array: np.ndarray,
@@ -28,7 +18,7 @@ def main(
     trained_model: Optional[os.PathLike] = None,
     model_weights: Optional[os.PathLike] = None,
     model: model_type = "resnet50_tv",
-    batch_size: int = 32,
+    batch_size: int = 64,
     n_free_cpus: int = 2,
     network_voxel_sizes: Tuple[int, int, int] = (5, 1, 1),
     soma_diameter: int = 16,
@@ -58,13 +48,11 @@ def main(
         Called every time a plane has finished being processed during the
         detection stage. Called with the plane number that has finished.
     classify_callback : Callable[int], optional
-        Called every time tensorflow has finished classifying a point.
+        Called every time a point has finished being classified.
         Called with the batch number that has just finished.
     detect_finished_callback : Callable[list], optional
         Called after detection is finished with the list of detected points.
     """
-    suppress_tf_logging(tf_suppress_log_messages)
-
     from cellfinder.core.classify import classify
     from cellfinder.core.detect import detect
     from cellfinder.core.tools import prep
@@ -98,7 +86,7 @@ def main(
     if not skip_classification:
         install_path = None
         model_weights = prep.prep_model_weights(
-            model_weights, install_path, model, n_free_cpus
+            model_weights, install_path, model
         )
         if len(points) > 0:
             logger.info("Running classification")
@@ -120,17 +108,4 @@ def main(
         )
     else:
         logger.info("No candidates, skipping classification")
-
     return points
-
-
-def suppress_tf_logging(tf_suppress_log_messages: List[str]) -> None:
-    """
-    Prevents many lines of logs such as:
-    "2019-10-24 16:54:41.363978: I tensorflow/stream_executor/platform/default
-    /dso_loader.cc:44] Successfully opened dynamic library libcuda.so.1"
-    """
-    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-
-    for message in tf_suppress_log_messages:
-        suppress_specific_logs("tensorflow", message)
{cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/tools/prep.py

@@ -9,9 +9,7 @@ from pathlib import Path
 from typing import Optional

 from brainglobe_utils.general.config import get_config_obj
-from brainglobe_utils.general.system import get_num_processes

-import cellfinder.core.tools.tf as tf_tools
 from cellfinder.core import logger
 from cellfinder.core.download.download import (
     DEFAULT_DOWNLOAD_DIRECTORY,
@@ -26,20 +24,13 @@ def prep_model_weights(
     model_weights: Optional[os.PathLike],
     install_path: Optional[os.PathLike],
     model_name: model_type,
-    n_free_cpus: int,
 ) -> Path:
-    n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus)
-    prep_tensorflow(n_processes)
+    # prepare models (get default weights or provided ones)
     model_weights = prep_models(model_weights, install_path, model_name)

     return model_weights


-def prep_tensorflow(max_threads: int) -> None:
-    tf_tools.set_tf_threads(max_threads)
-    tf_tools.allow_gpu_memory_growth()
-
-
 def prep_models(
     model_weights_path: Optional[os.PathLike],
     install_path: Optional[os.PathLike],
{cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/tools/system.py

@@ -1,5 +1,6 @@
 from pathlib import Path

+import keras
 from brainglobe_utils.general.exceptions import CommandLineInputError


@@ -80,3 +81,12 @@ def memory_in_bytes(memory_amount, unit):
         )
     else:
         return memory_amount * 10 ** supported_units[unit]
+
+
+def force_cpu():
+    """
+    Forces the CPU to be used, even if a GPU is available
+    """
+    keras.src.backend.common.global_state.set_global_attribute(
+        "torch_device", "cpu"
+    )
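
Note: `force_cpu()` reaches into `keras.src.backend...`, which is private Keras internals rather than a documented entry point, presumably backing the new CELLFINDER_TEST_DEVICE=cpu setting in CI. A hedged usage sketch:

    from cellfinder.core.tools.system import force_cpu

    # pin the torch backend to the CPU before building or running any model,
    # e.g. in tests where GPU/MPS execution is unavailable or non-deterministic
    force_cpu()
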
{cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/core/train/train_yml.py

@@ -22,7 +22,10 @@ from brainglobe_utils.general.numerical import (
     check_positive_float,
     check_positive_int,
 )
-from brainglobe_utils.general.system import ensure_directory_exists
+from brainglobe_utils.general.system import (
+    ensure_directory_exists,
+    get_num_processes,
+)
 from brainglobe_utils.IO.cells import find_relevant_tiffs
 from brainglobe_utils.IO.yaml import read_yaml_section
 from fancylog import fancylog
@@ -33,11 +36,6 @@ from cellfinder.core import logger
 from cellfinder.core.classify.resnet import layer_type
 from cellfinder.core.download.download import DEFAULT_DOWNLOAD_DIRECTORY

-tf_suppress_log_messages = [
-    "sample_weight modes were coerced from",
-    "multiprocessing can interact badly with TensorFlow",
-]
-
 depth_type = Literal["18", "34", "50", "101", "152"]

 models: Dict[depth_type, layer_type] = {
@@ -318,11 +316,7 @@ def run(
     save_progress=False,
     epochs=100,
 ):
-    from cellfinder.core.main import suppress_tf_logging
-
-    suppress_tf_logging(tf_suppress_log_messages)
-
-    from tensorflow.keras.callbacks import (
+    from keras.callbacks import (
         CSVLogger,
         ModelCheckpoint,
         TensorBoard,
@@ -339,7 +333,6 @@ def run(
         model_weights=model_weights,
         install_path=install_path,
         model_name=model,
-        n_free_cpus=n_free_cpus,
     )

     yaml_contents = parse_yaml(yaml_file)
@@ -361,6 +354,7 @@ def run(

     signal_train, background_train, labels_train = make_lists(tiff_files)

+    n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus)
     if test_fraction > 0:
         logger.info("Splitting data into training and validation datasets")
         (
@@ -387,15 +381,17 @@ def run(
             labels=labels_test,
             batch_size=batch_size,
             train=True,
+            use_multiprocessing=False,
+            workers=n_processes,
         )

         # for saving checkpoints
-        base_checkpoint_file_name = "-epoch.{epoch:02d}-loss-{val_loss:.3f}.h5"
+        base_checkpoint_file_name = "-epoch.{epoch:02d}-loss-{val_loss:.3f}"

     else:
         logger.info("No validation data selected.")
         validation_generator = None
-        base_checkpoint_file_name = "-epoch.{epoch:02d}.h5"
+        base_checkpoint_file_name = "-epoch.{epoch:02d}"

     training_generator = CubeGeneratorFromDisk(
         signal_train,
@@ -405,6 +401,8 @@ def run(
         shuffle=True,
         train=True,
         augment=not no_augment,
+        use_multiprocessing=False,
+        workers=n_processes,
     )
     callbacks = []

@@ -421,9 +419,14 @@ def run(

     if not no_save_checkpoints:
         if save_weights:
-            filepath = str(output_dir / ("weight" + base_checkpoint_file_name))
+            filepath = str(
+                output_dir
+                / ("weight" + base_checkpoint_file_name + ".weights.h5")
+            )
         else:
-            filepath = str(output_dir / ("model" + base_checkpoint_file_name))
+            filepath = str(
+                output_dir / ("model" + base_checkpoint_file_name + ".keras")
+            )

         checkpoints = ModelCheckpoint(
             filepath,
@@ -432,25 +435,26 @@ def run(
         callbacks.append(checkpoints)

     if save_progress:
-        filepath = str(output_dir / "training.csv")
-        csv_logger = CSVLogger(filepath)
+        csv_filepath = str(output_dir / "training.csv")
+        csv_logger = CSVLogger(csv_filepath)
         callbacks.append(csv_logger)

     logger.info("Beginning training.")
+    # Keras 3.0: `use_multiprocessing` input is set in the
+    # `training_generator` (False by default)
     model.fit(
         training_generator,
         validation_data=validation_generator,
-        use_multiprocessing=False,
         epochs=epochs,
         callbacks=callbacks,
     )

     if save_weights:
         logger.info("Saving model weights")
-        model.save_weights(str(output_dir / "model_weights.h5"))
+        model.save_weights(output_dir / "model.weights.h5")
     else:
         logger.info("Saving model")
-        model.save(output_dir / "model.h5")
+        model.save(output_dir / "model.keras")

     logger.info(
         "Finished training, " "Total time taken: %s",
{cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/curation.py

@@ -54,7 +54,7 @@ class CurationWidget(QWidget):
         self.save_empty_cubes = save_empty_cubes
         self.max_ram = max_ram
         self.voxel_sizes = [5, 2, 2]
-        self.batch_size = 32
+        self.batch_size = 64
         self.viewer = viewer

         self.signal_layer = None
{cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/detect/detect.py

@@ -253,8 +253,9 @@ def detect_widget() -> FunctionGui:
         max_cluster_size: int,
         classification_options,
         skip_classification: bool,
-        trained_model: Optional[Path],
         use_pre_trained_weights: bool,
+        trained_model: Optional[Path],
+        batch_size: int,
         misc_options,
         start_plane: int,
         end_plane: int,
@@ -298,6 +299,8 @@ def detect_widget() -> FunctionGui:
         should be attempted
     use_pre_trained_weights : bool
         Select to use pre-trained model weights
+    batch_size : int
+        How many points to classify at one time
     skip_classification : bool
         If selected, the classification step is skipped and all cells from
         the detection stage are added
@@ -372,7 +375,10 @@ def detect_widget() -> FunctionGui:
     if use_pre_trained_weights:
         trained_model = None
     classification_inputs = ClassificationInputs(
-        skip_classification, use_pre_trained_weights, trained_model
+        skip_classification,
+        use_pre_trained_weights,
+        trained_model,
+        batch_size,
     )

     if analyse_local:
{cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/detect/detect_containers.py

@@ -114,6 +114,7 @@ class ClassificationInputs(InputContainer):
     skip_classification: bool = False
     use_pre_trained_weights: bool = True
     trained_model: Optional[Path] = Path.home()
+    batch_size: int = 64

     def as_core_arguments(self) -> dict:
         args = super().as_core_arguments()
@@ -131,6 +132,7 @@ class ClassificationInputs(InputContainer):
             skip_classification=dict(
                 value=cls.defaults()["skip_classification"]
             ),
+            batch_size=dict(value=cls.defaults()["batch_size"]),
         )

{cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder/napari/detect/thread_worker.py

@@ -72,10 +72,10 @@ class Worker(WorkerBase):
         def classify_callback(batch: int) -> None:
             self.update_progress_bar.emit(
                 "Classifying cells",
-                # Default cellfinder-core batch size is 32. This seems to give
+                # Default cellfinder-core batch size is 64. This seems to give
                 # a slight underestimate of the number of batches though, so
                 # allow for batch number to go over this
-                max(self.npoints_detected // 32 + 1, batch + 1),
+                max(self.npoints_detected // 64 + 1, batch + 1),
                 batch + 1,
             )
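
Note: the batch-count estimate above is a quick ceiling approximation of points/batch-size; the `max(..., batch + 1)` guard keeps the progress bar valid whenever the observed batch index runs past the estimate. Worked example of the arithmetic:

    npoints_detected = 130
    batch_size = 64
    estimated_batches = npoints_detected // batch_size + 1  # 3 == ceil(130 / 64)
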
{cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: cellfinder
-Version: 1.2.0
+Version: 1.3.0
 Summary: Automated 3D cell detection in large microscopy images
 Author-email: "Adam Tyson, Christian Niedworok, Charly Rousseau" <code@adamltyson.com>
 License: BSD-3-Clause
@@ -18,6 +18,7 @@ Classifier: Programming Language :: Python
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
 Classifier: Topic :: Scientific/Engineering :: Image Recognition
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
@@ -31,8 +32,8 @@ Requires-Dist: numba
 Requires-Dist: numpy
 Requires-Dist: scikit-image
 Requires-Dist: scikit-learn
-Requires-Dist: tensorflow-macos<2.12.0,>=2.5.0; platform_system == "Darwin" and platform_machine == "arm64"
-Requires-Dist: tensorflow<2.12.0,>=2.5.0; platform_system != "Darwin" or platform_machine != "arm64"
+Requires-Dist: keras>=3.0.0
+Requires-Dist: torch>=2.1.0
 Requires-Dist: tifffile
 Requires-Dist: tqdm
 Provides-Extra: dev
{cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder.egg-info/SOURCES.txt

@@ -49,7 +49,6 @@ cellfinder/core/tools/image_processing.py
 cellfinder/core/tools/prep.py
 cellfinder/core/tools/source_files.py
 cellfinder/core/tools/system.py
-cellfinder/core/tools/tf.py
 cellfinder/core/tools/tiff.py
 cellfinder/core/tools/tools.py
 cellfinder/core/train/__init__.py
{cellfinder-1.2.0 → cellfinder-1.3.0}/cellfinder.egg-info/requires.txt

@@ -7,15 +7,11 @@ numba
 numpy
 scikit-image
 scikit-learn
+keras>=3.0.0
+torch>=2.1.0
 tifffile
 tqdm

-[:platform_system != "Darwin" or platform_machine != "arm64"]
-tensorflow<2.12.0,>=2.5.0
-
-[:platform_system == "Darwin" and platform_machine == "arm64"]
-tensorflow-macos<2.12.0,>=2.5.0
-
 [dev]
 black
 pre-commit
{cellfinder-1.2.0 → cellfinder-1.3.0}/pyproject.toml

@@ -16,6 +16,7 @@ classifiers = [
     "Programming Language :: Python :: 3",
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
     "Topic :: Scientific/Engineering :: Image Recognition",
 ]
 requires-python = ">=3.9"
@@ -29,9 +30,8 @@ dependencies = [
     "numpy",
     "scikit-image",
     "scikit-learn",
-    # See https://github.com/brainglobe/cellfinder-core/issues/103 for < 2.12.0 pin
-    "tensorflow-macos>=2.5.0,<2.12.0; platform_system=='Darwin' and platform_machine=='arm64'",
-    "tensorflow>=2.5.0,<2.12.0; platform_system!='Darwin' or platform_machine!='arm64'",
+    "keras>=3.0.0",
+    "torch>=2.1.0",
     "tifffile",
     "tqdm",
 ]
@@ -79,7 +79,7 @@ requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"]
 build-backend = 'setuptools.build_meta'

 [tool.black]
-target-version = ['py39', 'py310']
+target-version = ['py39', 'py310','py311']
 skip-string-normalization = false
 line-length = 79

@@ -111,27 +111,22 @@ markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
 legacy_tox_ini = """
 # For more information about tox, see https://tox.readthedocs.io/en/latest/
 [tox]
-envlist = py{39,310}
+envlist = py{39,310,311}
 isolated_build = true

 [gh-actions]
 python =
     3.9: py39
     3.10: py310
+    3.11: py311

 [testenv]
 commands = python -m pytest -v --color=yes --cov=cellfinder --cov-report=xml
-deps =
-    pytest
-    pytest-cov
-    pytest-mock
-    pytest-timeout
-    # Even though napari is a requirement for cellfinder.napari, we have to
-    # ensure it is installed with the default Qt backend here.
-    napari[all]
-    pytest-qt
 extras =
+    dev
     napari
+setenv =
+    KERAS_BACKEND = torch
 passenv =
     NUMBA_DISABLE_JIT
     CI
cellfinder-1.2.0/cellfinder/__init__.py (deleted)

@@ -1,27 +0,0 @@
-from importlib.metadata import PackageNotFoundError, version
-from pathlib import Path
-
-try:
-    __version__ = version("cellfinder")
-except PackageNotFoundError as e:
-    raise PackageNotFoundError("cellfinder package not installed") from e
-
-# If tensorflow is not present, tools cannot be used.
-# Throw an error in this case to prevent invocation of functions.
-try:
-    TF_VERSION = version("tensorflow")
-except PackageNotFoundError as e:
-    try:
-        TF_VERSION = version("tensorflow-macos")
-    except PackageNotFoundError as e:
-        raise PackageNotFoundError(
-            f"cellfinder tools cannot be invoked without tensorflow. "
-            f"Please install tensorflow into your environment to use cellfinder tools. "
-            f"For more information, please see "
-            f"https://github.com/brainglobe/brainglobe-meta#readme."
-        ) from e
-
-__author__ = "Adam Tyson, Christian Niedworok, Charly Rousseau"
-__license__ = "BSD-3-Clause"
-
-DEFAULT_CELLFINDER_DIRECTORY = Path.home() / ".brainglobe" / "cellfinder"
cellfinder-1.2.0/cellfinder/core/tools/tf.py (deleted)

@@ -1,46 +0,0 @@
-import tensorflow as tf
-
-from cellfinder.core import logger
-
-
-def allow_gpu_memory_growth():
-    """
-    If a gpu is present, prevent tensorflow from using all the memory straight
-    away. Allows multiple processes to use the GPU (and avoid occasional
-    errors on some systems) at the cost of a slight performance penalty.
-    """
-    gpus = tf.config.experimental.list_physical_devices("GPU")
-    if gpus:
-        logger.debug("Allowing GPU memory growth")
-        try:
-            # Currently, memory growth needs to be the same across GPUs
-            for gpu in gpus:
-                tf.config.experimental.set_memory_growth(gpu, True)
-            logical_gpus = tf.config.experimental.list_logical_devices("GPU")
-            logger.debug(
-                f"{len(gpus)} physical GPUs, {len(logical_gpus)} logical GPUs"
-            )
-        except RuntimeError as e:
-            # Memory growth must be set before GPUs have been initialized
-            print(e)
-    else:
-        logger.debug("No GPUs found, using CPU.")
-
-
-def set_tf_threads(max_threads):
-    """
-    Limit the number of threads that tensorflow uses
-    :param max_threads: Maximum number of threads to use
-    :return:
-    """
-    logger.debug(
-        f"Setting maximum number of threads for tensorflow "
-        f"to: {max_threads}"
-    )
-
-    # If statements are for testing. If tf is initialised, then setting these
-    # parameters throws an error
-    if tf.config.threading.get_inter_op_parallelism_threads() != 0:
-        tf.config.threading.set_inter_op_parallelism_threads(max_threads)
-    if tf.config.threading.get_intra_op_parallelism_threads() != 0:
-        tf.config.threading.set_intra_op_parallelism_threads(max_threads)
The remaining files listed above (items 23–72) are unchanged between the two versions.