cellfinder 1.3.2__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cellfinder might be problematic. Click here for more details.

@@ -4,18 +4,20 @@ from typing import Dict, Optional, Tuple, TypeVar, Union
4
4
  import numba.typed
5
5
  import numpy as np
6
6
  import numpy.typing as npt
7
- from numba import njit, typed
7
+ from numba import njit, objmode, typed
8
8
  from numba.core import types
9
9
  from numba.experimental import jitclass
10
10
  from numba.types import DictType
11
11
 
12
+ from cellfinder.core.tools.tools import get_max_possible_int_value
13
+
12
14
  T = TypeVar("T")
13
15
  # type used for the domain of the volume - the size of the vol
14
16
  vol_np_type = np.int64
15
17
  vol_numba_type = types.int64
16
18
  # type used for the structure id
17
- sid_np_type = np.int64
18
- sid_numba_type = types.int64
19
+ sid_np_type = np.uint64
20
+ sid_numba_type = types.uint64
19
21
 
20
22
 
21
23
  @dataclass
@@ -26,14 +28,17 @@ class Point:
26
28
 
27
29
 
28
30
  @njit
29
- def get_non_zero_dtype_min(values: np.ndarray) -> int:
31
+ def get_non_zero_dtype_min(values: np.ndarray) -> sid_numba_type:
30
32
  """
31
33
  Get the minimum of non-zero entries in *values*.
32
34
 
33
35
  If all entries are zero, returns maximum storeable number
34
36
  in the values array.
35
37
  """
36
- min_val = np.iinfo(values.dtype).max
38
+ # we don't know how big the int is, so make it as large as possible (64)
39
+ with objmode(min_val="u8"):
40
+ min_val = get_max_possible_int_value(values.dtype)
41
+
37
42
  for v in values:
38
43
  if v != 0 and v < min_val:
39
44
  min_val = v
@@ -97,6 +102,7 @@ list_of_points_type = types.ListType(tuple_point_type)
97
102
  spec = [
98
103
  ("z", vol_numba_type),
99
104
  ("next_structure_id", sid_numba_type),
105
+ ("soma_centre_value", sid_numba_type), # as large as possible
100
106
  ("shape", types.UniTuple(vol_numba_type, 2)),
101
107
  ("obsolete_ids", DictType(sid_numba_type, sid_numba_type)),
102
108
  ("coords_maps", DictType(sid_numba_type, list_of_points_type)),
@@ -133,18 +139,25 @@ class CellDetector:
133
139
  points.
134
140
  """
135
141
 
136
- def __init__(self, width: int, height: int, start_z: int):
142
+ def __init__(
143
+ self,
144
+ height: int,
145
+ width: int,
146
+ start_z: int,
147
+ soma_centre_value: sid_numba_type,
148
+ ):
137
149
  """
138
150
  Parameters
139
151
  ----------
140
- width, height
152
+ height, width:
141
153
  Shape of the planes input to self.process()
142
154
  start_z:
143
155
  The z-coordinate of the first processed plane.
144
156
  """
145
- self.shape = width, height
157
+ self.shape = height, width
146
158
  self.z = start_z
147
159
  self.next_structure_id = 1
160
+ self.soma_centre_value = soma_centre_value
148
161
 
149
162
  # Mapping from obsolete IDs to the IDs that they have been
150
163
  # made obsolete by
@@ -156,11 +169,18 @@ class CellDetector:
156
169
  key_type=sid_numba_type, value_type=list_of_points_type
157
170
  )
158
171
 
172
+ def _set_soma(self, soma_centre_value: sid_numba_type):
173
+ # Due to https://github.com/numba/numba/issues/9576. For testing we try
174
+ # different data types. Because of that issue we cannot pass a uint64
175
+ # soma_centre_value to constructor after we pass a uint32. This is the
176
+ # only way for now until numba fixes the issue
177
+ self.soma_centre_value = soma_centre_value
178
+
159
179
  def process(
160
180
  self, plane: np.ndarray, previous_plane: Optional[np.ndarray]
161
181
  ) -> np.ndarray:
162
182
  """
163
- Process a new plane.
183
+ Process a new plane (should be in Y, X axis order).
164
184
  """
165
185
  if plane.shape[:2] != self.shape:
166
186
  raise ValueError("plane does not have correct shape")
@@ -185,21 +205,21 @@ class CellDetector:
185
205
  -------
186
206
  plane :
187
207
  Plane with pixels either set to zero (no structure) or labelled
188
- with their structure ID.
208
+ with their structure ID. Plane is in Y, X axis order.
189
209
  """
190
- SOMA_CENTRE_VALUE = np.iinfo(plane.dtype).max
191
- for y in range(plane.shape[1]):
192
- for x in range(plane.shape[0]):
193
- if plane[x, y] == SOMA_CENTRE_VALUE:
210
+ soma_centre_value = self.soma_centre_value
211
+ for y in range(plane.shape[0]):
212
+ for x in range(plane.shape[1]):
213
+ if plane[y, x] == soma_centre_value:
194
214
  # Labels of structures below, left and behind
195
215
  neighbour_ids = np.zeros(3, dtype=sid_np_type)
196
216
  # If in bounds look at neighbours
197
- if x > 0:
198
- neighbour_ids[0] = plane[x - 1, y]
199
217
  if y > 0:
200
- neighbour_ids[1] = plane[x, y - 1]
218
+ neighbour_ids[0] = plane[y - 1, x]
219
+ if x > 0:
220
+ neighbour_ids[1] = plane[y, x - 1]
201
221
  if previous_plane is not None:
202
- neighbour_ids[2] = previous_plane[x, y]
222
+ neighbour_ids[2] = previous_plane[y, x]
203
223
 
204
224
  if is_new_structure(neighbour_ids):
205
225
  neighbour_ids[0] = self.next_structure_id
@@ -210,17 +230,20 @@ class CellDetector:
210
230
  # structure in next iterations
211
231
  struct_id = 0
212
232
 
213
- plane[x, y] = struct_id
233
+ plane[y, x] = struct_id
214
234
 
215
235
  return plane
216
236
 
217
237
  def get_cell_centres(self) -> np.ndarray:
238
+ """
239
+ Returns the 2D array of cell centers. It's (N, 3) with X, Y, Z columns.
240
+ """
218
241
  return self.structures_to_cells()
219
242
 
220
243
  def get_structures(self) -> Dict[int, np.ndarray]:
221
244
  """
222
245
  Gets the structures as a dict of structure IDs mapped to the 2D array
223
- of structure points.
246
+ of structure points (one row per point; columns are x, y, z).
224
247
  """
225
248
  d = {}
226
249
  for sid, points in self.coords_maps.items():
@@ -228,7 +251,10 @@ class CellDetector:
228
251
  # `item = np.array(points, dtype=vol_np_type)` so we need to create
229
252
  # array and then fill in the point
230
253
  item = np.empty((len(points), 3), dtype=vol_np_type)
231
- d[sid] = item
254
+ # need to cast to int64, otherwise when dict is used we can get
255
+ # warnings as in numba issue #8829 b/c it assumes it's uint64.
256
+ # Python uses int(64) as the type
257
+ d[types.int64(sid)] = item
232
258
 
233
259
  for i, point in enumerate(points):
234
260
  item[i, :] = point
@@ -239,33 +265,41 @@ class CellDetector:
239
265
  self, sid: int, point: Union[tuple, list, np.ndarray]
240
266
  ) -> None:
241
267
  """
242
- Add single 3d *point* to the structure with the given *sid*.
268
+ Add single 3d (x, y, z) *point* to the structure with the given *sid*.
243
269
  """
244
- if sid not in self.coords_maps:
245
- self.coords_maps[sid] = typed.List.empty_list(tuple_point_type)
270
+ # cast in case user passes in int64 (default type for int in python)
271
+ # and numba complains
272
+ key = sid_numba_type(sid)
273
+ if key not in self.coords_maps:
274
+ self.coords_maps[key] = typed.List.empty_list(tuple_point_type)
246
275
 
247
- self._add_point(sid, (int(point[0]), int(point[1]), int(point[2])))
276
+ self._add_point(key, (int(point[0]), int(point[1]), int(point[2])))
248
277
 
249
278
  def add_points(self, sid: int, points: np.ndarray):
250
279
  """
251
280
  Adds ndarray of *points* to the structure with the given *sid*.
252
- Each row is a 3d point.
281
+ Each row is a 3-column (x, y, z) point.
253
282
  """
254
- if sid not in self.coords_maps:
255
- self.coords_maps[sid] = typed.List.empty_list(tuple_point_type)
283
+ # cast in case user passes in int64 (default type for int in python)
284
+ # and numba complains
285
+ key = sid_numba_type(sid)
286
+ if key not in self.coords_maps:
287
+ self.coords_maps[key] = typed.List.empty_list(tuple_point_type)
256
288
 
257
- append = self.coords_maps[sid].append
289
+ append = self.coords_maps[key].append
258
290
  pts = np.round(points).astype(vol_np_type)
259
291
  for point in pts:
260
292
  append((point[0], point[1], point[2]))
261
293
 
262
- def _add_point(self, sid: int, point: Tuple[int, int, int]) -> None:
294
+ def _add_point(
295
+ self, sid: sid_numba_type, point: Tuple[int, int, int]
296
+ ) -> None:
263
297
  # sid must exist
264
298
  self.coords_maps[sid].append(point)
265
299
 
266
300
  def add(
267
301
  self, x: int, y: int, z: int, neighbour_ids: npt.NDArray[sid_np_type]
268
- ) -> int:
302
+ ) -> sid_numba_type:
269
303
  """
270
304
  For the current coordinates takes all the neighbours and find the
271
305
  minimum structure including obsolete structures mapping to any of
@@ -287,7 +321,9 @@ class CellDetector:
287
321
  self._add_point(updated_id, (int(x), int(y), int(z)))
288
322
  return updated_id
289
323
 
290
- def sanitise_ids(self, neighbour_ids: npt.NDArray[sid_np_type]) -> int:
324
+ def sanitise_ids(
325
+ self, neighbour_ids: npt.NDArray[sid_np_type]
326
+ ) -> sid_numba_type:
291
327
  """
292
328
  Get the smallest ID of all the structures that are connected to IDs
293
329
  in `neighbour_ids`.
@@ -300,15 +336,17 @@ class CellDetector:
300
336
  """
301
337
  for i, neighbour_id in enumerate(neighbour_ids):
302
338
  # walk up the chain of obsolescence
303
- neighbour_id = int(traverse_dict(self.obsolete_ids, neighbour_id))
339
+ neighbour_id = traverse_dict(self.obsolete_ids, neighbour_id)
304
340
  neighbour_ids[i] = neighbour_id
305
341
 
306
342
  # Get minimum of all non-obsolete IDs
307
343
  updated_id = get_non_zero_dtype_min(neighbour_ids)
308
- return int(updated_id)
344
+ return updated_id
309
345
 
310
346
  def merge_structures(
311
- self, updated_id: int, neighbour_ids: npt.NDArray[sid_np_type]
347
+ self,
348
+ updated_id: sid_numba_type,
349
+ neighbour_ids: npt.NDArray[sid_np_type],
312
350
  ) -> None:
313
351
  """
314
352
  For all the neighbours, reassign all the points of neighbour to
@@ -1,9 +1,14 @@
1
- from typing import List, Tuple
1
+ from typing import List, Tuple, Type
2
2
 
3
3
  import numpy as np
4
+ import torch
4
5
 
5
6
  from cellfinder.core import logger
6
- from cellfinder.core.detect.filters.volume.ball_filter import BallFilter
7
+ from cellfinder.core.detect.filters.setup_filters import DetectionSettings
8
+ from cellfinder.core.detect.filters.volume.ball_filter import (
9
+ BallFilter,
10
+ InvalidVolume,
11
+ )
7
12
  from cellfinder.core.detect.filters.volume.structure_detection import (
8
13
  CellDetector,
9
14
  get_structure_centre,
@@ -14,41 +19,62 @@ class StructureSplitException(Exception):
14
19
  pass
15
20
 
16
21
 
17
- def get_shape(xs: np.ndarray, ys: np.ndarray, zs: np.ndarray) -> List[int]:
22
+ def get_shape(
23
+ xs: np.ndarray, ys: np.ndarray, zs: np.ndarray
24
+ ) -> Tuple[int, int, int]:
25
+ """
26
+ Takes a list of x, y, z coordinates and returns a volume size such that
27
+ all the points will fit into it. With axis order = x, y, z.
28
+ """
18
29
  # +1 because difference. TEST:
19
- shape = [int((dim.max() - dim.min()) + 1) for dim in (xs, ys, zs)]
30
+ shape = tuple(int((dim.max() - dim.min()) + 1) for dim in (xs, ys, zs))
20
31
  return shape
21
32
 
22
33
 
23
34
  def coords_to_volume(
24
- xs: np.ndarray, ys: np.ndarray, zs: np.ndarray, ball_radius: int = 1
25
- ) -> np.ndarray:
35
+ xs: np.ndarray,
36
+ ys: np.ndarray,
37
+ zs: np.ndarray,
38
+ volume_shape: Tuple[int, int, int],
39
+ ball_radius: int,
40
+ dtype: Type[np.number],
41
+ threshold_value: int,
42
+ ) -> torch.Tensor:
43
+ """
44
+ Takes the series of x, y, z points along with the shape of the volume
45
+ that fully enclose them (also x, y, z order). It than expands the
46
+ shape by the ball diameter in each axis. Then, each point, shifted
47
+ by the radius internally is set to the threshold value.
48
+
49
+ The volume is then transposed and returned in the Z, Y, X order.
50
+ """
51
+ # it's faster doing the work in numpy and then returning as torch array,
52
+ # than doing the work in torch
26
53
  ball_diameter = ball_radius * 2
27
54
  # Expanded to ensure the ball fits even at the border
28
- expanded_shape = [
29
- dim_size + ball_diameter for dim_size in get_shape(xs, ys, zs)
30
- ]
31
- volume = np.zeros(expanded_shape, dtype=np.uint32)
55
+ expanded_shape = [dim_size + ball_diameter for dim_size in volume_shape]
56
+ # volume is now x, y, z order
57
+ volume = np.zeros(expanded_shape, dtype=dtype)
32
58
 
33
59
  x_min, y_min, z_min = xs.min(), ys.min(), zs.min()
34
60
 
61
+ # shift the points so any sphere centered on it would not have its
62
+ # radius expand beyond the volume
35
63
  relative_xs = np.array((xs - x_min + ball_radius), dtype=np.int64)
36
64
  relative_ys = np.array((ys - y_min + ball_radius), dtype=np.int64)
37
65
  relative_zs = np.array((zs - z_min + ball_radius), dtype=np.int64)
38
66
 
39
- # OPTIMISE: vectorize
67
+ # set each point as the center with a value of threshold
40
68
  for rel_x, rel_y, rel_z in zip(relative_xs, relative_ys, relative_zs):
41
- volume[rel_x, rel_y, rel_z] = np.iinfo(volume.dtype).max - 1
42
- return volume
69
+ volume[rel_x, rel_y, rel_z] = threshold_value
70
+
71
+ volume = volume.swapaxes(0, 2)
72
+ return torch.from_numpy(volume)
43
73
 
44
74
 
45
75
  def ball_filter_imgs(
46
- volume: np.ndarray,
47
- threshold_value: int,
48
- soma_centre_value: int,
49
- ball_xy_size: int = 3,
50
- ball_z_size: int = 3,
51
- ) -> Tuple[np.ndarray, np.ndarray]:
76
+ volume: torch.Tensor, settings: DetectionSettings
77
+ ) -> np.ndarray:
52
78
  """
53
79
  Apply ball filtering to a 3D volume and detect cell centres.
54
80
 
@@ -56,105 +82,118 @@ def ball_filter_imgs(
56
82
  and the `CellDetector` class to detect cell centres.
57
83
 
58
84
  Args:
59
- volume (np.ndarray): The 3D volume to be filtered.
60
- threshold_value (int): The threshold value for ball filtering.
61
- soma_centre_value (int): The value representing the soma centre.
62
- ball_xy_size (int, optional):
63
- The size of the ball filter in the XY plane. Defaults to 3.
64
- ball_z_size (int, optional):
65
- The size of the ball filter in the Z plane. Defaults to 3.
85
+ volume (torch.Tensor): The 3D volume to be filtered (Z, Y, X order).
86
+ settings (DetectionSettings):
87
+ The settings to use.
66
88
 
67
89
  Returns:
68
- Tuple[np.ndarray, np.ndarray]:
69
- A tuple containing the filtered volume and the cell centres.
90
+ The 2D array of cell centres (N, 3) - X, Y, Z order.
70
91
 
71
92
  """
72
- # OPTIMISE: reuse ball filter instance
73
-
74
- good_tiles_mask = np.ones((1, 1, volume.shape[2]), dtype=np.bool_)
75
-
76
- plane_width, plane_height = volume.shape[:2]
77
- current_z = ball_z_size // 2
78
-
79
- bf = BallFilter(
80
- plane_width,
81
- plane_height,
82
- ball_xy_size,
83
- ball_z_size,
84
- overlap_fraction=0.8,
85
- tile_step_width=plane_width,
86
- tile_step_height=plane_height,
87
- threshold_value=threshold_value,
88
- soma_centre_value=soma_centre_value,
93
+ detection_convert = settings.detection_data_converter_func
94
+ batch_size = settings.batch_size
95
+
96
+ # make sure volume is not less than kernel etc
97
+ try:
98
+ bf = BallFilter(
99
+ plane_height=settings.plane_height,
100
+ plane_width=settings.plane_width,
101
+ ball_xy_size=settings.ball_xy_size,
102
+ ball_z_size=settings.ball_z_size,
103
+ overlap_fraction=settings.ball_overlap_fraction,
104
+ threshold_value=settings.threshold_value,
105
+ soma_centre_value=settings.soma_centre_value,
106
+ tile_height=settings.tile_height,
107
+ tile_width=settings.tile_width,
108
+ dtype=settings.filtering_dtype.__name__,
109
+ batch_size=batch_size,
110
+ torch_device=settings.torch_device,
111
+ use_mask=False, # we don't need a mask here
112
+ )
113
+ except InvalidVolume:
114
+ return np.empty((0, 3))
115
+
116
+ start_z = bf.first_valid_plane
117
+ cell_detector = CellDetector(
118
+ settings.plane_height,
119
+ settings.plane_width,
120
+ start_z=start_z,
121
+ soma_centre_value=settings.detection_soma_centre_value,
89
122
  )
90
- cell_detector = CellDetector(plane_width, plane_height, start_z=current_z)
91
123
 
92
- # FIXME: hard coded type
93
- ball_filtered_volume = np.zeros(volume.shape, dtype=np.uint32)
94
124
  previous_plane = None
95
- for z in range(volume.shape[2]):
96
- bf.append(volume[:, :, z].astype(np.uint32), good_tiles_mask[:, :, z])
125
+ for z in range(0, volume.shape[0], batch_size):
126
+ bf.append(volume[z : z + batch_size, :, :])
127
+
97
128
  if bf.ready:
98
129
  bf.walk()
99
- middle_plane = bf.get_middle_plane()
100
130
 
101
- # first valid middle plane is the current_z, not z
102
- ball_filtered_volume[:, :, current_z] = middle_plane[:]
103
- current_z += 1
131
+ middle_planes = bf.get_processed_planes()
132
+ n = middle_planes.shape[0]
104
133
 
105
- # DEBUG: TEST: transpose
106
- previous_plane = cell_detector.process(
107
- middle_plane.copy(), previous_plane
134
+ # we edit volume, but only for planes already processed that won't
135
+ # be passed to the filter in this run
136
+ volume[start_z : start_z + n, :, :] = torch.from_numpy(
137
+ middle_planes
108
138
  )
109
- return ball_filtered_volume, cell_detector.get_cell_centres()
139
+ start_z += n
140
+
141
+ # convert to type needed for detection
142
+ middle_planes = detection_convert(middle_planes)
143
+ for plane in middle_planes:
144
+ previous_plane = cell_detector.process(plane, previous_plane)
145
+
146
+ return cell_detector.get_cell_centres()
110
147
 
111
148
 
112
149
  def iterative_ball_filter(
113
- volume: np.ndarray, n_iter: int = 10
150
+ volume: torch.Tensor, settings: DetectionSettings
114
151
  ) -> Tuple[List[int], List[np.ndarray]]:
115
152
  """
116
153
  Apply iterative ball filtering to the given volume.
117
154
  The volume is eroded at each iteration, by subtracting 1 from the volume.
118
155
 
119
156
  Parameters:
120
- volume (np.ndarray): The input volume.
121
- n_iter (int): The number of iterations to perform. Default is 10.
157
+ volume (torch.Tensor): The input volume. It is edited inplace.
158
+ Of shape Z, Y, X.
159
+ settings (DetectionSettings): The settings to use.
122
160
 
123
161
  Returns:
124
- Tuple[List[int], List[np.ndarray]]: A tuple containing two lists:
125
- The structures found in each iteration.
162
+ tuple: A tuple containing two lists:
163
+ The number of structures found in each iteration.
126
164
  The cell centres found in each iteration.
127
165
  """
128
166
  ns = []
129
167
  centres = []
130
168
 
131
- threshold_value = np.iinfo(volume.dtype).max - 1
132
- soma_centre_value = np.iinfo(volume.dtype).max
133
-
134
- vol = volume.copy() # TODO: check if required
135
-
136
- for i in range(n_iter):
137
- vol, cell_centres = ball_filter_imgs(
138
- vol, threshold_value, soma_centre_value
139
- )
140
-
141
- # vol is unsigned, so can't let zeros underflow to max value
142
- vol[:, :, :] = np.where(vol != 0, vol - 1, 0)
169
+ for i in range(settings.n_splitting_iter):
170
+ cell_centres = ball_filter_imgs(volume, settings)
171
+ volume.sub_(1)
143
172
 
144
173
  n_structures = len(cell_centres)
145
174
  ns.append(n_structures)
146
175
  centres.append(cell_centres)
147
176
  if n_structures == 0:
148
177
  break
178
+
149
179
  return ns, centres
150
180
 
151
181
 
152
182
  def check_centre_in_cuboid(centre: np.ndarray, max_coords: np.ndarray) -> bool:
153
183
  """
154
- Checks whether a coordinate is in a cuboid
155
- :param centre: x,y,z coordinate
156
- :param max_coords: far corner of cuboid
157
- :return: True if within cuboid, otherwise False
184
+ Checks whether a coordinate is in a cuboid.
185
+
186
+ Parameters
187
+ ----------
188
+
189
+ centre : np.ndarray
190
+ x, y, z coordinate.
191
+ max_coords : np.ndarray
192
+ Far corner of cuboid.
193
+
194
+ Returns
195
+ -------
196
+ True if within cuboid, otherwise False.
158
197
  """
159
198
  relative_coords = centre
160
199
  if (relative_coords > max_coords).all():
@@ -168,7 +207,7 @@ def check_centre_in_cuboid(centre: np.ndarray, max_coords: np.ndarray) -> bool:
168
207
 
169
208
 
170
209
  def split_cells(
171
- cell_points: np.ndarray, outlier_keep: bool = False
210
+ cell_points: np.ndarray, settings: DetectionSettings
172
211
  ) -> np.ndarray:
173
212
  """
174
213
  Split the given cell points into individual cell centres.
@@ -177,28 +216,24 @@ def split_cells(
177
216
  cell_points (np.ndarray): Array of cell points with shape (N, 3),
178
217
  where N is the number of cell points and each point is represented
179
218
  by its x, y, and z coordinates.
180
- outlier_keep (bool, optional): Flag indicating whether to keep outliers
181
- during the splitting process. Defaults to False.
219
+ settings (DetectionSettings) : The settings to use for splitting. It is
220
+ modified inplace.
182
221
 
183
222
  Returns:
184
223
  np.ndarray: Array of absolute cell centres with shape (M, 3),
185
224
  where M is the number of individual cells and each centre is
186
225
  represented by its x, y, and z coordinates.
187
226
  """
227
+ # these points are in x, y, z order columnwise, in absolute pixels
188
228
  orig_centre = get_structure_centre(cell_points)
189
229
 
190
230
  xs = cell_points[:, 0]
191
231
  ys = cell_points[:, 1]
192
232
  zs = cell_points[:, 2]
193
233
 
194
- orig_corner = np.array(
195
- [
196
- orig_centre[0] - (orig_centre[0] - xs.min()),
197
- orig_centre[1] - (orig_centre[1] - ys.min()),
198
- orig_centre[2] - (orig_centre[2] - zs.min()),
199
- ]
200
- )
201
-
234
+ # corner coordinates in absolute pixels
235
+ orig_corner = np.array([xs.min(), ys.min(), zs.min()])
236
+ # volume center relative to corner
202
237
  relative_orig_centre = np.array(
203
238
  [
204
239
  orig_centre[0] - orig_corner[0],
@@ -207,22 +242,51 @@ def split_cells(
207
242
  ]
208
243
  )
209
244
 
245
+ # total volume enclosing all points
210
246
  original_bounding_cuboid_shape = get_shape(xs, ys, zs)
211
247
 
212
- ball_radius = 1
213
- vol = coords_to_volume(xs, ys, zs, ball_radius=ball_radius)
248
+ ball_radius = settings.ball_xy_size // 2
249
+ # they should be the same dtype so as to not need a conversion before
250
+ # passing the input data with marked cells to the filters (we currently
251
+ # set both to float32)
252
+ assert settings.filtering_dtype == settings.plane_original_np_dtype
253
+ # volume will now be z, y, x order
254
+ vol = coords_to_volume(
255
+ xs,
256
+ ys,
257
+ zs,
258
+ volume_shape=original_bounding_cuboid_shape,
259
+ ball_radius=ball_radius,
260
+ dtype=settings.filtering_dtype,
261
+ threshold_value=settings.threshold_value,
262
+ )
263
+
264
+ # get an estimate of how much memory processing a single batch of original
265
+ # input planes takes. For this much smaller volume, our batch will be such
266
+ # that it uses at most that much memory
267
+ total_vol_size = (
268
+ settings.plane_height * settings.plane_width * settings.batch_size
269
+ )
270
+ batch_size = total_vol_size // (vol.shape[1] * vol.shape[2])
271
+ batch_size = min(batch_size, vol.shape[0])
272
+
273
+ # update settings with our volume data
274
+ settings.plane_shape = vol.shape[1:]
275
+ settings.start_plane = 0
276
+ settings.end_plane = vol.shape[0]
277
+ settings.batch_size = batch_size
214
278
 
215
279
  # centres is a list of arrays of centres (1 array of centres per ball run)
216
- ns, centres = iterative_ball_filter(vol)
280
+ # in x, y, z order
281
+ ns, centres = iterative_ball_filter(vol, settings)
217
282
  ns.insert(0, 1)
218
283
  centres.insert(0, np.array([relative_orig_centre]))
219
284
 
220
285
  best_iteration = ns.index(max(ns))
221
-
222
286
  # TODO: put constraint on minimum centres distance ?
223
287
  relative_centres = centres[best_iteration]
224
288
 
225
- if not outlier_keep:
289
+ if not settings.outlier_keep:
226
290
  # TODO: change to checking whether in original cluster shape
227
291
  original_max_coords = np.array(original_bounding_cuboid_shape)
228
292
  relative_centres = np.array(
@@ -234,7 +298,7 @@ def split_cells(
234
298
  )
235
299
 
236
300
  absolute_centres = np.empty((len(relative_centres), 3))
237
- # FIXME: extract functionality
301
+ # convert centers to absolute pixels
238
302
  absolute_centres[:, 0] = orig_corner[0] + relative_centres[:, 0]
239
303
  absolute_centres[:, 1] = orig_corner[1] + relative_centres[:, 1]
240
304
  absolute_centres[:, 2] = orig_corner[2] + relative_centres[:, 2]