cellfinder 1.3.2__py3-none-any.whl → 1.4.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -30,7 +30,7 @@ def main(
     max_workers: int = 3,
     *,
     callback: Optional[Callable[[int], None]] = None,
-) -> List:
+) -> List[Cell]:
    """
    Parameters
    ----------
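The return annotation is narrowed from a bare List to List[Cell], so callers get a typed list of brainglobe_utils Cell objects. A minimal, hypothetical consumer sketch (the helper below is illustrative, not part of cellfinder):

    from typing import List

    from brainglobe_utils.cells.cells import Cell

    def cell_centres(cells: List[Cell]) -> List[tuple]:
        # With the annotation narrowed to List[Cell], type checkers can verify
        # attribute access on the returned objects (x, y, z are Cell fields).
        return [(cell.x, cell.y, cell.z) for cell in cells]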
@@ -68,6 +68,14 @@ def main(
         workers=workers,
     )
 
+    if trained_model and trained_model.suffix == ".h5":
+        print(
+            "Weights provided in place of the model, "
+            "loading weights into default model."
+        )
+        model_weights = trained_model
+        trained_model = None
+
     model = get_model(
         existing_model=trained_model,
         model_weights=model_weights,
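With this addition, a Keras weights file (.h5) passed where a full model is expected is loaded into the default network instead of failing. A hedged sketch that mirrors the new fallback from the caller's side (the path is illustrative; the real logic runs inside main() as shown above):

    from pathlib import Path

    # Illustrative only: mirrors the new fallback in main().
    trained_model = Path("custom_weights.h5")
    if trained_model and trained_model.suffix == ".h5":
        model_weights = trained_model  # treat the file as weights
        trained_model = None           # fall back to the default model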
@@ -13,76 +13,47 @@ bright points, the input data is clipped to [0, (max_val - 2)]
 - (max_val) is used to mark bright points during 3D filtering
 """
 
-import multiprocessing
+import dataclasses
 from datetime import datetime
-from queue import Queue
-from threading import Lock
-from typing import Callable, List, Optional, Sequence, Tuple, TypeVar
+from typing import Callable, List, Optional, Tuple
 
 import numpy as np
+import torch
 from brainglobe_utils.cells.cells import Cell
-from brainglobe_utils.general.system import get_num_processes
-from numba import set_num_threads
 
 from cellfinder.core import logger, types
 from cellfinder.core.detect.filters.plane import TileProcessor
-from cellfinder.core.detect.filters.setup_filters import setup_tile_filtering
+from cellfinder.core.detect.filters.setup_filters import DetectionSettings
 from cellfinder.core.detect.filters.volume.volume_filter import VolumeFilter
+from cellfinder.core.tools.tools import inference_wrapper
 
 
-def calculate_parameters_in_pixels(
-    voxel_sizes: Tuple[float, float, float],
-    soma_diameter_um: float,
-    max_cluster_size_um3: float,
-    ball_xy_size_um: float,
-    ball_z_size_um: float,
-) -> Tuple[int, int, int, int]:
-    """
-    Convert the command-line arguments from real (um) units to pixels
-    """
-
-    mean_in_plane_pixel_size = 0.5 * (
-        float(voxel_sizes[2]) + float(voxel_sizes[1])
-    )
-    voxel_volume = (
-        float(voxel_sizes[2]) * float(voxel_sizes[1]) * float(voxel_sizes[0])
-    )
-    soma_diameter = int(round(soma_diameter_um / mean_in_plane_pixel_size))
-    max_cluster_size = int(round(max_cluster_size_um3 / voxel_volume))
-    ball_xy_size = int(round(ball_xy_size_um / mean_in_plane_pixel_size))
-    ball_z_size = int(round(ball_z_size_um / float(voxel_sizes[0])))
-
-    if ball_z_size == 0:
-        raise ValueError(
-            "Ball z size has been calculated to be 0 voxels."
-            " This may be due to large axial spacing of your data or the "
-            "ball_z_size_um parameter being too small. "
-            "Please check input parameters are correct. "
-            "Note that cellfinder requires high resolution data in all "
-            "dimensions, so that cells can be detected in multiple "
-            "image planes."
-        )
-    return soma_diameter, max_cluster_size, ball_xy_size, ball_z_size
-
-
+@inference_wrapper
 def main(
     signal_array: types.array,
-    start_plane: int,
-    end_plane: int,
-    voxel_sizes: Tuple[float, float, float],
-    soma_diameter: float,
-    max_cluster_size: float,
-    ball_xy_size: float,
-    ball_z_size: float,
-    ball_overlap_fraction: float,
-    soma_spread_factor: float,
-    n_free_cpus: int,
-    log_sigma_size: float,
-    n_sds_above_mean_thresh: float,
+    start_plane: int = 0,
+    end_plane: int = -1,
+    voxel_sizes: Tuple[float, float, float] = (5, 2, 2),
+    soma_diameter: float = 16,
+    max_cluster_size: float = 100_000,
+    ball_xy_size: float = 6,
+    ball_z_size: float = 15,
+    ball_overlap_fraction: float = 0.6,
+    soma_spread_factor: float = 1.4,
+    n_free_cpus: int = 2,
+    log_sigma_size: float = 0.2,
+    n_sds_above_mean_thresh: float = 10,
     outlier_keep: bool = False,
     artifact_keep: bool = False,
     save_planes: bool = False,
     plane_directory: Optional[str] = None,
+    batch_size: Optional[int] = None,
+    torch_device: str = "cpu",
+    use_scipy: bool = True,
+    split_ball_xy_size: int = 3,
+    split_ball_z_size: int = 3,
+    split_ball_overlap_fraction: float = 0.8,
+    split_soma_diameter: int = 7,
     *,
     callback: Optional[Callable[[int], None]] = None,
 ) -> List[Cell]:
@@ -101,7 +72,7 @@ def main(
         Index of the ending plane for detection.
 
     voxel_sizes : Tuple[float, float, float]
-        Tuple of voxel sizes in each dimension (x, y, z).
+        Tuple of voxel sizes in each dimension (z, y, x).
 
     soma_diameter : float
         Diameter of the soma in physical units.
@@ -142,6 +113,18 @@ def main(
     plane_directory : str, optional
         Directory path to save the planes. Defaults to None.
 
+    batch_size : int, optional
+        The number of planes to process in each batch. Defaults to 1.
+        For CPU, there's no benefit for a larger batch size. Only a memory
+        usage increase. For CUDA, the larger the batch size the better the
+        performance. Until it fills up the GPU memory - after which it
+        becomes slower.
+
+    torch_device : str, optional
+        The device on which to run the computation. By default, it's "cpu".
+        To run on a gpu, specify the PyTorch device name, such as "cuda" to
+        run on the first GPU.
+
     callback : Callable[int], optional
         A callback function that is called every time a plane has finished
         being processed. Called with the plane number that has finished.
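Taken together with the new defaults in the signature above, a detection run can now be configured entirely through keyword arguments. A hedged usage sketch (the module path and array shape are assumptions based on the imports shown in this diff; voxel sizes follow the corrected (z, y, x) order):

    import numpy as np

    from cellfinder.core.detect.detect import main as detect_main

    # planes (z), then y, then x
    signal = np.random.random((50, 512, 512)).astype(np.float32)

    cells = detect_main(
        signal,
        voxel_sizes=(5, 2, 2),  # (z, y, x) in microns
        torch_device="cuda",    # any PyTorch device string; "cpu" is the default
        batch_size=8,           # larger batches help on CUDA until GPU memory fills
    )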
@@ -151,151 +134,103 @@ def main(
     List[Cell]
         List of detected cells.
     """
-    if not np.issubdtype(signal_array.dtype, np.integer):
-        raise ValueError(
-            "signal_array must be integer datatype, but has datatype "
+    start_time = datetime.now()
+    if batch_size is None:
+        if torch_device == "cpu":
+            batch_size = 4
+        else:
+            batch_size = 1
+
+    if not np.issubdtype(signal_array.dtype, np.number):
+        raise TypeError(
+            "signal_array must be a numpy datatype, but has datatype "
             f"{signal_array.dtype}"
         )
-    n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus)
-    n_ball_procs = max(n_processes - 1, 1)
-
-    # we parallelize 2d filtering, which typically lags behind the 3d
-    # processing so for n_ball_procs 2d filtering threads, ball_z_size will
-    # typically be in use while the others stall waiting for 3d processing
-    # so we can use those for other things, such as numba threading
-    set_num_threads(max(n_ball_procs - int(ball_z_size), 1))
-
-    start_time = datetime.now()
-
-    (
-        soma_diameter,
-        max_cluster_size,
-        ball_xy_size,
-        ball_z_size,
-    ) = calculate_parameters_in_pixels(
-        voxel_sizes,
-        soma_diameter,
-        max_cluster_size,
-        ball_xy_size,
-        ball_z_size,
-    )
-
-    if end_plane == -1:
-        end_plane = len(signal_array)
-    signal_array = signal_array[start_plane:end_plane]
-    signal_array = signal_array.astype(np.uint32)
-
-    callback = callback or (lambda *args, **kwargs: None)
 
     if signal_array.ndim != 3:
         raise ValueError("Input data must be 3D")
 
-    setup_params = (
-        signal_array[0, :, :],
-        soma_diameter,
-        ball_xy_size,
-        ball_z_size,
-        ball_overlap_fraction,
-        start_plane,
-    )
-
-    # Create 3D analysis filter
-    mp_3d_filter = VolumeFilter(
-        soma_diameter=soma_diameter,
-        setup_params=setup_params,
-        soma_size_spread_factor=soma_spread_factor,
-        n_planes=len(signal_array),
-        n_locks_release=n_ball_procs,
-        save_planes=save_planes,
-        plane_directory=plane_directory,
+    if end_plane < 0:
+        end_plane = len(signal_array)
+    end_plane = min(len(signal_array), end_plane)
+
+    torch_device = torch_device.lower()
+    batch_size = max(batch_size, 1)
+    # brainmapper can pass them in as str
+    voxel_sizes = list(map(float, voxel_sizes))
+
+    settings = DetectionSettings(
+        plane_shape=signal_array.shape[1:],
+        plane_original_np_dtype=signal_array.dtype,
+        voxel_sizes=voxel_sizes,
+        soma_spread_factor=soma_spread_factor,
+        soma_diameter_um=soma_diameter,
+        max_cluster_size_um3=max_cluster_size,
+        ball_xy_size_um=ball_xy_size,
+        ball_z_size_um=ball_z_size,
         start_plane=start_plane,
-        max_cluster_size=max_cluster_size,
+        end_plane=end_plane,
+        n_free_cpus=n_free_cpus,
+        ball_overlap_fraction=ball_overlap_fraction,
+        log_sigma_size=log_sigma_size,
+        n_sds_above_mean_thresh=n_sds_above_mean_thresh,
         outlier_keep=outlier_keep,
         artifact_keep=artifact_keep,
+        save_planes=save_planes,
+        plane_directory=plane_directory,
+        batch_size=batch_size,
+        torch_device=torch_device,
+    )
+
+    # replicate the settings specific to splitting, before we access anything
+    # of the original settings, causing cached properties
+    kwargs = dataclasses.asdict(settings)
+    kwargs["ball_z_size_um"] = split_ball_z_size * settings.z_pixel_size
+    kwargs["ball_xy_size_um"] = (
+        split_ball_xy_size * settings.in_plane_pixel_size
     )
+    kwargs["ball_overlap_fraction"] = split_ball_overlap_fraction
+    kwargs["soma_diameter_um"] = (
+        split_soma_diameter * settings.in_plane_pixel_size
+    )
+    # always run on cpu because copying to gpu overhead is likely slower than
+    # any benefit for detection on smallish volumes
+    kwargs["torch_device"] = "cpu"
+    # for splitting, we only do 3d filtering. Its input is a zero volume
+    # with cell voxels marked with threshold_value. So just use float32
+    # for input because the filters will also use float(32). So there will
+    # not be need to convert the input a different dtype before passing to
+    # the filters.
+    kwargs["plane_original_np_dtype"] = np.float32
+    splitting_settings = DetectionSettings(**kwargs)
+
+    # Create 3D analysis filter
+    mp_3d_filter = VolumeFilter(settings=settings)
 
-    clipping_val, threshold_value = setup_tile_filtering(signal_array[0, :, :])
     # Create 2D analysis filter
     mp_tile_processor = TileProcessor(
-        clipping_val,
-        threshold_value,
-        soma_diameter,
-        log_sigma_size,
-        n_sds_above_mean_thresh,
+        plane_shape=settings.plane_shape,
+        clipping_value=settings.clipping_value,
+        threshold_value=settings.threshold_value,
+        n_sds_above_mean_thresh=n_sds_above_mean_thresh,
+        log_sigma_size=log_sigma_size,
+        soma_diameter=settings.soma_diameter,
+        torch_device=torch_device,
+        dtype=settings.filtering_dtype.__name__,
+        use_scipy=use_scipy,
     )
 
-    # Force spawn context
-    mp_ctx = multiprocessing.get_context("spawn")
-    with mp_ctx.Pool(n_ball_procs) as worker_pool:
-        async_results, locks = _map_with_locks(
-            mp_tile_processor.get_tile_mask,
-            signal_array,  # type: ignore
-            worker_pool,
-        )
-
-        # Release the first set of locks for the 2D filtering
-        for i in range(min(n_ball_procs + ball_z_size, len(locks))):
-            logger.debug(f"🔓 Releasing lock for plane {i}")
-            locks[i].release()
+    orig_n_threads = torch.get_num_threads()
+    torch.set_num_threads(settings.n_torch_comp_threads)
 
-        # Start 3D filter
-        #
-        # This runs in the main thread, and blocks until the all the 2D and
-        # then 3D filtering has finished. As batches of planes are filtered
-        # by the 3D filter, it releases the locks of subsequent 2D filter
-        # processes.
-        mp_3d_filter.process(async_results, locks, callback=callback)
+    # process the data
+    mp_3d_filter.process(mp_tile_processor, signal_array, callback=callback)
+    cells = mp_3d_filter.get_results(splitting_settings)
 
-        # it's now done filtering, get results with pool
-        cells = mp_3d_filter.get_results(worker_pool)
+    torch.set_num_threads(orig_n_threads)
 
     time_elapsed = datetime.now() - start_time
-    logger.debug(
-        f"All Planes done. Found {len(cells)} cells in {format(time_elapsed)}"
-    )
-    print("Detection complete - all planes done in : {}".format(time_elapsed))
+    s = f"Detection complete. Found {len(cells)} cells in {time_elapsed}"
+    logger.debug(s)
+    print(s)
     return cells
-
-
-Tin = TypeVar("Tin")
-Tout = TypeVar("Tout")
-
-
-def _run_func_with_lock(
-    func: Callable[[Tin], Tout], arg: Tin, lock: Lock
-) -> Tout:
-    """
-    Run a function after acquiring a lock.
-    """
-    lock.acquire(blocking=True)
-    return func(arg)
-
-
-def _map_with_locks(
-    func: Callable[[Tin], Tout],
-    iterable: Sequence[Tin],
-    worker_pool: multiprocessing.pool.Pool,
-) -> Tuple[Queue, List[Lock]]:
-    """
-    Map a function to arguments, blocking execution.
-
-    Maps *func* to args in *iterable*, but blocks all execution and
-    return a queue of asyncronous results and locks for each of the
-    results. Execution can be enabled by releasing the returned
-    locks in order.
-    """
-    # Setup a manager to handle the locks
-    m = multiprocessing.Manager()
-    # Setup one lock per argument to be mapped
-    locks = [m.Lock() for _ in range(len(iterable))]
-    [lock.acquire(blocking=False) for lock in locks]
-
-    async_results: Queue = Queue()
-
-    for arg, lock in zip(iterable, locks):
-        async_result = worker_pool.apply_async(
-            _run_func_with_lock, args=(func, arg, lock)
-        )
-        async_results.put(async_result)
-
-    return async_results, locks
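A side note on the thread handling above: the new implementation saves PyTorch's intra-op thread count, overrides it for the duration of filtering, and restores the original value afterwards. A generic sketch of the same pattern as a context manager (plain PyTorch calls, not cellfinder API):

    from contextlib import contextmanager

    import torch

    @contextmanager
    def torch_thread_count(n: int):
        # Temporarily override torch's intra-op thread count and restore the
        # previous value afterwards, even if the body raises.
        original = torch.get_num_threads()
        torch.set_num_threads(n)
        try:
            yield
        finally:
            torch.set_num_threads(original)

    # e.g. run CPU-bound filtering with a bounded number of threads
    with torch_thread_count(4):
        _ = torch.rand(1024, 1024) @ torch.rand(1024, 1024)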