nettracer3d 1.2.7__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nettracer3d might be problematic. Click here for more details.

nettracer3d/nettracer.py CHANGED
@@ -785,6 +785,147 @@ def get_surface_areas(labeled, xy_scale=1, z_scale=1):
785
785
  result = {int(label): float(surface_areas[label]) for label in labels}
786
786
  return result
787
787
 
788
def _label_face_exposure_3d(labeled, xy_scale, z_scale):
    """Measure, per label, the exposed voxel-face area of a 3D label volume.

    Shared worker for ``get_background_surface_areas`` and
    ``get_background_proportion`` so both take their numbers from a single
    pass that uses one face-area convention.

    The volume is assumed to be ordered (z, y, x), the order this module uses
    when it passes ``sampling=[z_scale, xy_scale, xy_scale]`` to its distance
    transforms (NOTE(review): confirm for every caller). Under that order:
    - a face normal to axis 0 is one y-by-x pixel (``xy_scale ** 2``);
    - a face normal to axis 1 or 2 is a z-by-xy rectangle
      (``xy_scale * z_scale``).

    Args:
        labeled: 3D label array. 0 is background and positive values are
            objects. Values <= 0 are never treated as objects.
        xy_scale: Physical voxel size along y and x.
        z_scale: Physical voxel size along z.

    Returns:
        ``(labels, total, background)``:
        - ``labels`` is the sorted array of positive labels present.
        - ``total`` and ``background`` are float64 arrays indexed by label
          value.
        - ``total`` sums faces whose neighbour holds a different value
          (another label, background, or the outside of the volume).
        - ``background`` sums only faces whose neighbour is 0 or the outside
          of the volume.

    Raises:
        ValueError: If ``labeled`` is not 3D.
    """
    volume = np.asarray(labeled)
    if volume.ndim != 3:
        raise ValueError(f"expected a 3D label array, got shape {volume.shape}")

    labels = np.unique(volume)
    labels = labels[labels > 0]
    if labels.size == 0:
        # Also covers zero-size arrays, on which np.max would raise.
        empty = np.zeros(0, dtype=np.float64)
        return labels, empty, empty

    size = int(labels[-1]) + 1  # np.unique is sorted, so this is max label + 1
    total = np.zeros(size, dtype=np.float64)
    background = np.zeros(size, dtype=np.float64)
    foreground = volume > 0

    # Pad once with background so out-of-volume neighbours read as 0. Each
    # neighbour view below is then a zero-copy slice of this one buffer,
    # instead of a fresh pad + roll per direction.
    padded = np.pad(volume, 1, mode='constant', constant_values=0)
    face_areas = (xy_scale * xy_scale, xy_scale * z_scale, xy_scale * z_scale)

    for axis, face_area in enumerate(face_areas):
        extent = volume.shape[axis]
        for offset in (-1, 1):
            window = [slice(1, -1)] * 3
            window[axis] = slice(1 + offset, 1 + offset + extent)
            neighbour = padded[tuple(window)]

            touching_other = foreground & (neighbour != volume)
            touching_background = foreground & (neighbour == 0)
            # Cast to intp so bincount also accepts uint64 / float label arrays.
            total += np.bincount(volume[touching_other].astype(np.intp, copy=False),
                                 minlength=size) * face_area
            background += np.bincount(volume[touching_background].astype(np.intp, copy=False),
                                      minlength=size) * face_area

    return labels, total, background


def get_background_surface_areas(labeled, xy_scale=1, z_scale=1):
    """Calculate surface area exposed to background (value 0) for each object.

    Voxel faces on the outer edge of the array count as touching background.

    Args:
        labeled: 3D label array, assumed (z, y, x); 0 is background.
        xy_scale: Physical voxel size along y and x.
        z_scale: Physical voxel size along z.

    Returns:
        dict mapping each positive label (int) to its background-exposed
        surface area (float). Empty when no positive labels are present.

    Raises:
        ValueError: If ``labeled`` is not 3D.
    """
    labels, _, background = _label_face_exposure_3d(labeled, xy_scale, z_scale)
    return {int(label): float(background[int(label)]) for label in labels}


def get_background_proportion(labeled, xy_scale=1, z_scale=1):
    """Calculate proportion of surface area exposed to background for each object.

    The denominator is the object's total exposed surface: faces touching
    background, the array edge, or a different label. This mirrors how the
    2D ``get_perimeters`` defines total perimeter.

    Args:
        labeled: 3D label array, assumed (z, y, x); 0 is background.
        xy_scale: Physical voxel size along y and x.
        z_scale: Physical voxel size along z.

    Returns:
        dict mapping each positive label (int) to a float in [0, 1]. The value
        is 0.0 for a label with no exposed surface.

    Raises:
        ValueError: If ``labeled`` is not 3D.
    """
    # Numerator and denominator come from the same pass with the same
    # face-area convention, so the ratio stays within [0, 1] even with
    # anisotropic scaling.
    labels, total, background = _label_face_exposure_3d(labeled, xy_scale, z_scale)
    proportions = {}
    for label in labels:
        idx = int(label)
        if total[idx] > 0:
            proportions[idx] = float(background[idx] / total[idx])
        else:
            proportions[idx] = 0.0
    return proportions
838
+
839
def _coerce_2d_labels(labeled):
    """Return ``labeled`` as a 2D array, dropping singleton axes one at a time.

    Unlike a bare ``np.squeeze``, this stops at two dimensions, so a
    single-row image such as shape ``(1, 1, N)`` becomes ``(1, N)`` instead of
    collapsing to 1D. A 1D input is promoted to a single row.

    Raises:
        ValueError: If the array still has more than two axes and none of
            them has length 1.
    """
    grid = np.atleast_2d(np.asarray(labeled))
    while grid.ndim > 2:
        singleton_axes = [axis for axis, length in enumerate(grid.shape) if length == 1]
        if not singleton_axes:
            raise ValueError(
                f"expected a 2D label array (or one with singleton extra axes), got shape {grid.shape}"
            )
        grid = np.squeeze(grid, axis=singleton_axes[0])
    return grid


def _label_edge_exposure_2d(labeled, xy_scale):
    """Measure, per label, exposed pixel-edge length in a 2D label image.

    Shared worker for the three 2D perimeter functions, so they cannot drift
    apart and the proportion needs only one pass.

    Args:
        labeled: 2D label array, or one with singleton extra axes such as
            ``(1, Y, X)``. 0 is background; values <= 0 are never objects.
        xy_scale: Physical pixel size (same along both axes).

    Returns:
        ``(labels, total, background)``:
        - ``labels`` is the sorted array of positive labels present.
        - ``total`` and ``background`` are float64 arrays indexed by label
          value.
        - ``total`` sums edges whose neighbour holds a different value
          (another label, background, or the outside of the image).
        - ``background`` sums only edges whose neighbour is 0 or the outside
          of the image.
    """
    grid = _coerce_2d_labels(labeled)

    labels = np.unique(grid)
    labels = labels[labels > 0]
    if labels.size == 0:
        # Also covers zero-size arrays, on which np.max would raise.
        empty = np.zeros(0, dtype=np.float64)
        return labels, empty, empty

    size = int(labels[-1]) + 1  # np.unique is sorted, so this is max label + 1
    total = np.zeros(size, dtype=np.float64)
    background = np.zeros(size, dtype=np.float64)
    foreground = grid > 0

    # Pad once with background so out-of-image neighbours read as 0. Each
    # neighbour view is then a zero-copy slice of this one buffer.
    padded = np.pad(grid, 1, mode='constant', constant_values=0)

    for axis in range(2):
        extent = grid.shape[axis]
        for offset in (-1, 1):
            window = [slice(1, -1), slice(1, -1)]
            window[axis] = slice(1 + offset, 1 + offset + extent)
            neighbour = padded[tuple(window)]

            touching_other = foreground & (neighbour != grid)
            touching_background = foreground & (neighbour == 0)
            # Cast to intp so bincount also accepts uint64 / float label arrays.
            total += np.bincount(grid[touching_other].astype(np.intp, copy=False),
                                 minlength=size) * xy_scale
            background += np.bincount(grid[touching_background].astype(np.intp, copy=False),
                                      minlength=size) * xy_scale

    return labels, total, background


def get_perimeters(labeled, xy_scale=1):
    """Calculate total perimeter for each object in a 2D array (pseudo-3D with z=1).

    An edge counts when the neighbouring pixel is background, a different
    label, or outside the image.

    Args:
        labeled: 2D label array, or one with singleton extra axes such as
            ``(1, Y, X)``.
        xy_scale: Physical pixel size.

    Returns:
        dict mapping each positive label (int) to its perimeter (float).
        Empty when no positive labels are present.
    """
    labels, total, _ = _label_edge_exposure_2d(labeled, xy_scale)
    return {int(label): float(total[int(label)]) for label in labels}


def get_background_perimeters(labeled, xy_scale=1):
    """Calculate perimeter exposed to background (value 0) for each object in a 2D array.

    Edges on the outer border of the image count as touching background.

    Args:
        labeled: 2D label array, or one with singleton extra axes such as
            ``(1, Y, X)``.
        xy_scale: Physical pixel size.

    Returns:
        dict mapping each positive label (int) to its background-exposed
        perimeter (float). Empty when no positive labels are present.
    """
    labels, _, background = _label_edge_exposure_2d(labeled, xy_scale)
    return {int(label): float(background[int(label)]) for label in labels}


def get_background_perimeter_proportion(labeled, xy_scale=1):
    """Calculate proportion of perimeter exposed to background for each object in a 2D array.

    Args:
        labeled: 2D label array, or one with singleton extra axes such as
            ``(1, Y, X)``.
        xy_scale: Physical pixel size.

    Returns:
        dict mapping each positive label (int) to a float in [0, 1]. The value
        is 0.0 for a label with no exposed perimeter.
    """
    # One pass yields both numerator and denominator.
    labels, total, background = _label_edge_exposure_2d(labeled, xy_scale)
    proportions = {}
    for label in labels:
        idx = int(label)
        if total[idx] > 0:
            proportions[idx] = float(background[idx] / total[idx])
        else:
            proportions[idx] = 0.0
    return proportions
928
+
788
929
  def break_and_label_skeleton(skeleton, peaks = 1, branch_removal = 0, comp_dil = 0, max_vol = 0, directory = None, return_skele = False, nodes = None, compute = True, unify = False, xy_scale = 1, z_scale = 1):
789
930
  """Internal method to break open a skeleton at its branchpoints and label the remaining components, for an 8bit binary array"""
790
931
 
@@ -2392,7 +2533,7 @@ def fix_branches_network(array, G, communities, fix_val = None):
2392
2533
 
2393
2534
  return targs
2394
2535
 
2395
- def fix_branches(array, G, max_val):
2536
+ def fix_branches(array, G, max_val, consider_prop = True):
2396
2537
  """
2397
2538
  Parameters:
2398
2539
  array: numpy array containing the labeled regions
@@ -2416,8 +2557,29 @@ def fix_branches(array, G, max_val):
2416
2557
 
2417
2558
  # Find all neighbors of not_safe nodes in one pass
2418
2559
  neighbors_of_not_safe = set()
2419
- for node in not_safe_initial:
2420
- neighbors_of_not_safe.update(adj[node])
2560
+ if consider_prop:
2561
+ if array.shape[0] != 1:
2562
+ areas = get_background_proportion(array, xy_scale=1, z_scale=1)
2563
+ else:
2564
+ areas = get_background_perimeter_proportion(array, xy_scale=1)
2565
+ valid_areas = {label: proportion for label, proportion in areas.items() if proportion < 0.4}
2566
+
2567
+ for node in not_safe_initial:
2568
+ # Filter neighbors based on whether they're in the valid areas dict
2569
+ valid_neighbors = [neighbor for neighbor in adj[node] if neighbor in valid_areas]
2570
+
2571
+ # If no valid neighbors, fall back to the one with lowest proportion
2572
+ if not valid_neighbors:
2573
+ node_neighbors = list(adj[node])
2574
+ if node_neighbors:
2575
+ # Find neighbor with minimum background proportion
2576
+ min_neighbor = min(node_neighbors, key=lambda n: areas.get(n, float('inf')))
2577
+ valid_neighbors = [min_neighbor]
2578
+
2579
+ neighbors_of_not_safe.update(valid_neighbors)
2580
+ else:
2581
+ for node in not_safe_initial:
2582
+ neighbors_of_not_safe.update(adj[node])
2421
2583
 
2422
2584
  # Remove max_val if present
2423
2585
  neighbors_of_not_safe.discard(max_val)
@@ -2428,7 +2590,7 @@ def fix_branches(array, G, max_val):
2428
2590
  # Update sets
2429
2591
  not_safe = not_safe_initial | nodes_to_move
2430
2592
 
2431
- # The rest of the function - FIX STARTS HERE
2593
+ # The rest of the function
2432
2594
  targs = np.array(list(not_safe))
2433
2595
 
2434
2596
  if len(targs) == 0:
@@ -2441,18 +2603,12 @@ def fix_branches(array, G, max_val):
2441
2603
  # Get the current maximum label in the array to avoid collisions
2442
2604
  current_max = np.max(array)
2443
2605
 
2444
- # Assign new unique labels to each connected component
2445
- for component_id in range(1, num_components + 1):
2446
- component_mask = labeled == component_id
2447
- array[component_mask] = current_max + component_id
2606
+ # Vectorized relabeling - single operation instead of loop
2607
+ array[mask] = labeled[mask] + current_max
2448
2608
 
2449
2609
  return array
2450
2610
 
2451
2611
 
2452
-
2453
-
2454
-
2455
-
2456
2612
  def label_vertices(array, peaks = 0, branch_removal = 0, comp_dil = 0, max_vol = 0, down_factor = 0, directory = None, return_skele = False, order = 0, fastdil = True):
2457
2613
  """
2458
2614
  Can be used to label vertices (where multiple branches connect) a binary image. Labelled output will be saved to the active directory if none is specified. Note this works better on already thin filaments and may over-divide larger trunkish objects.
@@ -2487,7 +2643,7 @@ def label_vertices(array, peaks = 0, branch_removal = 0, comp_dil = 0, max_vol =
2487
2643
  old_skeleton = copy.deepcopy(array) # The skeleton might get modified in label_vertices so we can make a preserved copy of it to use later
2488
2644
 
2489
2645
  if branch_removal > 0:
2490
- array = remove_branches(array, branch_removal)
2646
+ array = remove_branches_new(array, branch_removal)
2491
2647
 
2492
2648
  array = np.pad(array, pad_width=1, mode='constant', constant_values=0)
2493
2649
 
@@ -4135,21 +4291,37 @@ class Network_3D:
4135
4291
  self._nodes, num_nodes = label_objects(nodes, structure_3d)
4136
4292
 
4137
4293
  def combine_nodes(self, root_nodes, other_nodes, other_ID, identity_dict, root_ID = None, centroids = False, down_factor = None):
4138
-
4139
4294
  """Internal method to merge two labelled node arrays into one"""
4140
-
4141
4295
  print("Combining node arrays")
4142
-
4296
+
4297
+ # Calculate the maximum value that will exist in the output
4298
+ max_root = np.max(root_nodes)
4299
+ max_other = np.max(other_nodes)
4300
+ max_output = max_root + max_other # Worst case: all other_nodes shifted by max_root
4301
+
4302
+ # Determine the minimum dtype needed
4303
+ if max_output <= 255:
4304
+ target_dtype = np.uint8
4305
+ elif max_output <= 65535:
4306
+ target_dtype = np.uint16
4307
+ else:
4308
+ target_dtype = np.uint32
4309
+
4310
+ # Convert arrays to appropriate dtype
4311
+ root_nodes = root_nodes.astype(target_dtype)
4312
+ other_nodes = other_nodes.astype(target_dtype)
4313
+
4314
+ # Now perform the merge
4143
4315
  mask = (root_nodes == 0) & (other_nodes > 0)
4144
4316
  if np.any(mask):
4145
- max_val = np.max(root_nodes)
4146
- other_nodes[:] = np.where(mask, other_nodes + max_val, 0)
4147
- if centroids:
4148
- new_dict = network_analysis._find_centroids(other_nodes, down_factor = down_factor)
4149
- if down_factor is not None:
4150
- for item in new_dict:
4151
- new_dict[item] = down_factor * new_dict[item]
4152
- self.node_centroids.update(new_dict)
4317
+ other_nodes_shifted = np.where(other_nodes > 0, other_nodes + max_root, 0)
4318
+ if centroids:
4319
+ new_dict = network_analysis._find_centroids(other_nodes_shifted, down_factor = down_factor)
4320
+ if down_factor is not None:
4321
+ for item in new_dict:
4322
+ new_dict[item] = down_factor * new_dict[item]
4323
+ self.node_centroids.update(new_dict)
4324
+ other_nodes = np.where(mask, other_nodes_shifted, 0)
4153
4325
 
4154
4326
  if root_ID is not None:
4155
4327
  rootIDs = list(np.unique(root_nodes)) #Sets up adding these vals to the identitiy dictionary. Gets skipped if this has already been done.
@@ -4188,7 +4360,7 @@ class Network_3D:
4188
4360
 
4189
4361
  return nodes, identity_dict
4190
4362
 
4191
- def merge_nodes(self, addn_nodes_name, label_nodes = True, root_id = "Root_Nodes", centroids = False, down_factor = None):
4363
+ def merge_nodes(self, addn_nodes_name, label_nodes = True, root_id = "Root_Nodes", centroids = False, down_factor = None, is_array = False):
4192
4364
  """
4193
4365
  Merges the self._nodes attribute with alternate labelled node images. The alternate nodes can be inputted as a string for a filepath to a tif,
4194
4366
  or as a directory address containing only tif images, which will merge the _nodes attribute with all tifs in the folder. The _node_identities attribute
@@ -4215,7 +4387,11 @@ class Network_3D:
4215
4387
  self.node_centroids[item] = down_factor * self.node_centroids[item]
4216
4388
 
4217
4389
  try: #Try presumes the input is a tif
4218
- addn_nodes = tifffile.imread(addn_nodes_name) #If not this will fail and activate the except block
4390
+ if not is_array:
4391
+ addn_nodes = tifffile.imread(addn_nodes_name) #If not this will fail and activate the except block
4392
+ else:
4393
+ addn_nodes = addn_nodes_name # Passing it an array directly
4394
+ addn_nodes_name = "Node"
4219
4395
 
4220
4396
  if label_nodes is True:
4221
4397
  addn_nodes, num_nodes2 = label_objects(addn_nodes) # Label the node objects. Note this presumes no overlap between node masks.
@@ -4227,7 +4403,6 @@ class Network_3D:
4227
4403
  num_nodes = int(np.max(node_labels))
4228
4404
 
4229
4405
  except: #Exception presumes the input is a directory containing multiple tifs, to allow multi-node stackage.
4230
-
4231
4406
  addn_nodes_list = directory_info(addn_nodes_name)
4232
4407
 
4233
4408
  for i, addn_nodes in enumerate(addn_nodes_list):
@@ -4882,14 +5057,14 @@ class Network_3D:
4882
5057
 
4883
5058
 
4884
5059
 
4885
- def prune_samenode_connections(self):
5060
+ def prune_samenode_connections(self, target = None):
4886
5061
  """
4887
5062
  If working with a network that has multiple node identities (from merging nodes or otherwise manipulating this property),
4888
5063
  this method will remove from the network and network_lists properties any connections that exist between the same node identity,
4889
5064
  in case we want to investigate only connections between differing objects.
4890
5065
  """
4891
5066
 
4892
- self._network_lists, self._node_identities = network_analysis.prune_samenode_connections(self._network_lists, self._node_identities)
5067
+ self._network_lists, self._node_identities = network_analysis.prune_samenode_connections(self._network_lists, self._node_identities, target = target)
4893
5068
  self._network, num_weights = network_analysis.weighted_network(self._network_lists)
4894
5069
 
4895
5070
 
@@ -5608,7 +5783,7 @@ class Network_3D:
5608
5783
  print(f"Using {volume} for the volume measurement (Volume of provided mask as scaled by xy and z scaling)")
5609
5784
 
5610
5785
  # Compute distance transform on padded array
5611
- legal = smart_dilate.compute_distance_transform_distance(legal, sampling = [self.z_scale, self.xy_scale, self.xy_scale])
5786
+ legal = smart_dilate.compute_distance_transform_distance(legal, sampling = [self.z_scale, self.xy_scale, self.xy_scale], fast_dil = True)
5612
5787
 
5613
5788
  # Remove padding after distance transform
5614
5789
  if dim == 2:
@@ -5702,7 +5877,6 @@ class Network_3D:
5702
5877
 
5703
5878
 
5704
5879
  def morph_proximity(self, search = 0, targets = None, fastdil = False):
5705
-
5706
5880
  if type(search) == list:
5707
5881
  search_x, search_z = search #Suppose we just want to directly pass these params
5708
5882
  else:
@@ -5711,7 +5885,6 @@ class Network_3D:
5711
5885
  num_nodes = int(np.max(self._nodes))
5712
5886
 
5713
5887
  my_dict = proximity.create_node_dictionary(self._nodes, num_nodes, search_x, search_z, targets = targets, fastdil = fastdil, xy_scale = self._xy_scale, z_scale = self._z_scale, search = search)
5714
-
5715
5888
  my_dict = proximity.find_shared_value_pairs(my_dict)
5716
5889
 
5717
5890
  my_dict = create_and_save_dataframe(my_dict)