nettracer3d 0.8.0__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nettracer3d might be problematic. Click here for more details.

@@ -1,52 +1,176 @@
1
1
  import numpy as np
2
2
  from sklearn.cluster import KMeans
3
+ from sklearn.metrics import calinski_harabasz_score
3
4
  import matplotlib.pyplot as plt
4
5
  from typing import Dict, Set
5
6
  import umap
6
-
7
+ from matplotlib.colors import LinearSegmentedColormap
8
+ from sklearn.cluster import DBSCAN
9
+ from sklearn.neighbors import NearestNeighbors
7
10
 
8
11
 
9
12
  import os
10
13
  os.environ['LOKY_MAX_CPU_COUNT'] = '4'
11
14
 
12
- def cluster_arrays(data_input, n_clusters, seed = 42):
15
def cluster_arrays_dbscan(data_input, seed=42):
    """
    Cluster 1D arrays with DBSCAN, estimating its parameters from the data.

    Parameters:
    -----------
    data_input : dict or List[List[float]]
        Dictionary {key: array} or list of arrays to cluster
    seed : int
        Accepted for signature parity with cluster_arrays. Nothing in this
        function is randomized, so the value is currently unused.

    Returns:
    --------
    list: [[key1, key2], [key3, key4, key5]] - List of clusters, each containing keys/indices
        Note: Outliers (DBSCAN label -1) are excluded from the output.
        An empty input returns an empty list.
    """

    # Handle both dict and list inputs
    if isinstance(data_input, dict):
        keys = list(data_input.keys())
        array_values = list(data_input.values())
    else:
        keys = list(range(len(data_input)))
        array_values = data_input

    # Nothing to cluster; DBSCAN cannot be fitted on zero samples
    if not keys:
        return []

    data = np.array(array_values)
    n_samples = len(data)

    # Roughly sqrt(n)/5, but never fewer than 3 points per core neighborhood
    min_samples = max(3, int(np.sqrt(n_samples) * 0.2))

    # Estimate eps from the distance to the 4th nearest neighbor (common heuristic).
    # Column 0 of `distances` is each point's distance to itself, hence k + 1 neighbors.
    eps = 0.1  # fallback when no estimate is possible
    k = min(4, n_samples - 1)
    if k > 0:
        nbrs = NearestNeighbors(n_neighbors=k + 1)
        nbrs.fit(data)
        distances, _ = nbrs.kneighbors(data)
        estimated_eps = np.percentile(distances[:, k], 80)
        # DBSCAN requires eps > 0; the estimate is 0 when most samples are duplicates
        if estimated_eps > 0:
            eps = estimated_eps

    print(f"Using DBSCAN with eps={eps:.4f}, min_samples={min_samples}")

    labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(data)

    # One list of keys per real cluster; label -1 marks noise and is left out
    main_clusters = [label for label in np.unique(labels) if label != -1]
    clusters = [
        [keys[i] for i in np.flatnonzero(labels == label)]
        for label in main_clusters
    ]

    n_main_clusters = len(main_clusters)
    n_outliers = np.sum(labels == -1)

    print(f"Found {n_main_clusters} main clusters and {n_outliers} outliers (go in neighborhood 0)")

    return clusters
41
82
 
83
def cluster_arrays(data_input, n_clusters=None, seed=42):
    """
    Cluster 1D arrays with KMeans, tracking keys and optionally auto-detecting the cluster count.

    Parameters:
    -----------
    data_input : dict or List[List[float]]
        Dictionary {key: array} or list of arrays to cluster
    n_clusters : int or None
        How many groups you want. If None, the count is chosen automatically,
        considering at most len(data_input) // 3 clusters
    seed : int
        Random seed for reproducibility

    Returns:
    --------
    list: [[key1, key2], [key3, key4, key5]] - List of clusters, each containing keys/indices
        An empty input returns an empty list.
    """

    # Handle both dict and list inputs
    if isinstance(data_input, dict):
        keys = list(data_input.keys())
        array_values = list(data_input.values())
    else:
        keys = list(range(len(data_input)))
        array_values = data_input

    # Nothing to cluster; KMeans cannot be fitted on zero samples
    if not keys:
        return []

    data = np.array(array_values)

    # Auto-detect optimal number of clusters if not specified
    if n_clusters is None:
        # "or 1" keeps KMeans usable even if the helper reports no usable k
        n_clusters = _find_optimal_clusters(data, seed, max_k=len(keys) // 3) or 1
        print(f"Auto-detected optimal number of clusters: {n_clusters}")

    labels = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10).fit_predict(data)

    # Group keys by assigned label; clusters[i] holds the keys labelled i
    clusters = [[] for _ in range(n_clusters)]
    for key, label in zip(keys, labels):
        clusters[label].append(key)

    return clusters
129
+
130
def _find_optimal_clusters(data, seed, max_k):
    """
    Find optimal number of clusters using Calinski-Harabasz index.

    Parameters:
    -----------
    data : array-like
        Samples to cluster, one row per sample
    seed : int
        Random seed passed to KMeans
    max_k : int
        Largest cluster count to try; capped at n_samples - 1 and at 20

    Returns:
    --------
    int: the k in [2, max_k] with the highest Calinski-Harabasz score, or 1
        when there are too few samples or no candidate k produced a usable score
    """
    n_samples = len(data)

    # Need at least 2 samples to cluster
    if n_samples < 2:
        return 1

    # Limit max_k to reasonable bounds
    max_k = min(max_k, n_samples - 1, 20)
    print(f"Max_k: {max_k}, n_samples: {n_samples}")

    if max_k < 2:
        return 1

    # Start from the single-cluster fallback so the function never returns None
    best_k, best_score = 1, 0

    for k in range(2, max_k + 1):
        try:
            labels = KMeans(n_clusters=k, random_state=seed, n_init=10).fit_predict(data)

            # A fit that collapsed to fewer than k clusters is not a valid candidate
            if len(np.unique(labels)) != k:
                continue

            score = calinski_harabasz_score(data, labels)
        except Exception:
            # Best-effort search: a failure for one k only disqualifies that k
            continue

        # Strict comparison keeps the smallest k among tied scores
        if score > best_score:
            best_k, best_score = k, score

    if best_score > 0:
        print(f"Using {best_k} neighborhoods")

    return best_k
171
+
172
+ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighborhood Heatmap",
173
+ center_at_one=False):
50
174
  """
51
175
  Create a heatmap from a dictionary of numpy arrays.
52
176
 
@@ -60,6 +184,11 @@ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighb
60
184
  Figure size (width, height)
61
185
  title : str, optional
62
186
  Title for the heatmap
187
+ center_at_one : bool, optional
188
+ If True, uses a diverging colormap centered at 1 with nonlinear scaling:
189
+ - 0 to 1: blue to white (underrepresentation to normal)
190
+ - 1+: white to red (overrepresentation)
191
+ If False (default), uses standard white-to-red scaling from 0 to 1
63
192
 
64
193
  Returns:
65
194
  --------
@@ -67,7 +196,6 @@ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighb
67
196
  """
68
197
 
69
198
  data_dict = {k: unsorted_data_dict[k] for k in sorted(unsorted_data_dict.keys())}
70
-
71
199
  # Convert dict to 2D array for heatmap
72
200
  # Each row represents one key from the dict
73
201
  keys = list(data_dict.keys())
@@ -76,8 +204,70 @@ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighb
76
204
  # Create the plot
77
205
  fig, ax = plt.subplots(figsize=figsize)
78
206
 
79
- # Create heatmap with white-to-red colormap
80
- im = ax.imshow(data_matrix, cmap='Reds', aspect='auto', vmin=0, vmax=1)
207
+ if center_at_one:
208
+ # Custom colormap and scaling for center_at_one mode
209
+ # Find the actual data range
210
+ data_min = np.min(data_matrix)
211
+ data_max = np.max(data_matrix)
212
+
213
+ # Create a custom colormap: blue -> white -> red
214
+ colors = ['#2166ac', '#4393c3', '#92c5de', '#d1e5f0', '#f7f7f7',
215
+ '#fddbc7', '#f4a582', '#d6604d', '#b2182b']
216
+ n_bins = 256
217
+ cmap = LinearSegmentedColormap.from_list('custom_diverging', colors, N=n_bins)
218
+
219
+ # Create nonlinear transformation
220
+ # Map 0->1 with more resolution, 1+ with less resolution
221
+ def transform_data(data):
222
+ transformed = np.zeros_like(data)
223
+
224
+ # For values 0 to 1: use square root for faster approach to middle
225
+ mask_low = data <= 1
226
+ transformed[mask_low] = 0.5 * np.sqrt(data[mask_low])
227
+
228
+ # For values > 1: use slower logarithmic scaling
229
+ mask_high = data > 1
230
+ if np.any(mask_high):
231
+ # Scale from 0.5 to 1.0 based on log of excess above 1
232
+ max_excess = np.max(data[mask_high] - 1) if np.any(mask_high) else 0
233
+ if max_excess > 0:
234
+ excess_normalized = np.log1p(data[mask_high] - 1) / np.log1p(max_excess)
235
+ transformed[mask_high] = 0.5 + 0.5 * excess_normalized
236
+ else:
237
+ transformed[mask_high] = 0.5
238
+
239
+ return transformed
240
+
241
+ # Transform the data for visualization
242
+ transformed_matrix = transform_data(data_matrix)
243
+
244
+ # Create heatmap with custom colormap
245
+ im = ax.imshow(transformed_matrix, cmap=cmap, aspect='auto', vmin=0, vmax=1)
246
+
247
+ # Create custom colorbar with original values
248
+ cbar = ax.figure.colorbar(im, ax=ax)
249
+
250
+ # Set colorbar ticks to show meaningful values
251
+ if data_max > 1:
252
+ tick_values = [0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 2.0]
253
+ tick_values = [v for v in tick_values if data_min <= v <= data_max]
254
+ else:
255
+ tick_values = [0, 0.25, 0.5, 0.75, 1.0]
256
+ tick_values = [v for v in tick_values if data_min <= v <= data_max]
257
+
258
+ # Transform tick values for colorbar positioning
259
+ transformed_ticks = transform_data(np.array(tick_values))
260
+ cbar.set_ticks(transformed_ticks)
261
+ cbar.set_ticklabels([f'{v:.2f}' for v in tick_values])
262
+ cbar.ax.set_ylabel('Representation Ratio', rotation=-90, va="bottom")
263
+
264
+ else:
265
+ # Default behavior: white-to-red colormap
266
+ im = ax.imshow(data_matrix, cmap='Reds', aspect='auto', vmin=0, vmax=1)
267
+
268
+ # Add standard colorbar
269
+ cbar = ax.figure.colorbar(im, ax=ax)
270
+ cbar.ax.set_ylabel('Intensity', rotation=-90, va="bottom")
81
271
 
82
272
  # Set ticks and labels
83
273
  ax.set_xticks(np.arange(len(id_set)))
@@ -91,23 +281,32 @@ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighb
91
281
  # Add text annotations showing the actual values
92
282
  for i in range(len(keys)):
93
283
  for j in range(len(id_set)):
284
+ # Use original data values for annotations
94
285
  text = ax.text(j, i, f'{data_matrix[i, j]:.3f}',
95
286
  ha="center", va="center", color="black", fontsize=8)
96
-
97
- # Add colorbar
98
- cbar = ax.figure.colorbar(im, ax=ax)
99
- cbar.ax.set_ylabel('Intensity', rotation=-90, va="bottom")
287
+
288
+
289
+ ret_dict = {}
290
+
291
+ for i, row in enumerate(data_matrix):
292
+ ret_dict[keys[i]] = row
100
293
 
101
294
  # Set labels and title
102
- ax.set_xlabel('Proportion of Node Type')
295
+ if center_at_one:
296
+ ax.set_xlabel('Representation Factor of Node Type')
297
+ else:
298
+ ax.set_xlabel('Proportion of Node Type')
299
+
103
300
  ax.set_ylabel('Neighborhood')
104
301
  ax.set_title(title)
105
302
 
106
303
  # Adjust layout to prevent label cutoff
107
304
  plt.tight_layout()
108
-
109
305
  plt.show()
110
306
 
307
+ return ret_dict
308
+
309
+
111
310
 
112
311
  def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
113
312
  class_names: Set[str],
@@ -202,10 +401,12 @@ def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
202
401
 
203
402
  return embedding
204
403
 
205
- def create_community_heatmap(community_intensity, node_community, node_centroids, is_3d=True,
206
- figsize=(12, 8), point_size=50, alpha=0.7, colorbar_label="Community Intensity"):
404
+ def create_community_heatmap(community_intensity, node_community, node_centroids, shape=None, is_3d=True,
405
+ labeled_array=None, figsize=(12, 8), point_size=50, alpha=0.7,
406
+ colorbar_label="Community Intensity", title="Community Intensity Heatmap"):
207
407
  """
208
408
  Create a 2D or 3D heatmap showing nodes colored by their community intensities.
409
+ Can return either matplotlib plot or numpy RGB array for overlay purposes.
209
410
 
210
411
  Parameters:
211
412
  -----------
@@ -220,25 +421,39 @@ def create_community_heatmap(community_intensity, node_community, node_centroids
220
421
  Dictionary mapping node IDs to centroids
221
422
  Centroids should be [Z, Y, X] for 3D or [1, Y, X] for pseudo-3D
222
423
 
424
+ shape : tuple, optional
425
+ Shape of the output array in [Z, Y, X] format
426
+ If None, will be inferred from node_centroids
427
+
223
428
  is_3d : bool, default=True
224
- If True, create 3D plot. If False, create 2D plot.
429
+ If True, create 3D plot/array. If False, create 2D plot/array.
430
+
431
+ labeled_array : np.ndarray, optional
432
+ If provided, returns numpy RGB array overlay using this labeled array template
433
+ instead of matplotlib plot. Uses lookup table approach for efficiency.
225
434
 
226
435
  figsize : tuple, default=(12, 8)
227
- Figure size (width, height)
436
+ Figure size (width, height) - only used for matplotlib
228
437
 
229
438
  point_size : int, default=50
230
- Size of scatter plot points
439
+ Size of scatter plot points - only used for matplotlib
231
440
 
232
441
  alpha : float, default=0.7
233
- Transparency of points (0-1)
442
+ Transparency of points (0-1) - only used for matplotlib
234
443
 
235
444
  colorbar_label : str, default="Community Intensity"
236
- Label for the colorbar
445
+ Label for the colorbar - only used for matplotlib
446
+
447
+ title : str, default="Community Intensity Heatmap"
448
+ Title for the plot
237
449
 
238
450
  Returns:
239
451
  --------
240
- fig, ax : matplotlib figure and axis objects
452
+ If labeled_array is None: fig, ax (matplotlib figure and axis objects)
453
+ If labeled_array is provided: np.ndarray (RGB heatmap array with community intensity colors)
241
454
  """
455
+ import numpy as np
456
+ import matplotlib.pyplot as plt
242
457
 
243
458
  # Convert numpy int64 keys to regular ints for consistency
244
459
  community_intensity_clean = {}
@@ -254,6 +469,10 @@ def create_community_heatmap(community_intensity, node_community, node_centroids
254
469
 
255
470
  for node_id, centroid in node_centroids.items():
256
471
  try:
472
+ # Convert node_id to regular int if it's numpy
473
+ if hasattr(node_id, 'item'):
474
+ node_id = node_id.item()
475
+
257
476
  # Get community for this node
258
477
  community_id = node_community[node_id]
259
478
 
@@ -266,74 +485,392 @@ def create_community_heatmap(community_intensity, node_community, node_centroids
266
485
 
267
486
  node_positions.append(centroid)
268
487
  node_intensities.append(intensity)
269
- except:
488
+ except KeyError:
489
+ # Skip nodes that don't have community assignments or community intensities
270
490
  pass
271
491
 
272
492
  # Convert to numpy arrays
273
493
  positions = np.array(node_positions)
274
494
  intensities = np.array(node_intensities)
275
495
 
276
- # Determine min and max intensities for color scaling
277
- min_intensity = np.min(intensities)
278
- max_intensity = np.max(intensities)
496
+ # Determine shape if not provided
497
+ if shape is None:
498
+ if len(positions) > 0:
499
+ max_coords = np.max(positions, axis=0).astype(int)
500
+ shape = tuple(max_coords + 1)
501
+ else:
502
+ shape = (100, 100, 100) if is_3d else (1, 100, 100)
279
503
 
280
- # Create figure
281
- fig = plt.figure(figsize=figsize)
504
+ # Determine min and max intensities for scaling
505
+ if len(intensities) > 0:
506
+ min_intensity = np.min(intensities)
507
+ max_intensity = np.max(intensities)
508
+ else:
509
+ min_intensity, max_intensity = 0, 1
282
510
 
283
- if is_3d:
284
- # 3D plot
285
- ax = fig.add_subplot(111, projection='3d')
511
+ if labeled_array is not None:
512
+ # Create numpy RGB array output using labeled array and lookup table approach
286
513
 
287
- # Extract coordinates (assuming [Z, Y, X] format)
288
- z_coords = positions[:, 0]
289
- y_coords = positions[:, 1]
290
- x_coords = positions[:, 2]
514
+ # Create mapping from node ID to community intensity value
515
+ node_to_community_intensity = {}
516
+ for node_id, centroid in node_centroids.items():
517
+ # Convert node_id to regular int if it's numpy
518
+ if hasattr(node_id, 'item'):
519
+ node_id = node_id.item()
520
+
521
+ try:
522
+ # Get community for this node
523
+ community_id = node_community[node_id]
524
+
525
+ # Convert community_id to regular int if it's numpy
526
+ if hasattr(community_id, 'item'):
527
+ community_id = community_id.item()
528
+
529
+ # Get intensity for this community
530
+ if community_id in community_intensity_clean:
531
+ node_to_community_intensity[node_id] = community_intensity_clean[community_id]
532
+ except KeyError:
533
+ # Skip nodes that don't have community assignments
534
+ pass
535
+
536
+ # Create colormap function (RdBu_r - red for high, blue for low, yellow/white for middle)
537
+ def intensity_to_rgb(intensity, min_val, max_val):
538
+ """Convert intensity value to RGB using RdBu_r colormap logic"""
539
+ if max_val == min_val:
540
+ # All same value, use neutral color
541
+ return np.array([255, 255, 255], dtype=np.uint8) # White
542
+
543
+ # Normalize to -1 to 1 range (like RdBu_r colormap)
544
+ normalized = 2 * (intensity - min_val) / (max_val - min_val) - 1
545
+ normalized = np.clip(normalized, -1, 1)
546
+
547
+ if normalized > 0:
548
+ # Positive values: white to red
549
+ r = 255
550
+ g = int(255 * (1 - normalized))
551
+ b = int(255 * (1 - normalized))
552
+ else:
553
+ # Negative values: white to blue
554
+ r = int(255 * (1 + normalized))
555
+ g = int(255 * (1 + normalized))
556
+ b = 255
557
+
558
+ return np.array([r, g, b], dtype=np.uint8)
291
559
 
292
- # Create scatter plot
293
- scatter = ax.scatter(x_coords, y_coords, z_coords,
294
- c=intensities, s=point_size, alpha=alpha,
295
- cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
560
+ # Create lookup table for RGB colors
561
+ max_label = max(max(labeled_array.flat), max(node_to_community_intensity.keys()) if node_to_community_intensity else 0)
562
+ color_lut = np.zeros((max_label + 1, 3), dtype=np.uint8) # Default to black (0,0,0)
296
563
 
297
- ax.set_xlabel('X')
298
- ax.set_ylabel('Y')
299
- ax.set_zlabel('Z')
300
- ax.set_title('3D Community Intensity Heatmap')
564
+ # Fill lookup table with RGB colors based on community intensity
565
+ for node_id, intensity in node_to_community_intensity.items():
566
+ rgb_color = intensity_to_rgb(intensity, min_intensity, max_intensity)
567
+ color_lut[int(node_id)] = rgb_color
301
568
 
569
+ # Apply lookup table to labeled array - single vectorized operation
570
+ if is_3d:
571
+ # Return full 3D RGB array [Z, Y, X, 3]
572
+ heatmap_array = color_lut[labeled_array]
573
+ else:
574
+ # Return 2D RGB array
575
+ if labeled_array.ndim == 3:
576
+ # Take middle slice for 2D representation
577
+ middle_slice = labeled_array.shape[0] // 2
578
+ heatmap_array = color_lut[labeled_array[middle_slice]]
579
+ else:
580
+ # Already 2D
581
+ heatmap_array = color_lut[labeled_array]
582
+
583
+ return heatmap_array
584
+
302
585
  else:
303
- # 2D plot (using Y, X coordinates, ignoring Z/first dimension)
304
- ax = fig.add_subplot(111)
586
+ # Create matplotlib plot
587
+ fig = plt.figure(figsize=figsize)
305
588
 
306
- # Extract Y, X coordinates
307
- y_coords = positions[:, 1]
308
- x_coords = positions[:, 2]
589
+ if is_3d:
590
+ # 3D plot
591
+ ax = fig.add_subplot(111, projection='3d')
592
+
593
+ # Extract coordinates (assuming [Z, Y, X] format)
594
+ z_coords = positions[:, 0]
595
+ y_coords = positions[:, 1]
596
+ x_coords = positions[:, 2]
597
+
598
+ # Create scatter plot
599
+ scatter = ax.scatter(x_coords, y_coords, z_coords,
600
+ c=intensities, s=point_size, alpha=alpha,
601
+ cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
602
+
603
+ ax.set_xlabel('X')
604
+ ax.set_ylabel('Y')
605
+ ax.set_zlabel('Z')
606
+ ax.set_title(f'{title}')
607
+
608
+ # Set axis limits based on shape
609
+ ax.set_xlim(0, shape[2])
610
+ ax.set_ylim(0, shape[1])
611
+ ax.set_zlim(0, shape[0])
612
+
613
+ else:
614
+ # 2D plot (using Y, X coordinates, ignoring Z/first dimension)
615
+ ax = fig.add_subplot(111)
616
+
617
+ # Extract Y, X coordinates
618
+ y_coords = positions[:, 1]
619
+ x_coords = positions[:, 2]
620
+
621
+ # Create scatter plot
622
+ scatter = ax.scatter(x_coords, y_coords,
623
+ c=intensities, s=point_size, alpha=alpha,
624
+ cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
625
+
626
+ ax.set_xlabel('X')
627
+ ax.set_ylabel('Y')
628
+ ax.set_title(f'{title}')
629
+ ax.grid(True, alpha=0.3)
630
+
631
+ # Set axis limits based on shape
632
+ ax.set_xlim(0, shape[2])
633
+ ax.set_ylim(0, shape[1])
634
+
635
+ # Set origin to top-left (invert Y-axis)
636
+ ax.invert_yaxis()
309
637
 
310
- # Create scatter plot
311
- scatter = ax.scatter(x_coords, y_coords,
312
- c=intensities, s=point_size, alpha=alpha,
313
- cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
638
+ # Add colorbar
639
+ cbar = plt.colorbar(scatter, ax=ax, shrink=0.8)
640
+ cbar.set_label(colorbar_label)
314
641
 
315
- ax.set_xlabel('X')
316
- ax.set_ylabel('Y')
317
- ax.set_title('2D Community Intensity Heatmap')
318
- ax.grid(True, alpha=0.3)
642
+ # Add text annotations for min/max values
643
+ cbar.ax.text(1.05, 0, f'Min: {min_intensity:.3f}\n(Blue)',
644
+ transform=cbar.ax.transAxes, va='bottom')
645
+ cbar.ax.text(1.05, 1, f'Max: {max_intensity:.3f}\n(Red)',
646
+ transform=cbar.ax.transAxes, va='top')
319
647
 
320
- # Set origin to top-left (invert Y-axis)
321
- ax.invert_yaxis()
648
+ plt.tight_layout()
649
+ plt.show()
650
+
651
+
652
def _intensities_to_rgb(values, min_val, max_val):
    """Map intensities to (n, 3) uint8 colours: blue at min_val, white at the midpoint, red at max_val."""
    values = np.asarray(values, dtype=float)

    if max_val == min_val:
        # All same value, use neutral colour (white)
        return np.full((values.size, 3), 255, dtype=np.uint8)

    # Normalize to the -1..1 range, 0 being the midpoint between min and max
    normalized = np.clip(2 * (values - min_val) / (max_val - min_val) - 1, -1, 1)
    above_mid = normalized > 0

    # The channels that fade toward 0 as the value moves away from the midpoint
    faded = (255 * np.where(above_mid, 1 - normalized, 1 + normalized)).astype(np.uint8)
    full = np.full_like(faded, 255)

    red = np.where(above_mid, full, faded)
    blue = np.where(above_mid, faded, full)
    return np.stack([red, faded, blue], axis=1)


def create_node_heatmap(node_intensity, node_centroids, shape=None, is_3d=True,
                        labeled_array=None, figsize=(12, 8), point_size=50, alpha=0.7,
                        colorbar_label="Node Intensity", title="Node Clustering Intensity Heatmap"):
    """
    Create a 2D or 3D heatmap showing nodes colored by their individual intensities.
    Can return either matplotlib plot or numpy array for overlay purposes.

    Parameters:
    -----------
    node_intensity : dict
        Dictionary mapping node IDs to intensity values
        Keys can be np.int64 or regular ints

    node_centroids : dict
        Dictionary mapping node IDs to centroids
        Centroids should be [Z, Y, X] for 3D or [1, Y, X] for pseudo-3D
        Nodes without an entry in node_intensity are skipped

    shape : tuple, optional
        Shape in [Z, Y, X] format, used for the plot's axis limits
        If None, will be inferred from node_centroids

    is_3d : bool, default=True
        If True, create 3D plot/array. If False, create 2D plot/array.
        For a 3D labeled_array with is_3d=False the middle Z slice is used.

    labeled_array : np.ndarray, optional
        If provided, returns numpy RGB array overlay using this labeled array template
        instead of matplotlib plot. Uses lookup table approach for efficiency.
        Labels without an intensity are left black (0, 0, 0).

    figsize : tuple, default=(12, 8)
        Figure size (width, height) - only used for matplotlib

    point_size : int, default=50
        Size of scatter plot points - only used for matplotlib

    alpha : float, default=0.7
        Transparency of points (0-1) - only used for matplotlib

    colorbar_label : str, default="Node Intensity"
        Label for the colorbar - only used for matplotlib

    title : str, default="Node Clustering Intensity Heatmap"
        Title for the plot - only used for matplotlib

    Returns:
    --------
    If labeled_array is None: fig, ax (matplotlib figure and axis objects)
    If labeled_array is provided: np.ndarray (uint8 RGB array, labeled_array's shape plus a
    trailing axis of 3)
    """

    def _plain(key):
        # numpy scalars expose .item(); regular ints pass through unchanged
        return key.item() if hasattr(key, 'item') else key

    # Convert numpy int64 keys to regular ints for consistency
    node_intensity_clean = {_plain(k): v for k, v in node_intensity.items()}

    # Collect every node that has both a centroid and an intensity
    node_ids = []
    node_positions = []
    node_intensities = []
    for raw_id, centroid in node_centroids.items():
        node_id = _plain(raw_id)
        if node_id not in node_intensity_clean:
            continue
        node_ids.append(node_id)
        node_positions.append(centroid)
        node_intensities.append(node_intensity_clean[node_id])

    positions = np.array(node_positions)
    if positions.size == 0:
        # Keep the column slicing below valid when no node matched
        positions = np.empty((0, 3))
    intensities = np.array(node_intensities)

    # Determine shape if not provided
    if shape is None:
        if len(positions) > 0:
            max_coords = np.max(positions, axis=0).astype(int)
            shape = tuple(max_coords + 1)
        else:
            shape = (100, 100, 100) if is_3d else (1, 100, 100)

    # Determine min and max intensities for scaling
    if len(intensities) > 0:
        min_intensity = np.min(intensities)
        max_intensity = np.max(intensities)
    else:
        min_intensity, max_intensity = 0, 1

    if labeled_array is not None:
        labeled_array = np.asarray(labeled_array)
        ids = np.array([int(node_id) for node_id in node_ids], dtype=np.int64)

        # The lookup table must cover every label in the image and every node ID.
        # ndarray.max() is used because the builtin max() iterates the volume in Python.
        max_label = 0
        if labeled_array.size:
            max_label = max(max_label, int(labeled_array.max()))
        if ids.size:
            max_label = max(max_label, int(ids.max()))
        color_lut = np.zeros((max_label + 1, 3), dtype=np.uint8)  # Default to black (0,0,0)

        if ids.size:
            color_lut[ids] = _intensities_to_rgb(node_intensities, min_intensity, max_intensity)

        # Apply lookup table to labeled array - single vectorized operation
        if is_3d or labeled_array.ndim != 3:
            # Full array: [Z, Y, X, 3] for a volume, [Y, X, 3] for an image
            return color_lut[labeled_array]

        # Take middle slice for 2D representation of a 3D volume
        middle_slice = labeled_array.shape[0] // 2
        return color_lut[labeled_array[middle_slice]]

    # Imported here so the array-only path above does not require matplotlib
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=figsize)
    scatter_style = dict(c=intensities, s=point_size, alpha=alpha,
                         cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)

    # Centroids are [Z, Y, X]
    z_coords = positions[:, 0]
    y_coords = positions[:, 1]
    x_coords = positions[:, 2]

    if is_3d:
        ax = fig.add_subplot(111, projection='3d')
        scatter = ax.scatter(x_coords, y_coords, z_coords, **scatter_style)

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title(f'{title}')

        # Set axis limits based on shape
        ax.set_xlim(0, shape[2])
        ax.set_ylim(0, shape[1])
        ax.set_zlim(0, shape[0])
    else:
        # 2D plot (using Y, X coordinates, ignoring Z/first dimension)
        ax = fig.add_subplot(111)
        scatter = ax.scatter(x_coords, y_coords, **scatter_style)

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_title(f'{title}')
        ax.grid(True, alpha=0.3)

        # Set axis limits based on shape
        ax.set_xlim(0, shape[2])
        ax.set_ylim(0, shape[1])

        # Set origin to top-left (invert Y-axis)
        ax.invert_yaxis()

    # Add colorbar
    cbar = plt.colorbar(scatter, ax=ax, shrink=0.8)
    cbar.set_label(colorbar_label)

    # Add text annotations for min/max values
    cbar.ax.text(1.05, 0, f'Min: {min_intensity:.3f}\n(Blue)',
                 transform=cbar.ax.transAxes, va='bottom')
    cbar.ax.text(1.05, 1, f'Max: {max_intensity:.3f}\n(Red)',
                 transform=cbar.ax.transAxes, va='top')

    plt.tight_layout()
    plt.show()

    return fig, ax
337
874
 
338
875
  # Example usage:
339
876
  if __name__ == "__main__":
@@ -350,5 +887,4 @@ if __name__ == "__main__":
350
887
  fig, ax = plot_dict_heatmap(sample_dict, sample_id_set,
351
888
  title="Sample Heatmap Visualization")
352
889
 
353
- plt.show()
354
-
890
+ plt.show()