cellfinder 1.3.2__py3-none-any.whl → 1.4.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,151 +1,461 @@
-import math
-import multiprocessing.pool
+import multiprocessing as mp
 import os
 from functools import partial
-from queue import Queue
-from threading import Lock
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import numpy as np
+import torch
 from brainglobe_utils.cells.cells import Cell
 from tifffile import tifffile
 from tqdm import tqdm
 
-from cellfinder.core import logger
-from cellfinder.core.detect.filters.setup_filters import (
-    get_ball_filter,
-    get_cell_detector,
-)
+from cellfinder.core import logger, types
+from cellfinder.core.detect.filters.plane import TileProcessor
+from cellfinder.core.detect.filters.setup_filters import DetectionSettings
+from cellfinder.core.detect.filters.volume.ball_filter import BallFilter
 from cellfinder.core.detect.filters.volume.structure_detection import (
+    CellDetector,
     get_structure_centre,
 )
 from cellfinder.core.detect.filters.volume.structure_splitting import (
     StructureSplitException,
     split_cells,
 )
+from cellfinder.core.tools.threading import (
+    EOFSignal,
+    ProcessWithException,
+    ThreadWithException,
+)
+from cellfinder.core.tools.tools import inference_wrapper
+
+
+@inference_wrapper
+def _plane_filter(
+    process: ProcessWithException,
+    tile_processor: TileProcessor,
+    n_threads: int,
+    buffers: List[Tuple[torch.Tensor, torch.Tensor]],
+):
+    """
+    When running on the CPU, we spin up a process for each plane in the
+    batch. This function runs in that process.
+
+    For every new batch, the main process sends a buffer token and plane
+    index to this function. We process that plane and let the main process
+    know we are done.
+    """
+    # more than about 4 threads seems to slow down computation
+    torch.set_num_threads(min(n_threads, 4))
+
+    while True:
+        msg = process.get_msg_from_mainthread()
+        if msg == EOFSignal:
+            return
+        # with torch multiprocessing, tensors are shared in memory - so
+        # just update in place
+        token, i = msg
+        tensor, masks = buffers[token]
+
+        plane, mask = tile_processor.get_tile_mask(tensor[i : i + 1, :, :])
+        tensor[i : i + 1, :, :] = plane
+        masks[i : i + 1, :, :] = mask
+
+        # tell the main thread we processed all the planes for this tensor
+        process.send_msg_to_mainthread(None)
+
+
+class VolumeFilter:
+    """
+    Filters and detects cells in the input data.
+
+    This takes a 3d data array, first filters each plane with 2d filters
+    that find bright spots, then filters the stack with a ball filter to
+    find voxels that are potential cells, and finally runs cell detection
+    on the result to identify the cells.
+
+    Parameters
+    ----------
+    settings : DetectionSettings
+        Settings object that contains all the configuration data.
+    """
+
+    def __init__(self, settings: DetectionSettings):
+        self.settings = settings
+
+        self.ball_filter = BallFilter(
+            plane_height=settings.plane_height,
+            plane_width=settings.plane_width,
+            ball_xy_size=settings.ball_xy_size,
+            ball_z_size=settings.ball_z_size,
+            overlap_fraction=settings.ball_overlap_fraction,
+            threshold_value=settings.threshold_value,
+            soma_centre_value=settings.soma_centre_value,
+            tile_height=settings.tile_height,
+            tile_width=settings.tile_width,
+            dtype=settings.filtering_dtype.__name__,
+            batch_size=settings.batch_size,
+            torch_device=settings.torch_device,
+            use_mask=True,
+        )
 
+        self.z = settings.start_plane + self.ball_filter.first_valid_plane
 
-class VolumeFilter(object):
-    def __init__(
-        self,
-        *,
-        soma_diameter: float,
-        soma_size_spread_factor: float = 1.4,
-        setup_params: Tuple[np.ndarray, Any, int, int, float, Any],
-        n_planes: int,
-        n_locks_release: int,
-        save_planes: bool = False,
-        plane_directory: Optional[str] = None,
-        start_plane: int = 0,
-        max_cluster_size: int = 5000,
-        outlier_keep: bool = False,
-        artifact_keep: bool = True,
-    ):
-        self.soma_diameter = soma_diameter
-        self.soma_size_spread_factor = soma_size_spread_factor
-        self.n_planes = n_planes
-        self.z = start_plane
-        self.save_planes = save_planes
-        self.plane_directory = plane_directory
-        self.max_cluster_size = max_cluster_size
-        self.outlier_keep = outlier_keep
-        self.n_locks_release = n_locks_release
-
-        self.artifact_keep = artifact_keep
-
-        self.clipping_val = None
-        self.threshold_value = None
-        self.setup_params = setup_params
-
-        self.previous_plane: Optional[np.ndarray] = None
-
-        self.ball_filter = get_ball_filter(
-            plane=self.setup_params[0],
-            soma_diameter=self.setup_params[1],
-            ball_xy_size=self.setup_params[2],
-            ball_z_size=self.setup_params[3],
-            ball_overlap_fraction=self.setup_params[4],
+        self.cell_detector = CellDetector(
+            settings.plane_height,
+            settings.plane_width,
+            start_z=self.z,
+            soma_centre_value=settings.detection_soma_centre_value,
         )
-
-        self.cell_detector = get_cell_detector(
-            plane_shape=self.setup_params[0].shape,  # type: ignore
-            ball_z_size=self.setup_params[3],
-            z_offset=self.setup_params[5],
+        # make sure we load enough data to filter. Otherwise, we won't be
+        # ready to filter and the data loading thread will wait for data to
+        # be processed before loading more data, but that will never happen
+        self.n_queue_buffer = max(
+            self.settings.num_prefetch_batches,
+            self.ball_filter.num_batches_before_ready,
        )
 
+    def _get_filter_buffers(
+        self, cpu: bool, tile_processor: TileProcessor
+    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Generates buffers to use for data loading and filtering.
+
+        It creates pinned tensors ahead of time for faster copying to the
+        GPU. Pinned tensors are kept in RAM and are faster to copy to the
+        GPU because they can't be paged out. Loaded data is copied into the
+        tensor and then sent to the device.
+
+        On the CPU, even though we don't pin, it's still useful to create
+        the buffers ahead of time and reuse them so we can filter in
+        sub-processes (see `_plane_filter`).
+        Tile-mask buffers are only created for the CPU. On CUDA, they are
+        generated anew on the device each time.
+        """
+        batch_size = self.settings.batch_size
+        torch_dtype = getattr(torch, self.settings.filtering_dtype.__name__)
+
+        buffers = []
+        for _ in range(self.n_queue_buffer):
+            # the tensor used for data loading
+            tensor = torch.empty(
+                (batch_size, *self.settings.plane_shape),
+                dtype=torch_dtype,
+                pin_memory=not cpu,
+                device="cpu",
+            )
+
+            # tile mask buffer - only for cpu
+            masks = None
+            if cpu:
+                masks = tile_processor.get_tiled_buffer(
+                    batch_size, self.settings.torch_device
+                )
+
+            buffers.append((tensor, masks))
+
+        return buffers
+
+    @inference_wrapper
+    def _feed_signal_batches(
+        self,
+        thread: ThreadWithException,
+        data: types.array,
+        processors: List[ProcessWithException],
+        buffers: List[Tuple[torch.Tensor, torch.Tensor]],
+    ) -> None:
+        """
+        Runs in its own thread. It loads the input data planes, converts
+        them to torch tensors of the right data type, and sends them to
+        CUDA, or to sub-processes when on the CPU, to be filtered.
+        """
+        batch_size = self.settings.batch_size
+        device = self.settings.torch_device
+        start_plane = self.settings.start_plane
+        end_plane = start_plane + self.settings.n_planes
+        data_converter = self.settings.filter_data_converter_func
+        cpu = self.settings.torch_device == "cpu"
+        # should only have 2d filter processors on the cpu
+        assert bool(processors) == cpu
+
+        # seed the queue with tokens for the buffers
+        for token in range(len(buffers)):
+            thread.send_msg_to_thread(token)
+
+        for z in range(start_plane, end_plane, batch_size):
+            # convert the data to the right type
+            np_data = data_converter(data[z : z + batch_size, :, :])
+            # if we ran out of batches, we are done!
+            n = np_data.shape[0]
+            assert n
+
+            # thread/underlying queues get first crack at msg. Unless we get
+            # eof, this will block until a buffer is returned from the main
+            # thread for reuse
+            token = thread.get_msg_from_mainthread()
+            if token is EOFSignal:
+                return
+
+            # buffer is free, get it from token
+            tensor, masks = buffers[token]
+
+            # the last batch can be smaller than normal, so only set up to n
+            tensor[:n, :, :] = torch.from_numpy(np_data)
+            tensor = tensor[:n, :, :]
+            if not cpu:
+                # send to device - it won't block here because we pinned memory
+                tensor = tensor.to(device=device, non_blocking=True)
+
+            # if used, send each plane in the batch to a processor
+            used_processors = []
+            if cpu:
+                used_processors = processors[:n]
+                for i, process in enumerate(used_processors):
+                    process.send_msg_to_thread((token, i))
+
+            # tell the main thread to wait for processors (if used)
+            msg = token, tensor, masks, used_processors, n
+
+            if n < batch_size:
+                # on last batch, we are also done after this
+                thread.send_msg_to_mainthread(msg)
+                return
+            # send the data to the main thread
+            thread.send_msg_to_mainthread(msg)
+
     def process(
         self,
-        async_result_queue: Queue,
-        locks: List[Lock],
+        tile_processor: TileProcessor,
+        signal_array,
         *,
-        callback: Callable[[int], None],
+        callback: Optional[Callable[[int], None]],
     ) -> None:
-        progress_bar = tqdm(total=self.n_planes, desc="Processing planes")
-        for z in range(self.n_planes):
-            # Get result from the queue.
-            #
-            # It is important to remove the result from the queue here
-            # to free up memory once this plane has been processed by
-            # the 3D filter here
-            logger.debug(f"🏐 Waiting for plane {z}")
-            result = async_result_queue.get()
-            # .get() blocks until the result is available
-            plane, mask = result.get()
-            logger.debug(f"🏐 Got plane {z}")
-
-            self.ball_filter.append(plane, mask)
-
-            if self.ball_filter.ready:
-                # Let the next 2D filter run
-                z_release = z + self.n_locks_release + 1
-                if z_release < len(locks):
-                    logger.debug(f"🔓 Releasing lock for plane {z_release}")
-                    locks[z_release].release()
-
-                self._run_filter()
+        """
+        Takes the processor and the data and passes them through the
+        filtering and cell detection stages.
+
+        If the callback is provided, it is called after every plane with
+        the current z index to report progress. It may be called from
+        secondary threads.
+        """
+        progress_bar = tqdm(
+            total=self.settings.n_planes, desc="Processing planes"
+        )
+        cpu = self.settings.torch_device == "cpu"
+        n_threads = self.settings.n_torch_comp_threads
+
+        # we re-use these tensors for data loading, so we have a fixed number
+        # of planes in memory. The feeder thread will wait to load more data
+        # until a tensor is free to be reused.
+        # We have to keep the tensors alive in the main process while they
+        # are in use elsewhere
+        buffers = self._get_filter_buffers(cpu, tile_processor)
+
+        # on cpu these processes will 2d filter each plane in the batch
+        plane_processes = []
+        if cpu:
+            for _ in range(self.settings.batch_size):
+                process = ProcessWithException(
+                    target=_plane_filter,
+                    args=(tile_processor, n_threads, buffers),
+                    pass_self=True,
+                )
+                process.start()
+                plane_processes.append(process)
+
+        # thread that reads and sends us data
+        feed_thread = ThreadWithException(
+            target=self._feed_signal_batches,
+            args=(signal_array, plane_processes, buffers),
+            pass_self=True,
+        )
+        feed_thread.start()
 
-            callback(self.z)
-            self.z += 1
-            progress_bar.update()
+        # thread that takes the 3d filtered data and does cell detection
+        cells_thread = ThreadWithException(
+            target=self._run_filter_thread,
+            args=(callback, progress_bar),
+            pass_self=True,
+        )
+        cells_thread.start()
+
+        try:
+            self._process(feed_thread, cells_thread, tile_processor, cpu)
+        finally:
+            # if we end, make sure to tell the threads to stop
+            feed_thread.notify_to_end_thread()
+            cells_thread.notify_to_end_thread()
+            for process in plane_processes:
+                process.notify_to_end_thread()
+
+            # the notification above ensures this won't block forever
+            feed_thread.join()
+            cells_thread.join()
+            for process in plane_processes:
+                process.join()
+
+            # in case these threads sent us an exception we haven't yet
+            # read, make sure to process it
+            feed_thread.clear_remaining()
+            cells_thread.clear_remaining()
+            for process in plane_processes:
+                process.clear_remaining()
 
         progress_bar.close()
         logger.debug("3D filter done")
 
-    def _run_filter(self) -> None:
-        logger.debug(f"🏐 Ball filtering plane {self.z}")
-        # filtering original images, the images should be large enough in x/y
-        # to benefit from parallelization. Note: don't pass arg as keyword arg
-        # because numba gets stuck (probably b/c class jit is new)
-        self.ball_filter.walk(True)
-
-        middle_plane = self.ball_filter.get_middle_plane()
-        if self.save_planes:
-            self.save_plane(middle_plane)
-
-        logger.debug(f"🏫 Detecting structures for plane {self.z}")
-        self.previous_plane = self.cell_detector.process(
-            middle_plane, self.previous_plane
-        )
+    def _process(
+        self,
+        feed_thread: ThreadWithException,
+        cells_thread: ThreadWithException,
+        tile_processor: TileProcessor,
+        cpu: bool,
+    ) -> None:
+        """
+        Processes the data loaded by the feeder thread. On the cpu it is
+        already 2d filtered, so we only 3d filter. On cuda we need to do
+        both the 2d and 3d filtering. The filtered data is then sent off
+        to the detection thread for cell detection.
+        """
+        processing_tokens = []
+
+        while True:
+            # thread/underlying queues get first crack at msg. Unless we get
+            # eof, this will block until more data is loaded, there is no
+            # more data, or an exception is raised
+            msg = feed_thread.get_msg_from_thread()
+            # feeder thread exits at the end, causing an eof to be sent
+            if msg is EOFSignal:
+                break
+            token, tensor, masks, used_processors, n = msg
+            # this token is in use until we return it
+            processing_tokens.append(token)
+
+            if cpu:
+                # we did the 2d filtering in other processes. Make sure all
+                # the planes are done filtering. Each feeder msg has a
+                # corresponding msg from each used processor (unless exception)
+                for process in used_processors:
+                    process.get_msg_from_thread()
+                # batch size can change at the end so resize buffer
+                planes = tensor[:n, :, :]
+                masks = masks[:n, :, :]
+            else:
+                # we're not doing 2d filtering in a different process
+                planes, masks = tile_processor.get_tile_mask(tensor)
 
-        logger.debug(f"🏫 Structures done for plane {self.z}")
+            self.ball_filter.append(planes, masks)
+            if self.ball_filter.ready:
+                self.ball_filter.walk()
+                middle_planes = self.ball_filter.get_processed_planes()
+
+                # at this point we know the input tensor can be reused -
+                # return it so the feeder thread can load more data into it
+                for token in processing_tokens:
+                    feed_thread.send_msg_to_thread(token)
+                processing_tokens.clear()
+
+                # thread/underlying queues get first crack at msg. Unless
+                # we get eof, this will block until we get a token,
+                # indicating we can send more data. The cells thread has a
+                # fixed token supply, ensuring we don't send it too much
+                # data, in case detection takes longer than filtering.
+                # Also, incoming error messages are at most # tokens behind
+                token = cells_thread.get_msg_from_thread()
+                if token is EOFSignal:
+                    break
+                # send it more data and return the token
+                cells_thread.send_msg_to_thread((middle_planes, token))
+
+    @inference_wrapper
+    def _run_filter_thread(
+        self, thread: ThreadWithException, callback, progress_bar
+    ) -> None:
+        """
+        Runs in its own thread. Takes the filtered planes and passes them
+        through the cell detection system. Also saves the planes as needed.
+        """
+        detector = self.cell_detector
+        original_dtype = self.settings.plane_original_np_dtype
+        detection_converter = self.settings.detection_data_converter_func
+        save_planes = self.settings.save_planes
+        previous_plane = None
+        bf = self.ball_filter
+
+        # this many planes are not processed at the start because the 3d
+        # filter uses them as padding at the start of the filter
+        progress_bar.update(bf.first_valid_plane)
+
+        # the main thread needs a token to send us planes - seed it with some
+        for _ in range(self.n_queue_buffer):
+            thread.send_msg_to_mainthread(object())
+
+        while True:
+            # thread/underlying queues get first crack at msg. Unless we get
+            # eof, this will block until we get more data
+            msg = thread.get_msg_from_mainthread()
+            # requested that we return. This can mean the main thread finished
+            # sending data and appended eof - so we get eof after all planes
+            if msg is EOFSignal:
+                # this many planes are not processed at the end because the
+                # 3d filter uses them as padding at the end of the filter
+                progress_bar.update(bf.remaining_planes)
+                return
+
+            # convert the planes to the type needed by the detection system.
+            # We should not need scaling because throughout filtering we
+            # make sure the result fits in this data type
+            middle_planes, token = msg
+            detection_middle_planes = detection_converter(middle_planes)
+
+            logger.debug(f"🏫 Detecting structures for planes {self.z}+")
+            for plane, detection_plane in zip(
+                middle_planes, detection_middle_planes
+            ):
+                if save_planes:
+                    self.save_plane(plane.astype(original_dtype))
+
+                previous_plane = detector.process(
+                    detection_plane, previous_plane
+                )
+
+                if callback is not None:
+                    callback(self.z)
+                self.z += 1
+                progress_bar.update()
+
+            # we must return the token, otherwise the main thread will run
+            # out and won't send more data to us
+            thread.send_msg_to_mainthread(token)
+            logger.debug(f"🏫 Structures done for planes {self.z}+")
 
     def save_plane(self, plane: np.ndarray) -> None:
-        if self.plane_directory is None:
+        """
+        Saves the plane as an image according to the settings.
+        """
+        if self.settings.plane_directory is None:
             raise ValueError(
                 "plane_directory must be set to save planes to file"
             )
-        plane_name = f"plane_{str(self.z).zfill(4)}.tif"
-        f_path = os.path.join(self.plane_directory, plane_name)
-        tifffile.imsave(f_path, plane.T)
-
-    def get_results(self, worker_pool: multiprocessing.Pool) -> List[Cell]:
+        # self.z is zero based, but we save the names as 1-based.
+        plane_name = self.settings.plane_prefix.format(n=self.z + 1) + ".tif"
+        f_path = os.path.join(self.settings.plane_directory, plane_name)
+        tifffile.imwrite(f_path, plane)
+
+    def get_results(self, settings: DetectionSettings) -> List[Cell]:
+        """
+        Returns the detected cells.
+
+        After filtering, this parses the resulting cells and splits large
+        bright regions into individual cells.
+        """
         logger.info("Splitting cell clusters and writing results")
 
-        max_cell_volume = sphere_volume(
-            self.soma_size_spread_factor * self.soma_diameter / 2
-        )
+        root_settings = self.settings
+        max_cell_volume = settings.max_cell_volume
 
+        # valid cells
         cells = []
+        # regions that must be split into cells
         needs_split = []
         structures = self.cell_detector.get_structures().items()
         logger.debug(f"Processing {len(structures)} found cells")
@@ -158,7 +468,7 @@ class VolumeFilter(object):
                 cell_centre = get_structure_centre(cell_points)
                 cells.append(Cell(cell_centre.tolist(), Cell.UNKNOWN))
             else:
-                if cell_volume < self.max_cluster_size:
+                if cell_volume < settings.max_cluster_size:
                     needs_split.append((cell_id, cell_points))
                 else:
                     cell_centre = get_structure_centre(cell_points)
@@ -174,13 +484,22 @@ class VolumeFilter(object):
             total=len(needs_split), desc="Splitting cell clusters"
         )
 
-        # we are not returning Cell instances from func because it'd be pickled
-        # by multiprocess which slows it down
-        func = partial(_split_cells, outlier_keep=self.outlier_keep)
-        for cell_centres in worker_pool.imap_unordered(func, needs_split):
-            for cell_centre in cell_centres:
-                cells.append(Cell(cell_centre.tolist(), Cell.UNKNOWN))
-            progress_bar.update()
+        # the settings are pickled and re-created for each process, which is
+        # important because splitting can modify the settings, so we don't
+        # want parallel modifications of the same object
+        f = partial(_split_cells, settings=settings)
+        ctx = mp.get_context("spawn")
+        # we can't use the context manager because of coverage issues:
+        # https://pytest-cov.readthedocs.io/en/latest/subprocess-support.html
+        pool = ctx.Pool(processes=root_settings.n_processes)
+        try:
+            for cell_centres in pool.imap_unordered(f, needs_split):
+                for cell_centre in cell_centres:
+                    cells.append(Cell(cell_centre.tolist(), Cell.UNKNOWN))
+                progress_bar.update()
+        finally:
+            pool.close()
+            pool.join()
 
         progress_bar.close()
         logger.debug(
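
The splitting loop above fans work out with a "spawn" multiprocessing context and `imap_unordered`, closing the pool explicitly rather than using it as a context manager. The same standard-library pattern in isolation, with a hypothetical `_work` function standing in for `_split_cells`:

    import multiprocessing as mp
    from functools import partial


    def _work(item: int, scale: int = 1) -> int:
        # hypothetical worker standing in for _split_cells; runs in a child process
        return item * scale


    if __name__ == "__main__":
        # "spawn" starts fresh interpreters, so everything bound into the partial
        # (here `scale`, in get_results the `settings`) is pickled and re-created
        # independently in every worker.
        ctx = mp.get_context("spawn")
        func = partial(_work, scale=10)
        pool = ctx.Pool(processes=4)
        try:
            # imap_unordered yields results as soon as any worker finishes
            for result in pool.imap_unordered(func, range(8)):
                print(result)
        finally:
            # closed explicitly rather than via a context manager, as in the diff
            pool.close()
            pool.join()
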
@@ -190,13 +509,15 @@ class VolumeFilter(object):
         return cells
 
 
-def _split_cells(arg, outlier_keep):
+@inference_wrapper
+def _split_cells(arg, settings: DetectionSettings):
+    # runs in its own process for each bright region to be split.
+    # For splitting cells we only run with one thread, because the volume
+    # is likely small and multiple threads would cost more in overhead than
+    # they are worth. The thread count can only be set per process.
+    torch.set_num_threads(1)
     cell_id, cell_points = arg
     try:
-        return split_cells(cell_points, outlier_keep=outlier_keep)
+        return split_cells(cell_points, settings=settings)
     except (ValueError, AssertionError) as err:
         raise StructureSplitException(f"Cell {cell_id}, error; {err}")
-
-
-def sphere_volume(radius: float) -> float:
-    return (4 / 3) * math.pi * radius**3
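
Underneath the threads and processes, the new `VolumeFilter.process` pipeline is a bounded producer/consumer exchange: the feeder thread may only load another batch once a buffer token has been handed back, and the cell-detection thread meters the main thread the same way. A stripped-down sketch of that token handoff using only standard-library threads (illustrative names, not the cellfinder `ThreadWithException` API):

    import queue
    import threading

    # a fixed pool of buffer tokens circulates between producer and consumer,
    # bounding how many batches can be loaded but not yet processed at once
    N_BUFFERS = 3
    free_tokens = queue.Queue()  # tokens for buffers the producer may fill
    filled = queue.Queue()       # (token, payload) messages ready to process


    def producer(n_batches: int) -> None:
        for batch_id in range(n_batches):
            token = free_tokens.get()  # blocks until a buffer is handed back
            filled.put((token, f"batch {batch_id} in buffer {token}"))
        filled.put(None)  # end-of-data marker, mirroring the EOFSignal message


    def consumer() -> None:
        while (msg := filled.get()) is not None:
            token, payload = msg
            print("processing", payload)
            free_tokens.put(token)  # return the token so more data can be loaded


    for token in range(N_BUFFERS):  # seed the token supply up front
        free_tokens.put(token)

    worker = threading.Thread(target=producer, args=(10,))
    worker.start()
    consumer()
    worker.join()
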
cellfinder/core/main.py CHANGED
@@ -26,7 +26,7 @@ def main(
     ball_z_size: int = 15,
     ball_overlap_fraction: float = 0.6,
     log_sigma_size: float = 0.2,
-    n_sds_above_mean_thresh: int = 10,
+    n_sds_above_mean_thresh: float = 10,
     soma_spread_factor: float = 1.4,
     max_cluster_size: int = 100000,
     cube_width: int = 50,
@@ -36,11 +36,13 @@ def main(
     skip_detection: bool = False,
     skip_classification: bool = False,
     detected_cells: List[Cell] = None,
+    classification_batch_size: Optional[int] = None,
+    classification_torch_device: str = "cpu",
     *,
     detect_callback: Optional[Callable[[int], None]] = None,
     classify_callback: Optional[Callable[[int], None]] = None,
     detect_finished_callback: Optional[Callable[[list], None]] = None,
-) -> List:
+) -> List[Cell]:
     """
     Parameters
     ----------
@@ -74,6 +76,8 @@ def main(
             n_free_cpus,
             log_sigma_size,
             n_sds_above_mean_thresh,
+            batch_size=classification_batch_size,
+            torch_device=classification_torch_device,
             callback=detect_callback,
         )
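
The two new keyword arguments expose the torch device and batch size used for classification to callers of `cellfinder.core.main.main`. A usage sketch, assuming the signal array, background array, and voxel sizes accepted by released versions of `main`, and using small random volumes purely as placeholders:

    import numpy as np

    from cellfinder.core.main import main

    # tiny random stacks as placeholders for real signal/background channels
    signal = np.random.randint(0, 2**16, size=(30, 256, 256), dtype=np.uint16)
    background = np.random.randint(0, 2**16, size=(30, 256, 256), dtype=np.uint16)

    cells = main(
        signal,
        background,
        voxel_sizes=(5, 2, 2),               # z, y, x in microns
        n_sds_above_mean_thresh=10.0,        # now annotated as float
        classification_batch_size=64,        # new keyword in 1.4.0a0
        classification_torch_device="cuda",  # new keyword; defaults to "cpu"
    )
    print(f"{len(cells)} cell candidates detected")
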