nettracer3d 0.8.8-py3-none-any.whl → 0.8.9-py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the package contents as they appear in the respective public registry.

Potentially problematic release.


This version of nettracer3d might be problematic. Click here for more details.

nettracer3d/segmenter.py CHANGED
@@ -61,6 +61,74 @@ class InteractiveSegmenter:
61
61
  self.previous_z_fore = None
62
62
  self.previous_z_back = None
63
63
 
64
+ def get_minimal_chunks_for_coordinates_cpu(self, coordinates_by_z):
65
+ """
66
+ Get minimal set of 2D chunks needed to cover the given coordinates
67
+ Uses same chunking logic as create_2d_chunks() - CPU version
68
+ """
69
+ MAX_CHUNK_SIZE = self.twod_chunk_size
70
+ needed_chunks = {}
71
+
72
+ for z in coordinates_by_z:
73
+ y_coords = [coord[0] for coord in coordinates_by_z[z]]
74
+ x_coords = [coord[1] for coord in coordinates_by_z[z]]
75
+
76
+ # Find bounding box of coordinates in this Z-slice
77
+ y_min, y_max = min(y_coords), max(y_coords)
78
+ x_min, x_max = min(x_coords), max(x_coords)
79
+
80
+ # Create chunks using same logic as create_2d_chunks
81
+ y_dim = self.image_3d.shape[1]
82
+ x_dim = self.image_3d.shape[2]
83
+ total_pixels = y_dim * x_dim
84
+
85
+ if total_pixels <= MAX_CHUNK_SIZE:
86
+ # Single chunk for entire Z slice
87
+ needed_chunks[z] = [[z, 0, y_dim, 0, x_dim]]
88
+ else:
89
+ # Multiple chunks - find which ones contain our coordinates
90
+ largest_dim = 'y' if y_dim >= x_dim else 'x'
91
+ num_divisions = int(np.ceil(total_pixels / MAX_CHUNK_SIZE))
92
+
93
+ chunks_for_z = []
94
+
95
+ if largest_dim == 'y':
96
+ div_size = int(np.ceil(y_dim / num_divisions))
97
+ for i in range(0, y_dim, div_size):
98
+ end_i = min(i + div_size, y_dim)
99
+ # Check if this chunk contains any of our coordinates
100
+ if any(i <= y <= end_i-1 for y in y_coords):
101
+ chunks_for_z.append([z, i, end_i, 0, x_dim])
102
+ else:
103
+ div_size = int(np.ceil(x_dim / num_divisions))
104
+ for i in range(0, x_dim, div_size):
105
+ end_i = min(i + div_size, x_dim)
106
+ # Check if this chunk contains any of our coordinates
107
+ if any(i <= x <= end_i-1 for x in x_coords):
108
+ chunks_for_z.append([z, 0, y_dim, i, end_i])
109
+
110
+ needed_chunks[z] = chunks_for_z
111
+
112
+ return needed_chunks
113
+
114
+ def compute_features_for_chunk_2d_cpu(self, chunk_coords, speed):
115
+ """
116
+ Compute features for a 2D chunk - CPU version
117
+ chunk_coords: [z, y_start, y_end, x_start, x_end]
118
+ """
119
+ z, y_start, y_end, x_start, x_end = chunk_coords
120
+
121
+ # Extract 2D subarray for this chunk
122
+ subarray_2d = self.image_3d[z, y_start:y_end, x_start:x_end]
123
+
124
+ # Compute features for this chunk
125
+ if speed:
126
+ feature_map = self.compute_feature_maps_cpu_2d(image_2d=subarray_2d)
127
+ else:
128
+ feature_map = self.compute_deep_feature_maps_cpu_2d(image_2d=subarray_2d)
129
+
130
+ return feature_map, (y_start, x_start) # Return offset for coordinate mapping
131
+
64
132
  def compute_deep_feature_maps_cpu_2d(self, z=None, image_2d = None):
65
133
  """Vectorized detailed version with Gaussian gradient magnitudes, Laplacians, and largest Hessian eigenvalue for 2D images"""
66
134
  if z is None:
@@ -131,7 +199,7 @@ class InteractiveSegmenter:
131
199
  features[..., feature_idx] = ndimage.laplace(gaussian_img, mode='reflect')
132
200
  feature_idx += 1
133
201
 
134
- # Largest Hessian eigenvalue for each sigma (fully vectorized, 2D version)
202
+ # Largest Hessian eigenvalue for each sigma (analytical 2D version)
135
203
  for sigma in self.sigmas:
136
204
  gaussian_img = gaussian_cache[sigma]
137
205
 
@@ -140,29 +208,25 @@ class InteractiveSegmenter:
140
208
  hyy = ndimage.gaussian_filter(gaussian_img, sigma=0, order=[2, 0], mode='reflect')
141
209
  hxy = ndimage.gaussian_filter(gaussian_img, sigma=0, order=[1, 1], mode='reflect')
142
210
 
143
- # Vectorized eigenvalue computation using numpy broadcasting
144
- # Create arrays with shape (d0, d1, 2, 2) for all 2D Hessian matrices
145
- shape = image_2d.shape
146
- hessian_matrices = np.zeros(shape + (2, 2))
211
+ # Analytical eigenvalue computation for 2x2 symmetric matrices
212
+ # For matrix [[hxx, hxy], [hxy, hyy]], eigenvalues are:
213
+ # λ = (trace ± sqrt(trace² - 4*det)) / 2
147
214
 
148
- # Fill the symmetric 2D Hessian matrices
149
- hessian_matrices[..., 0, 0] = hxx
150
- hessian_matrices[..., 1, 1] = hyy
151
- hessian_matrices[..., 0, 1] = hessian_matrices[..., 1, 0] = hxy
215
+ trace = hxx + hyy
216
+ det = hxx * hyy - hxy * hxy
152
217
 
153
- # Reshape for batch eigenvalue computation
154
- original_shape = hessian_matrices.shape[:-2] # (d0, d1)
155
- batch_size = np.prod(original_shape)
156
- hessian_batch = hessian_matrices.reshape(batch_size, 2, 2)
218
+ # Calculate discriminant and ensure it's non-negative
219
+ discriminant = trace * trace - 4 * det
220
+ discriminant = np.maximum(discriminant, 0) # Handle numerical errors
157
221
 
158
- # Compute eigenvalues for all matrices at once
159
- eigenvalues_batch = np.real(np.linalg.eigvals(hessian_batch))
222
+ sqrt_discriminant = np.sqrt(discriminant)
160
223
 
161
- # Get only the largest eigenvalue for each matrix
162
- largest_eigenvalues = np.max(eigenvalues_batch, axis=1)
224
+ # Calculate both eigenvalues
225
+ eigenval1 = (trace + sqrt_discriminant) / 2
226
+ eigenval2 = (trace - sqrt_discriminant) / 2
163
227
 
164
- # Reshape back to original spatial dimensions
165
- largest_eigenvalues = largest_eigenvalues.reshape(original_shape)
228
+ # Take the larger eigenvalue (most positive/least negative)
229
+ largest_eigenvalues = np.maximum(eigenval1, eigenval2)
166
230
 
167
231
  # Add the largest eigenvalue as a feature
168
232
  features[..., feature_idx] = largest_eigenvalues
@@ -548,73 +612,98 @@ class InteractiveSegmenter:
548
612
  return dict(z_dict) # Convert back to regular dict
549
613
 
550
614
  def process_chunk(self, chunk_coords):
551
- """
552
- Vectorized process_chunk that releases GIL more effectively
553
- """
615
+ """Updated process_chunk with chunked 2D feature computation"""
616
+
554
617
  if self.realtimechunks is None:
555
- # Generate coordinates using vectorized operations
618
+ # 3D processing (unchanged)
556
619
  z_min, z_max = chunk_coords[0], chunk_coords[1]
557
620
  y_min, y_max = chunk_coords[2], chunk_coords[3]
558
621
  x_min, x_max = chunk_coords[4], chunk_coords[5]
559
622
 
560
- # More efficient coordinate generation
561
623
  z_range = np.arange(z_min, z_max)
562
624
  y_range = np.arange(y_min, y_max)
563
625
  x_range = np.arange(x_min, x_max)
564
626
 
565
- # Create coordinate grid efficiently
566
627
  z_grid, y_grid, x_grid = np.meshgrid(z_range, y_range, x_range, indexing='ij')
567
628
  chunk_coords_array = np.column_stack([
568
- z_grid.ravel(),
569
- y_grid.ravel(),
570
- x_grid.ravel()
629
+ z_grid.ravel(), y_grid.ravel(), x_grid.ravel()
571
630
  ])
572
- else:
573
- # Convert to numpy array for vectorized operations
574
- chunk_coords_array = np.array(chunk_coords)
575
- z_coords, y_coords, x_coords = chunk_coords_array[:, 0], chunk_coords_array[:, 1], chunk_coords_array[:, 2]
576
- z_min, z_max = z_coords.min(), z_coords.max()
577
- y_min, y_max = y_coords.min(), y_coords.max()
578
- x_min, x_max = x_coords.min(), x_coords.max()
579
-
580
- # Extract subarray
581
- subarray = self.image_3d[z_min:z_max+1, y_min:y_max+1, x_min:x_max+1]
631
+
632
+ # Extract subarray
633
+ subarray = self.image_3d[z_min:z_max+1, y_min:y_max+1, x_min:x_max+1]
634
+
635
+ # Compute features for entire subarray
636
+ if self.speed:
637
+ feature_map = self.compute_feature_maps_cpu(subarray)
638
+ else:
639
+ feature_map = self.compute_deep_feature_maps_cpu(subarray)
640
+
641
+ # Vectorized feature extraction
642
+ local_coords = chunk_coords_array - np.array([z_min, y_min, x_min])
643
+ features = feature_map[local_coords[:, 0], local_coords[:, 1], local_coords[:, 2]]
644
+
645
+ # Vectorized predictions
646
+ predictions = self.model.predict(features)
647
+ predictions = np.array(predictions, dtype=bool)
648
+
649
+ # Use boolean indexing to separate coordinates
650
+ foreground_coords = chunk_coords_array[predictions]
651
+ background_coords = chunk_coords_array[~predictions]
652
+
653
+ # Convert to sets
654
+ foreground = set(map(tuple, foreground_coords))
655
+ background = set(map(tuple, background_coords))
656
+
657
+ return foreground, background
582
658
 
583
- # Compute features for entire subarray at once
584
- if self.speed:
585
- feature_map = self.compute_feature_maps_cpu(subarray)
586
659
  else:
587
- feature_map = self.compute_deep_feature_maps_cpu(subarray)
588
-
589
- # Vectorized feature extraction
590
- # Convert global coordinates to local coordinates in one operation
591
- local_coords = chunk_coords_array - np.array([z_min, y_min, x_min])
592
-
593
- # Extract all features at once using advanced indexing
594
- features = feature_map[local_coords[:, 0], local_coords[:, 1], local_coords[:, 2]]
595
-
596
- # Vectorized predictions (assuming your model can handle batch predictions)
597
- if hasattr(self.model, 'predict_batch') or features.ndim > 1:
598
- # If model supports batch prediction
660
+ # 2D processing - compute features for chunk only
661
+ chunk_coords_array = np.array(chunk_coords)
662
+ z_coords = chunk_coords_array[:, 0]
663
+ y_coords = chunk_coords_array[:, 1]
664
+ x_coords = chunk_coords_array[:, 2]
665
+
666
+ z = int(np.unique(z_coords)[0]) # All coordinates should have same Z
667
+
668
+ # Get chunk bounds
669
+ y_min, y_max = int(np.min(y_coords)), int(np.max(y_coords))
670
+ x_min, x_max = int(np.min(x_coords)), int(np.max(x_coords))
671
+
672
+ # Expand bounds slightly to ensure we capture the chunk properly
673
+ y_min = max(0, y_min)
674
+ x_min = max(0, x_min)
675
+ y_max = min(self.image_3d.shape[1], y_max + 1)
676
+ x_max = min(self.image_3d.shape[2], x_max + 1)
677
+
678
+ # Extract 2D subarray for this chunk
679
+ subarray_2d = self.image_3d[z, y_min:y_max, x_min:x_max]
680
+
681
+ # Compute features for just this chunk
682
+ if self.speed:
683
+ feature_map = self.compute_feature_maps_cpu_2d(image_2d=subarray_2d)
684
+ else:
685
+ feature_map = self.compute_deep_feature_maps_cpu_2d(image_2d=subarray_2d)
686
+
687
+ # Convert global coordinates to local chunk coordinates
688
+ local_y_coords = y_coords - y_min
689
+ local_x_coords = x_coords - x_min
690
+
691
+ # Extract features using local coordinates
692
+ features = feature_map[local_y_coords, local_x_coords]
693
+
694
+ # Vectorized predictions
599
695
  predictions = self.model.predict(features)
600
- else:
601
- # Fallback to individual predictions but still vectorized preparation
602
- predictions = np.array([self.model.predict([feat]) for feat in features])
603
-
604
- # Vectorized coordinate assignment
605
- predictions = np.array(predictions, dtype=bool)
606
- foreground_mask = predictions
607
- background_mask = ~predictions
608
-
609
- # Use boolean indexing to separate coordinates
610
- foreground_coords = chunk_coords_array[foreground_mask]
611
- background_coords = chunk_coords_array[background_mask]
612
-
613
- # Convert to sets (still needed for your return format)
614
- foreground = set(map(tuple, foreground_coords))
615
- background = set(map(tuple, background_coords))
616
-
617
- return foreground, background
696
+ predictions = np.array(predictions, dtype=bool)
697
+
698
+ # Use boolean indexing to separate coordinates
699
+ foreground_coords = chunk_coords_array[predictions]
700
+ background_coords = chunk_coords_array[~predictions]
701
+
702
+ # Convert to sets
703
+ foreground = set(map(tuple, foreground_coords))
704
+ background = set(map(tuple, background_coords))
705
+
706
+ return foreground, background
618
707
 
619
708
  def twodim_coords(self, y_dim, x_dim, z, chunk_size = None, subrange = None):
620
709
 
@@ -967,51 +1056,53 @@ class InteractiveSegmenter:
967
1056
 
968
1057
  def get_realtime_chunks_2d(self, chunk_size=None):
969
1058
  """
970
- Create square chunks with 1 z-thickness (2D chunks across XY planes)
1059
+ Updated 2D chunking to match create_2d_chunks logic
971
1060
  """
972
1061
 
973
- if chunk_size is None:
974
- chunk_size = int(np.sqrt(self.twod_chunk_size))
975
-
976
- # Determine if we need to chunk XY planes
977
- small_dims = (self.image_3d.shape[1] <= chunk_size and
978
- self.image_3d.shape[2] <= chunk_size)
979
- few_z = self.image_3d.shape[0] <= 100 # arbitrary threshold
980
-
981
- # If small enough, each Z is one chunk
982
- if small_dims and few_z:
983
- chunk_size_xy = max(self.image_3d.shape[1], self.image_3d.shape[2])
984
- else:
985
- chunk_size_xy = chunk_size
986
-
987
- # Calculate chunks for XY plane
988
- y_chunks = (self.image_3d.shape[1] + chunk_size_xy - 1) // chunk_size_xy
989
- x_chunks = (self.image_3d.shape[2] + chunk_size_xy - 1) // chunk_size_xy
1062
+ MAX_CHUNK_SIZE = self.twod_chunk_size
990
1063
 
991
1064
  # Populate chunk dictionary
992
1065
  chunk_dict = {}
993
1066
 
994
- # Create chunks for each Z plane (single Z thickness)
1067
+ # Create chunks for each Z plane using the same logic as create_2d_chunks
995
1068
  for z in range(self.image_3d.shape[0]):
996
- if small_dims:
1069
+ y_dim = self.image_3d.shape[1]
1070
+ x_dim = self.image_3d.shape[2]
1071
+ total_pixels = y_dim * x_dim
1072
+
1073
+ if total_pixels <= MAX_CHUNK_SIZE:
1074
+ # Single chunk for entire Z slice
997
1075
  chunk_dict[(z, 0, 0)] = {
998
- 'coords': [0, self.image_3d.shape[1], 0, self.image_3d.shape[2]],
1076
+ 'coords': [0, y_dim, 0, x_dim], # [y_start, y_end, x_start, x_end]
999
1077
  'processed': False,
1000
- 'z': z # Keep for backward compatibility
1078
+ 'z': z
1001
1079
  }
1002
1080
  else:
1003
- # Multiple chunks per Z plane
1004
- for y_chunk in range(y_chunks):
1005
- for x_chunk in range(x_chunks):
1006
- y_start = y_chunk * chunk_size_xy
1007
- x_start = x_chunk * chunk_size_xy
1008
- y_end = min(y_start + chunk_size_xy, self.image_3d.shape[1])
1009
- x_end = min(x_start + chunk_size_xy, self.image_3d.shape[2])
1010
-
1011
- chunk_dict[(z, y_start, x_start)] = {
1012
- 'coords': [y_start, y_end, x_start, x_end],
1081
+ # Multiple chunks per Z plane - divide along largest dimension
1082
+ largest_dim = 'y' if y_dim >= x_dim else 'x'
1083
+ num_divisions = int(np.ceil(total_pixels / MAX_CHUNK_SIZE))
1084
+
1085
+ if largest_dim == 'y':
1086
+ # Divide along Y dimension
1087
+ div_size = int(np.ceil(y_dim / num_divisions))
1088
+ for i in range(0, y_dim, div_size):
1089
+ end_i = min(i + div_size, y_dim)
1090
+ # Use (z, y_start, x_start) as key for consistency
1091
+ chunk_dict[(z, i, 0)] = {
1092
+ 'coords': [i, end_i, 0, x_dim], # [y_start, y_end, x_start, x_end]
1013
1093
  'processed': False,
1014
- 'z': z # Keep for backward compatibility
1094
+ 'z': z
1095
+ }
1096
+ else:
1097
+ # Divide along X dimension
1098
+ div_size = int(np.ceil(x_dim / num_divisions))
1099
+ for i in range(0, x_dim, div_size):
1100
+ end_i = min(i + div_size, x_dim)
1101
+ # Use (z, y_start, x_start) as key for consistency
1102
+ chunk_dict[(z, 0, i)] = {
1103
+ 'coords': [0, y_dim, i, end_i], # [y_start, y_end, x_start, x_end]
1104
+ 'processed': False,
1105
+ 'z': z
1015
1106
  }
1016
1107
 
1017
1108
  self.realtimechunks = chunk_dict
@@ -1041,33 +1132,37 @@ class InteractiveSegmenter:
1041
1132
 
1042
1133
  return slice_foreground_features, slice_background_features
1043
1134
 
1044
- def extract_features_parallel(self, slices: List[int], speed: Any, use_gpu: bool,
1045
- z_fores: Dict[int, List[Tuple[int, int]]],
1046
- z_backs: Dict[int, List[Tuple[int, int]]]) -> Tuple[List[Any], List[Any]]:
1135
+ def extract_features_parallel(self, needed_chunks, speed, use_gpu, z_fores, z_backs):
1047
1136
  """
1048
- Process feature extraction using ThreadPoolExecutor for parallel execution.
1137
+ Updated version that processes chunks instead of full Z-slices
1049
1138
  """
1050
1139
  max_cores = multiprocessing.cpu_count()
1051
1140
  foreground_features = []
1052
1141
  background_features = []
1053
1142
 
1143
+ # Flatten all chunks into a single list for parallel processing
1144
+ all_chunk_tasks = []
1145
+ for z in needed_chunks:
1146
+ for chunk_coords in needed_chunks[z]:
1147
+ all_chunk_tasks.append((chunk_coords, z_fores, z_backs, speed))
1148
+
1054
1149
  with ThreadPoolExecutor(max_workers=max_cores) as executor:
1055
- # Submit all slice processing tasks
1056
- future_to_slice = {
1057
- executor.submit(self.process_slice_features, z, speed, use_gpu, z_fores, z_backs): z
1058
- for z in slices
1150
+ # Submit all chunk processing tasks
1151
+ future_to_chunk = {
1152
+ executor.submit(self.process_chunk_features_for_training, chunk_task): chunk_task
1153
+ for chunk_task in all_chunk_tasks
1059
1154
  }
1060
1155
 
1061
1156
  # Collect results as they complete
1062
- for future in future_to_slice:
1063
- slice_foreground, slice_background = future.result()
1064
- foreground_features.extend(slice_foreground)
1065
- background_features.extend(slice_background)
1157
+ for future in future_to_chunk:
1158
+ chunk_foreground, chunk_background = future.result()
1159
+ foreground_features.extend(chunk_foreground)
1160
+ background_features.extend(chunk_background)
1066
1161
 
1067
1162
  return foreground_features, background_features
1068
1163
 
1069
- def segment_volume_realtime(self, gpu = False):
1070
-
1164
+ def segment_volume_realtime(self, gpu=False):
1165
+ """Updated realtime segmentation with chunked 2D processing"""
1071
1166
 
1072
1167
  if self.realtimechunks is None:
1073
1168
  if not self.use_two:
@@ -1075,11 +1170,10 @@ class InteractiveSegmenter:
1075
1170
  else:
1076
1171
  self.get_realtime_chunks_2d()
1077
1172
  else:
1078
- for chunk_pos in self.realtimechunks: # chunk_pos is the (z, y_start, x_start) tuple
1173
+ for chunk_pos in self.realtimechunks:
1079
1174
  self.realtimechunks[chunk_pos]['processed'] = False
1080
1175
 
1081
1176
  chunk_dict = self.realtimechunks
1082
-
1083
1177
 
1084
1178
  def get_nearest_unprocessed_chunk(self):
1085
1179
  """Get nearest unprocessed chunk prioritizing current Z"""
@@ -1092,7 +1186,6 @@ class InteractiveSegmenter:
1092
1186
  if pos[0] == curr_z and not info['processed']]
1093
1187
 
1094
1188
  if current_z_chunks:
1095
- # Find nearest chunk in current Z plane using the chunk positions from the key
1096
1189
  nearest = min(current_z_chunks,
1097
1190
  key=lambda x: ((x[0][1] - curr_y) ** 2 +
1098
1191
  (x[0][2] - curr_x) ** 2))
@@ -1107,7 +1200,6 @@ class InteractiveSegmenter:
1107
1200
 
1108
1201
  if available_z:
1109
1202
  target_z = available_z[0][0]
1110
- # Find nearest chunk in target Z plane
1111
1203
  z_chunks = [(pos, info) for pos, info in chunk_dict.items()
1112
1204
  if pos[0] == target_z and not info['processed']]
1113
1205
  nearest = min(z_chunks,
@@ -1117,41 +1209,31 @@ class InteractiveSegmenter:
1117
1209
 
1118
1210
  return None
1119
1211
 
1120
-
1121
1212
  while True:
1122
- # Find nearest unprocessed chunk using class attributes
1123
1213
  chunk_idx = get_nearest_unprocessed_chunk(self)
1124
1214
  if chunk_idx is None:
1125
1215
  break
1126
1216
 
1127
- # Process the chunk directly
1128
1217
  chunk = chunk_dict[chunk_idx]
1129
1218
  chunk['processed'] = True
1130
- coords = chunk['coords']
1219
+ coords = chunk['coords'] # [y_start, y_end, x_start, x_end]
1220
+ z = chunk['z']
1131
1221
 
1132
- coords = np.stack(np.meshgrid(
1133
- [chunk['z']],
1222
+ # Generate coordinates for this chunk
1223
+ coords_array = np.stack(np.meshgrid(
1224
+ [z],
1134
1225
  np.arange(coords[0], coords[1]),
1135
1226
  np.arange(coords[2], coords[3]),
1136
1227
  indexing='ij'
1137
1228
  )).reshape(3, -1).T
1138
1229
 
1139
- coords = list(map(tuple, coords))
1140
-
1230
+ coords_list = list(map(tuple, coords_array))
1141
1231
 
1142
- # Process the chunk directly based on whether GPU is available
1143
- if gpu:
1144
- try:
1145
- fore, back = self.process_chunk_GPU(coords)
1146
- except:
1147
- fore, back = self.process_chunk(coords)
1148
- else:
1149
- fore, back = self.process_chunk(coords)
1232
+ # Process the chunk with updated method
1233
+ fore, back = self.process_chunk(coords_list)
1150
1234
 
1151
- # Yield the results
1152
1235
  yield fore, back
1153
1236
 
1154
-
1155
1237
  def cleanup(self):
1156
1238
  """Clean up GPU memory"""
1157
1239
  if self.use_gpu:
@@ -1245,8 +1327,8 @@ class InteractiveSegmenter:
1245
1327
 
1246
1328
  return foreground_features, background_features
1247
1329
 
1248
- def train_batch(self, foreground_array, speed = True, use_gpu = False, use_two = False, mem_lock = False, saving = False):
1249
- """Train directly on foreground and background arrays"""
1330
+ def train_batch(self, foreground_array, speed=True, use_gpu=False, use_two=False, mem_lock=False, saving=False):
1331
+ """Updated train_batch with chunked 2D processing"""
1250
1332
 
1251
1333
  if not saving:
1252
1334
  print("Training model...")
@@ -1268,85 +1350,73 @@ class InteractiveSegmenter:
1268
1350
  self.mem_lock = mem_lock
1269
1351
 
1270
1352
  if use_two:
1271
-
1272
- #changed = [] #Track which slices need feature maps
1273
-
1274
- if not self.use_two: #Clarifies if we need to redo feature cache for 2D
1353
+ if not self.use_two:
1275
1354
  self.use_two = True
1276
-
1277
1355
  self.two_slices = []
1278
1356
 
1279
-
1280
1357
  # Get foreground coordinates and features
1281
1358
  z_fore, y_fore, x_fore = np.where(foreground_array == 1)
1282
-
1283
-
1284
1359
  fore_coords = list(zip(z_fore, y_fore, x_fore))
1285
1360
 
1286
1361
  # Get background coordinates and features
1287
1362
  z_back, y_back, x_back = np.where(foreground_array == 2)
1288
-
1289
1363
  back_coords = list(zip(z_back, y_back, x_back))
1290
-
1364
+
1291
1365
  foreground_features = []
1292
1366
  background_features = []
1293
-
1367
+
1368
+ # Organize coordinates by Z
1294
1369
  z_fores = self.organize_by_z(fore_coords)
1295
1370
  z_backs = self.organize_by_z(back_coords)
1296
- slices = set(list(z_fores.keys()) + list(z_backs.keys()))
1297
-
1371
+
1372
+ # Combine all Z-slices that have coordinates
1373
+ all_z_coords = {}
1374
+ for z in z_fores:
1375
+ if z not in all_z_coords:
1376
+ all_z_coords[z] = []
1377
+ all_z_coords[z].extend(z_fores[z])
1378
+ for z in z_backs:
1379
+ if z not in all_z_coords:
1380
+ all_z_coords[z] = []
1381
+ all_z_coords[z].extend(z_backs[z])
1382
+
1383
+ # Get minimal chunks needed to cover all coordinates
1384
+ needed_chunks = self.get_minimal_chunks_for_coordinates_cpu(all_z_coords)
1385
+
1386
+ # Use existing parallel infrastructure with chunked approach
1298
1387
  foreground_features, background_features = self.extract_features_parallel(
1299
- slices, speed, use_gpu, z_fores, z_backs
1388
+ needed_chunks, speed, use_gpu, z_fores, z_backs
1300
1389
  )
1301
1390
 
1302
-
1303
- else: #Forces ram efficiency
1304
-
1391
+ else:
1392
+ # 3D processing (unchanged - your existing code)
1305
1393
  box_size = self.master_chunk
1306
-
1307
- # Memory-efficient approach: compute features only for necessary subarrays
1308
1394
  foreground_features = []
1309
1395
  background_features = []
1310
1396
 
1311
- # Find coordinates of foreground and background scribbles
1312
1397
  z_fore = np.argwhere(foreground_array == 1)
1313
1398
  z_back = np.argwhere(foreground_array == 2)
1314
1399
 
1315
- # If no scribbles, return empty lists
1316
1400
  if len(z_fore) == 0 and len(z_back) == 0:
1317
1401
  return foreground_features, background_features
1318
1402
 
1319
- # Get dimensions of the input array
1320
1403
  depth, height, width = foreground_array.shape
1321
-
1322
- # Determine the minimum number of boxes needed to cover all scribbles
1323
- half_box = box_size // 2
1324
-
1325
- # Step 1: Find the minimum set of boxes that cover all scribbles
1326
- # We'll divide the volume into a grid of boxes of size box_size
1327
-
1328
- # Calculate how many boxes are needed in each dimension
1329
1404
  z_grid_size = (depth + box_size - 1) // box_size
1330
1405
  y_grid_size = (height + box_size - 1) // box_size
1331
1406
  x_grid_size = (width + box_size - 1) // box_size
1332
1407
 
1333
- # Track which grid cells contain scribbles
1334
1408
  grid_cells_with_scribbles = set()
1335
1409
 
1336
- # Map original coordinates to grid cells
1337
1410
  for z, y, x in np.vstack((z_fore, z_back)) if len(z_back) > 0 else z_fore:
1338
1411
  grid_z = z // box_size
1339
1412
  grid_y = y // box_size
1340
1413
  grid_x = x // box_size
1341
1414
  grid_cells_with_scribbles.add((grid_z, grid_y, grid_x))
1342
1415
 
1343
- # Create a mapping from original coordinates to their corresponding subarray and local coordinates
1344
- coord_mapping = {}
1345
-
1346
- # Step 2: Process each grid cell that contains scribbles
1347
-
1348
- foreground_features, background_features = self.process_grid_cells_parallel(grid_cells_with_scribbles, box_size, depth, height, width, foreground_array)
1416
+ foreground_features, background_features = self.process_grid_cells_parallel(
1417
+ grid_cells_with_scribbles, box_size, depth, height, width, foreground_array)
1349
1418
 
1419
+ # Rest of the method unchanged (combining with previous features, training, etc.)
1350
1420
  if self.previous_foreground is not None:
1351
1421
  failed = True
1352
1422
  try:
@@ -1371,7 +1441,6 @@ class InteractiveSegmenter:
1371
1441
  print("Could not combine new model with old loaded model. Perhaps you are trying to combine a quick model with a deep model? I cannot combine these...")
1372
1442
 
1373
1443
  if saving:
1374
-
1375
1444
  return foreground_features, background_features, z_fore, z_back
1376
1445
 
1377
1446
  # Combine features and labels
@@ -1386,13 +1455,8 @@ class InteractiveSegmenter:
1386
1455
  print(y)
1387
1456
 
1388
1457
  self.current_speed = speed
1389
-
1390
-
1391
-
1392
-
1393
1458
  print("Done")
1394
1459
 
1395
-
1396
1460
  def save_model(self, file_name, foreground_array):
1397
1461
 
1398
1462
  print("Saving model data")
@@ -1453,3 +1517,47 @@ class InteractiveSegmenter:
1453
1517
 
1454
1518
  return output
1455
1519
 
1520
+
1521
+ def process_chunk_features_for_training(self, chunk_task):
1522
+ """
1523
+ Process a single chunk for training feature extraction
1524
+ chunk_task: (chunk_coords, z_fores, z_backs, speed)
1525
+ """
1526
+ chunk_coords, z_fores, z_backs, speed = chunk_task
1527
+ z = chunk_coords[0]
1528
+
1529
+ # Compute features for this chunk
1530
+ feature_map, (y_offset, x_offset) = self.compute_features_for_chunk_2d_cpu(chunk_coords, speed)
1531
+
1532
+ chunk_foreground_features = []
1533
+ chunk_background_features = []
1534
+
1535
+ # Extract foreground features from this chunk
1536
+ if z in z_fores:
1537
+ for y, x in z_fores[z]:
1538
+ # Check if this coordinate is in the current chunk
1539
+ y_start, y_end = chunk_coords[1], chunk_coords[2]
1540
+ x_start, x_end = chunk_coords[3], chunk_coords[4]
1541
+
1542
+ if y_start <= y < y_end and x_start <= x < x_end:
1543
+ # Convert global coordinates to local chunk coordinates
1544
+ local_y = y - y_offset
1545
+ local_x = x - x_offset
1546
+ feature_vector = feature_map[local_y, local_x]
1547
+ chunk_foreground_features.append(feature_vector)
1548
+
1549
+ # Extract background features from this chunk
1550
+ if z in z_backs:
1551
+ for y, x in z_backs[z]:
1552
+ # Check if this coordinate is in the current chunk
1553
+ y_start, y_end = chunk_coords[1], chunk_coords[2]
1554
+ x_start, x_end = chunk_coords[3], chunk_coords[4]
1555
+
1556
+ if y_start <= y < y_end and x_start <= x < x_end:
1557
+ # Convert global coordinates to local chunk coordinates
1558
+ local_y = y - y_offset
1559
+ local_x = x - x_offset
1560
+ feature_vector = feature_map[local_y, local_x]
1561
+ chunk_background_features.append(feature_vector)
1562
+
1563
+ return chunk_foreground_features, chunk_background_features