nettracer3d 1.0.6__py3-none-any.whl → 1.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nettracer3d might be problematic. Click here for more details.

nettracer3d/nettracer.py CHANGED
@@ -804,7 +804,7 @@ def threshold(arr, proportion, custom_rad = None):
804
804
 
805
805
  threshold_index = int(len(sorted_values) * proportion)
806
806
  threshold_value = sorted_values[threshold_index]
807
- print(f"Thresholding as if smallest_radius as assigned {threshold_value}")
807
+ print(f"Thresholding as if smallest_radius was assigned {threshold_value}")
808
808
 
809
809
 
810
810
  mask = arr > threshold_value
@@ -36,6 +36,7 @@ from threading import Lock
36
36
  from scipy import ndimage
37
37
  import os
38
38
  from . import painting
39
+ from . import stats as net_stats
39
40
 
40
41
 
41
42
 
@@ -425,8 +426,6 @@ class ImageViewerWindow(QMainWindow):
425
426
  self.canvas.mpl_connect('button_release_event', self.on_mouse_release)
426
427
  self.canvas.mpl_connect('motion_notify_event', self.on_mouse_move)
427
428
 
428
- #self.canvas.mpl_connect('button_press_event', self.on_mouse_click)
429
-
430
429
  # Initialize measurement tracking
431
430
  self.measurement_points = [] # List to store point pairs
432
431
  self.angle_measurements = [] # NEW: List to store angle trios
@@ -879,6 +878,24 @@ class ImageViewerWindow(QMainWindow):
879
878
  elif self.scroll_direction > 0 and new_value <= self.slice_slider.maximum():
880
879
  self.slice_slider.setValue(new_value)
881
880
 
881
+ def evaluate_mini(self, mode = 'nodes'):
882
+ if mode == 'nodes':
883
+ if self.channel_data[0].shape[0] * self.channel_data[0].shape[1] * self.channel_data[0].shape[2] > self.mini_thresh:
884
+ self.mini_overlay = True
885
+ self.create_mini_overlay(node_indices = self.clicked_values['nodes'], edge_indices = self.clicked_values['edges'])
886
+ else:
887
+ self.create_highlight_overlay(node_indices=self.clicked_values['nodes'], edge_indices = self.clicked_values['edges'])
888
+ elif mode == 'edges':
889
+
890
+ if self.channel_data[1].shape[0] * self.channel_data[1].shape[1] * self.channel_data[1].shape[2] > self.mini_thresh:
891
+ self.mini_overlay = True
892
+ self.create_mini_overlay(node_indices = self.clicked_values['nodes'], edge_indices = self.clicked_values['edges'])
893
+ else:
894
+ self.create_highlight_overlay(
895
+ node_indices=self.clicked_values['nodes'],
896
+ edge_indices=self.clicked_values['edges']
897
+ )
898
+
882
899
  def create_highlight_overlay(self, node_indices=None, edge_indices=None, overlay1_indices = None, overlay2_indices = None, bounds = False):
883
900
  """
884
901
  Create a binary overlay highlighting specific nodes and/or edges using parallel processing.
@@ -1705,28 +1722,12 @@ class ImageViewerWindow(QMainWindow):
1705
1722
  if edges:
1706
1723
  edge_indices = filtered_df.iloc[:, 2].unique().tolist()
1707
1724
  self.clicked_values['edges'] = edge_indices
1708
-
1709
- if self.channel_data[1].shape[0] * self.channel_data[1].shape[1] * self.channel_data[1].shape[2] > self.mini_thresh:
1710
- self.mini_overlay = True
1711
- self.create_mini_overlay(node_indices = self.clicked_values['nodes'], edge_indices = self.clicked_values['edges'])
1712
- else:
1713
- self.create_highlight_overlay(
1714
- node_indices=self.clicked_values['nodes'],
1715
- edge_indices=self.clicked_values['edges']
1716
- )
1725
+ self.evaluate_mini(mode = 'edges')
1717
1726
  else:
1718
- if self.channel_data[0].shape[0] * self.channel_data[0].shape[1] * self.channel_data[0].shape[2] > self.mini_thresh:
1719
- self.mini_overlay = True
1720
- self.create_mini_overlay(node_indices = self.clicked_values['nodes'], edge_indices = self.clicked_values['edges'])
1721
- else:
1722
- self.create_highlight_overlay(
1723
- node_indices=self.clicked_values['nodes'],
1724
- edge_indices = self.clicked_values['edges']
1725
- )
1726
-
1727
+ self.evaluate_mini()
1727
1728
 
1728
1729
  except Exception as e:
1729
- print(f"Error processing neighbors: {e}")
1730
+ print(f"Error showing neighbors: {e}")
1730
1731
 
1731
1732
 
1732
1733
  def handle_show_component(self, edges = False, nodes = True):
@@ -1797,23 +1798,10 @@ class ImageViewerWindow(QMainWindow):
1797
1798
  if edges:
1798
1799
  edge_indices = filtered_df.iloc[:, 2].unique().tolist()
1799
1800
  self.clicked_values['edges'] = edge_indices
1800
- if self.channel_data[1].shape[0] * self.channel_data[1].shape[1] * self.channel_data[1].shape[2] > self.mini_thresh:
1801
- self.mini_overlay = True
1802
- self.create_mini_overlay(node_indices = self.clicked_values['nodes'], edge_indices = self.clicked_values['edges'])
1803
- else:
1804
- self.create_highlight_overlay(
1805
- node_indices=self.clicked_values['nodes'],
1806
- edge_indices=edge_indices
1807
- )
1801
+ self.evaluate_mini(mode = 'edges')
1808
1802
  else:
1809
- if self.channel_data[0].shape[0] * self.channel_data[0].shape[1] * self.channel_data[0].shape[2] > self.mini_thresh:
1810
- self.mini_overlay = True
1811
- self.create_mini_overlay(node_indices = self.clicked_values['nodes'], edge_indices = self.clicked_values['edges'])
1812
- else:
1813
- self.create_highlight_overlay(
1814
- node_indices = self.clicked_values['nodes'],
1815
- edge_indices = self.clicked_values['edges']
1816
- )
1803
+ self.evaluate_mini()
1804
+
1817
1805
 
1818
1806
  except Exception as e:
1819
1807
 
@@ -2640,9 +2628,9 @@ class ImageViewerWindow(QMainWindow):
2640
2628
  self.create_mini_overlay(node_indices = self.clicked_values['nodes'], edge_indices = self.clicked_values['edges'])
2641
2629
  self.needs_mini = False
2642
2630
  else:
2643
- self.create_highlight_overlay(node_indices = self.clicked_values['nodes'], edge_indices = self.clicked_values['edges'])
2631
+ self.evaluate_mini()
2644
2632
  else:
2645
- self.create_highlight_overlay(node_indices = self.clicked_values['nodes'], edge_indices = self.clicked_values['edges'])
2633
+ self.evaluate_mini()
2646
2634
 
2647
2635
 
2648
2636
  self.update_display(preserve_zoom=(current_xlim, current_ylim))
@@ -4569,6 +4557,8 @@ class ImageViewerWindow(QMainWindow):
4569
4557
  allstats_action.triggered.connect(self.stats)
4570
4558
  histos_action = stats_menu.addAction("Network Statistic Histograms")
4571
4559
  histos_action.triggered.connect(self.histos)
4560
+ sig_action = stats_menu.addAction("Significance Testing")
4561
+ sig_action.triggered.connect(self.sig_test)
4572
4562
  radial_action = stats_menu.addAction("Radial Distribution Analysis")
4573
4563
  radial_action.triggered.connect(self.show_radial_dialog)
4574
4564
  neighbor_id_action = stats_menu.addAction("Identity Distribution of Neighbors")
@@ -4913,6 +4903,16 @@ class ImageViewerWindow(QMainWindow):
4913
4903
  except Exception as e:
4914
4904
  print(f"Error creating histogram selector: {e}")
4915
4905
 
4906
+ def sig_test(self):
4907
+ # Get the existing QApplication instance
4908
+ app = QApplication.instance()
4909
+
4910
+ # Create the statistical GUI window without starting a new event loop
4911
+ stats_window = net_stats.main(app)
4912
+
4913
+ # Keep a reference so it doesn't get garbage collected
4914
+ self.stats_window = stats_window
4915
+
4916
4916
  def volumes(self):
4917
4917
 
4918
4918
 
@@ -5060,6 +5060,10 @@ class ImageViewerWindow(QMainWindow):
5060
5060
  dialog = MergeNodeIdDialog(self)
5061
5061
  dialog.exec()
5062
5062
 
5063
+ def show_multichan_dialog(self, data):
5064
+ dialog = MultiChanDialog(self, data)
5065
+ dialog.show()
5066
+
5063
5067
  def show_gray_water_dialog(self):
5064
5068
  """Show the gray watershed parameter dialog."""
5065
5069
  dialog = GrayWaterDialog(self)
@@ -5162,7 +5166,7 @@ class ImageViewerWindow(QMainWindow):
5162
5166
 
5163
5167
  my_network.edges = (my_network.nodes == 0) * my_network.edges
5164
5168
 
5165
- my_network.calculate_all(my_network.nodes, my_network.edges, xy_scale = my_network.xy_scale, z_scale = my_network.z_scale, search = None, diledge = None, inners = False, hash_inners = False, remove_trunk = 0, ignore_search_region = True, other_nodes = None, label_nodes = True, directory = None, GPU = False, fast_dil = False, skeletonize = False, GPU_downsample = None)
5169
+ my_network.calculate_all(my_network.nodes, my_network.edges, xy_scale = my_network.xy_scale, z_scale = my_network.z_scale, search = None, diledge = None, inners = False, remove_trunk = 0, ignore_search_region = True, other_nodes = None, label_nodes = True, directory = None, GPU = False, fast_dil = False, skeletonize = False, GPU_downsample = None)
5166
5170
 
5167
5171
  self.load_channel(1, my_network.edges, data = True)
5168
5172
  self.load_channel(0, my_network.nodes, data = True)
@@ -5916,6 +5920,16 @@ class ImageViewerWindow(QMainWindow):
5916
5920
  msg.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No)
5917
5921
  return msg.exec() == QMessageBox.StandardButton.Yes
5918
5922
 
5923
+ def confirm_multichan_dialog(self):
5924
+ """Shows a dialog asking user to confirm if image is multichan"""
5925
+ msg = QMessageBox()
5926
+ msg.setIcon(QMessageBox.Icon.Question)
5927
+ msg.setText("Image Format Alert")
5928
+ msg.setInformativeText("Is this a Multi-Channel (4D) image?")
5929
+ msg.setWindowTitle("Confirm Image Format")
5930
+ msg.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No)
5931
+ return msg.exec() == QMessageBox.StandardButton.Yes
5932
+
5919
5933
  def confirm_resize_dialog(self):
5920
5934
  """Shows a dialog asking user to resize image"""
5921
5935
  msg = QMessageBox()
@@ -5999,7 +6013,7 @@ class ImageViewerWindow(QMainWindow):
5999
6013
  try:
6000
6014
  if len(self.channel_data[channel_index].shape) == 3: # potentially 2D RGB
6001
6015
  if self.channel_data[channel_index].shape[-1] in (3, 4): # last dim is 3 or 4
6002
- if not data and self.shape is None:
6016
+ if not data:
6003
6017
  if self.confirm_rgb_dialog():
6004
6018
  # User confirmed it's 2D RGB, expand to 4D
6005
6019
  self.channel_data[channel_index] = np.expand_dims(self.channel_data[channel_index], axis=0)
@@ -6009,12 +6023,18 @@ class ImageViewerWindow(QMainWindow):
6009
6023
  except:
6010
6024
  pass
6011
6025
 
6012
- if not color:
6013
- try:
6014
- if len(self.channel_data[channel_index].shape) == 4 and (channel_index == 0 or channel_index == 1):
6026
+ if len(self.channel_data[channel_index].shape) == 4:
6027
+ if not self.channel_data[channel_index].shape[-1] in (3, 4):
6028
+ if self.confirm_multichan_dialog(): # User is trying to load 4D channel stack:
6029
+ my_data = copy.deepcopy(self.channel_data[channel_index])
6030
+ self.channel_data[channel_index] = None
6031
+ self.show_multichan_dialog(data = my_data)
6032
+ return
6033
+ elif not color and (channel_index == 0 or channel_index == 1):
6034
+ try:
6015
6035
  self.channel_data[channel_index] = self.reduce_rgb_dimension(self.channel_data[channel_index], 'weight')
6016
- except:
6017
- pass
6036
+ except:
6037
+ pass
6018
6038
 
6019
6039
  reset_resize = False
6020
6040
 
@@ -8431,6 +8451,89 @@ class MergeNodeIdDialog(QDialog):
8431
8451
  print(traceback.format_exc())
8432
8452
  #print(f"Error: {e}")
8433
8453
 
8454
+ class MultiChanDialog(QDialog):
8455
+
8456
+ def __init__(self, parent=None, data = None):
8457
+
8458
+ super().__init__(parent)
8459
+ self.setWindowTitle("Channel Loading")
8460
+ self.setModal(False)
8461
+
8462
+ layout = QFormLayout(self)
8463
+
8464
+ self.data = data
8465
+
8466
+ self.nodes = QComboBox()
8467
+ self.edges = QComboBox()
8468
+ self.overlay1 = QComboBox()
8469
+ self.overlay2 = QComboBox()
8470
+ options = ["None"]
8471
+ for i in range(self.data.shape[0]):
8472
+ options.append(str(i))
8473
+ self.nodes.addItems(options)
8474
+ self.edges.addItems(options)
8475
+ self.overlay1.addItems(options)
8476
+ self.overlay2.addItems(options)
8477
+ self.nodes.setCurrentIndex(0)
8478
+ self.edges.setCurrentIndex(0)
8479
+ self.overlay1.setCurrentIndex(0)
8480
+ self.overlay2.setCurrentIndex(0)
8481
+ layout.addRow("Load this channel into nodes?", self.nodes)
8482
+ layout.addRow("Load this channel into edges?", self.edges)
8483
+ layout.addRow("Load this channel into overlay1?", self.overlay1)
8484
+ layout.addRow("Load this channel into overlay2?", self.overlay2)
8485
+
8486
+ run_button = QPushButton("Load Channels")
8487
+ run_button.clicked.connect(self.run)
8488
+ layout.addWidget(run_button)
8489
+
8490
+ run_button2 = QPushButton("Save Channels to Directory")
8491
+ run_button2.clicked.connect(self.run2)
8492
+ layout.addWidget(run_button2)
8493
+
8494
+
8495
+ def run(self):
8496
+
8497
+ try:
8498
+ node_chan = int(self.nodes.currentText())
8499
+ self.parent().load_channel(0, self.data[node_chan, :, :, :], data = True)
8500
+ except:
8501
+ pass
8502
+ try:
8503
+ edge_chan = int(self.edges.currentText())
8504
+ self.parent().load_channel(1, self.data[edge_chan, :, :, :], data = True)
8505
+ except:
8506
+ pass
8507
+ try:
8508
+ overlay1_chan = int(self.overlay1.currentText())
8509
+ self.parent().load_channel(2, self.data[overlay1_chan, :, :, :], data = True)
8510
+ except:
8511
+ pass
8512
+ try:
8513
+ overlay2_chan = int(self.overlay2.currentText())
8514
+ self.parent().load_channel(3, self.data[overlay2_chan, :, :, :], data = True)
8515
+ except:
8516
+ pass
8517
+
8518
+ def run2(self):
8519
+
8520
+ try:
8521
+ # First let user select parent directory
8522
+ parent_dir = QFileDialog.getExistingDirectory(
8523
+ self,
8524
+ "Select Location to Save Channels",
8525
+ "",
8526
+ QFileDialog.Option.ShowDirsOnly
8527
+ )
8528
+
8529
+ for i in range(self.data.shape[0]):
8530
+ try:
8531
+ tifffile.imwrite(f'{parent_dir}/C{i}.tif', self.data[i, :, :, :])
8532
+ except:
8533
+ continue
8534
+ except:
8535
+ pass
8536
+
8434
8537
 
8435
8538
  class Show3dDialog(QDialog):
8436
8539
  def __init__(self, parent=None):
@@ -10513,6 +10616,9 @@ class MotherDialog(QDialog):
10513
10616
 
10514
10617
  except Exception as e:
10515
10618
 
10619
+ import traceback
10620
+ print(traceback.format_exc())
10621
+
10516
10622
  print(f"Error finding mothers: {e}")
10517
10623
 
10518
10624
 
@@ -12257,7 +12363,7 @@ class ThresholdWindow(QMainWindow):
12257
12363
  button_layout.addWidget(run_button)
12258
12364
 
12259
12365
  # Add Cancel button for external dialog use
12260
- cancel_button = QPushButton("Cancel/Skip")
12366
+ cancel_button = QPushButton("Cancel/Skip (Retains Selection)")
12261
12367
  cancel_button.clicked.connect(self.cancel_processing)
12262
12368
  button_layout.addWidget(cancel_button)
12263
12369
 
nettracer3d/segmenter.py CHANGED
@@ -1181,12 +1181,22 @@ class InteractiveSegmenter:
1181
1181
  (x[0][2] - curr_x) ** 2))
1182
1182
  return nearest[0]
1183
1183
  else:
1184
- # 3D chunks: use existing center-based distance calculation
1185
- nearest = min(unprocessed_chunks,
1184
+ # 3D chunks: find chunks on nearest Z-plane, then closest in X/Y
1185
+ # First find the nearest Z-plane among available chunks
1186
+ nearest_z = min(unprocessed_chunks,
1187
+ key=lambda x: abs(x[1]['center'][0] - curr_z))[1]['center'][0]
1188
+
1189
+ # Get all chunks on that nearest Z-plane
1190
+ nearest_z_chunks = [chunk for chunk in unprocessed_chunks
1191
+ if chunk[1]['center'][0] == nearest_z]
1192
+
1193
+ # From those chunks, find closest in X/Y
1194
+ nearest = min(nearest_z_chunks,
1186
1195
  key=lambda x: sum((a - b) ** 2 for a, b in
1187
- zip(x[1]['center'], (curr_z, curr_y, curr_x))))
1196
+ zip(x[1]['center'][1:], (curr_y, curr_x))))
1197
+
1188
1198
  return nearest[0]
1189
-
1199
+
1190
1200
  return None
1191
1201
 
1192
1202
  while True:
@@ -1284,12 +1294,14 @@ class InteractiveSegmenter:
1284
1294
 
1285
1295
  return foreground_features, background_features
1286
1296
 
1287
- def compute_3d_chunks(self, chunk_size=None):
1297
+ def compute_3d_chunks(self, chunk_size=None, thickness=49):
1288
1298
  """
1289
- Compute 3D chunks with consistent logic across all operations.
1299
+ Compute 3D chunks as rectangular prisms with consistent logic across all operations.
1300
+ Creates chunks that are thin in Z and square-like in X/Y dimensions.
1290
1301
 
1291
1302
  Args:
1292
- chunk_size: Optional chunk size, otherwise uses dynamic calculation
1303
+ chunk_size: Optional chunk size for volume calculation, otherwise uses dynamic calculation
1304
+ thickness: Z-dimension thickness of chunks (default: 9)
1293
1305
 
1294
1306
  Returns:
1295
1307
  list: List of chunk coordinates [z_start, z_end, y_start, y_end, x_start, x_end]
@@ -1313,27 +1325,57 @@ class InteractiveSegmenter:
1313
1325
  except:
1314
1326
  depth, height, width, rgb = self.image_3d.shape
1315
1327
 
1316
- # Calculate chunk grid dimensions
1317
- z_chunks = (depth + chunk_size - 1) // chunk_size
1318
- y_chunks = (height + chunk_size - 1) // chunk_size
1319
- x_chunks = (width + chunk_size - 1) // chunk_size
1328
+ # Calculate target volume per chunk (same as original cube)
1329
+ target_volume = chunk_size ** 3
1330
+
1331
+ # Calculate XY side length based on thickness and target volume
1332
+ # Volume = thickness * xy_side * xy_side
1333
+ # So xy_side = sqrt(volume / thickness)
1334
+ xy_side = int(np.sqrt(target_volume / thickness))
1335
+ xy_side = max(1, xy_side) # Ensure minimum size of 1
1320
1336
 
1321
- # Generate all chunk start positions
1322
- chunk_starts = np.array(np.meshgrid(
1323
- np.arange(z_chunks) * chunk_size,
1324
- np.arange(y_chunks) * chunk_size,
1325
- np.arange(x_chunks) * chunk_size,
1326
- indexing='ij'
1327
- )).reshape(3, -1).T
1337
+ # Calculate actual chunk dimensions for grid calculation
1338
+ z_chunk_size = thickness
1339
+ xy_chunk_size = xy_side
1328
1340
 
1329
- # Create chunk coordinate list
1341
+ # Calculate number of chunks in each dimension
1342
+ z_chunks = (depth + z_chunk_size - 1) // z_chunk_size
1343
+ y_chunks = (height + xy_chunk_size - 1) // xy_chunk_size
1344
+ x_chunks = (width + xy_chunk_size - 1) // xy_chunk_size
1345
+
1346
+ # Calculate actual chunk sizes to distribute remainder evenly
1347
+ # This ensures all chunks are roughly the same size
1348
+ z_sizes = np.full(z_chunks, depth // z_chunks)
1349
+ z_remainder = depth % z_chunks
1350
+ z_sizes[:z_remainder] += 1
1351
+
1352
+ y_sizes = np.full(y_chunks, height // y_chunks)
1353
+ y_remainder = height % y_chunks
1354
+ y_sizes[:y_remainder] += 1
1355
+
1356
+ x_sizes = np.full(x_chunks, width // x_chunks)
1357
+ x_remainder = width % x_chunks
1358
+ x_sizes[:x_remainder] += 1
1359
+
1360
+ # Calculate cumulative positions
1361
+ z_positions = np.concatenate([[0], np.cumsum(z_sizes)])
1362
+ y_positions = np.concatenate([[0], np.cumsum(y_sizes)])
1363
+ x_positions = np.concatenate([[0], np.cumsum(x_sizes)])
1364
+
1365
+ # Generate all chunk coordinates
1330
1366
  chunks = []
1331
- for z_start, y_start, x_start in chunk_starts:
1332
- z_end = min(z_start + chunk_size, depth)
1333
- y_end = min(y_start + chunk_size, height)
1334
- x_end = min(x_start + chunk_size, width)
1335
- coords = [z_start, z_end, y_start, y_end, x_start, x_end]
1336
- chunks.append(coords)
1367
+ for z_idx in range(z_chunks):
1368
+ for y_idx in range(y_chunks):
1369
+ for x_idx in range(x_chunks):
1370
+ z_start = z_positions[z_idx]
1371
+ z_end = z_positions[z_idx + 1]
1372
+ y_start = y_positions[y_idx]
1373
+ y_end = y_positions[y_idx + 1]
1374
+ x_start = x_positions[x_idx]
1375
+ x_end = x_positions[x_idx + 1]
1376
+
1377
+ coords = [z_start, z_end, y_start, y_end, x_start, x_end]
1378
+ chunks.append(coords)
1337
1379
 
1338
1380
  return chunks
1339
1381
 
@@ -1055,19 +1055,18 @@ class InteractiveSegmenter:
1055
1055
  self.realtimechunks = chunk_dict
1056
1056
  print("Ready!")
1057
1057
 
1058
- def compute_3d_chunks(self, chunk_size=None):
1058
+ def compute_3d_chunks(self, chunk_size=None, thickness=49):
1059
1059
  """
1060
- Compute 3D chunks with consistent logic across all operations (GPU version).
1060
+ Compute 3D chunks as rectangular prisms with consistent logic across all operations.
1061
+ Creates chunks that are thin in Z and square-like in X/Y dimensions.
1061
1062
 
1062
1063
  Args:
1063
- chunk_size: Optional chunk size, otherwise uses dynamic calculation
1064
+ chunk_size: Optional chunk size for volume calculation, otherwise uses dynamic calculation
1065
+ thickness: Z-dimension thickness of chunks (default: 9)
1064
1066
 
1065
1067
  Returns:
1066
1068
  list: List of chunk coordinates [z_start, z_end, y_start, y_end, x_start, x_end]
1067
1069
  """
1068
- import cupy as cp
1069
- import multiprocessing
1070
-
1071
1070
  # Use consistent chunk size calculation
1072
1071
  if chunk_size is None:
1073
1072
  if hasattr(self, 'master_chunk') and self.master_chunk is not None:
@@ -1075,10 +1074,10 @@ class InteractiveSegmenter:
1075
1074
  else:
1076
1075
  # Dynamic calculation (same as segmentation)
1077
1076
  total_cores = multiprocessing.cpu_count()
1078
- total_volume = cp.prod(cp.array(self.image_3d.shape))
1077
+ total_volume = np.prod(self.image_3d.shape)
1079
1078
  target_volume_per_chunk = total_volume / (total_cores * 4)
1080
1079
 
1081
- chunk_size = int(cp.cbrt(target_volume_per_chunk))
1080
+ chunk_size = int(np.cbrt(target_volume_per_chunk))
1082
1081
  chunk_size = max(16, min(chunk_size, min(self.image_3d.shape) // 2))
1083
1082
  chunk_size = ((chunk_size + 7) // 16) * 16
1084
1083
 
@@ -1087,28 +1086,57 @@ class InteractiveSegmenter:
1087
1086
  except:
1088
1087
  depth, height, width, rgb = self.image_3d.shape
1089
1088
 
1090
- # Calculate chunk grid dimensions
1091
- z_chunks = (depth + chunk_size - 1) // chunk_size
1092
- y_chunks = (height + chunk_size - 1) // chunk_size
1093
- x_chunks = (width + chunk_size - 1) // chunk_size
1089
+ # Calculate target volume per chunk (same as original cube)
1090
+ target_volume = chunk_size ** 3
1091
+
1092
+ # Calculate XY side length based on thickness and target volume
1093
+ # Volume = thickness * xy_side * xy_side
1094
+ # So xy_side = sqrt(volume / thickness)
1095
+ xy_side = int(np.sqrt(target_volume / thickness))
1096
+ xy_side = max(1, xy_side) # Ensure minimum size of 1
1097
+
1098
+ # Calculate actual chunk dimensions for grid calculation
1099
+ z_chunk_size = thickness
1100
+ xy_chunk_size = xy_side
1101
+
1102
+ # Calculate number of chunks in each dimension
1103
+ z_chunks = (depth + z_chunk_size - 1) // z_chunk_size
1104
+ y_chunks = (height + xy_chunk_size - 1) // xy_chunk_size
1105
+ x_chunks = (width + xy_chunk_size - 1) // xy_chunk_size
1106
+
1107
+ # Calculate actual chunk sizes to distribute remainder evenly
1108
+ # This ensures all chunks are roughly the same size
1109
+ z_sizes = np.full(z_chunks, depth // z_chunks)
1110
+ z_remainder = depth % z_chunks
1111
+ z_sizes[:z_remainder] += 1
1094
1112
 
1095
- # Generate all chunk start positions using CuPy
1096
- chunk_starts = cp.array(cp.meshgrid(
1097
- cp.arange(z_chunks) * chunk_size,
1098
- cp.arange(y_chunks) * chunk_size,
1099
- cp.arange(x_chunks) * chunk_size,
1100
- indexing='ij'
1101
- )).reshape(3, -1).T
1113
+ y_sizes = np.full(y_chunks, height // y_chunks)
1114
+ y_remainder = height % y_chunks
1115
+ y_sizes[:y_remainder] += 1
1102
1116
 
1117
+ x_sizes = np.full(x_chunks, width // x_chunks)
1118
+ x_remainder = width % x_chunks
1119
+ x_sizes[:x_remainder] += 1
1103
1120
 
1104
- # Create chunk coordinate list
1121
+ # Calculate cumulative positions
1122
+ z_positions = np.concatenate([[0], np.cumsum(z_sizes)])
1123
+ y_positions = np.concatenate([[0], np.cumsum(y_sizes)])
1124
+ x_positions = np.concatenate([[0], np.cumsum(x_sizes)])
1125
+
1126
+ # Generate all chunk coordinates
1105
1127
  chunks = []
1106
- for z_start, y_start, x_start in chunk_starts:
1107
- z_end = min(z_start + chunk_size, depth)
1108
- y_end = min(y_start + chunk_size, height)
1109
- x_end = min(x_start + chunk_size, width)
1110
- coords = [int(z_start), int(z_end), int(y_start), int(y_end), int(x_start), int(x_end)]
1111
- chunks.append(coords)
1128
+ for z_idx in range(z_chunks):
1129
+ for y_idx in range(y_chunks):
1130
+ for x_idx in range(x_chunks):
1131
+ z_start = z_positions[z_idx]
1132
+ z_end = z_positions[z_idx + 1]
1133
+ y_start = y_positions[y_idx]
1134
+ y_end = y_positions[y_idx + 1]
1135
+ x_start = x_positions[x_idx]
1136
+ x_end = x_positions[x_idx + 1]
1137
+
1138
+ coords = [z_start, z_end, y_start, y_end, x_start, x_end]
1139
+ chunks.append(coords)
1112
1140
 
1113
1141
  return chunks
1114
1142
 
@@ -1196,10 +1224,20 @@ class InteractiveSegmenter:
1196
1224
  (x[0][2] - curr_x) ** 2))
1197
1225
  return nearest[0]
1198
1226
  else:
1199
- # 3D chunks: use existing center-based distance calculation
1200
- nearest = min(unprocessed_chunks,
1227
+ # 3D chunks: find chunks on nearest Z-plane, then closest in X/Y
1228
+ # First find the nearest Z-plane among available chunks
1229
+ nearest_z = min(unprocessed_chunks,
1230
+ key=lambda x: abs(x[1]['center'][0] - curr_z))[1]['center'][0]
1231
+
1232
+ # Get all chunks on that nearest Z-plane
1233
+ nearest_z_chunks = [chunk for chunk in unprocessed_chunks
1234
+ if chunk[1]['center'][0] == nearest_z]
1235
+
1236
+ # From those chunks, find closest in X/Y
1237
+ nearest = min(nearest_z_chunks,
1201
1238
  key=lambda x: sum((a - b) ** 2 for a, b in
1202
- zip(x[1]['center'], (curr_z, curr_y, curr_x))))
1239
+ zip(x[1]['center'][1:], (curr_y, curr_x))))
1240
+
1203
1241
  return nearest[0]
1204
1242
 
1205
1243
  return None