nettracer3d 0.9.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nettracer3d might be problematic; see the registry's advisory page for this release for more details.

nettracer3d/segmenter.py CHANGED
@@ -52,7 +52,7 @@ class InteractiveSegmenter:
52
52
  self.windows = 10
53
53
  self.dogs = [(1, 2), (2, 4), (4, 8)]
54
54
  self.master_chunk = 49
55
- self.twod_chunk_size = 262144
55
+ self.twod_chunk_size = 117649
56
56
  self.batch_amplifier = 1
57
57
 
58
58
  #Data when loading prev model:
@@ -86,26 +86,81 @@ class InteractiveSegmenter:
86
86
  # Single chunk for entire Z slice
87
87
  needed_chunks[z] = [[z, 0, y_dim, 0, x_dim]]
88
88
  else:
89
- # Multiple chunks - find which ones contain our coordinates
90
- largest_dim = 'y' if y_dim >= x_dim else 'x'
91
- num_divisions = int(np.ceil(total_pixels / MAX_CHUNK_SIZE))
89
+ # Calculate optimal grid dimensions for square-ish chunks
90
+ num_chunks_needed = int(np.ceil(total_pixels / MAX_CHUNK_SIZE))
91
+
92
+ # Find factors that give us the most square-like grid
93
+ best_y_chunks = 1
94
+ best_x_chunks = num_chunks_needed
95
+ best_aspect_ratio = float('inf')
96
+
97
+ for y_chunks in range(1, num_chunks_needed + 1):
98
+ x_chunks = int(np.ceil(num_chunks_needed / y_chunks))
99
+
100
+ # Calculate actual chunk dimensions
101
+ chunk_y_size = int(np.ceil(y_dim / y_chunks))
102
+ chunk_x_size = int(np.ceil(x_dim / x_chunks))
103
+
104
+ # Check if chunk size constraint is satisfied
105
+ chunk_pixels = chunk_y_size * chunk_x_size
106
+ if chunk_pixels > MAX_CHUNK_SIZE:
107
+ continue
108
+
109
+ # Calculate aspect ratio of the chunk
110
+ aspect_ratio = max(chunk_y_size, chunk_x_size) / min(chunk_y_size, chunk_x_size)
111
+
112
+ # Prefer more square-like chunks (aspect ratio closer to 1)
113
+ if aspect_ratio < best_aspect_ratio:
114
+ best_aspect_ratio = aspect_ratio
115
+ best_y_chunks = y_chunks
116
+ best_x_chunks = x_chunks
92
117
 
93
118
  chunks_for_z = []
94
119
 
95
- if largest_dim == 'y':
96
- div_size = int(np.ceil(y_dim / num_divisions))
97
- for i in range(0, y_dim, div_size):
98
- end_i = min(i + div_size, y_dim)
99
- # Check if this chunk contains any of our coordinates
100
- if any(i <= y <= end_i-1 for y in y_coords):
101
- chunks_for_z.append([z, i, end_i, 0, x_dim])
120
+ # If no valid configuration found, fall back to single dimension division
121
+ if best_aspect_ratio == float('inf'):
122
+ # Fall back to original logic
123
+ largest_dim = 'y' if y_dim >= x_dim else 'x'
124
+ num_divisions = int(np.ceil(total_pixels / MAX_CHUNK_SIZE))
125
+
126
+ if largest_dim == 'y':
127
+ div_size = int(np.ceil(y_dim / num_divisions))
128
+ for i in range(0, y_dim, div_size):
129
+ end_i = min(i + div_size, y_dim)
130
+ # Check if this chunk contains any of our coordinates
131
+ if any(i <= y <= end_i-1 for y in y_coords):
132
+ chunks_for_z.append([z, i, end_i, 0, x_dim])
133
+ else:
134
+ div_size = int(np.ceil(x_dim / num_divisions))
135
+ for i in range(0, x_dim, div_size):
136
+ end_i = min(i + div_size, x_dim)
137
+ # Check if this chunk contains any of our coordinates
138
+ if any(i <= x <= end_i-1 for x in x_coords):
139
+ chunks_for_z.append([z, 0, y_dim, i, end_i])
102
140
  else:
103
- div_size = int(np.ceil(x_dim / num_divisions))
104
- for i in range(0, x_dim, div_size):
105
- end_i = min(i + div_size, x_dim)
106
- # Check if this chunk contains any of our coordinates
107
- if any(i <= x <= end_i-1 for x in x_coords):
108
- chunks_for_z.append([z, 0, y_dim, i, end_i])
141
+ # Create the 2D grid of chunks and check which ones contain coordinates
142
+ y_chunk_size = int(np.ceil(y_dim / best_y_chunks))
143
+ x_chunk_size = int(np.ceil(x_dim / best_x_chunks))
144
+
145
+ for y_idx in range(best_y_chunks):
146
+ for x_idx in range(best_x_chunks):
147
+ y_start = y_idx * y_chunk_size
148
+ y_end = min(y_start + y_chunk_size, y_dim)
149
+ x_start = x_idx * x_chunk_size
150
+ x_end = min(x_start + x_chunk_size, x_dim)
151
+
152
+ # Skip empty chunks (can happen at edges)
153
+ if y_start >= y_dim or x_start >= x_dim:
154
+ continue
155
+
156
+ # Check if this chunk contains any of our coordinates
157
+ chunk_contains_coords = any(
158
+ y_start <= y <= y_end-1 and x_start <= x <= x_end-1
159
+ for y, x in zip(y_coords, x_coords)
160
+ )
161
+
162
+ if chunk_contains_coords:
163
+ chunks_for_z.append([z, y_start, y_end, x_start, x_end])
109
164
 
110
165
  needed_chunks[z] = chunks_for_z
111
166
 
@@ -612,25 +667,32 @@ class InteractiveSegmenter:
612
667
  return dict(z_dict) # Convert back to regular dict
613
668
 
614
669
  def process_chunk(self, chunk_coords):
615
- """Updated process_chunk with chunked 2D feature computation"""
670
+ """Updated process_chunk with manual coordinate generation"""
616
671
 
617
- if self.realtimechunks is None:
618
- # 3D processing (unchanged)
619
- z_min, z_max = chunk_coords[0], chunk_coords[1]
620
- y_min, y_max = chunk_coords[2], chunk_coords[3]
621
- x_min, x_max = chunk_coords[4], chunk_coords[5]
672
+ chunk_info = self.realtimechunks[chunk_coords]
622
673
 
623
- z_range = np.arange(z_min, z_max)
624
- y_range = np.arange(y_min, y_max)
625
- x_range = np.arange(x_min, x_max)
626
-
627
- z_grid, y_grid, x_grid = np.meshgrid(z_range, y_range, x_range, indexing='ij')
628
- chunk_coords_array = np.column_stack([
629
- z_grid.ravel(), y_grid.ravel(), x_grid.ravel()
630
- ])
674
+ if not self.use_two:
675
+ chunk_bounds = chunk_info['bounds']
676
+ else:
677
+ # 2D chunk format: key is (z, y_start, x_start), coords are [y_start, y_end, x_start, x_end]
678
+ z = chunk_info['z']
679
+ y_start, y_end, x_start, x_end = chunk_info['coords']
680
+ chunk_bounds = (z, y_start, y_end, x_start, x_end) # Convert to tuple for consistency
681
+
682
+ if not self.use_two:
683
+ # 3D processing - generate coordinates manually
684
+ z_start, z_end, y_start, y_end, x_start, x_end = chunk_bounds
685
+
686
+ # Generate coordinate array manually
687
+ chunk_coords_array = np.stack(np.meshgrid(
688
+ np.arange(z_start, z_end),
689
+ np.arange(y_start, y_end),
690
+ np.arange(x_start, x_end),
691
+ indexing='ij'
692
+ )).reshape(3, -1).T
631
693
 
632
694
  # Extract subarray
633
- subarray = self.image_3d[z_min:z_max+1, y_min:y_max+1, x_min:x_max+1]
695
+ subarray = self.image_3d[z_start:z_end, y_start:y_end, x_start:x_end]
634
696
 
635
697
  # Compute features for entire subarray
636
698
  if self.speed:
@@ -638,130 +700,87 @@ class InteractiveSegmenter:
638
700
  else:
639
701
  feature_map = self.compute_deep_feature_maps_cpu(subarray)
640
702
 
641
- # Vectorized feature extraction
642
- local_coords = chunk_coords_array - np.array([z_min, y_min, x_min])
703
+ # Vectorized feature extraction using local coordinates
704
+ local_coords = chunk_coords_array - np.array([z_start, y_start, x_start])
643
705
  features = feature_map[local_coords[:, 0], local_coords[:, 1], local_coords[:, 2]]
644
706
 
645
- # Vectorized predictions
646
- predictions = self.model.predict(features)
647
- predictions = np.array(predictions, dtype=bool)
648
-
649
- # Use boolean indexing to separate coordinates
650
- foreground_coords = chunk_coords_array[predictions]
651
- background_coords = chunk_coords_array[~predictions]
652
-
653
- # Convert to sets
654
- foreground = set(map(tuple, foreground_coords))
655
- background = set(map(tuple, background_coords))
656
-
657
- return foreground, background
658
-
659
707
  else:
660
- # 2D processing - compute features for chunk only
661
- chunk_coords_array = np.array(chunk_coords)
662
- z_coords = chunk_coords_array[:, 0]
663
- y_coords = chunk_coords_array[:, 1]
664
- x_coords = chunk_coords_array[:, 2]
665
-
666
- z = int(np.unique(z_coords)[0]) # All coordinates should have same Z
667
-
668
- # Get chunk bounds
669
- y_min, y_max = int(np.min(y_coords)), int(np.max(y_coords))
670
- x_min, x_max = int(np.min(x_coords)), int(np.max(x_coords))
708
+ # 2D processing - generate coordinates manually
709
+ z, y_start, y_end, x_start, x_end = chunk_bounds
671
710
 
672
- # Expand bounds slightly to ensure we capture the chunk properly
673
- y_min = max(0, y_min)
674
- x_min = max(0, x_min)
675
- y_max = min(self.image_3d.shape[1], y_max + 1)
676
- x_max = min(self.image_3d.shape[2], x_max + 1)
711
+ # Generate coordinate array for this Z slice
712
+ chunk_coords_array = np.stack(np.meshgrid(
713
+ [z],
714
+ np.arange(y_start, y_end),
715
+ np.arange(x_start, x_end),
716
+ indexing='ij'
717
+ )).reshape(3, -1).T
677
718
 
678
719
  # Extract 2D subarray for this chunk
679
- subarray_2d = self.image_3d[z, y_min:y_max, x_min:x_max]
720
+ subarray_2d = self.image_3d[z, y_start:y_end, x_start:x_end]
680
721
 
681
- # Compute features for just this chunk
722
+ # Compute 2D features for the subarray
682
723
  if self.speed:
683
724
  feature_map = self.compute_feature_maps_cpu_2d(image_2d=subarray_2d)
684
725
  else:
685
726
  feature_map = self.compute_deep_feature_maps_cpu_2d(image_2d=subarray_2d)
686
727
 
687
728
  # Convert global coordinates to local chunk coordinates
688
- local_y_coords = y_coords - y_min
689
- local_x_coords = x_coords - x_min
729
+ local_y_coords = chunk_coords_array[:, 1] - y_start
730
+ local_x_coords = chunk_coords_array[:, 2] - x_start
690
731
 
691
732
  # Extract features using local coordinates
692
733
  features = feature_map[local_y_coords, local_x_coords]
693
-
694
- # Vectorized predictions
695
- predictions = self.model.predict(features)
696
- predictions = np.array(predictions, dtype=bool)
697
-
698
- # Use boolean indexing to separate coordinates
699
- foreground_coords = chunk_coords_array[predictions]
700
- background_coords = chunk_coords_array[~predictions]
701
-
702
- # Convert to sets
703
- foreground = set(map(tuple, foreground_coords))
704
- background = set(map(tuple, background_coords))
705
-
706
- return foreground, background
707
-
708
- def twodim_coords(self, y_dim, x_dim, z, chunk_size = None, subrange = None):
709
-
710
- if subrange is None:
711
- y_coords, x_coords = np.meshgrid(
712
- np.arange(y_dim),
713
- np.arange(x_dim),
714
- indexing='ij'
715
- )
716
734
 
717
- slice_coords = np.column_stack((
718
- np.full(chunk_size, z),
719
- y_coords.ravel(),
720
- x_coords.ravel()
721
- ))
722
-
723
- elif subrange[0] == 'y':
724
-
725
- y_subrange = np.arange(subrange[1], subrange[2])
726
-
727
- # Create meshgrid for this subchunk
728
- y_sub, x_sub = np.meshgrid(
729
- y_subrange,
730
- np.arange(x_dim),
731
- indexing='ij'
732
- )
733
-
734
- # Create coordinates for this subchunk
735
- subchunk_size = len(y_subrange) * x_dim
736
- slice_coords = np.column_stack((
737
- np.full(subchunk_size, z),
738
- y_sub.ravel(),
739
- x_sub.ravel()
740
- ))
741
-
742
- elif subrange[0] == 'x':
743
-
744
- x_subrange = np.arange(subrange[1], subrange[2])
745
-
746
- # Create meshgrid for this subchunk
747
- y_sub, x_sub = np.meshgrid(
748
- np.arange(y_dim),
749
- x_subrange,
750
- indexing='ij'
751
- )
752
-
753
- # Create coordinates for this subchunk
754
- subchunk_size = y_dim * len(x_subrange)
755
- slice_coords = np.column_stack((
756
- np.full(subchunk_size, z),
757
- y_sub.ravel(),
758
- x_sub.ravel()
759
- ))
760
-
761
-
735
+ # Common prediction logic
736
+ predictions = self.model.predict(features)
737
+ predictions = np.array(predictions, dtype=bool)
738
+
739
+ # Use boolean indexing to separate coordinates
740
+ foreground_coords = chunk_coords_array[predictions]
741
+ background_coords = chunk_coords_array[~predictions]
742
+
743
+ # Convert to sets
744
+ foreground = set(map(tuple, foreground_coords))
745
+ background = set(map(tuple, background_coords))
746
+
747
+ return foreground, background
762
748
 
763
- return list(map(tuple, slice_coords))
749
+ def twodim_coords(self, z, y_start, y_end, x_start, x_end):
750
+ """
751
+ Generate 2D coordinates for a z-slice using NumPy for CPU.
752
+ Updated to match GPU version signature and work with square chunking.
753
+
754
+ Args:
755
+ z (int): Z-slice index
756
+ y_start (int): Start index for y dimension
757
+ y_end (int): End index for y dimension
758
+ x_start (int): Start index for x dimension
759
+ x_end (int): End index for x dimension
760
+
761
+ Returns:
762
+ NumPy array of coordinates in format (z, y, x)
763
+ """
764
764
 
765
+ # Create ranges for y and x dimensions
766
+ y_range = np.arange(y_start, y_end, dtype=int)
767
+ x_range = np.arange(x_start, x_end, dtype=int)
768
+
769
+ # Create meshgrid
770
+ y_coords, x_coords = np.meshgrid(y_range, x_range, indexing='ij')
771
+
772
+ # Calculate total size
773
+ total_size = len(y_range) * len(x_range)
774
+
775
+ # Stack coordinates with z values
776
+ slice_coords = np.column_stack((
777
+ np.full(total_size, z, dtype=int),
778
+ y_coords.ravel(),
779
+ x_coords.ravel()
780
+ ))
781
+
782
+ return slice_coords
783
+
765
784
 
766
785
 
767
786
  def segment_volume(self, array, chunk_size=None, gpu=False):
@@ -773,7 +792,7 @@ class InteractiveSegmenter:
773
792
  chunk_size = self.master_chunk
774
793
 
775
794
  def create_2d_chunks():
776
- """Same as your existing implementation"""
795
+ """Updated 2D chunking to create more square-like chunks"""
777
796
  MAX_CHUNK_SIZE = self.twod_chunk_size
778
797
  chunks = []
779
798
 
@@ -785,53 +804,77 @@ class InteractiveSegmenter:
785
804
  if total_pixels <= MAX_CHUNK_SIZE:
786
805
  chunks.append([y_dim, x_dim, z, total_pixels, None])
787
806
  else:
788
- largest_dim = 'y' if y_dim >= x_dim else 'x'
789
- num_divisions = int(np.ceil(total_pixels / MAX_CHUNK_SIZE))
807
+ # Calculate optimal grid dimensions for square-ish chunks
808
+ num_chunks_needed = int(np.ceil(total_pixels / MAX_CHUNK_SIZE))
790
809
 
791
- if largest_dim == 'y':
792
- div_size = int(np.ceil(y_dim / num_divisions))
793
- for i in range(0, y_dim, div_size):
794
- end_i = min(i + div_size, y_dim)
795
- chunks.append([y_dim, x_dim, z, None, ['y', i, end_i]])
810
+ # Find factors that give us the most square-like grid
811
+ best_y_chunks = 1
812
+ best_x_chunks = num_chunks_needed
813
+ best_aspect_ratio = float('inf')
814
+
815
+ for y_chunks in range(1, num_chunks_needed + 1):
816
+ x_chunks = int(np.ceil(num_chunks_needed / y_chunks))
817
+
818
+ # Calculate actual chunk dimensions
819
+ chunk_y_size = int(np.ceil(y_dim / y_chunks))
820
+ chunk_x_size = int(np.ceil(x_dim / x_chunks))
821
+
822
+ # Check if chunk size constraint is satisfied
823
+ chunk_pixels = chunk_y_size * chunk_x_size
824
+ if chunk_pixels > MAX_CHUNK_SIZE:
825
+ continue
826
+
827
+ # Calculate aspect ratio of the chunk
828
+ aspect_ratio = max(chunk_y_size, chunk_x_size) / min(chunk_y_size, chunk_x_size)
829
+
830
+ # Prefer more square-like chunks (aspect ratio closer to 1)
831
+ if aspect_ratio < best_aspect_ratio:
832
+ best_aspect_ratio = aspect_ratio
833
+ best_y_chunks = y_chunks
834
+ best_x_chunks = x_chunks
835
+
836
+ # If no valid configuration found, fall back to single dimension division
837
+ if best_aspect_ratio == float('inf'):
838
+ # Fall back to original logic
839
+ largest_dim = 'y' if y_dim >= x_dim else 'x'
840
+ num_divisions = int(np.ceil(total_pixels / MAX_CHUNK_SIZE))
841
+
842
+ if largest_dim == 'y':
843
+ div_size = int(np.ceil(y_dim / num_divisions))
844
+ for i in range(0, y_dim, div_size):
845
+ end_i = min(i + div_size, y_dim)
846
+ chunks.append([y_dim, x_dim, z, None, ['y', i, end_i]])
847
+ else:
848
+ div_size = int(np.ceil(x_dim / num_divisions))
849
+ for i in range(0, x_dim, div_size):
850
+ end_i = min(i + div_size, x_dim)
851
+ chunks.append([y_dim, x_dim, z, None, ['x', i, end_i]])
796
852
  else:
797
- div_size = int(np.ceil(x_dim / num_divisions))
798
- for i in range(0, x_dim, div_size):
799
- end_i = min(i + div_size, x_dim)
800
- chunks.append([y_dim, x_dim, z, None, ['x', i, end_i]])
801
-
853
+ # Create the 2D grid of chunks
854
+ y_chunk_size = int(np.ceil(y_dim / best_y_chunks))
855
+ x_chunk_size = int(np.ceil(x_dim / best_x_chunks))
856
+
857
+ for y_idx in range(best_y_chunks):
858
+ for x_idx in range(best_x_chunks):
859
+ y_start = y_idx * y_chunk_size
860
+ y_end = min(y_start + y_chunk_size, y_dim)
861
+ x_start = x_idx * x_chunk_size
862
+ x_end = min(x_start + x_chunk_size, x_dim)
863
+
864
+ # Skip empty chunks (can happen at edges)
865
+ if y_start >= y_dim or x_start >= x_dim:
866
+ continue
867
+
868
+ # For 2D chunks, we need to encode both y and x ranges
869
+ chunks.append([y_dim, x_dim, z, None, ['2d', y_start, y_end, x_start, x_end]])
870
+
802
871
  return chunks
803
872
 
873
+
804
874
  print("Chunking data...")
805
875
 
806
876
  if not self.use_two:
807
- # Create smaller chunks for better load balancing
808
- if chunk_size is None:
809
- total_cores = multiprocessing.cpu_count()
810
- total_volume = np.prod(self.image_3d.shape)
811
- target_volume_per_chunk = total_volume / (total_cores * 4) # 4x more chunks
812
-
813
- chunk_size = int(np.cbrt(target_volume_per_chunk))
814
- chunk_size = max(16, min(chunk_size, min(self.image_3d.shape) // 2))
815
- chunk_size = ((chunk_size + 7) // 16) * 16
816
-
817
- z_chunks = (self.image_3d.shape[0] + chunk_size - 1) // chunk_size
818
- y_chunks = (self.image_3d.shape[1] + chunk_size - 1) // chunk_size
819
- x_chunks = (self.image_3d.shape[2] + chunk_size - 1) // chunk_size
820
-
821
- chunk_starts = np.array(np.meshgrid(
822
- np.arange(z_chunks) * chunk_size,
823
- np.arange(y_chunks) * chunk_size,
824
- np.arange(x_chunks) * chunk_size,
825
- indexing='ij'
826
- )).reshape(3, -1).T
827
-
828
- chunks = []
829
- for z_start, y_start, x_start in chunk_starts:
830
- z_end = min(z_start + chunk_size, self.image_3d.shape[0])
831
- y_end = min(y_start + chunk_size, self.image_3d.shape[1])
832
- x_end = min(x_start + chunk_size, self.image_3d.shape[2])
833
- coords = [z_start, z_end, y_start, y_end, x_start, x_end]
834
- chunks.append(coords)
877
+ chunks = self.compute_3d_chunks(chunk_size)
835
878
  else:
836
879
  chunks = create_2d_chunks()
837
880
 
@@ -941,37 +984,45 @@ class InteractiveSegmenter:
941
984
  return features, chunk_coords_array
942
985
 
943
986
  else:
944
- # Handle 2D case
945
- chunk_coords_list = self.twodim_coords(chunk_coords[0], chunk_coords[1],
946
- chunk_coords[2], chunk_coords[3], chunk_coords[4])
947
- chunk_coords_by_z = self.organize_by_z(chunk_coords_list)
948
987
 
949
- all_features = []
950
- all_coords = []
988
+ y_dim, x_dim, z, chunk_size, subrange = chunk_coords
989
+
990
+ if subrange[0] == 'y':
991
+ y_start, y_end = subrange[1], subrange[2]
992
+ x_start, x_end = 0, x_dim
993
+ elif subrange[0] == 'x':
994
+ y_start, y_end = 0, y_dim
995
+ x_start, x_end = subrange[1], subrange[2]
996
+ elif subrange[0] == '2d':
997
+ y_start, y_end = subrange[1], subrange[2]
998
+ x_start, x_end = subrange[3], subrange[4]
999
+ else:
1000
+ raise ValueError(f"Unknown subrange format: {subrange}")
951
1001
 
952
- for z, coords in chunk_coords_by_z.items():
953
- coords_array = np.array(coords)
954
-
955
- # Get features for this z-slice
956
- features_slice = self.get_feature_map_slice(z, self.speed, self.cur_gpu)
957
- features = features_slice[coords_array[:, 0], coords_array[:, 1]]
958
-
959
-
960
- # Convert to 3D coordinates
961
- coords_3d = np.column_stack([
962
- np.full(len(coords_array), z),
963
- coords_array[:, 0],
964
- coords_array[:, 1]
965
- ])
966
-
967
- all_features.append(features)
968
- all_coords.append(coords_3d)
969
1002
 
970
- if all_features:
971
- return np.vstack(all_features), np.vstack(all_coords)
1003
+ # Generate coordinates for this chunk
1004
+ coords_array = self.twodim_coords(z, y_start, y_end, x_start, x_end)
1005
+
1006
+ # NEW: Compute features for just this chunk instead of full Z-slice
1007
+ # Extract 2D subarray for this chunk
1008
+ subarray_2d = self.image_3d[z, y_start:y_end, x_start:x_end]
1009
+
1010
+ # Compute features for this chunk only
1011
+ if self.speed:
1012
+ feature_map = self.compute_feature_maps_cpu_2d(image_2d=subarray_2d)
972
1013
  else:
973
- return np.array([]), np.array([])
974
-
1014
+ feature_map = self.compute_deep_feature_maps_cpu_2d(image_2d=subarray_2d)
1015
+
1016
+ # Convert global coordinates to local chunk coordinates
1017
+ y_indices = coords_array[:, 1] - y_start # Local Y coordinates
1018
+ x_indices = coords_array[:, 2] - x_start # Local X coordinates
1019
+
1020
+ # Extract features using local coordinates
1021
+ features = feature_map[y_indices, x_indices]
1022
+
1023
+ return features, coords_array
1024
+
1025
+
975
1026
  def update_position(self, z=None, x=None, y=None):
976
1027
  """Update current position for chunk prioritization with safeguards"""
977
1028
 
@@ -1006,57 +1057,34 @@ class InteractiveSegmenter:
1006
1057
  self.prev_z = z
1007
1058
 
1008
1059
 
1009
- def get_realtime_chunks(self, chunk_size = 49):
1010
-
1011
- # Determine if we need to chunk XY planes
1012
- small_dims = (self.image_3d.shape[1] <= chunk_size and
1013
- self.image_3d.shape[2] <= chunk_size)
1014
- few_z = self.image_3d.shape[0] <= 100 # arbitrary threshold
1015
-
1016
- # If small enough, each Z is one chunk
1017
- if small_dims and few_z:
1018
- chunk_size_xy = max(self.image_3d.shape[1], self.image_3d.shape[2])
1019
- else:
1020
- chunk_size_xy = chunk_size
1060
+ def get_realtime_chunks(self, chunk_size=None):
1061
+ if chunk_size is None:
1062
+ chunk_size = self.master_chunk
1021
1063
 
1022
- # Calculate chunks for XY plane
1023
- y_chunks = (self.image_3d.shape[1] + chunk_size_xy - 1) // chunk_size_xy
1024
- x_chunks = (self.image_3d.shape[2] + chunk_size_xy - 1) // chunk_size_xy
1064
+ all_chunks = self.compute_3d_chunks(chunk_size)
1025
1065
 
1026
- # Populate chunk dictionary
1027
- chunk_dict = {}
1028
-
1029
- # Create chunks for each Z plane
1030
- for z in range(self.image_3d.shape[0]):
1031
- if small_dims:
1032
-
1033
- chunk_dict[(z, 0, 0)] = {
1034
- 'coords': [0, self.image_3d.shape[1], 0, self.image_3d.shape[2]],
1035
- 'processed': False,
1036
- 'z': z
1037
- }
1038
- else:
1039
- # Multiple chunks per Z
1040
- for y_chunk in range(y_chunks):
1041
- for x_chunk in range(x_chunks):
1042
- y_start = y_chunk * chunk_size_xy
1043
- x_start = x_chunk * chunk_size_xy
1044
- y_end = min(y_start + chunk_size_xy, self.image_3d.shape[1])
1045
- x_end = min(x_start + chunk_size_xy, self.image_3d.shape[2])
1046
-
1047
- chunk_dict[(z, y_start, x_start)] = {
1048
- 'coords': [y_start, y_end, x_start, x_end],
1049
- 'processed': False,
1050
- 'z': z
1051
- }
1052
-
1053
- self.realtimechunks = chunk_dict
1054
-
1055
- print("Ready!")
1066
+ self.realtimechunks = {
1067
+ i: {
1068
+ 'bounds': chunk_coords, # Only store [z_start, z_end, y_start, y_end, x_start, x_end]
1069
+ 'processed': False,
1070
+ 'center': self._get_chunk_center(chunk_coords), # Small tuple for distance calc
1071
+ 'is_3d': True # Flag to indicate this is 3D chunking
1072
+ }
1073
+ for i, chunk_coords in enumerate(all_chunks)
1074
+ }
1075
+
1076
+ def _get_chunk_center(self, chunk_coords):
1077
+ """Get center coordinate of chunk for distance calculations"""
1078
+ z_start, z_end, y_start, y_end, x_start, x_end = chunk_coords
1079
+ return (
1080
+ (z_start + z_end) // 2,
1081
+ (y_start + y_end) // 2,
1082
+ (x_start + x_end) // 2
1083
+ )
1056
1084
 
1057
1085
  def get_realtime_chunks_2d(self, chunk_size=None):
1058
1086
  """
1059
- Updated 2D chunking to match create_2d_chunks logic
1087
+ Updated 2D chunking to create more square-like chunks
1060
1088
  """
1061
1089
 
1062
1090
  MAX_CHUNK_SIZE = self.twod_chunk_size
@@ -1064,7 +1092,7 @@ class InteractiveSegmenter:
1064
1092
  # Populate chunk dictionary
1065
1093
  chunk_dict = {}
1066
1094
 
1067
- # Create chunks for each Z plane using the same logic as create_2d_chunks
1095
+ # Create chunks for each Z plane
1068
1096
  for z in range(self.image_3d.shape[0]):
1069
1097
  y_dim = self.image_3d.shape[1]
1070
1098
  x_dim = self.image_3d.shape[2]
@@ -1078,32 +1106,81 @@ class InteractiveSegmenter:
1078
1106
  'z': z
1079
1107
  }
1080
1108
  else:
1081
- # Multiple chunks per Z plane - divide along largest dimension
1082
- largest_dim = 'y' if y_dim >= x_dim else 'x'
1083
- num_divisions = int(np.ceil(total_pixels / MAX_CHUNK_SIZE))
1109
+ # Calculate optimal grid dimensions for square-ish chunks
1110
+ # Start with square root of total chunks needed
1111
+ num_chunks_needed = int(np.ceil(total_pixels / MAX_CHUNK_SIZE))
1112
+
1113
+ # Find factors that give us the most square-like grid
1114
+ best_y_chunks = 1
1115
+ best_x_chunks = num_chunks_needed
1116
+ best_aspect_ratio = float('inf')
1084
1117
 
1085
- if largest_dim == 'y':
1086
- # Divide along Y dimension
1087
- div_size = int(np.ceil(y_dim / num_divisions))
1088
- for i in range(0, y_dim, div_size):
1089
- end_i = min(i + div_size, y_dim)
1090
- # Use (z, y_start, x_start) as key for consistency
1091
- chunk_dict[(z, i, 0)] = {
1092
- 'coords': [i, end_i, 0, x_dim], # [y_start, y_end, x_start, x_end]
1093
- 'processed': False,
1094
- 'z': z
1095
- }
1118
+ for y_chunks in range(1, num_chunks_needed + 1):
1119
+ x_chunks = int(np.ceil(num_chunks_needed / y_chunks))
1120
+
1121
+ # Calculate actual chunk dimensions
1122
+ chunk_y_size = int(np.ceil(y_dim / y_chunks))
1123
+ chunk_x_size = int(np.ceil(x_dim / x_chunks))
1124
+
1125
+ # Check if chunk size constraint is satisfied
1126
+ chunk_pixels = chunk_y_size * chunk_x_size
1127
+ if chunk_pixels > MAX_CHUNK_SIZE:
1128
+ continue
1129
+
1130
+ # Calculate aspect ratio of the chunk
1131
+ aspect_ratio = max(chunk_y_size, chunk_x_size) / min(chunk_y_size, chunk_x_size)
1132
+
1133
+ # Prefer more square-like chunks (aspect ratio closer to 1)
1134
+ if aspect_ratio < best_aspect_ratio:
1135
+ best_aspect_ratio = aspect_ratio
1136
+ best_y_chunks = y_chunks
1137
+ best_x_chunks = x_chunks
1138
+
1139
+ # If no valid configuration found, fall back to single dimension division
1140
+ if best_aspect_ratio == float('inf'):
1141
+ # Fall back to original logic
1142
+ largest_dim = 'y' if y_dim >= x_dim else 'x'
1143
+ num_divisions = int(np.ceil(total_pixels / MAX_CHUNK_SIZE))
1144
+
1145
+ if largest_dim == 'y':
1146
+ div_size = int(np.ceil(y_dim / num_divisions))
1147
+ for i in range(0, y_dim, div_size):
1148
+ end_i = min(i + div_size, y_dim)
1149
+ chunk_dict[(z, i, 0)] = {
1150
+ 'coords': [i, end_i, 0, x_dim],
1151
+ 'processed': False,
1152
+ 'z': z
1153
+ }
1154
+ else:
1155
+ div_size = int(np.ceil(x_dim / num_divisions))
1156
+ for i in range(0, x_dim, div_size):
1157
+ end_i = min(i + div_size, x_dim)
1158
+ chunk_dict[(z, 0, i)] = {
1159
+ 'coords': [0, y_dim, i, end_i],
1160
+ 'processed': False,
1161
+ 'z': z
1162
+ }
1096
1163
  else:
1097
- # Divide along X dimension
1098
- div_size = int(np.ceil(x_dim / num_divisions))
1099
- for i in range(0, x_dim, div_size):
1100
- end_i = min(i + div_size, x_dim)
1101
- # Use (z, y_start, x_start) as key for consistency
1102
- chunk_dict[(z, 0, i)] = {
1103
- 'coords': [0, y_dim, i, end_i], # [y_start, y_end, x_start, x_end]
1104
- 'processed': False,
1105
- 'z': z
1106
- }
1164
+ # Create the 2D grid of chunks
1165
+ y_chunk_size = int(np.ceil(y_dim / best_y_chunks))
1166
+ x_chunk_size = int(np.ceil(x_dim / best_x_chunks))
1167
+
1168
+ for y_idx in range(best_y_chunks):
1169
+ for x_idx in range(best_x_chunks):
1170
+ y_start = y_idx * y_chunk_size
1171
+ y_end = min(y_start + y_chunk_size, y_dim)
1172
+ x_start = x_idx * x_chunk_size
1173
+ x_end = min(x_start + x_chunk_size, x_dim)
1174
+
1175
+ # Skip empty chunks (can happen at edges)
1176
+ if y_start >= y_dim or x_start >= x_dim:
1177
+ continue
1178
+
1179
+ chunk_dict[(z, y_start, x_start)] = {
1180
+ 'coords': [y_start, y_end, x_start, x_end],
1181
+ 'processed': False,
1182
+ 'z': z
1183
+ }
1107
1184
 
1108
1185
  self.realtimechunks = chunk_dict
1109
1186
  print("Ready!")
@@ -1162,75 +1239,77 @@ class InteractiveSegmenter:
1162
1239
  return foreground_features, background_features
1163
1240
 
1164
1241
  def segment_volume_realtime(self, gpu=False):
1165
- """Updated realtime segmentation with chunked 2D processing"""
1166
-
1167
1242
  if self.realtimechunks is None:
1168
1243
  if not self.use_two:
1169
- self.get_realtime_chunks()
1244
+ self.get_realtime_chunks() # 3D chunks
1170
1245
  else:
1171
- self.get_realtime_chunks_2d()
1246
+ self.get_realtime_chunks_2d() # 2D chunks
1172
1247
  else:
1173
- for chunk_pos in self.realtimechunks:
1174
- self.realtimechunks[chunk_pos]['processed'] = False
1175
-
1176
- chunk_dict = self.realtimechunks
1248
+ for chunk_key in self.realtimechunks:
1249
+ self.realtimechunks[chunk_key]['processed'] = False
1177
1250
 
1178
- def get_nearest_unprocessed_chunk(self):
1179
- """Get nearest unprocessed chunk prioritizing current Z"""
1251
+ def get_nearest_unprocessed_chunk():
1180
1252
  curr_z = self.current_z if self.current_z is not None else self.image_3d.shape[0] // 2
1181
1253
  curr_y = self.current_y if self.current_y is not None else self.image_3d.shape[1] // 2
1182
1254
  curr_x = self.current_x if self.current_x is not None else self.image_3d.shape[2] // 2
1183
1255
 
1184
- # First try to find chunks at current Z
1185
- current_z_chunks = [(pos, info) for pos, info in chunk_dict.items()
1186
- if pos[0] == curr_z and not info['processed']]
1187
-
1188
- if current_z_chunks:
1189
- nearest = min(current_z_chunks,
1190
- key=lambda x: ((x[0][1] - curr_y) ** 2 +
1191
- (x[0][2] - curr_x) ** 2))
1192
- return nearest[0]
1193
-
1194
- # If no chunks at current Z, find nearest Z with available chunks
1195
- available_z = sorted(
1196
- [(pos[0], pos) for pos, info in chunk_dict.items()
1197
- if not info['processed']],
1198
- key=lambda x: abs(x[0] - curr_z)
1199
- )
1200
-
1201
- if available_z:
1202
- target_z = available_z[0][0]
1203
- z_chunks = [(pos, info) for pos, info in chunk_dict.items()
1204
- if pos[0] == target_z and not info['processed']]
1205
- nearest = min(z_chunks,
1206
- key=lambda x: ((x[0][1] - curr_y) ** 2 +
1207
- (x[0][2] - curr_x) ** 2))
1256
+ unprocessed_chunks = [
1257
+ (key, info) for key, info in self.realtimechunks.items()
1258
+ if not info['processed']
1259
+ ]
1260
+
1261
+ if not unprocessed_chunks:
1262
+ return None
1263
+
1264
+ if self.use_two:
1265
+ # 2D chunks: key format is (z, y_start, x_start)
1266
+ # First try to find chunks at current Z
1267
+ current_z_chunks = [
1268
+ (key, info) for key, info in unprocessed_chunks
1269
+ if key[0] == curr_z
1270
+ ]
1271
+
1272
+ if current_z_chunks:
1273
+ # Find nearest chunk at current Z by y,x distance
1274
+ nearest = min(current_z_chunks,
1275
+ key=lambda x: ((x[0][1] - curr_y) ** 2 +
1276
+ (x[0][2] - curr_x) ** 2))
1277
+ return nearest[0]
1278
+
1279
+ # If no chunks at current Z, find nearest Z with available chunks
1280
+ available_z_chunks = sorted(unprocessed_chunks,
1281
+ key=lambda x: abs(x[0][0] - curr_z))
1282
+
1283
+ if available_z_chunks:
1284
+ # Get the nearest Z that has unprocessed chunks
1285
+ target_z = available_z_chunks[0][0][0]
1286
+ z_chunks = [
1287
+ (key, info) for key, info in unprocessed_chunks
1288
+ if key[0] == target_z
1289
+ ]
1290
+ # Find nearest chunk in that Z by y,x distance
1291
+ nearest = min(z_chunks,
1292
+ key=lambda x: ((x[0][1] - curr_y) ** 2 +
1293
+ (x[0][2] - curr_x) ** 2))
1294
+ return nearest[0]
1295
+ else:
1296
+ # 3D chunks: use existing center-based distance calculation
1297
+ nearest = min(unprocessed_chunks,
1298
+ key=lambda x: sum((a - b) ** 2 for a, b in
1299
+ zip(x[1]['center'], (curr_z, curr_y, curr_x))))
1208
1300
  return nearest[0]
1209
1301
 
1210
1302
  return None
1211
1303
 
1212
1304
  while True:
1213
- chunk_idx = get_nearest_unprocessed_chunk(self)
1214
- if chunk_idx is None:
1305
+ chunk_key = get_nearest_unprocessed_chunk()
1306
+ if chunk_key is None:
1215
1307
  break
1216
1308
 
1217
- chunk = chunk_dict[chunk_idx]
1218
- chunk['processed'] = True
1219
- coords = chunk['coords'] # [y_start, y_end, x_start, x_end]
1220
- z = chunk['z']
1221
-
1222
- # Generate coordinates for this chunk
1223
- coords_array = np.stack(np.meshgrid(
1224
- [z],
1225
- np.arange(coords[0], coords[1]),
1226
- np.arange(coords[2], coords[3]),
1227
- indexing='ij'
1228
- )).reshape(3, -1).T
1229
-
1230
- coords_list = list(map(tuple, coords_array))
1309
+ self.realtimechunks[chunk_key]['processed'] = True
1231
1310
 
1232
- # Process the chunk with updated method
1233
- fore, back = self.process_chunk(coords_list)
1311
+ # Process the chunk - pass the key, process_chunk will handle the rest
1312
+ fore, back = self.process_chunk(chunk_key)
1234
1313
 
1235
1314
  yield fore, back
1236
1315
 
@@ -1243,30 +1322,23 @@ class InteractiveSegmenter:
1243
1322
  except:
1244
1323
  pass
1245
1324
 
1246
- def process_grid_cell(self, grid_cell_info):
1325
+ def process_grid_cell(self, chunk_info):
1247
1326
  """
1248
- Process a single grid cell and return foreground and background features.
1327
+ Process a single chunk and return foreground and background features.
1249
1328
 
1250
1329
  Args:
1251
- grid_cell_info: tuple of (grid_z, grid_y, grid_x, box_size, depth, height, width, foreground_array)
1330
+ chunk_info: tuple of (chunk_coords, foreground_array) where
1331
+ chunk_coords is [z_start, z_end, y_start, y_end, x_start, x_end]
1252
1332
 
1253
1333
  Returns:
1254
1334
  tuple: (foreground_features, background_features)
1255
1335
  """
1256
- grid_z, grid_y, grid_x, box_size, depth, height, width, foreground_array = grid_cell_info
1257
-
1258
- # Calculate the boundaries of this grid cell
1259
- z_min = grid_z * box_size
1260
- y_min = grid_y * box_size
1261
- x_min = grid_x * box_size
1336
+ chunk_coords, foreground_array = chunk_info
1337
+ z_start, z_end, y_start, y_end, x_start, x_end = chunk_coords
1262
1338
 
1263
- z_max = min(z_min + box_size, depth)
1264
- y_max = min(y_min + box_size, height)
1265
- x_max = min(x_min + box_size, width)
1266
-
1267
- # Extract the subarray
1268
- subarray = self.image_3d[z_min:z_max, y_min:y_max, x_min:x_max]
1269
- subarray2 = foreground_array[z_min:z_max, y_min:y_max, x_min:x_max]
1339
+ # Extract the subarray using chunk boundaries
1340
+ subarray = self.image_3d[z_start:z_end, y_start:y_end, x_start:x_end]
1341
+ subarray2 = foreground_array[z_start:z_end, y_start:y_end, x_start:x_end]
1270
1342
 
1271
1343
  # Compute features for this subarray
1272
1344
  if self.speed:
@@ -1290,34 +1362,31 @@ class InteractiveSegmenter:
1290
1362
 
1291
1363
  return foreground_features, background_features
1292
1364
 
1293
- # Modified main processing code
1294
- def process_grid_cells_parallel(self, grid_cells_with_scribbles, box_size, depth, height, width, foreground_array, max_workers=None):
1365
+ def process_grid_cells_parallel(self, chunks_with_scribbles, foreground_array, max_workers=None):
1295
1366
  """
1296
- Process grid cells in parallel using ThreadPoolExecutor.
1367
+ Process chunks in parallel using ThreadPoolExecutor.
1297
1368
 
1298
1369
  Args:
1299
- grid_cells_with_scribbles: List of grid cell coordinates
1300
- box_size: Size of each grid cell
1301
- depth, height, width: Dimensions of the 3D image
1370
+ chunks_with_scribbles: List of chunk coordinates [z_start, z_end, y_start, y_end, x_start, x_end]
1302
1371
  foreground_array: Array marking foreground/background points
1303
1372
  max_workers: Maximum number of threads (None for default)
1304
1373
 
1305
1374
  Returns:
1306
1375
  tuple: (foreground_features, background_features)
1307
1376
  """
1308
- # Prepare data for each grid cell
1309
- grid_cell_data = [
1310
- (grid_z, grid_y, grid_x, box_size, depth, height, width, foreground_array)
1311
- for grid_z, grid_y, grid_x in grid_cells_with_scribbles
1377
+ # Prepare data for each chunk
1378
+ chunk_data = [
1379
+ (chunk_coords, foreground_array)
1380
+ for chunk_coords in chunks_with_scribbles
1312
1381
  ]
1313
1382
 
1314
1383
  foreground_features = []
1315
1384
  background_features = []
1316
1385
 
1317
- # Process grid cells in parallel
1386
+ # Process chunks in parallel
1318
1387
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
1319
1388
  # Submit all tasks
1320
- futures = [executor.submit(self.process_grid_cell, cell_data) for cell_data in grid_cell_data]
1389
+ futures = [executor.submit(self.process_grid_cell, data) for data in chunk_data]
1321
1390
 
1322
1391
  # Collect results as they complete
1323
1392
  for future in futures:
@@ -1327,6 +1396,56 @@ class InteractiveSegmenter:
1327
1396
 
1328
1397
  return foreground_features, background_features
1329
1398
 
1399
+ def compute_3d_chunks(self, chunk_size=None):
1400
+ """
1401
+ Compute 3D chunks with consistent logic across all operations.
1402
+
1403
+ Args:
1404
+ chunk_size: Optional chunk size, otherwise uses dynamic calculation
1405
+
1406
+ Returns:
1407
+ list: List of chunk coordinates [z_start, z_end, y_start, y_end, x_start, x_end]
1408
+ """
1409
+ # Use consistent chunk size calculation
1410
+ if chunk_size is None:
1411
+ if hasattr(self, 'master_chunk') and self.master_chunk is not None:
1412
+ chunk_size = self.master_chunk
1413
+ else:
1414
+ # Dynamic calculation (same as segmentation)
1415
+ total_cores = multiprocessing.cpu_count()
1416
+ total_volume = np.prod(self.image_3d.shape)
1417
+ target_volume_per_chunk = total_volume / (total_cores * 4)
1418
+
1419
+ chunk_size = int(np.cbrt(target_volume_per_chunk))
1420
+ chunk_size = max(16, min(chunk_size, min(self.image_3d.shape) // 2))
1421
+ chunk_size = ((chunk_size + 7) // 16) * 16
1422
+
1423
+ depth, height, width = self.image_3d.shape
1424
+
1425
+ # Calculate chunk grid dimensions
1426
+ z_chunks = (depth + chunk_size - 1) // chunk_size
1427
+ y_chunks = (height + chunk_size - 1) // chunk_size
1428
+ x_chunks = (width + chunk_size - 1) // chunk_size
1429
+
1430
+ # Generate all chunk start positions
1431
+ chunk_starts = np.array(np.meshgrid(
1432
+ np.arange(z_chunks) * chunk_size,
1433
+ np.arange(y_chunks) * chunk_size,
1434
+ np.arange(x_chunks) * chunk_size,
1435
+ indexing='ij'
1436
+ )).reshape(3, -1).T
1437
+
1438
+ # Create chunk coordinate list
1439
+ chunks = []
1440
+ for z_start, y_start, x_start in chunk_starts:
1441
+ z_end = min(z_start + chunk_size, depth)
1442
+ y_end = min(y_start + chunk_size, height)
1443
+ x_end = min(x_start + chunk_size, width)
1444
+ coords = [z_start, z_end, y_start, y_end, x_start, x_end]
1445
+ chunks.append(coords)
1446
+
1447
+ return chunks
1448
+
1330
1449
  def train_batch(self, foreground_array, speed=True, use_gpu=False, use_two=False, mem_lock=False, saving=False):
1331
1450
  """Updated train_batch with chunked 2D processing"""
1332
1451
 
@@ -1389,8 +1508,9 @@ class InteractiveSegmenter:
1389
1508
  )
1390
1509
 
1391
1510
  else:
1392
- # 3D processing (unchanged - your existing code)
1393
- box_size = self.master_chunk
1511
+ # 3D processing - match segmentation chunking logic
1512
+ chunk_size = self.master_chunk
1513
+
1394
1514
  foreground_features = []
1395
1515
  background_features = []
1396
1516
 
@@ -1400,22 +1520,44 @@ class InteractiveSegmenter:
1400
1520
  if len(z_fore) == 0 and len(z_back) == 0:
1401
1521
  return foreground_features, background_features
1402
1522
 
1403
- depth, height, width = foreground_array.shape
1404
- z_grid_size = (depth + box_size - 1) // box_size
1405
- y_grid_size = (height + box_size - 1) // box_size
1406
- x_grid_size = (width + box_size - 1) // box_size
1407
-
1408
- grid_cells_with_scribbles = set()
1409
-
1410
- for z, y, x in np.vstack((z_fore, z_back)) if len(z_back) > 0 else z_fore:
1411
- grid_z = z // box_size
1412
- grid_y = y // box_size
1413
- grid_x = x // box_size
1414
- grid_cells_with_scribbles.add((grid_z, grid_y, grid_x))
1523
+ # Get all chunks using consistent method
1524
+ all_chunks = self.compute_3d_chunks(self.master_chunk)
1525
+
1526
+ # Convert chunks to numpy array for vectorized operations
1527
+ chunks_array = np.array(all_chunks) # Shape: (n_chunks, 6)
1528
+ # columns: [z_start, z_end, y_start, y_end, x_start, x_end]
1529
+
1530
+ # Combine all scribbles
1531
+ all_scribbles = np.vstack((z_fore, z_back)) if len(z_back) > 0 else z_fore
1532
+
1533
+ # For each scribble, find which chunk it belongs to using vectorized operations
1534
+ chunks_with_scribbles = set()
1535
+
1536
+ for z, y, x in all_scribbles:
1537
+ # Vectorized check: find chunks that contain this scribble
1538
+ # Check if scribble falls within each chunk's bounds
1539
+ z_in_chunk = (chunks_array[:, 0] <= z) & (z < chunks_array[:, 1])
1540
+ y_in_chunk = (chunks_array[:, 2] <= y) & (y < chunks_array[:, 3])
1541
+ x_in_chunk = (chunks_array[:, 4] <= x) & (x < chunks_array[:, 5])
1542
+
1543
+ # Find chunks where all conditions are true
1544
+ matching_chunks = z_in_chunk & y_in_chunk & x_in_chunk
1545
+
1546
+ # Get the chunk indices that match
1547
+ chunk_indices = np.where(matching_chunks)[0]
1548
+
1549
+ # Add matching chunks to set (set automatically handles duplicates)
1550
+ for idx in chunk_indices:
1551
+ chunk_coords = tuple(chunks_array[idx])
1552
+ chunks_with_scribbles.add(chunk_coords)
1553
+
1554
+ # Convert set to list
1555
+ chunks_with_scribbles = list(chunks_with_scribbles)
1415
1556
 
1557
+ # Process these chunks using the updated method
1416
1558
  foreground_features, background_features = self.process_grid_cells_parallel(
1417
- grid_cells_with_scribbles, box_size, depth, height, width, foreground_array)
1418
-
1559
+ chunks_with_scribbles, foreground_array)
1560
+
1419
1561
  # Rest of the method unchanged (combining with previous features, training, etc.)
1420
1562
  if self.previous_foreground is not None:
1421
1563
  failed = True