nettracer3d 0.9.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nettracer3d might be problematic. Click here for more details.

@@ -60,8 +60,8 @@ class InteractiveSegmenter:
60
60
  self.sigmas = [1,2,4,8]
61
61
  self.windows = 10
62
62
  self.dogs = [(1, 2), (2, 4), (4, 8)]
63
- self.master_chunk = 49
64
- self.twod_chunk_size = 262144
63
+ self.master_chunk = 64
64
+ self.twod_chunk_size = 117649
65
65
 
66
66
  #Data when loading prev model:
67
67
  self.previous_foreground = None
@@ -72,7 +72,7 @@ class InteractiveSegmenter:
72
72
 
73
73
  def get_minimal_chunks_for_coordinates(self, coordinates_by_z):
74
74
  """
75
- Get minimal set of 2D chunks needed to cover the given coordinates
75
+ GPU version - Get minimal set of 2D chunks needed to cover the given coordinates
76
76
  Uses same chunking logic as create_2d_chunks()
77
77
  """
78
78
  MAX_CHUNK_SIZE = self.twod_chunk_size
@@ -95,26 +95,81 @@ class InteractiveSegmenter:
95
95
  # Single chunk for entire Z slice
96
96
  needed_chunks[z] = [[z, 0, y_dim, 0, x_dim]]
97
97
  else:
98
- # Multiple chunks - find which ones contain our coordinates
99
- largest_dim = 'y' if y_dim >= x_dim else 'x'
100
- num_divisions = int(cp.ceil(total_pixels / MAX_CHUNK_SIZE))
98
+ # Calculate optimal grid dimensions for square-ish chunks
99
+ num_chunks_needed = int(cp.ceil(total_pixels / MAX_CHUNK_SIZE))
100
+
101
+ # Find factors that give us the most square-like grid
102
+ best_y_chunks = 1
103
+ best_x_chunks = num_chunks_needed
104
+ best_aspect_ratio = float('inf')
105
+
106
+ for y_chunks in range(1, num_chunks_needed + 1):
107
+ x_chunks = int(cp.ceil(num_chunks_needed / y_chunks))
108
+
109
+ # Calculate actual chunk dimensions
110
+ chunk_y_size = int(cp.ceil(y_dim / y_chunks))
111
+ chunk_x_size = int(cp.ceil(x_dim / x_chunks))
112
+
113
+ # Check if chunk size constraint is satisfied
114
+ chunk_pixels = chunk_y_size * chunk_x_size
115
+ if chunk_pixels > MAX_CHUNK_SIZE:
116
+ continue
117
+
118
+ # Calculate aspect ratio of the chunk
119
+ aspect_ratio = max(chunk_y_size, chunk_x_size) / min(chunk_y_size, chunk_x_size)
120
+
121
+ # Prefer more square-like chunks (aspect ratio closer to 1)
122
+ if aspect_ratio < best_aspect_ratio:
123
+ best_aspect_ratio = aspect_ratio
124
+ best_y_chunks = y_chunks
125
+ best_x_chunks = x_chunks
101
126
 
102
127
  chunks_for_z = []
103
128
 
104
- if largest_dim == 'y':
105
- div_size = int(cp.ceil(y_dim / num_divisions))
106
- for i in range(0, y_dim, div_size):
107
- end_i = min(i + div_size, y_dim)
108
- # Check if this chunk contains any of our coordinates
109
- if any(i <= y <= end_i-1 for y in y_coords):
110
- chunks_for_z.append([z, i, end_i, 0, x_dim])
129
+ # If no valid configuration found, fall back to single dimension division
130
+ if best_aspect_ratio == float('inf'):
131
+ # Fall back to original logic
132
+ largest_dim = 'y' if y_dim >= x_dim else 'x'
133
+ num_divisions = int(cp.ceil(total_pixels / MAX_CHUNK_SIZE))
134
+
135
+ if largest_dim == 'y':
136
+ div_size = int(cp.ceil(y_dim / num_divisions))
137
+ for i in range(0, y_dim, div_size):
138
+ end_i = min(i + div_size, y_dim)
139
+ # Check if this chunk contains any of our coordinates
140
+ if any(i <= y <= end_i-1 for y in y_coords):
141
+ chunks_for_z.append([z, i, end_i, 0, x_dim])
142
+ else:
143
+ div_size = int(cp.ceil(x_dim / num_divisions))
144
+ for i in range(0, x_dim, div_size):
145
+ end_i = min(i + div_size, x_dim)
146
+ # Check if this chunk contains any of our coordinates
147
+ if any(i <= x <= end_i-1 for x in x_coords):
148
+ chunks_for_z.append([z, 0, y_dim, i, end_i])
111
149
  else:
112
- div_size = int(cp.ceil(x_dim / num_divisions))
113
- for i in range(0, x_dim, div_size):
114
- end_i = min(i + div_size, x_dim)
115
- # Check if this chunk contains any of our coordinates
116
- if any(i <= x <= end_i-1 for x in x_coords):
117
- chunks_for_z.append([z, 0, y_dim, i, end_i])
150
+ # Create the 2D grid of chunks and check which ones contain coordinates
151
+ y_chunk_size = int(cp.ceil(y_dim / best_y_chunks))
152
+ x_chunk_size = int(cp.ceil(x_dim / best_x_chunks))
153
+
154
+ for y_idx in range(best_y_chunks):
155
+ for x_idx in range(best_x_chunks):
156
+ y_start = y_idx * y_chunk_size
157
+ y_end = min(y_start + y_chunk_size, y_dim)
158
+ x_start = x_idx * x_chunk_size
159
+ x_end = min(x_start + x_chunk_size, x_dim)
160
+
161
+ # Skip empty chunks (can happen at edges)
162
+ if y_start >= y_dim or x_start >= x_dim:
163
+ continue
164
+
165
+ # Check if this chunk contains any of our coordinates
166
+ chunk_contains_coords = any(
167
+ y_start <= y <= y_end-1 and x_start <= x <= x_end-1
168
+ for y, x in zip(y_coords, x_coords)
169
+ )
170
+
171
+ if chunk_contains_coords:
172
+ chunks_for_z.append([z, y_start, y_end, x_start, x_end])
118
173
 
119
174
  needed_chunks[z] = chunks_for_z
120
175
 
@@ -139,102 +194,87 @@ class InteractiveSegmenter:
139
194
  return feature_map, (y_start, x_start) # Return offset for coordinate mapping
140
195
 
141
196
 
142
- def process_chunk_updated(self, chunk_coords):
143
- """Updated process_chunk with proper 2D chunking"""
197
+ def process_chunk(self, chunk_coords):
198
+ """Updated GPU process_chunk with manual coordinate generation (faster)"""
199
+ import cupy as cp
144
200
 
145
- foreground_coords = []
146
- background_coords = []
147
-
148
- if self.realtimechunks is None:
149
- # 3D processing (original logic unchanged)
150
- z_min, z_max = chunk_coords[0], chunk_coords[1]
151
- y_min, y_max = chunk_coords[2], chunk_coords[3]
152
- x_min, x_max = chunk_coords[4], chunk_coords[5]
153
-
154
- z_range = cp.arange(z_min, z_max)
155
- y_range = cp.arange(y_min, y_max)
156
- x_range = cp.arange(x_min, x_max)
201
+ chunk_info = self.realtimechunks[chunk_coords]
202
+
203
+ if not self.use_two:
204
+ chunk_bounds = chunk_info['bounds']
205
+ else:
206
+ # 2D chunk format: key is (z, y_start, x_start), coords are [y_start, y_end, x_start, x_end]
207
+ z = chunk_info['z']
208
+ y_start, y_end, x_start, x_end = chunk_info['coords']
209
+ chunk_bounds = (z, y_start, y_end, x_start, x_end) # Convert to tuple for consistency
210
+
211
+ if not self.use_two:
212
+ # 3D processing - generate coordinates manually
213
+ z_start, z_end, y_start, y_end, x_start, x_end = chunk_bounds
157
214
 
215
+ # Generate coordinate array manually using CuPy
158
216
  chunk_coords_array = cp.stack(cp.meshgrid(
159
- z_range, y_range, x_range, indexing='ij'
217
+ cp.arange(z_start, z_end),
218
+ cp.arange(y_start, y_end),
219
+ cp.arange(x_start, x_end),
220
+ indexing='ij'
160
221
  )).reshape(3, -1).T
161
222
 
162
- chunk_coords_gpu = chunk_coords_array
223
+ # Extract subarray
224
+ subarray = self.image_3d[z_start:z_end, y_start:y_end, x_start:x_end]
163
225
 
164
- subarray = self.image_3d[z_min:z_max+1, y_min:y_max+1, x_min:x_max+1]
165
-
166
- if self.use_two:
167
- subarray = cp.squeeze(subarray)
168
-
169
- if self.use_two and self.speed:
170
- feature_map = self.compute_feature_maps_gpu_2d(image_2d=subarray)
171
- elif self.use_two:
172
- feature_map = self.compute_deep_feature_maps_gpu_2d(image_2d=subarray)
173
- elif self.speed:
226
+ # Compute features for entire subarray
227
+ if self.speed:
174
228
  feature_map = self.compute_feature_maps_gpu(subarray)
175
229
  else:
176
230
  feature_map = self.compute_deep_feature_maps_gpu(subarray)
177
231
 
178
- if self.use_two:
179
- feature_map = cp.expand_dims(feature_map, axis=0)
180
-
181
- local_coords = chunk_coords_gpu.copy()
182
- local_coords[:, 0] -= z_min
183
- local_coords[:, 1] -= y_min
184
- local_coords[:, 2] -= x_min
185
-
232
+ # Vectorized feature extraction using local coordinates
233
+ local_coords = chunk_coords_array - cp.array([z_start, y_start, x_start])
186
234
  features_gpu = feature_map[local_coords[:, 0], local_coords[:, 1], local_coords[:, 2]]
187
235
 
188
- features_cpu = cp.asnumpy(features_gpu)
189
- predictions = self.model.predict(features_cpu)
190
-
191
- pred_mask = cp.array(predictions, dtype=bool)
192
- foreground_coords = chunk_coords_gpu[pred_mask]
193
- background_coords = chunk_coords_gpu[~pred_mask]
194
-
195
236
  else:
196
- # 2D processing - compute features for chunk only (not full Z-slice)
197
- chunk_coords_gpu = cp.array(chunk_coords)
198
- z_coords = chunk_coords_gpu[:, 0]
199
- y_coords = chunk_coords_gpu[:, 1]
200
- x_coords = chunk_coords_gpu[:, 2]
201
-
202
- z = int(cp.unique(z_coords)[0]) # All coordinates should have same Z
237
+ # 2D processing - generate coordinates manually
238
+ z, y_start, y_end, x_start, x_end = chunk_bounds
203
239
 
204
- # Get chunk bounds
205
- y_min, y_max = int(cp.min(y_coords)), int(cp.max(y_coords))
206
- x_min, x_max = int(cp.min(x_coords)), int(cp.max(x_coords))
207
-
208
- # Expand bounds slightly to ensure we capture the chunk properly
209
- y_min = max(0, y_min)
210
- x_min = max(0, x_min)
211
- y_max = min(self.image_3d.shape[1], y_max + 1)
212
- x_max = min(self.image_3d.shape[2], x_max + 1)
240
+ # Generate coordinate array for this Z slice using CuPy
241
+ chunk_coords_array = cp.stack(cp.meshgrid(
242
+ cp.array([z]),
243
+ cp.arange(y_start, y_end),
244
+ cp.arange(x_start, x_end),
245
+ indexing='ij'
246
+ )).reshape(3, -1).T
213
247
 
214
248
  # Extract 2D subarray for this chunk
215
- subarray_2d = self.image_3d[z, y_min:y_max, x_min:x_max]
249
+ subarray_2d = self.image_3d[z, y_start:y_end, x_start:x_end]
216
250
 
217
- # Compute features for just this chunk
251
+ # Compute 2D features for the subarray
218
252
  if self.speed:
219
253
  feature_map = self.compute_feature_maps_gpu_2d(image_2d=subarray_2d)
220
254
  else:
221
255
  feature_map = self.compute_deep_feature_maps_gpu_2d(image_2d=subarray_2d)
222
256
 
223
257
  # Convert global coordinates to local chunk coordinates
224
- local_y_coords = y_coords - y_min
225
- local_x_coords = x_coords - x_min
258
+ local_y_coords = chunk_coords_array[:, 1] - y_start
259
+ local_x_coords = chunk_coords_array[:, 2] - x_start
226
260
 
227
261
  # Extract features using local coordinates
228
262
  features_gpu = feature_map[local_y_coords, local_x_coords]
229
-
230
- features_cpu = cp.asnumpy(features_gpu)
231
- predictions = self.model.predict(features_cpu)
232
-
233
- pred_mask = cp.array(predictions, dtype=bool)
234
- foreground_coords = chunk_coords_gpu[pred_mask]
235
- background_coords = chunk_coords_gpu[~pred_mask]
236
-
237
- return foreground_coords, background_coords
263
+
264
+ # Common prediction logic - convert to CPU for model prediction
265
+ features_cpu = cp.asnumpy(features_gpu)
266
+ predictions = self.model.predict(features_cpu)
267
+ predictions_gpu = cp.array(predictions, dtype=bool)
268
+
269
+ # Use boolean indexing to separate coordinates
270
+ foreground_coords = chunk_coords_array[predictions_gpu]
271
+ background_coords = chunk_coords_array[~predictions_gpu]
272
+
273
+ # Convert to sets (convert back to CPU for set operations)
274
+ foreground = set(map(tuple, cp.asnumpy(foreground_coords)))
275
+ background = set(map(tuple, cp.asnumpy(background_coords)))
276
+
277
+ return foreground, background
238
278
 
239
279
  def twodim_coords(self, z, y_start, y_end, x_start, x_end):
240
280
  """
@@ -638,7 +678,7 @@ class InteractiveSegmenter:
638
678
  return features
639
679
 
640
680
  def create_2d_chunks(self):
641
- """Same 2D chunking logic"""
681
+ """GPU version - Updated 2D chunking to create more square-like chunks"""
642
682
  MAX_CHUNK_SIZE = self.twod_chunk_size
643
683
  chunks = []
644
684
 
@@ -650,21 +690,71 @@ class InteractiveSegmenter:
650
690
  if total_pixels <= MAX_CHUNK_SIZE:
651
691
  chunks.append([z, 0, y_dim, 0, x_dim])
652
692
  else:
653
- largest_dim = 'y' if y_dim >= x_dim else 'x'
654
- num_divisions = int(cp.ceil(total_pixels / MAX_CHUNK_SIZE))
693
+ # Calculate optimal grid dimensions for square-ish chunks
694
+ num_chunks_needed = int(cp.ceil(total_pixels / MAX_CHUNK_SIZE))
695
+
696
+ # Find factors that give us the most square-like grid
697
+ best_y_chunks = 1
698
+ best_x_chunks = num_chunks_needed
699
+ best_aspect_ratio = float('inf')
655
700
 
656
- if largest_dim == 'y':
657
- div_size = int(cp.ceil(y_dim / num_divisions))
658
- for i in range(0, y_dim, div_size):
659
- end_i = min(i + div_size, y_dim)
660
- chunks.append([z, i, end_i, 0, x_dim])
701
+ for y_chunks in range(1, num_chunks_needed + 1):
702
+ x_chunks = int(cp.ceil(num_chunks_needed / y_chunks))
703
+
704
+ # Calculate actual chunk dimensions
705
+ chunk_y_size = int(cp.ceil(y_dim / y_chunks))
706
+ chunk_x_size = int(cp.ceil(x_dim / x_chunks))
707
+
708
+ # Check if chunk size constraint is satisfied
709
+ chunk_pixels = chunk_y_size * chunk_x_size
710
+ if chunk_pixels > MAX_CHUNK_SIZE:
711
+ continue
712
+
713
+ # Calculate aspect ratio of the chunk
714
+ aspect_ratio = max(chunk_y_size, chunk_x_size) / min(chunk_y_size, chunk_x_size)
715
+
716
+ # Prefer more square-like chunks (aspect ratio closer to 1)
717
+ if aspect_ratio < best_aspect_ratio:
718
+ best_aspect_ratio = aspect_ratio
719
+ best_y_chunks = y_chunks
720
+ best_x_chunks = x_chunks
721
+
722
+ # If no valid configuration found, fall back to single dimension division
723
+ if best_aspect_ratio == float('inf'):
724
+ # Fall back to original logic
725
+ largest_dim = 'y' if y_dim >= x_dim else 'x'
726
+ num_divisions = int(cp.ceil(total_pixels / MAX_CHUNK_SIZE))
727
+
728
+ if largest_dim == 'y':
729
+ div_size = int(cp.ceil(y_dim / num_divisions))
730
+ for i in range(0, y_dim, div_size):
731
+ end_i = min(i + div_size, y_dim)
732
+ chunks.append([z, i, end_i, 0, x_dim])
733
+ else:
734
+ div_size = int(cp.ceil(x_dim / num_divisions))
735
+ for i in range(0, x_dim, div_size):
736
+ end_i = min(i + div_size, x_dim)
737
+ chunks.append([z, 0, y_dim, i, end_i])
661
738
  else:
662
- div_size = int(cp.ceil(x_dim / num_divisions))
663
- for i in range(0, x_dim, div_size):
664
- end_i = min(i + div_size, x_dim)
665
- chunks.append([z, 0, y_dim, i, end_i])
666
-
739
+ # Create the 2D grid of chunks
740
+ y_chunk_size = int(cp.ceil(y_dim / best_y_chunks))
741
+ x_chunk_size = int(cp.ceil(x_dim / best_x_chunks))
742
+
743
+ for y_idx in range(best_y_chunks):
744
+ for x_idx in range(best_x_chunks):
745
+ y_start = y_idx * y_chunk_size
746
+ y_end = min(y_start + y_chunk_size, y_dim)
747
+ x_start = x_idx * x_chunk_size
748
+ x_end = min(x_start + x_chunk_size, x_dim)
749
+
750
+ # Skip empty chunks (can happen at edges)
751
+ if y_start >= y_dim or x_start >= x_dim:
752
+ continue
753
+
754
+ chunks.append([z, y_start, y_end, x_start, x_end])
755
+
667
756
  return chunks
757
+
668
758
 
669
759
  def segment_volume(self, array, chunk_size=None, gpu=True):
670
760
  """Optimized GPU version with sequential GPU processing and batched sklearn prediction"""
@@ -677,32 +767,8 @@ class InteractiveSegmenter:
677
767
  print("Chunking data...")
678
768
 
679
769
  if not self.use_two:
680
- # 3D Processing
681
- chunk_size = ((chunk_size + 15) // 32) * 32
682
-
683
- z_chunks = (self.image_3d.shape[0] + chunk_size - 1) // chunk_size
684
- y_chunks = (self.image_3d.shape[1] + chunk_size - 1) // chunk_size
685
- x_chunks = (self.image_3d.shape[2] + chunk_size - 1) // chunk_size
686
-
687
- chunk_starts = cp.array(cp.meshgrid(
688
- cp.arange(z_chunks) * chunk_size,
689
- cp.arange(y_chunks) * chunk_size,
690
- cp.arange(x_chunks) * chunk_size,
691
- indexing='ij'
692
- )).reshape(3, -1).T
693
-
694
- chunks = []
695
- for chunk_start_gpu in chunk_starts:
696
- z_start = int(chunk_start_gpu[0])
697
- y_start = int(chunk_start_gpu[1])
698
- x_start = int(chunk_start_gpu[2])
699
-
700
- z_end = min(z_start + chunk_size, self.image_3d.shape[0])
701
- y_end = min(y_start + chunk_size, self.image_3d.shape[1])
702
- x_end = min(x_start + chunk_size, self.image_3d.shape[2])
703
-
704
- coords = [z_start, z_end, y_start, y_end, x_start, x_end]
705
- chunks.append(coords)
770
+ chunks = self.compute_3d_chunks(chunk_size)
771
+
706
772
  else:
707
773
  chunks = self.create_2d_chunks()
708
774
 
@@ -889,8 +955,7 @@ class InteractiveSegmenter:
889
955
 
890
956
  def get_realtime_chunks_2d(self, chunk_size=None):
891
957
  """
892
- Create chunks with 1 z-thickness (2D chunks across XY planes)
893
- Now uses the same logic as create_2d_chunks for consistency
958
+ GPU version - Updated 2D chunking to match create_2d_chunks logic
894
959
  """
895
960
 
896
961
  MAX_CHUNK_SIZE = self.twod_chunk_size
@@ -912,154 +977,241 @@ class InteractiveSegmenter:
912
977
  'z': z
913
978
  }
914
979
  else:
915
- # Multiple chunks per Z plane - divide along largest dimension
916
- largest_dim = 'y' if y_dim >= x_dim else 'x'
917
- num_divisions = int(cp.ceil(total_pixels / MAX_CHUNK_SIZE))
980
+ # Calculate optimal grid dimensions for square-ish chunks
981
+ num_chunks_needed = int(cp.ceil(total_pixels / MAX_CHUNK_SIZE))
982
+
983
+ # Find factors that give us the most square-like grid
984
+ best_y_chunks = 1
985
+ best_x_chunks = num_chunks_needed
986
+ best_aspect_ratio = float('inf')
987
+
988
+ for y_chunks in range(1, num_chunks_needed + 1):
989
+ x_chunks = int(cp.ceil(num_chunks_needed / y_chunks))
990
+
991
+ # Calculate actual chunk dimensions
992
+ chunk_y_size = int(cp.ceil(y_dim / y_chunks))
993
+ chunk_x_size = int(cp.ceil(x_dim / x_chunks))
994
+
995
+ # Check if chunk size constraint is satisfied
996
+ chunk_pixels = chunk_y_size * chunk_x_size
997
+ if chunk_pixels > MAX_CHUNK_SIZE:
998
+ continue
999
+
1000
+ # Calculate aspect ratio of the chunk
1001
+ aspect_ratio = max(chunk_y_size, chunk_x_size) / min(chunk_y_size, chunk_x_size)
1002
+
1003
+ # Prefer more square-like chunks (aspect ratio closer to 1)
1004
+ if aspect_ratio < best_aspect_ratio:
1005
+ best_aspect_ratio = aspect_ratio
1006
+ best_y_chunks = y_chunks
1007
+ best_x_chunks = x_chunks
918
1008
 
919
- if largest_dim == 'y':
920
- # Divide along Y dimension
921
- div_size = int(cp.ceil(y_dim / num_divisions))
922
- for i in range(0, y_dim, div_size):
923
- end_i = min(i + div_size, y_dim)
924
- # Use (z, y_start, x_start) as key for consistency
925
- chunk_dict[(z, i, 0)] = {
926
- 'coords': [i, end_i, 0, x_dim], # [y_start, y_end, x_start, x_end]
927
- 'processed': False,
928
- 'z': z
929
- }
1009
+ # If no valid configuration found, fall back to single dimension division
1010
+ if best_aspect_ratio == float('inf'):
1011
+ # Fall back to original logic
1012
+ largest_dim = 'y' if y_dim >= x_dim else 'x'
1013
+ num_divisions = int(cp.ceil(total_pixels / MAX_CHUNK_SIZE))
1014
+
1015
+ if largest_dim == 'y':
1016
+ div_size = int(cp.ceil(y_dim / num_divisions))
1017
+ for i in range(0, y_dim, div_size):
1018
+ end_i = min(i + div_size, y_dim)
1019
+ chunk_dict[(z, i, 0)] = {
1020
+ 'coords': [i, end_i, 0, x_dim],
1021
+ 'processed': False,
1022
+ 'z': z
1023
+ }
1024
+ else:
1025
+ div_size = int(cp.ceil(x_dim / num_divisions))
1026
+ for i in range(0, x_dim, div_size):
1027
+ end_i = min(i + div_size, x_dim)
1028
+ chunk_dict[(z, 0, i)] = {
1029
+ 'coords': [0, y_dim, i, end_i],
1030
+ 'processed': False,
1031
+ 'z': z
1032
+ }
930
1033
  else:
931
- # Divide along X dimension
932
- div_size = int(cp.ceil(x_dim / num_divisions))
933
- for i in range(0, x_dim, div_size):
934
- end_i = min(i + div_size, x_dim)
935
- # Use (z, y_start, x_start) as key for consistency
936
- chunk_dict[(z, 0, i)] = {
937
- 'coords': [0, y_dim, i, end_i], # [y_start, y_end, x_start, x_end]
938
- 'processed': False,
939
- 'z': z
940
- }
1034
+ # Create the 2D grid of chunks
1035
+ y_chunk_size = int(cp.ceil(y_dim / best_y_chunks))
1036
+ x_chunk_size = int(cp.ceil(x_dim / best_x_chunks))
1037
+
1038
+ for y_idx in range(best_y_chunks):
1039
+ for x_idx in range(best_x_chunks):
1040
+ y_start = y_idx * y_chunk_size
1041
+ y_end = min(y_start + y_chunk_size, y_dim)
1042
+ x_start = x_idx * x_chunk_size
1043
+ x_end = min(x_start + x_chunk_size, x_dim)
1044
+
1045
+ # Skip empty chunks (can happen at edges)
1046
+ if y_start >= y_dim or x_start >= x_dim:
1047
+ continue
1048
+
1049
+ chunk_dict[(z, y_start, x_start)] = {
1050
+ 'coords': [y_start, y_end, x_start, x_end],
1051
+ 'processed': False,
1052
+ 'z': z
1053
+ }
941
1054
 
942
1055
  self.realtimechunks = chunk_dict
943
1056
  print("Ready!")
944
1057
 
945
- def get_realtime_chunks(self, chunk_size=49):
1058
+ def compute_3d_chunks(self, chunk_size=None):
1059
+ """
1060
+ Compute 3D chunks with consistent logic across all operations (GPU version).
946
1061
 
947
- # Determine if we need to chunk XY planes
948
- small_dims = (self.image_3d.shape[1] <= chunk_size and
949
- self.image_3d.shape[2] <= chunk_size)
950
- few_z = self.image_3d.shape[0] <= 100 # arbitrary threshold
1062
+ Args:
1063
+ chunk_size: Optional chunk size, otherwise uses dynamic calculation
1064
+
1065
+ Returns:
1066
+ list: List of chunk coordinates [z_start, z_end, y_start, y_end, x_start, x_end]
1067
+ """
1068
+ import cupy as cp
1069
+ import multiprocessing
951
1070
 
952
- # If small enough, each Z is one chunk
953
- if small_dims and few_z:
954
- chunk_size_xy = max(self.image_3d.shape[1], self.image_3d.shape[2])
955
- else:
956
- chunk_size_xy = chunk_size
1071
+ # Use consistent chunk size calculation
1072
+ if chunk_size is None:
1073
+ if hasattr(self, 'master_chunk') and self.master_chunk is not None:
1074
+ chunk_size = self.master_chunk
1075
+ else:
1076
+ # Dynamic calculation (same as segmentation)
1077
+ total_cores = multiprocessing.cpu_count()
1078
+ total_volume = cp.prod(cp.array(self.image_3d.shape))
1079
+ target_volume_per_chunk = total_volume / (total_cores * 4)
1080
+
1081
+ chunk_size = int(cp.cbrt(target_volume_per_chunk))
1082
+ chunk_size = max(16, min(chunk_size, min(self.image_3d.shape) // 2))
1083
+ chunk_size = ((chunk_size + 7) // 16) * 16
957
1084
 
958
- # Calculate chunks for XY plane
959
- y_chunks = (self.image_3d.shape[1] + chunk_size_xy - 1) // chunk_size_xy
960
- x_chunks = (self.image_3d.shape[2] + chunk_size_xy - 1) // chunk_size_xy
1085
+ depth, height, width = self.image_3d.shape
961
1086
 
962
- # Populate chunk dictionary
963
- chunk_dict = {}
1087
+ # Calculate chunk grid dimensions
1088
+ z_chunks = (depth + chunk_size - 1) // chunk_size
1089
+ y_chunks = (height + chunk_size - 1) // chunk_size
1090
+ x_chunks = (width + chunk_size - 1) // chunk_size
964
1091
 
965
- # Create chunks for each Z plane
966
- for z in range(self.image_3d.shape[0]):
967
- if small_dims:
968
-
969
- chunk_dict[(z, 0, 0)] = {
970
- 'coords': [0, self.image_3d.shape[1], 0, self.image_3d.shape[2]],
971
- 'processed': False,
972
- 'z': z
973
- }
974
- else:
975
- # Multiple chunks per Z
976
- for y_chunk in range(y_chunks):
977
- for x_chunk in range(x_chunks):
978
- y_start = y_chunk * chunk_size_xy
979
- x_start = x_chunk * chunk_size_xy
980
- y_end = min(y_start + chunk_size_xy, self.image_3d.shape[1])
981
- x_end = min(x_start + chunk_size_xy, self.image_3d.shape[2])
982
-
983
- chunk_dict[(z, y_start, x_start)] = {
984
- 'coords': [y_start, y_end, x_start, x_end],
985
- 'processed': False,
986
- 'z': z
987
- }
1092
+ # Generate all chunk start positions using CuPy
1093
+ chunk_starts = cp.array(cp.meshgrid(
1094
+ cp.arange(z_chunks) * chunk_size,
1095
+ cp.arange(y_chunks) * chunk_size,
1096
+ cp.arange(x_chunks) * chunk_size,
1097
+ indexing='ij'
1098
+ )).reshape(3, -1).T
1099
+
1100
+
1101
+ # Create chunk coordinate list
1102
+ chunks = []
1103
+ for z_start, y_start, x_start in chunk_starts:
1104
+ z_end = min(z_start + chunk_size, depth)
1105
+ y_end = min(y_start + chunk_size, height)
1106
+ x_end = min(x_start + chunk_size, width)
1107
+ coords = [int(z_start), int(z_end), int(y_start), int(y_end), int(x_start), int(x_end)]
1108
+ chunks.append(coords)
1109
+
1110
+ return chunks
988
1111
 
989
- self.realtimechunks = chunk_dict
990
1112
 
1113
+
1114
+ def get_realtime_chunks(self, chunk_size=None):
1115
+ if chunk_size is None:
1116
+ chunk_size = self.master_chunk
1117
+
1118
+ all_chunks = self.compute_3d_chunks(chunk_size)
1119
+
1120
+ self.realtimechunks = {
1121
+ i: {
1122
+ 'bounds': chunk_coords, # Only store [z_start, z_end, y_start, y_end, x_start, x_end]
1123
+ 'processed': False,
1124
+ 'center': self._get_chunk_center(chunk_coords), # Small tuple for distance calc
1125
+ 'is_3d': True # Flag to indicate this is 3D chunking
1126
+ }
1127
+ for i, chunk_coords in enumerate(all_chunks)
1128
+ }
991
1129
  print("Ready!")
992
1130
 
1131
+ def _get_chunk_center(self, chunk_coords):
1132
+ """Get center coordinate of chunk for distance calculations"""
1133
+ z_start, z_end, y_start, y_end, x_start, x_end = chunk_coords
1134
+ return (
1135
+ (z_start + z_end) // 2,
1136
+ (y_start + y_end) // 2,
1137
+ (x_start + x_end) // 2
1138
+ )
993
1139
 
994
- def segment_volume_realtime(self, gpu=True):
995
- """Updated realtime segmentation - no more feature map caching needed"""
996
- import cupy as cp
997
-
1140
+
1141
+ def segment_volume_realtime(self, gpu=False):
998
1142
  if self.realtimechunks is None:
999
1143
  if not self.use_two:
1000
- self.get_realtime_chunks()
1144
+ self.get_realtime_chunks() # 3D chunks
1001
1145
  else:
1002
- self.get_realtime_chunks_2d()
1146
+ self.get_realtime_chunks_2d() # 2D chunks
1003
1147
  else:
1004
- for chunk_pos in self.realtimechunks:
1005
- self.realtimechunks[chunk_pos]['processed'] = False
1006
-
1007
- chunk_dict = self.realtimechunks
1148
+ for chunk_key in self.realtimechunks:
1149
+ self.realtimechunks[chunk_key]['processed'] = False
1008
1150
 
1009
- def get_nearest_unprocessed_chunk(self):
1010
- """Get nearest unprocessed chunk prioritizing current Z"""
1151
+ def get_nearest_unprocessed_chunk():
1011
1152
  curr_z = self.current_z if self.current_z is not None else self.image_3d.shape[0] // 2
1012
1153
  curr_y = self.current_y if self.current_y is not None else self.image_3d.shape[1] // 2
1013
1154
  curr_x = self.current_x if self.current_x is not None else self.image_3d.shape[2] // 2
1014
1155
 
1015
- # First try to find chunks at current Z
1016
- current_z_chunks = [(pos, info) for pos, info in chunk_dict.items()
1017
- if pos[0] == curr_z and not info['processed']]
1156
+ unprocessed_chunks = [
1157
+ (key, info) for key, info in self.realtimechunks.items()
1158
+ if not info['processed']
1159
+ ]
1018
1160
 
1019
- if current_z_chunks:
1020
- nearest = min(current_z_chunks,
1021
- key=lambda x: ((x[0][1] - curr_y) ** 2 +
1022
- (x[0][2] - curr_x) ** 2))
1023
- return nearest[0]
1024
-
1025
- # If no chunks at current Z, find nearest Z with available chunks
1026
- available_z = sorted(
1027
- [(pos[0], pos) for pos, info in chunk_dict.items()
1028
- if not info['processed']],
1029
- key=lambda x: abs(x[0] - curr_z)
1030
- )
1161
+ if not unprocessed_chunks:
1162
+ return None
1031
1163
 
1032
- if available_z:
1033
- target_z = available_z[0][0]
1034
- z_chunks = [(pos, info) for pos, info in chunk_dict.items()
1035
- if pos[0] == target_z and not info['processed']]
1036
- nearest = min(z_chunks,
1037
- key=lambda x: ((x[0][1] - curr_y) ** 2 +
1038
- (x[0][2] - curr_x) ** 2))
1164
+ if self.use_two:
1165
+ # 2D chunks: key format is (z, y_start, x_start)
1166
+ # First try to find chunks at current Z
1167
+ current_z_chunks = [
1168
+ (key, info) for key, info in unprocessed_chunks
1169
+ if key[0] == curr_z
1170
+ ]
1171
+
1172
+ if current_z_chunks:
1173
+ # Find nearest chunk at current Z by y,x distance
1174
+ nearest = min(current_z_chunks,
1175
+ key=lambda x: ((x[0][1] - curr_y) ** 2 +
1176
+ (x[0][2] - curr_x) ** 2))
1177
+ return nearest[0]
1178
+
1179
+ # If no chunks at current Z, find nearest Z with available chunks
1180
+ available_z_chunks = sorted(unprocessed_chunks,
1181
+ key=lambda x: abs(x[0][0] - curr_z))
1182
+
1183
+ if available_z_chunks:
1184
+ # Get the nearest Z that has unprocessed chunks
1185
+ target_z = available_z_chunks[0][0][0]
1186
+ z_chunks = [
1187
+ (key, info) for key, info in unprocessed_chunks
1188
+ if key[0] == target_z
1189
+ ]
1190
+ # Find nearest chunk in that Z by y,x distance
1191
+ nearest = min(z_chunks,
1192
+ key=lambda x: ((x[0][1] - curr_y) ** 2 +
1193
+ (x[0][2] - curr_x) ** 2))
1194
+ return nearest[0]
1195
+ else:
1196
+ # 3D chunks: use existing center-based distance calculation
1197
+ nearest = min(unprocessed_chunks,
1198
+ key=lambda x: sum((a - b) ** 2 for a, b in
1199
+ zip(x[1]['center'], (curr_z, curr_y, curr_x))))
1039
1200
  return nearest[0]
1040
1201
 
1041
1202
  return None
1042
1203
 
1043
1204
  while True:
1044
- chunk_idx = get_nearest_unprocessed_chunk(self)
1045
- if chunk_idx is None:
1205
+ chunk_key = get_nearest_unprocessed_chunk()
1206
+ if chunk_key is None:
1046
1207
  break
1047
1208
 
1048
- chunk = chunk_dict[chunk_idx]
1049
- chunk['processed'] = True
1050
- coords = chunk['coords'] # [y_start, y_end, x_start, x_end]
1051
- z = chunk['z']
1052
-
1053
- # Generate coordinates for this chunk
1054
- coords_array = self.twodim_coords(z, coords[0], coords[1], coords[2], coords[3])
1055
-
1056
- # Convert to CPU for processing
1057
- coords_list = list(map(tuple, cp.asnumpy(coords_array)))
1209
+ self.realtimechunks[chunk_key]['processed'] = True
1058
1210
 
1059
- # Process the chunk - now computes features only for this chunk
1060
- fore, back = self.process_chunk_updated(coords_list)
1211
+ # Process the chunk - pass the key, process_chunk will handle the rest
1212
+ fore, back = self.process_chunk(chunk_key)
1061
1213
 
1062
- yield cp.asnumpy(fore), cp.asnumpy(back)
1214
+ yield fore, back
1063
1215
 
1064
1216
 
1065
1217
  def cleanup(self):
@@ -1180,9 +1332,10 @@ class InteractiveSegmenter:
1180
1332
  local_x = x - x_offset
1181
1333
  feature_vector = feature_map[local_y, local_x]
1182
1334
  background_features.append(cp.asnumpy(feature_vector))
1335
+
1183
1336
  else:
1184
-
1185
- box_size = self.master_chunk
1337
+ # 3D processing - match segmentation chunking logic using compute_3d_chunks
1338
+ chunk_size = self.master_chunk
1186
1339
 
1187
1340
  # Memory-efficient approach: compute features only for necessary subarrays
1188
1341
  foreground_features = []
@@ -1195,48 +1348,48 @@ class InteractiveSegmenter:
1195
1348
  z_fore = cp.argwhere(foreground_array_gpu == 1)
1196
1349
  z_back = cp.argwhere(foreground_array_gpu == 2)
1197
1350
 
1198
- # Convert back to NumPy for compatibility with the rest of the code
1199
- #z_fore_cpu = cp.asnumpy(z_fore)
1200
- #z_back_cpu = cp.asnumpy(z_back)
1201
-
1202
1351
  # If no scribbles, return empty lists
1203
1352
  if len(z_fore) == 0 and len(z_back) == 0:
1204
1353
  return foreground_features, background_features
1205
1354
 
1206
- # Get dimensions of the input array
1207
- depth, height, width = foreground_array.shape
1208
-
1209
- # Determine the minimum number of boxes needed to cover all scribbles
1210
- half_box = box_size // 2
1211
-
1212
- # Step 1: Find the minimum set of boxes that cover all scribbles
1213
- # We'll divide the volume into a grid of boxes of size box_size
1214
-
1215
- # Calculate how many boxes are needed in each dimension
1216
- z_grid_size = (depth + box_size - 1) // box_size
1217
- y_grid_size = (height + box_size - 1) // box_size
1218
- x_grid_size = (width + box_size - 1) // box_size
1219
-
1220
- # Track which grid cells contain scribbles
1221
- grid_cells_with_scribbles = set()
1222
-
1223
- # Map original coordinates to grid cells
1224
- for z, y, x in cp.vstack((z_fore, z_back)) if len(z_back) > 0 else z_fore:
1225
- grid_z = int(z // box_size)
1226
- grid_y = int(y // box_size)
1227
- grid_x = int(x // box_size)
1228
- grid_cells_with_scribbles.add((grid_z, grid_y, grid_x))
1229
-
1230
- # Step 2: Process each grid cell that contains scribbles
1231
- for grid_z, grid_y, grid_x in grid_cells_with_scribbles:
1232
- # Calculate the boundaries of this grid cell
1233
- z_min = grid_z * box_size
1234
- y_min = grid_y * box_size
1235
- x_min = grid_x * box_size
1355
+ # Get all chunks using consistent method
1356
+ all_chunks = self.compute_3d_chunks(self.master_chunk)
1357
+ # Convert chunks to cupy array for vectorized operations
1358
+ chunks_array = cp.array(all_chunks) # Shape: (n_chunks, 6)
1359
+ # columns: [z_start, z_end, y_start, y_end, x_start, x_end]
1360
+
1361
+ # Combine all scribbles
1362
+ all_scribbles = cp.vstack((z_fore, z_back)) if len(z_back) > 0 else z_fore
1363
+
1364
+ # For each scribble, find which chunk it belongs to using vectorized operations
1365
+ chunks_with_scribbles = set()
1366
+ all_scribbles_cpu = cp.asnumpy(all_scribbles) # Convert once for iteration
1367
+
1368
+ for z, y, x in all_scribbles_cpu:
1369
+ # Vectorized check: find chunks that contain this scribble
1370
+ # Check if scribble falls within each chunk's bounds
1371
+ z_in_chunk = (chunks_array[:, 0] <= z) & (z < chunks_array[:, 1])
1372
+ y_in_chunk = (chunks_array[:, 2] <= y) & (y < chunks_array[:, 3])
1373
+ x_in_chunk = (chunks_array[:, 4] <= x) & (x < chunks_array[:, 5])
1374
+
1375
+ # Find chunks where all conditions are true
1376
+ matching_chunks = z_in_chunk & y_in_chunk & x_in_chunk
1377
+
1378
+ # Get the chunk indices that match
1379
+ chunk_indices = cp.where(matching_chunks)[0]
1236
1380
 
1237
- z_max = min(z_min + box_size, depth)
1238
- y_max = min(y_min + box_size, height)
1239
- x_max = min(x_min + box_size, width)
1381
+ # Add matching chunks to set (set automatically handles duplicates)
1382
+ for idx in cp.asnumpy(chunk_indices):
1383
+ chunk_coords = tuple(cp.asnumpy(chunks_array[idx]))
1384
+ chunks_with_scribbles.add(chunk_coords)
1385
+
1386
+ # Convert set to list
1387
+ chunks_with_scribbles = list(chunks_with_scribbles)
1388
+
1389
+ # Step 2: Process each chunk that contains scribbles (manual coordinate extraction)
1390
+ for chunk_coords in chunks_with_scribbles:
1391
+ # Extract chunk boundaries
1392
+ z_min, z_max, y_min, y_max, x_min, x_max = chunk_coords
1240
1393
 
1241
1394
  # Extract the subarray (assuming image_3d is already a CuPy array)
1242
1395
  subarray = self.image_3d[z_min:z_max, y_min:y_max, x_min:x_max]