nettracer3d 0.9.8__py3-none-any.whl → 1.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -733,17 +733,33 @@ def assign_node_colors(node_list: List[int], labeled_array: np.ndarray) -> Tuple
733
733
  return rgba_array, node_to_color_names
734
734
 
735
735
  def assign_community_colors(community_dict: Dict[int, int], labeled_array: np.ndarray) -> Tuple[np.ndarray, Dict[int, str]]:
736
- """fast version using lookup table approach."""
736
+ """Fast version using lookup table approach with brown outliers for community 0."""
737
+
738
+ # Separate outliers (community 0) from regular communities
739
+ outliers = {node: comm for node, comm in community_dict.items() if comm == 0}
740
+ non_outlier_dict = {node: comm for node, comm in community_dict.items() if comm != 0}
737
741
 
738
- # Same setup as before
739
- communities = set(community_dict.values())
740
- community_sizes = Counter(community_dict.values())
741
- sorted_communities = sorted(communities, key=lambda x: community_sizes[x], reverse=True)
742
+ # Get communities excluding outliers
743
+ communities = set(non_outlier_dict.values()) if non_outlier_dict else set()
742
744
 
743
- colors = generate_distinct_colors(len(communities))
745
+ # Generate colors for non-outlier communities only
746
+ colors = generate_distinct_colors(len(communities)) if communities else []
744
747
  colors_rgba = np.array([(r, g, b, 255) for r, g, b in colors], dtype=np.uint8)
745
748
 
746
- community_to_color = {comm: colors_rgba[i] for i, comm in enumerate(sorted_communities)}
749
+ # Sort communities by size for consistent color assignment
750
+ if non_outlier_dict:
751
+ community_sizes = Counter(non_outlier_dict.values())
752
+ sorted_communities = sorted(communities, key=lambda x: (-community_sizes[x], x))
753
+ community_to_color = {comm: colors_rgba[i] for i, comm in enumerate(sorted_communities)}
754
+ else:
755
+ community_to_color = {}
756
+
757
+ # Add brown color for outliers (community 0)
758
+ brown_rgba = np.array([139, 69, 19, 255], dtype=np.uint8) # Brown color
759
+ if outliers:
760
+ community_to_color[0] = brown_rgba
761
+
762
+ # Create node to color mapping using original community_dict
747
763
  node_to_color = {node: community_to_color[comm] for node, comm in community_dict.items()}
748
764
 
749
765
  # Create lookup table - this is the key optimization
@@ -756,7 +772,7 @@ def assign_community_colors(community_dict: Dict[int, int], labeled_array: np.nd
756
772
  # Single vectorized operation - this is much faster!
757
773
  rgba_array = color_lut[labeled_array]
758
774
 
759
- # Rest remains the same
775
+ # Convert to RGB for color names (including brown for outliers)
760
776
  community_to_color_rgb = {k: tuple(v[:3]) for k, v in community_to_color.items()}
761
777
  node_to_color_names = convert_node_colors_to_names(community_to_color_rgb)
762
778
 
nettracer3d/morphology.py CHANGED
@@ -65,7 +65,7 @@ def reslice_3d_array(args):
65
65
  return resliced_array
66
66
 
67
67
 
68
- def _get_node_edge_dict(label_array, edge_array, label, dilate_xy, dilate_z, cores = 0, search = 0, fastdil = False, xy_scale = 1, z_scale = 1):
68
+ def _get_node_edge_dict(label_array, edge_array, label, dilate_xy, dilate_z, cores = 0, search = 0, fastdil = False, length = False, xy_scale = 1, z_scale = 1):
69
69
  """Internal method used for the secondary algorithm to find pixel involvement of nodes around an edge."""
70
70
 
71
71
  # Create a boolean mask where elements with the specified label are True
@@ -74,24 +74,25 @@ def _get_node_edge_dict(label_array, edge_array, label, dilate_xy, dilate_z, cor
74
74
 
75
75
  if cores == 0: #For getting the volume of objects. Cores presumes you want the 'core' included in the interaction.
76
76
  edge_array = edge_array * dil_array # Filter the edges by the label in question
77
- label_array = np.count_nonzero(dil_array)
78
- edge_array = np.count_nonzero(edge_array) # For getting the interacting skeleton
79
-
80
77
  elif cores == 1: #Cores being 1 presumes you do not want to 'core' included in the interaction
81
78
  label_array = dil_array - label_array
82
79
  edge_array = edge_array * label_array
83
- label_array = np.count_nonzero(label_array)
84
- edge_array = np.count_nonzero(edge_array) # For getting the interacting skeleton
85
-
86
80
  elif cores == 2: #Presumes you want skeleton within the core but to only 'count' the stuff around the core for volumes... because of imaging artifacts, perhaps
87
81
  edge_array = edge_array * dil_array
88
82
  label_array = dil_array - label_array
89
- label_array = np.count_nonzero(label_array)
90
- edge_array = np.count_nonzero(edge_array) # For getting the interacting skeleton
91
83
 
84
+ label_count = np.count_nonzero(label_array) * xy_scale * xy_scale * z_scale
92
85
 
93
-
94
- args = [edge_array, label_array]
86
+ if not length:
87
+ edge_count = np.count_nonzero(edge_array) * xy_scale * xy_scale * z_scale # For getting the interacting skeleton
88
+ else:
89
+ edge_count = calculate_skeleton_lengths(
90
+ edge_array,
91
+ xy_scale=xy_scale,
92
+ z_scale=z_scale
93
+ )
94
+
95
+ args = [edge_count, label_count]
95
96
 
96
97
  return args
97
98
 
@@ -115,7 +116,7 @@ def process_label(args):
115
116
 
116
117
 
117
118
 
118
- def create_node_dictionary(nodes, edges, num_nodes, dilate_xy, dilate_z, cores=0, search = 0, fastdil = False, xy_scale = 1, z_scale = 1):
119
+ def create_node_dictionary(nodes, edges, num_nodes, dilate_xy, dilate_z, cores=0, search = 0, fastdil = False, length = False, xy_scale = 1, z_scale = 1):
119
120
  """Modified to pre-compute all bounding boxes using find_objects"""
120
121
  node_dict = {}
121
122
  array_shape = nodes.shape
@@ -135,20 +136,20 @@ def create_node_dictionary(nodes, edges, num_nodes, dilate_xy, dilate_z, cores=0
135
136
  # Process results in parallel
136
137
  for label, sub_nodes, sub_edges in results:
137
138
  executor.submit(create_dict_entry, node_dict, label, sub_nodes, sub_edges,
138
- dilate_xy, dilate_z, cores, search, fastdil, xy_scale, z_scale)
139
+ dilate_xy, dilate_z, cores, search, fastdil, length, xy_scale, z_scale)
139
140
 
140
141
  return node_dict
141
142
 
142
def create_dict_entry(node_dict, label, sub_nodes, sub_edges, dilate_xy, dilate_z, cores = 0, search = 0, fastdil = False, length = False, xy_scale = 1, z_scale = 1):
    """Internal helper for the secondary algorithm: computes the edge/node
    quantification for a single label and stores it in the shared dict.

    Used as a parallel worker; a ``None`` label is an empty work item and
    is skipped.
    """
    if label is None:
        # Nothing to process for this work item.
        return
    node_dict[label] = _get_node_edge_dict(
        sub_nodes, sub_edges, label, dilate_xy, dilate_z,
        cores=cores, search=search, fastdil=fastdil, length=length,
        xy_scale=xy_scale, z_scale=z_scale,
    )
149
150
 
150
151
 
151
- def quantify_edge_node(nodes, edges, search = 0, xy_scale = 1, z_scale = 1, cores = 0, resize = None, save = True, skele = False, fastdil = False):
152
+ def quantify_edge_node(nodes, edges, search = 0, xy_scale = 1, z_scale = 1, cores = 0, resize = None, save = True, skele = False, length = False, auto = True, fastdil = False):
152
153
 
153
154
  def save_dubval_dict(dict, index_name, val1name, val2name, filename):
154
155
 
@@ -168,6 +169,9 @@ def quantify_edge_node(nodes, edges, search = 0, xy_scale = 1, z_scale = 1, core
168
169
  edges = tifffile.imread(edges)
169
170
 
170
171
  if skele:
172
+ if auto:
173
+ edges = nettracer.skeletonize(edges)
174
+ edges = nettracer.fill_holes_3d(edges)
171
175
  edges = nettracer.skeletonize(edges)
172
176
  else:
173
177
  edges = nettracer.binarize(edges)
@@ -188,7 +192,7 @@ def quantify_edge_node(nodes, edges, search = 0, xy_scale = 1, z_scale = 1, core
188
192
  dilate_xy, dilate_z = 0, 0
189
193
 
190
194
 
191
- edge_quants = create_node_dictionary(nodes, edges, num_nodes, dilate_xy, dilate_z, cores = cores, search = search, fastdil = fastdil, xy_scale = xy_scale, z_scale = z_scale) #Find which edges connect which nodes and put them in a dictionary.
195
+ edge_quants = create_node_dictionary(nodes, edges, num_nodes, dilate_xy, dilate_z, cores = cores, search = search, fastdil = fastdil, length = length, xy_scale = xy_scale, z_scale = z_scale) #Find which edges connect which nodes and put them in a dictionary.
192
196
 
193
197
  if save:
194
198
 
@@ -199,6 +203,98 @@ def quantify_edge_node(nodes, edges, search = 0, xy_scale = 1, z_scale = 1, core
199
203
  return edge_quants
200
204
 
201
205
 
206
+ # Helper methods for counting the lens of skeletons:
207
+
208
def calculate_skeleton_lengths(skeleton_binary, xy_scale=1.0, z_scale=1.0, skeleton_coords=None):
    """
    Calculate the total physical length of all skeletons in a 3D binary image.

    Parameters
    ----------
    skeleton_binary : 3D array or shape tuple
        When ``skeleton_coords`` is None, the skeleton image itself
        (nonzero = skeleton voxel).  When ``skeleton_coords`` is supplied,
        callers historically pass the array *shape* in this slot instead;
        both a real array and a plain shape tuple are accepted here.
    xy_scale : float
        Physical units per voxel in-plane (x and y).
    z_scale : float
        Physical units per voxel between planes (z).
    skeleton_coords : optional sequence of (z, y, x) voxel coordinates
        Precomputed skeleton voxel positions; skips the ``argwhere`` scan.

    Returns
    -------
    float
        Sum of scaled distances over every unique pair of 26-adjacent
        skeleton voxels; 0.0 for an empty skeleton.
    """
    if skeleton_coords is None:
        # Locate every skeleton voxel directly from the binary image.
        skeleton_coords = np.argwhere(skeleton_binary)
        shape = skeleton_binary.shape
    else:
        # Legacy calling convention passes the array shape as the first
        # argument; tolerate a real array here as well.
        shape = getattr(skeleton_binary, 'shape', skeleton_binary)

    if len(skeleton_coords) == 0:
        return 0.0

    # Coordinate -> index map for O(1) neighbor lookups.
    coord_to_idx = {tuple(coord): idx for idx, coord in enumerate(skeleton_coords)}

    # 26-connectivity adjacency between skeleton voxels.
    adjacency_list = build_adjacency_graph(skeleton_coords, coord_to_idx, shape)

    # Total length = sum of scaled edge distances in the adjacency graph.
    return calculate_graph_length(skeleton_coords, adjacency_list, xy_scale, z_scale)
236
+
237
def build_adjacency_graph(skeleton_coords, coord_to_idx, shape):
    """
    Build an adjacency list linking 26-connected skeleton voxels.

    Parameters
    ----------
    skeleton_coords : sequence of (z, y, x) voxel coordinates.
    coord_to_idx : dict
        Maps each coordinate tuple to its index within ``skeleton_coords``.
    shape : (Z, Y, X)
        Bounds of the source array; candidate neighbors outside these
        bounds are ignored.

    Returns
    -------
    list of list of int
        ``result[i]`` holds the indices of every skeleton voxel that is
        26-adjacent to ``skeleton_coords[i]``.
    """
    from itertools import product

    adjacency_list = [[] for _ in range(len(skeleton_coords))]

    # All 26 neighbor offsets: every (dz, dy, dx) in {-1, 0, 1}^3 except
    # the zero offset (the voxel itself).
    offsets = [off for off in product((-1, 0, 1), repeat=3) if off != (0, 0, 0)]

    max_z, max_y, max_x = shape[0], shape[1], shape[2]
    for idx, coord in enumerate(skeleton_coords):
        z, y, x = coord

        for dz, dy, dx in offsets:
            nz, ny, nx = z + dz, y + dy, x + dx

            # Stay inside the array bounds.
            if not (0 <= nz < max_z and 0 <= ny < max_y and 0 <= nx < max_x):
                continue

            # Single dict lookup instead of membership test + index fetch.
            neighbor_idx = coord_to_idx.get((nz, ny, nx))
            if neighbor_idx is not None:
                adjacency_list[idx].append(neighbor_idx)

    return adjacency_list
267
+
268
def calculate_graph_length(skeleton_coords, adjacency_list, xy_scale, z_scale):
    """Sum the scaled distances over every unique pair of adjacent voxels.

    Each undirected edge of the adjacency graph is counted exactly once;
    dz uses ``z_scale`` while dy/dx use ``xy_scale``.
    """
    total = 0.0
    seen = set()

    for src_idx, nbr_indices in enumerate(adjacency_list):
        src = skeleton_coords[src_idx]

        for nbr_idx in nbr_indices:
            # Canonical (low, high) key so each edge is processed once.
            key = (src_idx, nbr_idx) if src_idx < nbr_idx else (nbr_idx, src_idx)
            if key in seen:
                continue
            seen.add(key)

            dst = skeleton_coords[nbr_idx]

            # Physical displacement between the two voxels.
            dz = (src[0] - dst[0]) * z_scale
            dy = (src[1] - dst[1]) * xy_scale
            dx = (src[2] - dst[2]) * xy_scale

            total += np.sqrt(dx * dx + dy * dy + dz * dz)

    return total
294
+
295
+ # End helper methods
296
+
297
+
202
298
 
203
299
  def calculate_voxel_volumes(array, xy_scale=1, z_scale=1):
204
300
  """
@@ -8,7 +8,8 @@ from matplotlib.colors import LinearSegmentedColormap
8
8
  from sklearn.cluster import DBSCAN
9
9
  from sklearn.neighbors import NearestNeighbors
10
10
  import matplotlib.colors as mcolors
11
-
11
+ from collections import Counter
12
+ from . import community_extractor
12
13
 
13
14
 
14
15
  import os
@@ -347,7 +348,8 @@ def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
347
348
  id_dictionary: Optional[Dict[int, str]] = None,
348
349
  graph_label = "Community ID",
349
350
  title = 'UMAP Visualization of Community Compositions',
350
- neighborhoods: Optional[Dict[int, int]] = None):
351
+ neighborhoods: Optional[Dict[int, int]] = None,
352
+ original_communities = None):
351
353
  """
352
354
  Convert cluster composition data to UMAP visualization.
353
355
 
@@ -394,37 +396,50 @@ def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
394
396
  embedding = reducer.fit_transform(compositions)
395
397
 
396
398
  # Determine coloring scheme based on parameters
397
- if neighborhoods is not None:
399
+ if neighborhoods is not None and original_communities is not None:
398
400
  # Use neighborhood coloring - import the community extractor methods
399
401
  from . import community_extractor
402
+ from collections import Counter
403
+
404
+ # Use original_communities (which is {node: neighborhood}) for color generation
405
+ # This ensures we use the proper node counts for sorting
400
406
 
401
- # Filter neighborhoods to only include cluster_ids that exist in our data
402
- filtered_neighborhoods = {node_id: neighborhood_id
403
- for node_id, neighborhood_id in neighborhoods.items()
404
- if node_id in cluster_ids}
407
+ # Separate outliers (neighborhood 0) from regular neighborhoods in ORIGINAL structure
408
+ outlier_neighborhoods = {node: neighborhood for node, neighborhood in original_communities.items() if neighborhood == 0}
409
+ non_outlier_neighborhoods = {node: neighborhood for node, neighborhood in original_communities.items() if neighborhood != 0}
405
410
 
406
- # Create a dummy labeled array just for the coloring function
407
- # We only need the coloring logic, not actual clustering
408
- dummy_array = np.array(cluster_ids)
411
+ # Get neighborhoods excluding outliers
412
+ unique_neighborhoods = set(non_outlier_neighborhoods.values()) if non_outlier_neighborhoods else set()
409
413
 
410
- # Get colors using the community coloration method
411
- _, neighborhood_color_names = community_extractor.assign_community_colors(
412
- filtered_neighborhoods, dummy_array
413
- )
414
+ # Generate colors for non-outlier neighborhoods only (same as assign_community_colors)
415
+ colors = community_extractor.generate_distinct_colors(len(unique_neighborhoods)) if unique_neighborhoods else []
416
+
417
+ # Sort neighborhoods by size for consistent color assignment (same logic as assign_community_colors)
418
+ # Use the ORIGINAL node counts from original_communities
419
+ if non_outlier_neighborhoods:
420
+ neighborhood_sizes = Counter(non_outlier_neighborhoods.values())
421
+ sorted_neighborhoods = sorted(unique_neighborhoods, key=lambda x: (-neighborhood_sizes[x], x))
422
+ neighborhood_to_color = {neighborhood: colors[i] for i, neighborhood in enumerate(sorted_neighborhoods)}
423
+ else:
424
+ neighborhood_to_color = {}
414
425
 
415
- # Create color mapping for our points
416
- unique_neighborhoods = sorted(list(set(filtered_neighborhoods.values())))
417
- colors = community_extractor.generate_distinct_colors(len(unique_neighborhoods))
418
- neighborhood_to_color = {neighborhood: colors[i] for i, neighborhood in enumerate(unique_neighborhoods)}
426
+ # Add brown color for outliers (neighborhood 0) - same as assign_community_colors
427
+ if outlier_neighborhoods:
428
+ neighborhood_to_color[0] = (139, 69, 19) # Brown color (RGB, not RGBA here)
419
429
 
420
- # Map each cluster to its neighborhood color
430
+ # Map each cluster to its neighborhood color using 'neighborhoods' ({community: neighborhood}) for assignment
421
431
  point_colors = []
422
432
  neighborhood_labels = []
423
433
  for cluster_id in cluster_ids:
424
- if cluster_id in filtered_neighborhoods:
425
- neighborhood_id = filtered_neighborhoods[cluster_id]
426
- point_colors.append(neighborhood_to_color[neighborhood_id])
427
- neighborhood_labels.append(neighborhood_id)
434
+ if cluster_id in neighborhoods:
435
+ neighborhood_id = neighborhoods[cluster_id] # This is {community: neighborhood}
436
+ if neighborhood_id in neighborhood_to_color:
437
+ point_colors.append(neighborhood_to_color[neighborhood_id])
438
+ neighborhood_labels.append(neighborhood_id)
439
+ else:
440
+ # Default color for neighborhoods not found
441
+ point_colors.append((128, 128, 128)) # Gray
442
+ neighborhood_labels.append("Unknown")
428
443
  else:
429
444
  # Default color for nodes not in any neighborhood
430
445
  point_colors.append((128, 128, 128)) # Gray
@@ -432,6 +447,10 @@ def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
432
447
 
433
448
  # Normalize RGB values for matplotlib (0-1 range)
434
449
  point_colors = [(r/255.0, g/255.0, b/255.0) for r, g, b in point_colors]
450
+
451
+ # Get unique neighborhoods for legend
452
+ unique_neighborhoods_for_legend = sorted(list(set(neighborhood_to_color.keys())))
453
+
435
454
  use_neighborhood_coloring = True
436
455
 
437
456
  elif id_dictionary is not None:
@@ -467,8 +486,8 @@ def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
467
486
  # Add cluster ID labels
468
487
  for i, cluster_id in enumerate(cluster_ids):
469
488
  display_label = f'{cluster_id}'
470
- if use_neighborhood_coloring and cluster_id in filtered_neighborhoods:
471
- neighborhood_id = filtered_neighborhoods[cluster_id]
489
+ if use_neighborhood_coloring and cluster_id in neighborhoods:
490
+ neighborhood_id = neighborhoods[cluster_id]
472
491
  display_label = f'{cluster_id}\n(N{neighborhood_id})'
473
492
  elif id_dictionary is not None:
474
493
  identity = id_dictionary.get(cluster_id, "Unknown")
@@ -483,7 +502,7 @@ def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
483
502
  if use_neighborhood_coloring:
484
503
  # Create custom legend for neighborhoods
485
504
  legend_elements = []
486
- for neighborhood_id in unique_neighborhoods:
505
+ for neighborhood_id in unique_neighborhoods_for_legend:
487
506
  color = neighborhood_to_color[neighborhood_id]
488
507
  norm_color = (color[0]/255.0, color[1]/255.0, color[2]/255.0)
489
508
  legend_elements.append(
@@ -530,8 +549,8 @@ def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
530
549
  # Add cluster ID labels
531
550
  for i, cluster_id in enumerate(cluster_ids):
532
551
  display_label = f'C{cluster_id}'
533
- if use_neighborhood_coloring and cluster_id in filtered_neighborhoods:
534
- neighborhood_id = filtered_neighborhoods[cluster_id]
552
+ if use_neighborhood_coloring and cluster_id in neighborhoods:
553
+ neighborhood_id = neighborhoods[cluster_id]
535
554
  display_label = f'C{cluster_id}\n(N{neighborhood_id})'
536
555
  elif id_dictionary is not None:
537
556
  identity = id_dictionary.get(cluster_id, "Unknown")
@@ -554,7 +573,7 @@ def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
554
573
  if use_neighborhood_coloring:
555
574
  # Create custom legend for neighborhoods
556
575
  legend_elements = []
557
- for neighborhood_id in unique_neighborhoods:
576
+ for neighborhood_id in unique_neighborhoods_for_legend:
558
577
  color = neighborhood_to_color[neighborhood_id]
559
578
  norm_color = (color[0]/255.0, color[1]/255.0, color[2]/255.0)
560
579
  legend_elements.append(
@@ -585,8 +604,8 @@ def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
585
604
  for i, cluster_id in enumerate(cluster_ids):
586
605
  composition = compositions[i]
587
606
  additional_info = ""
588
- if use_neighborhood_coloring and cluster_id in filtered_neighborhoods:
589
- neighborhood_id = filtered_neighborhoods[cluster_id]
607
+ if use_neighborhood_coloring and cluster_id in neighborhoods:
608
+ neighborhood_id = neighborhoods[cluster_id]
590
609
  additional_info = f" (Neighborhood: {neighborhood_id})"
591
610
  elif id_dictionary is not None:
592
611
  identity = id_dictionary.get(cluster_id, "Unknown")
@@ -974,60 +993,63 @@ def create_node_heatmap(node_intensity, node_centroids, shape=None, is_3d=True,
974
993
  node_to_intensity[node_id] = node_intensity_clean[node_id]
975
994
 
976
995
  # Create colormap function (RdBu_r - red for high, blue for low, yellow/white for middle)
977
def intensity_to_rgba(intensity, min_val, max_val):
    """Map an intensity to an RGBA color using RdBu_r-style logic centered
    at 0: positive values ramp white->red, negative values ramp white->blue,
    and exactly zero is fully transparent."""

    def rgba(r, g, b, a):
        return np.array([r, g, b, a], dtype=np.uint8)

    # Degenerate range: all values identical, so use flat colors.
    if max_val == min_val:
        if intensity == 0:
            return rgba(255, 255, 255, 0)      # transparent white for 0
        if intensity > 0:
            return rgba(255, 200, 200, 255)    # opaque light red
        return rgba(200, 200, 255, 255)        # opaque light blue

    # Symmetric scaling around zero based on the largest magnitude.
    max_abs = max(abs(min_val), abs(max_val))
    if max_abs == 0:
        return rgba(255, 255, 255, 0)          # everything is zero

    normalized = np.clip(intensity / max_abs, -1, 1)

    if normalized > 0:
        # White -> red ramp for positives, fully opaque.
        shade = int(255 * (1 - normalized))
        return rgba(255, shade, shade, 255)
    if normalized < 0:
        # White -> blue ramp for negatives, fully opaque.
        shade = int(255 * (1 + normalized))
        return rgba(shade, shade, 255, 255)

    # Exactly zero: transparent.
    return rgba(255, 255, 255, 0)
1036
+
1037
+ # Modified usage in your main function:
1038
+ # Create lookup table for RGBA colors (note the 4 channels now)
1017
1039
  max_label = max(max(labeled_array.flat), max(node_to_intensity.keys()) if node_to_intensity else 0)
1018
- color_lut = np.zeros((max_label + 1, 3), dtype=np.uint8) # Default to black (0,0,0)
1019
-
1020
- # Fill lookup table with RGB colors based on intensity
1040
+ color_lut = np.zeros((max_label + 1, 4), dtype=np.uint8) # Default to transparent (0,0,0,0)
1041
+
1042
+ # Fill lookup table with RGBA colors based on intensity
1021
1043
  for node_id, intensity in node_to_intensity.items():
1022
- rgb_color = intensity_to_rgb(intensity, min_intensity, max_intensity)
1023
- color_lut[int(node_id)] = rgb_color
1024
-
1044
+ rgba_color = intensity_to_rgba(intensity, min_intensity, max_intensity)
1045
+ color_lut[int(node_id)] = rgba_color
1046
+
1025
1047
  # Apply lookup table to labeled array - single vectorized operation
1026
1048
  if is_3d:
1027
- # Return full 3D RGB array [Z, Y, X, 3]
1049
+ # Return full 3D RGBA array [Z, Y, X, 4]
1028
1050
  heatmap_array = color_lut[labeled_array]
1029
1051
  else:
1030
- # Return 2D RGB array
1052
+ # Return 2D RGBA array
1031
1053
  if labeled_array.ndim == 3:
1032
1054
  # Take middle slice for 2D representation
1033
1055
  middle_slice = labeled_array.shape[0] // 2
@@ -1035,7 +1057,7 @@ def create_node_heatmap(node_intensity, node_centroids, shape=None, is_3d=True,
1035
1057
  else:
1036
1058
  # Already 2D
1037
1059
  heatmap_array = color_lut[labeled_array]
1038
-
1060
+
1039
1061
  return heatmap_array
1040
1062
 
1041
1063
  else:
@@ -1104,19 +1126,132 @@ def create_node_heatmap(node_intensity, node_centroids, shape=None, is_3d=True,
1104
1126
  plt.tight_layout()
1105
1127
  plt.show()
1106
1128
 
1107
- # Example usage:
1108
- if __name__ == "__main__":
1109
- # Sample data for demonstration
1110
- sample_dict = {
1111
- 'category_A': np.array([0.1, 0.5, 0.8, 0.3, 0.9]),
1112
- 'category_B': np.array([0.7, 0.2, 0.6, 0.4, 0.1]),
1113
- 'category_C': np.array([0.9, 0.8, 0.2, 0.7, 0.5])
1114
- }
1129
def create_violin_plots(data_dict, graph_title="Violin Plots"):
    """
    Create violin plots from dictionary data with distinct colors.

    Parameters:
    data_dict (dict): Dictionary where keys are column headers (strings) and
                      values are lists of floats
    graph_title (str): Title for the overall plot
    """
    if not data_dict:
        print("No data to plot")
        return

    # Prepare data, sorted by column header for a stable plot order.
    data_dict = dict(sorted(data_dict.items()))
    labels = list(data_dict.keys())
    data_lists = list(data_dict.values())

    # One distinct color per column; fall back to a matplotlib colormap if
    # the shared color generator fails for any reason.
    try:
        final_colors = generate_distinct_colors(len(labels))
    except Exception as e:
        print(f"Color generation failed, using default colors: {e}")
        final_colors = plt.cm.Set3(np.linspace(0, 1, len(labels)))

    # Widen the figure as the number of columns grows.
    fig, ax = plt.subplots(figsize=(max(8, len(labels) * 1.5), 6))

    violin_parts = ax.violinplot(data_lists, positions=range(len(labels)),
                                 showmeans=False, showmedians=True, showextrema=True)

    # Color the violin bodies.
    for i, pc in enumerate(violin_parts['bodies']):
        if i < len(final_colors):
            pc.set_facecolor(final_colors[i])
            pc.set_alpha(0.7)

    # Outline the bars, extrema, and median markers in black.
    for partname in ('cbars', 'cmins', 'cmaxes', 'cmedians'):
        if partname in violin_parts:
            violin_parts[partname].set_edgecolor('black')
            violin_parts[partname].set_linewidth(1)

    # Focus the y-axis on the bulk of the data (5th-95th percentile) so
    # extreme outliers don't dominate the default view.
    all_data = [val for sublist in data_lists for val in sublist]
    if all_data:
        y_min = np.percentile(all_data, 5)
        y_max = np.percentile(all_data, 95)
        y_range = y_max - y_min
        y_padding = y_range * 0.15
        ax.set_ylim(y_min - y_padding, y_max + y_padding)

    # Add median/IQR annotations below each violin.
    for i, data in enumerate(data_lists):
        if len(data) > 0:
            q1, median, q3 = np.percentile(data, [25, 50, 75])
            iqr = q3 - q1

            # Position the text below the violin using current axis limits.
            y_min_current = ax.get_ylim()[0]
            y_text = y_min_current - (ax.get_ylim()[1] - ax.get_ylim()[0]) * 0.15

            ax.text(i, y_text, f'Median: {median:.2f}\nIQR: {iqr:.2f}',
                    horizontalalignment='center', fontsize=8,
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

    # Customize the plot.
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.set_title(graph_title, fontsize=14, fontweight='bold')
    ax.set_ylabel('Normalized Values (Z-score-like)', fontsize=12)
    ax.grid(True, alpha=0.3)

    # Mark the identity centerpoint at y=0.
    ax.axhline(y=0, color='red', linestyle='--', alpha=0.5, linewidth=1,
               label='Identity Centerpoint')
    ax.legend(loc='upper right')

    # Prevent label cutoff and leave room for the bottom annotations.
    plt.subplots_adjust(bottom=0.2)
    plt.tight_layout()
    plt.show()