nettracer3d 1.2.5 → 1.3.1 (wheels nettracer3d-1.2.5-py3-none-any.whl → nettracer3d-1.3.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nettracer3d has been flagged as potentially problematic; consult the registry's advisory page for this release for details.

nettracer3d/nettracer.py CHANGED
@@ -90,6 +90,7 @@ def reslice_3d_array(args):
90
90
  def _get_node_edge_dict(label_array, edge_array, label, dilate_xy, dilate_z):
91
91
  """Internal method used for the secondary algorithm to find which nodes interact with which edges."""
92
92
 
93
+ import tifffile
93
94
  # Create a boolean mask where elements with the specified label are True
94
95
  label_array = label_array == label
95
96
  label_array = dilate_3D(label_array, dilate_xy, dilate_xy, dilate_z) #Dilate the label to see where the dilated label overlaps
@@ -104,7 +105,7 @@ def _get_node_edge_dict(label_array, edge_array, label, dilate_xy, dilate_z):
104
105
  def process_label(args):
105
106
  """Modified to use pre-computed bounding boxes instead of argwhere"""
106
107
  nodes, edges, label, dilate_xy, dilate_z, array_shape, bounding_boxes = args
107
- print(f"Processing node {label}")
108
+ #print(f"Processing node {label}")
108
109
 
109
110
  # Get the pre-computed bounding box for this label
110
111
  slice_obj = bounding_boxes[int(label)-1] # -1 because label numbers start at 1
@@ -122,6 +123,7 @@ def process_label(args):
122
123
 
123
124
  def create_node_dictionary(nodes, edges, num_nodes, dilate_xy, dilate_z):
124
125
  """Modified to pre-compute all bounding boxes using find_objects"""
126
+ print("Calculating network...")
125
127
  node_dict = {}
126
128
  array_shape = nodes.shape
127
129
 
@@ -783,6 +785,147 @@ def get_surface_areas(labeled, xy_scale=1, z_scale=1):
783
785
  result = {int(label): float(surface_areas[label]) for label in labels}
784
786
  return result
785
787
 
788
+ def get_background_surface_areas(labeled, xy_scale=1, z_scale=1):
789
+ """Calculate surface area exposed to background (value 0) for each object."""
790
+ labels = np.unique(labeled)
791
+ labels = labels[labels > 0]
792
+ max_label = int(np.max(labeled))
793
+
794
+ surface_areas = np.zeros(max_label + 1, dtype=np.float64)
795
+
796
+ for axis in range(3):
797
+ if axis == 2:
798
+ face_area = xy_scale * xy_scale
799
+ else:
800
+ face_area = xy_scale * z_scale
801
+
802
+ for direction in [-1, 1]:
803
+ # Pad with zeros only on the axis we're checking
804
+ pad_width = [(1, 1) if i == axis else (0, 0) for i in range(3)]
805
+ padded = np.pad(labeled, pad_width, mode='constant', constant_values=0)
806
+
807
+ # Roll the padded array
808
+ shifted = np.roll(padded, direction, axis=axis)
809
+
810
+ # Extract the center region (original size) from shifted
811
+ slices = [slice(1, -1) if i == axis else slice(None) for i in range(3)]
812
+ shifted_cropped = shifted[tuple(slices)]
813
+
814
+ # Find faces exposed to background (neighbor is 0)
815
+ exposed_faces = (shifted_cropped == 0) & (labeled > 0)
816
+
817
+ face_counts = np.bincount(labeled[exposed_faces],
818
+ minlength=max_label + 1)
819
+ surface_areas += face_counts * face_area
820
+
821
+ result = {int(label): float(surface_areas[label]) for label in labels}
822
+ return result
823
+
824
+
825
+ def get_background_proportion(labeled, xy_scale=1, z_scale=1):
826
+ """Calculate proportion of surface area exposed to background for each object."""
827
+ total_areas = get_surface_areas(labeled, xy_scale, z_scale)
828
+ background_areas = get_background_surface_areas(labeled, xy_scale, z_scale)
829
+
830
+ proportions = {}
831
+ for label in total_areas:
832
+ if total_areas[label] > 0:
833
+ proportions[label] = background_areas[label] / total_areas[label]
834
+ else:
835
+ proportions[label] = 0.0
836
+
837
+ return proportions
838
+
839
+ def get_perimeters(labeled, xy_scale=1):
840
+ """Calculate total perimeter for each object in a 2D array (pseudo-3D with z=1)."""
841
+ # Squeeze to 2D without modifying the original array reference
842
+ labeled_2d = np.squeeze(labeled)
843
+
844
+ labels = np.unique(labeled_2d)
845
+ labels = labels[labels > 0]
846
+ max_label = int(np.max(labeled_2d))
847
+
848
+ perimeters = np.zeros(max_label + 1, dtype=np.float64)
849
+
850
+ # Only check 2 axes for 2D
851
+ for axis in range(2):
852
+ edge_length = xy_scale
853
+
854
+ for direction in [-1, 1]:
855
+ # Pad with zeros only on the axis we're checking
856
+ pad_width = [(1, 1) if i == axis else (0, 0) for i in range(2)]
857
+ padded = np.pad(labeled_2d, pad_width, mode='constant', constant_values=0)
858
+
859
+ # Roll the padded array
860
+ shifted = np.roll(padded, direction, axis=axis)
861
+
862
+ # Extract the center region (original size) from shifted
863
+ slices = [slice(1, -1) if i == axis else slice(None) for i in range(2)]
864
+ shifted_cropped = shifted[tuple(slices)]
865
+
866
+ # Find exposed edges
867
+ exposed_edges = (labeled_2d != shifted_cropped) & (labeled_2d > 0)
868
+
869
+ edge_counts = np.bincount(labeled_2d[exposed_edges],
870
+ minlength=max_label + 1)
871
+ perimeters += edge_counts * edge_length
872
+
873
+ result = {int(label): float(perimeters[label]) for label in labels}
874
+ return result
875
+
876
+
877
+ def get_background_perimeters(labeled, xy_scale=1):
878
+ """Calculate perimeter exposed to background (value 0) for each object in a 2D array."""
879
+ # Squeeze to 2D without modifying the original array reference
880
+ labeled_2d = np.squeeze(labeled)
881
+
882
+ labels = np.unique(labeled_2d)
883
+ labels = labels[labels > 0]
884
+ max_label = int(np.max(labeled_2d))
885
+
886
+ perimeters = np.zeros(max_label + 1, dtype=np.float64)
887
+
888
+ # Only check 2 axes for 2D
889
+ for axis in range(2):
890
+ edge_length = xy_scale
891
+
892
+ for direction in [-1, 1]:
893
+ # Pad with zeros only on the axis we're checking
894
+ pad_width = [(1, 1) if i == axis else (0, 0) for i in range(2)]
895
+ padded = np.pad(labeled_2d, pad_width, mode='constant', constant_values=0)
896
+
897
+ # Roll the padded array
898
+ shifted = np.roll(padded, direction, axis=axis)
899
+
900
+ # Extract the center region (original size) from shifted
901
+ slices = [slice(1, -1) if i == axis else slice(None) for i in range(2)]
902
+ shifted_cropped = shifted[tuple(slices)]
903
+
904
+ # Find edges exposed to background (neighbor is 0)
905
+ exposed_edges = (shifted_cropped == 0) & (labeled_2d > 0)
906
+
907
+ edge_counts = np.bincount(labeled_2d[exposed_edges],
908
+ minlength=max_label + 1)
909
+ perimeters += edge_counts * edge_length
910
+
911
+ result = {int(label): float(perimeters[label]) for label in labels}
912
+ return result
913
+
914
+
915
+ def get_background_perimeter_proportion(labeled, xy_scale=1):
916
+ """Calculate proportion of perimeter exposed to background for each object in a 2D array."""
917
+ total_perimeters = get_perimeters(labeled, xy_scale)
918
+ background_perimeters = get_background_perimeters(labeled, xy_scale)
919
+
920
+ proportions = {}
921
+ for label in total_perimeters:
922
+ if total_perimeters[label] > 0:
923
+ proportions[label] = background_perimeters[label] / total_perimeters[label]
924
+ else:
925
+ proportions[label] = 0.0
926
+
927
+ return proportions
928
+
786
929
  def break_and_label_skeleton(skeleton, peaks = 1, branch_removal = 0, comp_dil = 0, max_vol = 0, directory = None, return_skele = False, nodes = None, compute = True, unify = False, xy_scale = 1, z_scale = 1):
787
930
  """Internal method to break open a skeleton at its branchpoints and label the remaining components, for an 8bit binary array"""
788
931
 
@@ -1520,7 +1663,7 @@ def dilate_2D(array, search, scaling = 1):
1520
1663
  return inv
1521
1664
 
1522
1665
 
1523
- def dilate_3D_dt(array, search_distance, xy_scaling=1.0, z_scaling=1.0):
1666
+ def dilate_3D_dt(array, search_distance, xy_scaling=1.0, z_scaling=1.0, fast_dil = False):
1524
1667
  """
1525
1668
  Dilate a 3D array using distance transform method. Dt dilation produces perfect results but only works in euclidean geometry and lags in big arrays.
1526
1669
 
@@ -1568,7 +1711,7 @@ def dilate_3D_dt(array, search_distance, xy_scaling=1.0, z_scaling=1.0):
1568
1711
  """
1569
1712
 
1570
1713
  # Compute distance transform (Euclidean)
1571
- inv = smart_dilate.compute_distance_transform_distance(inv, sampling = [z_scaling, xy_scaling, xy_scaling])
1714
+ inv = smart_dilate.compute_distance_transform_distance(inv, sampling = [z_scaling, xy_scaling, xy_scaling], fast_dil = fast_dil)
1572
1715
 
1573
1716
  #inv = inv * cardinal
1574
1717
 
@@ -1615,7 +1758,7 @@ def erode_2D(array, search, scaling=1, preserve_labels = False):
1615
1758
 
1616
1759
  return array
1617
1760
 
1618
- def erode_3D_dt(array, search_distance, xy_scaling=1.0, z_scaling=1.0, preserve_labels = False):
1761
+ def erode_3D_dt(array, search_distance, xy_scaling=1.0, z_scaling=1.0, fast_dil = False, preserve_labels = False):
1619
1762
  """
1620
1763
  Erode a 3D array using distance transform method. DT erosion produces perfect results
1621
1764
  with Euclidean geometry, but may be slower for large arrays.
@@ -1643,34 +1786,30 @@ def erode_3D_dt(array, search_distance, xy_scaling=1.0, z_scaling=1.0, preserve_
1643
1786
 
1644
1787
  borders = find_boundaries(array, mode='thick')
1645
1788
  mask = array * invert_array(borders)
1646
- mask = smart_dilate.compute_distance_transform_distance(mask, sampling = [z_scaling, xy_scaling, xy_scaling])
1789
+ mask = smart_dilate.compute_distance_transform_distance(mask, sampling = [z_scaling, xy_scaling, xy_scaling], fast_dil = fast_dil)
1647
1790
  mask = mask >= search_distance
1648
1791
  array = mask * array
1649
1792
  else:
1650
- array = smart_dilate.compute_distance_transform_distance(array, sampling = [z_scaling, xy_scaling, xy_scaling])
1793
+ array = smart_dilate.compute_distance_transform_distance(array, sampling = [z_scaling, xy_scaling, xy_scaling], fast_dil = fast_dil)
1651
1794
  # Threshold the distance transform to get eroded result
1652
1795
  # For erosion, we keep only the points that are at least search_distance from the boundary
1653
1796
  array = array > search_distance
1654
1797
 
1655
- # Resample back to original dimensions if needed
1656
- #if rev_factor:
1657
- #array = ndimage.zoom(array, rev_factor, order=0) # Use order=0 for binary masks
1658
-
1659
1798
  return array.astype(np.uint8)
1660
1799
 
1661
1800
 
1662
1801
  def dilate_3D(tiff_array, dilated_x, dilated_y, dilated_z):
1663
- """Internal method to dilate an array in 3D. Dilation this way is much faster than using a distance transform although the latter is theoretically more accurate.
1802
+ """Internal method to dilate an array in 3D. Dilation this way is much faster than using a distance transform although the latter is more accurate.
1664
1803
  Arguments are an array, and the desired pixel dilation amounts in X, Y, Z. Uses psuedo-3D kernels (imagine a 3D + sign rather than a cube) to approximate 3D neighborhoods but will miss diagonally located things with larger kernels, if those are needed use the distance transform version.
1665
1804
  """
1666
1805
 
1667
- if tiff_array.shape[0] == 1:
1668
- return dilate_2D(tiff_array, ((dilated_x - 1) / 2))
1669
-
1670
1806
  if dilated_x == 3 and dilated_y == 3 and dilated_z == 3:
1671
1807
 
1672
1808
  return dilate_3D_old(tiff_array, dilated_x, dilated_y, dilated_z)
1673
1809
 
1810
+ if tiff_array.shape[0] == 1:
1811
+ return dilate_2D(tiff_array, ((dilated_x - 1) / 2))
1812
+
1674
1813
  def create_circular_kernel(diameter):
1675
1814
  """Create a 2D circular kernel with a given radius.
1676
1815
 
@@ -1802,11 +1941,6 @@ def dilate_3D_old(tiff_array, dilated_x=3, dilated_y=3, dilated_z=3):
1802
1941
  Dilated 3D array
1803
1942
  """
1804
1943
 
1805
- # Handle special case for 2D arrays
1806
- if tiff_array.shape[0] == 1:
1807
- # Call 2D dilation function if needed
1808
- return dilate_2D(tiff_array, 1) # For a 3x3 kernel, radius is 1
1809
-
1810
1944
  # Create a simple 3x3x3 cubic kernel (all ones)
1811
1945
  kernel = np.ones((3, 3, 3), dtype=bool)
1812
1946
 
@@ -1816,107 +1950,6 @@ def dilate_3D_old(tiff_array, dilated_x=3, dilated_y=3, dilated_z=3):
1816
1950
  return dilated_array.astype(np.uint8)
1817
1951
 
1818
1952
 
1819
- def erode_3D(tiff_array, eroded_x, eroded_y, eroded_z):
1820
- """Internal method to erode an array in 3D. Erosion this way is faster than using a distance transform although the latter is theoretically more accurate.
1821
- Arguments are an array, and the desired pixel erosion amounts in X, Y, Z."""
1822
-
1823
- if tiff_array.shape[0] == 1:
1824
- return erode_2D(tiff_array, ((eroded_x - 1) / 2))
1825
-
1826
- def create_circular_kernel(diameter):
1827
- """Create a 2D circular kernel with a given radius.
1828
- Parameters:
1829
- radius (int or float): The radius of the circle.
1830
- Returns:
1831
- numpy.ndarray: A 2D numpy array representing the circular kernel.
1832
- """
1833
- # Determine the size of the kernel
1834
- radius = diameter/2
1835
- size = radius # Diameter of the circle
1836
- size = int(np.ceil(size)) # Ensure size is an integer
1837
-
1838
- # Create a grid of (x, y) coordinates
1839
- y, x = np.ogrid[-radius:radius+1, -radius:radius+1]
1840
-
1841
- # Calculate the distance from the center (0,0)
1842
- distance = np.sqrt(x**2 + y**2)
1843
-
1844
- # Create the circular kernel: points within the radius are 1, others are 0
1845
- kernel = distance <= radius
1846
-
1847
- # Convert the boolean array to integer (0 and 1)
1848
- return kernel.astype(np.uint8)
1849
-
1850
- def create_ellipsoidal_kernel(long_axis, short_axis):
1851
- """Create a 2D ellipsoidal kernel with specified axis lengths and orientation.
1852
- Parameters:
1853
- long_axis (int or float): The length of the long axis.
1854
- short_axis (int or float): The length of the short axis.
1855
- Returns:
1856
- numpy.ndarray: A 2D numpy array representing the ellipsoidal kernel.
1857
- """
1858
- semi_major, semi_minor = long_axis / 2, short_axis / 2
1859
- # Determine the size of the kernel
1860
- size_y = int(np.ceil(semi_minor))
1861
- size_x = int(np.ceil(semi_major))
1862
-
1863
- # Create a grid of (x, y) coordinates centered at (0,0)
1864
- y, x = np.ogrid[-semi_minor:semi_minor+1, -semi_major:semi_major+1]
1865
-
1866
- # Ellipsoid equation: (x/a)^2 + (y/b)^2 <= 1
1867
- ellipse = (x**2 / semi_major**2) + (y**2 / semi_minor**2) <= 1
1868
-
1869
- return ellipse.astype(np.uint8)
1870
-
1871
- z_depth = tiff_array.shape[0]
1872
-
1873
- # Function to process each slice
1874
- def process_slice(z):
1875
- tiff_slice = tiff_array[z].astype(np.uint8)
1876
- eroded_slice = cv2.erode(tiff_slice, kernel, iterations=1)
1877
- return z, eroded_slice
1878
-
1879
- def process_slice_other(y):
1880
- tiff_slice = tiff_array[:, y, :].astype(np.uint8)
1881
- eroded_slice = cv2.erode(tiff_slice, kernel, iterations=1)
1882
- return y, eroded_slice
1883
-
1884
- # Create empty arrays to store the eroded results for the XY and XZ planes
1885
- eroded_xy = np.zeros_like(tiff_array, dtype=np.uint8)
1886
- eroded_xz = np.zeros_like(tiff_array, dtype=np.uint8)
1887
-
1888
- kernel_x = int(eroded_x)
1889
- kernel = create_circular_kernel(kernel_x)
1890
-
1891
- num_cores = mp.cpu_count()
1892
- with ThreadPoolExecutor(max_workers=num_cores) as executor:
1893
- futures = {executor.submit(process_slice, z): z for z in range(tiff_array.shape[0])}
1894
- for future in as_completed(futures):
1895
- z, eroded_slice = future.result()
1896
- eroded_xy[z] = eroded_slice
1897
-
1898
- kernel_x = int(eroded_x)
1899
- kernel_z = int(eroded_z)
1900
- kernel = create_ellipsoidal_kernel(kernel_x, kernel_z)
1901
-
1902
- if z_depth != 2:
1903
-
1904
- with ThreadPoolExecutor(max_workers=num_cores) as executor:
1905
- futures = {executor.submit(process_slice_other, y): y for y in range(tiff_array.shape[1])}
1906
-
1907
- for future in as_completed(futures):
1908
- y, eroded_slice = future.result()
1909
- eroded_xz[:, y, :] = eroded_slice
1910
-
1911
- # Overlay the results using AND operation instead of OR for erosion
1912
- if z_depth != 2:
1913
- final_result = eroded_xy & eroded_xz
1914
- else:
1915
- return eroded_xy
1916
-
1917
- return final_result
1918
-
1919
-
1920
1953
  def dilation_length_to_pixels(xy_scaling, z_scaling, micronx, micronz):
1921
1954
  """Internal method to find XY and Z dilation parameters based on voxel micron scaling"""
1922
1955
  dilate_xy = 2 * int(round(micronx/xy_scaling))
@@ -2321,12 +2354,13 @@ def dilate(arrayimage, amount, xy_scale = 1, z_scale = 1, directory = None, fast
2321
2354
  def erode(arrayimage, amount, xy_scale = 1, z_scale = 1, mode = 0, preserve_labels = False):
2322
2355
  if not preserve_labels and len(np.unique(arrayimage)) > 2: #binarize
2323
2356
  arrayimage = binarize(arrayimage)
2324
- erode_xy, erode_z = dilation_length_to_pixels(xy_scale, z_scale, amount, amount)
2325
2357
 
2326
- if mode == 2:
2327
- arrayimage = (erode_3D(arrayimage, erode_xy, erode_xy, erode_z)) * 255
2358
+ if mode == 0 or mode == 2:
2359
+ fast_dil = True
2328
2360
  else:
2329
- arrayimage = erode_3D_dt(arrayimage, amount, xy_scaling=xy_scale, z_scaling=z_scale, preserve_labels = preserve_labels)
2361
+ fast_dil = False
2362
+
2363
+ arrayimage = erode_3D_dt(arrayimage, amount, xy_scaling=xy_scale, z_scaling=z_scale, fast_dil = fast_dil, preserve_labels = preserve_labels)
2330
2364
 
2331
2365
  if np.max(arrayimage) == 1:
2332
2366
  arrayimage = arrayimage * 255
@@ -2378,7 +2412,7 @@ def skeletonize(arrayimage, directory = None):
2378
2412
 
2379
2413
  return arrayimage
2380
2414
 
2381
- def label_branches(array, peaks = 0, branch_removal = 0, comp_dil = 0, max_vol = 0, down_factor = None, directory = None, nodes = None, bonus_array = None, GPU = True, arrayshape = None, compute = False, unify = False, union_val = 10, xy_scale = 1, z_scale = 1):
2415
+ def label_branches(array, peaks = 0, branch_removal = 0, comp_dil = 0, max_vol = 0, down_factor = None, directory = None, nodes = None, bonus_array = None, GPU = True, arrayshape = None, compute = False, unify = False, union_val = 10, mode = 0, xy_scale = 1, z_scale = 1):
2382
2416
  """
2383
2417
  Can be used to label branches a binary image. Labelled output will be saved to the active directory if none is specified. Note this works better on already thin filaments and may over-divide larger trunkish objects.
2384
2418
  :param array: (Mandatory, string or ndarray) - If string, a path to a tif file to label. Note that the ndarray alternative is for internal use mainly and will not save its output.
@@ -2419,24 +2453,25 @@ def label_branches(array, peaks = 0, branch_removal = 0, comp_dil = 0, max_vol =
2419
2453
  from . import branch_stitcher
2420
2454
  verts = dilate_3D_old(verts, 3, 3, 3,)
2421
2455
  verts, _ = label_objects(verts)
2422
- array = branch_stitcher.trace(bonus_array, array, verts, score_thresh = union_val)
2456
+ print("Merging branches...")
2457
+ array = branch_stitcher.trace(bonus_array, array, verts, score_thresh = union_val, xy_scale = xy_scale, z_scale = z_scale)
2423
2458
  verts = None
2424
2459
 
2425
2460
 
2426
2461
  if nodes is None:
2427
2462
 
2428
- array = smart_dilate.smart_label(array, other_array, GPU = GPU, remove_template = True)
2463
+ array = smart_dilate.smart_label(array, other_array, GPU = GPU, remove_template = True, mode = mode)
2429
2464
  #distance = smart_dilate.compute_distance_transform_distance(array)
2430
2465
  #array = water(-distance, other_array, mask=array) #Tried out skimage watershed as shown and found it did not label branches as well as smart_label (esp combined combined with post-processing label splitting if needed)
2431
2466
 
2432
2467
  else:
2433
2468
  if down_factor is not None:
2434
- array = smart_dilate.smart_label(bonus_array, array, GPU = GPU, predownsample = down_factor, remove_template = True)
2469
+ array = smart_dilate.smart_label(bonus_array, array, GPU = GPU, predownsample = down_factor, remove_template = True, mode = mode)
2435
2470
  #distance = smart_dilate.compute_distance_transform_distance(bonus_array)
2436
2471
  #array = water(-distance, array, mask=bonus_array)
2437
2472
  else:
2438
2473
 
2439
- array = smart_dilate.smart_label(bonus_array, array, GPU = GPU, remove_template = True)
2474
+ array = smart_dilate.smart_label(bonus_array, array, GPU = GPU, remove_template = True, mode = mode)
2440
2475
  #distance = smart_dilate.compute_distance_transform_distance(bonus_array)
2441
2476
  #array = water(-distance, array, mask=bonus_array)
2442
2477
 
@@ -2498,7 +2533,7 @@ def fix_branches_network(array, G, communities, fix_val = None):
2498
2533
 
2499
2534
  return targs
2500
2535
 
2501
- def fix_branches(array, G, max_val):
2536
+ def fix_branches(array, G, max_val, consider_prop = True):
2502
2537
  """
2503
2538
  Parameters:
2504
2539
  array: numpy array containing the labeled regions
@@ -2522,8 +2557,29 @@ def fix_branches(array, G, max_val):
2522
2557
 
2523
2558
  # Find all neighbors of not_safe nodes in one pass
2524
2559
  neighbors_of_not_safe = set()
2525
- for node in not_safe_initial:
2526
- neighbors_of_not_safe.update(adj[node])
2560
+ if consider_prop:
2561
+ if array.shape[0] != 1:
2562
+ areas = get_background_proportion(array, xy_scale=1, z_scale=1)
2563
+ else:
2564
+ areas = get_background_perimeter_proportion(array, xy_scale=1)
2565
+ valid_areas = {label: proportion for label, proportion in areas.items() if proportion < 0.4}
2566
+
2567
+ for node in not_safe_initial:
2568
+ # Filter neighbors based on whether they're in the valid areas dict
2569
+ valid_neighbors = [neighbor for neighbor in adj[node] if neighbor in valid_areas]
2570
+
2571
+ # If no valid neighbors, fall back to the one with lowest proportion
2572
+ if not valid_neighbors:
2573
+ node_neighbors = list(adj[node])
2574
+ if node_neighbors:
2575
+ # Find neighbor with minimum background proportion
2576
+ min_neighbor = min(node_neighbors, key=lambda n: areas.get(n, float('inf')))
2577
+ valid_neighbors = [min_neighbor]
2578
+
2579
+ neighbors_of_not_safe.update(valid_neighbors)
2580
+ else:
2581
+ for node in not_safe_initial:
2582
+ neighbors_of_not_safe.update(adj[node])
2527
2583
 
2528
2584
  # Remove max_val if present
2529
2585
  neighbors_of_not_safe.discard(max_val)
@@ -2534,7 +2590,7 @@ def fix_branches(array, G, max_val):
2534
2590
  # Update sets
2535
2591
  not_safe = not_safe_initial | nodes_to_move
2536
2592
 
2537
- # The rest of the function - FIX STARTS HERE
2593
+ # The rest of the function
2538
2594
  targs = np.array(list(not_safe))
2539
2595
 
2540
2596
  if len(targs) == 0:
@@ -2547,18 +2603,12 @@ def fix_branches(array, G, max_val):
2547
2603
  # Get the current maximum label in the array to avoid collisions
2548
2604
  current_max = np.max(array)
2549
2605
 
2550
- # Assign new unique labels to each connected component
2551
- for component_id in range(1, num_components + 1):
2552
- component_mask = labeled == component_id
2553
- array[component_mask] = current_max + component_id
2606
+ # Vectorized relabeling - single operation instead of loop
2607
+ array[mask] = labeled[mask] + current_max
2554
2608
 
2555
2609
  return array
2556
2610
 
2557
2611
 
2558
-
2559
-
2560
-
2561
-
2562
2612
  def label_vertices(array, peaks = 0, branch_removal = 0, comp_dil = 0, max_vol = 0, down_factor = 0, directory = None, return_skele = False, order = 0, fastdil = True):
2563
2613
  """
2564
2614
  Can be used to label vertices (where multiple branches connect) a binary image. Labelled output will be saved to the active directory if none is specified. Note this works better on already thin filaments and may over-divide larger trunkish objects.
@@ -2593,7 +2643,7 @@ def label_vertices(array, peaks = 0, branch_removal = 0, comp_dil = 0, max_vol =
2593
2643
  old_skeleton = copy.deepcopy(array) # The skeleton might get modified in label_vertices so we can make a preserved copy of it to use later
2594
2644
 
2595
2645
  if branch_removal > 0:
2596
- array = remove_branches(array, branch_removal)
2646
+ array = remove_branches_new(array, branch_removal)
2597
2647
 
2598
2648
  array = np.pad(array, pad_width=1, mode='constant', constant_values=0)
2599
2649
 
@@ -2628,19 +2678,18 @@ def label_vertices(array, peaks = 0, branch_removal = 0, comp_dil = 0, max_vol =
2628
2678
  if peaks > 0:
2629
2679
  image_copy = filter_size_by_peaks(image_copy, peaks)
2630
2680
  if comp_dil > 0:
2631
- image_copy = dilate(image_copy, comp_dil, fast_dil = fastdil)
2681
+ image_copy = dilate_3D_dt(image_copy, comp_dil, fast_dil = fastdil)
2632
2682
 
2633
2683
  labeled_image, num_labels = label_objects(image_copy)
2634
2684
  elif max_vol > 0:
2635
2685
  image_copy = filter_size_by_vol(image_copy, max_vol)
2636
2686
  if comp_dil > 0:
2637
- image_copy = dilate(image_copy, comp_dil, fast_dil = fastdil)
2687
+ image_copy = dilate_3D_dt(image_copy, comp_dil, fast_dil = fastdil)
2638
2688
 
2639
2689
  labeled_image, num_labels = label_objects(image_copy)
2640
2690
  else:
2641
-
2642
2691
  if comp_dil > 0:
2643
- image_copy = dilate(image_copy, comp_dil, fast_dil = fastdil)
2692
+ image_copy = dilate_3D_dt(image_copy, comp_dil, fast_dil = fastdil)
2644
2693
  labeled_image, num_labels = label_objects(image_copy)
2645
2694
 
2646
2695
  #if down_factor > 0:
@@ -2775,7 +2824,7 @@ def gray_watershed(image, min_distance = 1, threshold_abs = None):
2775
2824
  return image
2776
2825
 
2777
2826
 
2778
- def watershed(image, directory = None, proportion = 0.1, GPU = True, smallest_rad = None, predownsample = None, predownsample2 = None):
2827
+ def watershed(image, directory = None, proportion = 0.1, GPU = True, smallest_rad = None, fast_dil = False, predownsample = None, predownsample2 = None):
2779
2828
  """
2780
2829
  Can be used to 3D watershed a binary image. Watershedding attempts to use an algorithm to split touching objects into seperate labelled components. Labelled output will be saved to the active directory if none is specified.
2781
2830
  This watershed algo essentially uses the distance transform to decide where peaks are and then after thresholding out the non-peaks, uses the peaks as labelling kernels for a smart label. It runs semi slow without GPU accel since it requires two dts to be computed.
@@ -2838,7 +2887,7 @@ def watershed(image, directory = None, proportion = 0.1, GPU = True, smallest_ra
2838
2887
  if GPU:
2839
2888
  print("GPU dt failed or did not detect GPU (cupy must be installed with a CUDA toolkit setup...). Computing CPU distance transform instead.")
2840
2889
  print(f"Error message: {str(e)}")
2841
- distance = smart_dilate.compute_distance_transform_distance(image)
2890
+ distance = smart_dilate.compute_distance_transform_distance(image, fast_dil = fast_dil)
2842
2891
 
2843
2892
 
2844
2893
  distance = threshold(distance, proportion, custom_rad = smallest_rad)
@@ -4189,10 +4238,7 @@ class Network_3D:
4189
4238
  if search is not None and hasattr(self, '_nodes') and self._nodes is not None and self._search_region is None:
4190
4239
  search_region = binarize(self._nodes)
4191
4240
  dilate_xy, dilate_z = dilation_length_to_pixels(self._xy_scale, self._z_scale, search, search)
4192
- if fast_dil:
4193
- search_region = dilate_3D(search_region, dilate_xy, dilate_xy, dilate_z)
4194
- else:
4195
- search_region = dilate_3D_dt(search_region, diledge, self._xy_scale, self._z_scale)
4241
+ search_region = dilate_3D_dt(search_region, diledge, self._xy_scale, self._z_scale, fast_dil = fast_dil)
4196
4242
  else:
4197
4243
  search_region = binarize(self._search_region)
4198
4244
 
@@ -4210,10 +4256,8 @@ class Network_3D:
4210
4256
 
4211
4257
  if dilate_xy <= 3 and dilate_z <= 3:
4212
4258
  outer_edges = dilate_3D_old(outer_edges, dilate_xy, dilate_xy, dilate_z)
4213
- elif fast_dil:
4214
- outer_edges = dilate_3D(outer_edges, dilate_xy, dilate_xy, dilate_z)
4215
4259
  else:
4216
- outer_edges = dilate_3D_dt(outer_edges, diledge, self._xy_scale, self._z_scale)
4260
+ outer_edges = dilate_3D_dt(outer_edges, diledge, self._xy_scale, self._z_scale, fast_dil = fast_dil)
4217
4261
  else:
4218
4262
  outer_edges = dilate_3D_old(outer_edges)
4219
4263
 
@@ -4247,21 +4291,37 @@ class Network_3D:
4247
4291
  self._nodes, num_nodes = label_objects(nodes, structure_3d)
4248
4292
 
4249
4293
  def combine_nodes(self, root_nodes, other_nodes, other_ID, identity_dict, root_ID = None, centroids = False, down_factor = None):
4250
-
4251
4294
  """Internal method to merge two labelled node arrays into one"""
4252
-
4253
4295
  print("Combining node arrays")
4254
-
4296
+
4297
+ # Calculate the maximum value that will exist in the output
4298
+ max_root = np.max(root_nodes)
4299
+ max_other = np.max(other_nodes)
4300
+ max_output = max_root + max_other # Worst case: all other_nodes shifted by max_root
4301
+
4302
+ # Determine the minimum dtype needed
4303
+ if max_output <= 255:
4304
+ target_dtype = np.uint8
4305
+ elif max_output <= 65535:
4306
+ target_dtype = np.uint16
4307
+ else:
4308
+ target_dtype = np.uint32
4309
+
4310
+ # Convert arrays to appropriate dtype
4311
+ root_nodes = root_nodes.astype(target_dtype)
4312
+ other_nodes = other_nodes.astype(target_dtype)
4313
+
4314
+ # Now perform the merge
4255
4315
  mask = (root_nodes == 0) & (other_nodes > 0)
4256
4316
  if np.any(mask):
4257
- max_val = np.max(root_nodes)
4258
- other_nodes[:] = np.where(mask, other_nodes + max_val, 0)
4259
- if centroids:
4260
- new_dict = network_analysis._find_centroids(other_nodes, down_factor = down_factor)
4261
- if down_factor is not None:
4262
- for item in new_dict:
4263
- new_dict[item] = down_factor * new_dict[item]
4264
- self.node_centroids.update(new_dict)
4317
+ other_nodes_shifted = np.where(other_nodes > 0, other_nodes + max_root, 0)
4318
+ if centroids:
4319
+ new_dict = network_analysis._find_centroids(other_nodes_shifted, down_factor = down_factor)
4320
+ if down_factor is not None:
4321
+ for item in new_dict:
4322
+ new_dict[item] = down_factor * new_dict[item]
4323
+ self.node_centroids.update(new_dict)
4324
+ other_nodes = np.where(mask, other_nodes_shifted, 0)
4265
4325
 
4266
4326
  if root_ID is not None:
4267
4327
  rootIDs = list(np.unique(root_nodes)) #Sets up adding these vals to the identitiy dictionary. Gets skipped if this has already been done.
@@ -4300,7 +4360,7 @@ class Network_3D:
4300
4360
 
4301
4361
  return nodes, identity_dict
4302
4362
 
4303
- def merge_nodes(self, addn_nodes_name, label_nodes = True, root_id = "Root_Nodes", centroids = False, down_factor = None):
4363
+ def merge_nodes(self, addn_nodes_name, label_nodes = True, root_id = "Root_Nodes", centroids = False, down_factor = None, is_array = False):
4304
4364
  """
4305
4365
  Merges the self._nodes attribute with alternate labelled node images. The alternate nodes can be inputted as a string for a filepath to a tif,
4306
4366
  or as a directory address containing only tif images, which will merge the _nodes attribute with all tifs in the folder. The _node_identities attribute
@@ -4327,7 +4387,11 @@ class Network_3D:
4327
4387
  self.node_centroids[item] = down_factor * self.node_centroids[item]
4328
4388
 
4329
4389
  try: #Try presumes the input is a tif
4330
- addn_nodes = tifffile.imread(addn_nodes_name) #If not this will fail and activate the except block
4390
+ if not is_array:
4391
+ addn_nodes = tifffile.imread(addn_nodes_name) #If not this will fail and activate the except block
4392
+ else:
4393
+ addn_nodes = addn_nodes_name # Passing it an array directly
4394
+ addn_nodes_name = "Node"
4331
4395
 
4332
4396
  if label_nodes is True:
4333
4397
  addn_nodes, num_nodes2 = label_objects(addn_nodes) # Label the node objects. Note this presumes no overlap between node masks.
@@ -4339,7 +4403,6 @@ class Network_3D:
4339
4403
  num_nodes = int(np.max(node_labels))
4340
4404
 
4341
4405
  except: #Exception presumes the input is a directory containing multiple tifs, to allow multi-node stackage.
4342
-
4343
4406
  addn_nodes_list = directory_info(addn_nodes_name)
4344
4407
 
4345
4408
  for i, addn_nodes in enumerate(addn_nodes_list):
@@ -4539,9 +4602,6 @@ class Network_3D:
4539
4602
  self._nodes = nodes
4540
4603
  del nodes
4541
4604
 
4542
- if self._nodes.shape[0] == 1:
4543
- fast_dil = True #Set this to true because the 2D algo always uses the distance transform and doesnt need this special ver
4544
-
4545
4605
  if label_nodes:
4546
4606
  self._nodes, num_nodes = label_objects(self._nodes)
4547
4607
  if other_nodes is not None:
@@ -4937,7 +4997,17 @@ class Network_3D:
4937
4997
  nodeb = []
4938
4998
  edgec = []
4939
4999
 
4940
- trunk = stats.mode(edgesc)
5000
+ from collections import Counter
5001
+ counts = Counter(edgesc)
5002
+ if 0 not in edgesc:
5003
+ trunk = stats.mode(edgesc)
5004
+ else:
5005
+ try:
5006
+ sorted_edges = counts.most_common()
5007
+ trunk = sorted_edges[1][0]
5008
+ except:
5009
+ return
5010
+
4941
5011
  addtrunk = max(set(nodesa + nodesb)) + 1
4942
5012
 
4943
5013
  for i in range(len(nodesa)):
@@ -4955,7 +5025,10 @@ class Network_3D:
4955
5025
 
4956
5026
  self.network_lists = [nodea, nodeb, edgec]
4957
5027
 
4958
- self._node_centroids[addtrunk] = self._edge_centroids[trunk]
5028
+ try:
5029
+ self._node_centroids[addtrunk] = self._edge_centroids[trunk]
5030
+ except:
5031
+ pass
4959
5032
 
4960
5033
  if self._node_identities is None:
4961
5034
  self._node_identities = {}
@@ -4984,14 +5057,14 @@ class Network_3D:
4984
5057
 
4985
5058
 
4986
5059
 
4987
- def prune_samenode_connections(self):
5060
+ def prune_samenode_connections(self, target = None):
4988
5061
  """
4989
5062
  If working with a network that has multiple node identities (from merging nodes or otherwise manipulating this property),
4990
5063
  this method will remove from the network and network_lists properties any connections that exist between the same node identity,
4991
5064
  in case we want to investigate only connections between differing objects.
4992
5065
  """
4993
5066
 
4994
- self._network_lists, self._node_identities = network_analysis.prune_samenode_connections(self._network_lists, self._node_identities)
5067
+ self._network_lists, self._node_identities = network_analysis.prune_samenode_connections(self._network_lists, self._node_identities, target = target)
4995
5068
  self._network, num_weights = network_analysis.weighted_network(self._network_lists)
4996
5069
 
4997
5070
 
@@ -5193,7 +5266,7 @@ class Network_3D:
5193
5266
 
5194
5267
  #Methods related to visualizing the network using networkX and matplotlib
5195
5268
 
5196
- def show_network(self, geometric = False, directory = None):
5269
+ def show_network(self, geometric = False, directory = None, show_labels = True):
5197
5270
  """
5198
5271
  Shows the network property as a simplistic graph, and some basic stats. Does not support viewing edge weights.
5199
5272
  :param geoemtric: (Optional - Val = False; boolean). If False, node placement in the graph will be random. If True, nodes
@@ -5204,19 +5277,19 @@ class Network_3D:
5204
5277
 
5205
5278
  if not geometric:
5206
5279
 
5207
- simple_network.show_simple_network(self._network_lists, directory = directory)
5280
+ simple_network.show_simple_network(self._network_lists, directory = directory, show_labels = show_labels)
5208
5281
 
5209
5282
  else:
5210
- simple_network.show_simple_network(self._network_lists, geometric = True, geo_info = [self._node_centroids, self._nodes.shape], directory = directory)
5283
+ simple_network.show_simple_network(self._network_lists, geometric = True, geo_info = [self._node_centroids, self._nodes.shape], directory = directory, show_labels = show_labels)
5211
5284
 
5212
- def show_communities_flex(self, geometric = False, directory = None, weighted = True, partition = False, style = 0):
5285
+ def show_communities_flex(self, geometric = False, directory = None, weighted = True, partition = False, style = 0, show_labels = True):
5213
5286
 
5214
5287
 
5215
- self._communities, self.normalized_weights = modularity.show_communities_flex(self._network, self._network_lists, self.normalized_weights, geo_info = [self._node_centroids, self._nodes.shape], geometric = geometric, directory = directory, weighted = weighted, partition = partition, style = style)
5288
+ self._communities, self.normalized_weights = modularity.show_communities_flex(self._network, self._network_lists, self.normalized_weights, geo_info = [self._node_centroids, self._nodes.shape], geometric = geometric, directory = directory, weighted = weighted, partition = partition, style = style, show_labels = show_labels)
5216
5289
 
5217
5290
 
5218
5291
 
5219
- def show_identity_network(self, geometric = False, directory = None):
5292
+ def show_identity_network(self, geometric = False, directory = None, show_labels = True):
5220
5293
  """
5221
5294
  Shows the network property, and some basic stats, as a graph where nodes are labelled by colors representing the identity of the node as detailed in the node_identities property. Does not support viewing edge weights.
5222
5295
  :param geoemtric: (Optional - Val = False; boolean). If False, node placement in the graph will be random. If True, nodes
@@ -5225,9 +5298,9 @@ class Network_3D:
5225
5298
  :param directory: (Optional – Val = None; string). An optional string path to a directory to save the network plot image to. If not set, nothing will be saved.
5226
5299
  """
5227
5300
  if not geometric:
5228
- simple_network.show_identity_network(self._network_lists, self._node_identities, geometric = False, directory = directory)
5301
+ simple_network.show_identity_network(self._network_lists, self._node_identities, geometric = False, directory = directory, show_labels = show_labels)
5229
5302
  else:
5230
- simple_network.show_identity_network(self._network_lists, self._node_identities, geometric = True, geo_info = [self._node_centroids, self._nodes.shape], directory = directory)
5303
+ simple_network.show_identity_network(self._network_lists, self._node_identities, geometric = True, geo_info = [self._node_centroids, self._nodes.shape], directory = directory, show_labels = show_labels)
5231
5304
 
5232
5305
 
5233
5306
 
@@ -5710,7 +5783,7 @@ class Network_3D:
5710
5783
  print(f"Using {volume} for the volume measurement (Volume of provided mask as scaled by xy and z scaling)")
5711
5784
 
5712
5785
  # Compute distance transform on padded array
5713
- legal = smart_dilate.compute_distance_transform_distance(legal, sampling = [self.z_scale, self.xy_scale, self.xy_scale])
5786
+ legal = smart_dilate.compute_distance_transform_distance(legal, sampling = [self.z_scale, self.xy_scale, self.xy_scale], fast_dil = True)
5714
5787
 
5715
5788
  # Remove padding after distance transform
5716
5789
  if dim == 2:
@@ -5804,7 +5877,6 @@ class Network_3D:
5804
5877
 
5805
5878
 
5806
5879
  def morph_proximity(self, search = 0, targets = None, fastdil = False):
5807
-
5808
5880
  if type(search) == list:
5809
5881
  search_x, search_z = search #Suppose we just want to directly pass these params
5810
5882
  else:
@@ -5813,7 +5885,6 @@ class Network_3D:
5813
5885
  num_nodes = int(np.max(self._nodes))
5814
5886
 
5815
5887
  my_dict = proximity.create_node_dictionary(self._nodes, num_nodes, search_x, search_z, targets = targets, fastdil = fastdil, xy_scale = self._xy_scale, z_scale = self._z_scale, search = search)
5816
-
5817
5888
  my_dict = proximity.find_shared_value_pairs(my_dict)
5818
5889
 
5819
5890
  my_dict = create_and_save_dataframe(my_dict)