nettracer3d-0.7.9-py3-none-any.whl → nettracer3d-0.8.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nettracer3d has been flagged as potentially problematic; the link to further details on the original page was not preserved in this copy.

@@ -492,48 +492,39 @@ def generate_distinct_colors(n_colors: int) -> List[Tuple[int, int, int]]:
492
492
  return colors
493
493
 
494
494
  def assign_node_colors(node_list: List[int], labeled_array: np.ndarray) -> Tuple[np.ndarray, Dict[int, str]]:
495
- """
496
- Assign distinct colors to nodes and create an RGBA image.
497
-
498
- Args:
499
- node_list: List of node IDs
500
- labeled_array: 3D numpy array with labels corresponding to node IDs
495
+ """fast version using lookup table approach."""
501
496
 
502
- Returns:
503
- Tuple of (RGBA-coded numpy array (H, W, D, 4), dictionary mapping nodes to color names)
504
- """
505
-
506
- # Sort communities by size (descending)
507
- sorted_nodes= sorted(node_list, reverse=True)
497
+ # Sort nodes by size (descending)
498
+ sorted_nodes = sorted(node_list, reverse=True)
508
499
 
509
500
  # Generate distinct colors
510
501
  colors = generate_distinct_colors(len(node_list))
511
- random.shuffle(colors) #Randomly sorted to make adjacent structures likely stand out
502
+ random.shuffle(colors) # Randomly sorted to make adjacent structures likely stand out
512
503
 
513
504
  # Convert RGB colors to RGBA by adding alpha channel
514
- colors_rgba = [(r, g, b, 255) for r, g, b in colors] # Full opacity for colored regions
505
+ colors_rgba = np.array([(r, g, b, 255) for r, g, b in colors], dtype=np.uint8)
515
506
 
516
- # Create mapping from community to color
507
+ # Create mapping from node to color
517
508
  node_to_color = {node: colors_rgba[i] for i, node in enumerate(sorted_nodes)}
518
509
 
519
- # Create RGBA array (initialize with transparent background)
520
- rgba_array = np.zeros((*labeled_array.shape, 4), dtype=np.uint8)
510
+ # Create lookup table
511
+ max_label = max(max(labeled_array.flat), max(node_list) if node_list else 0)
512
+ color_lut = np.zeros((max_label + 1, 4), dtype=np.uint8) # Transparent by default
521
513
 
522
- # Assign colors to each voxel based on its label
523
- for label in np.unique(labeled_array):
524
- if label in node_to_color: # Skip background (usually label 0)
525
- mask = labeled_array == label
526
- for i in range(4): # RGBA channels
527
- rgba_array[mask, i] = node_to_color[label][i]
528
-
529
- # Convert the RGB portion of community_to_color back to RGB for color naming
514
+ for node_id, color in node_to_color.items():
515
+ color_lut[node_id] = color
516
+
517
+ # Single vectorized operation - eliminates all loops!
518
+ rgba_array = color_lut[labeled_array]
519
+
520
+ # Convert colors for naming
530
521
  node_to_color_rgb = {k: tuple(v[:3]) for k, v in node_to_color.items()}
531
522
  node_to_color_names = convert_node_colors_to_names(node_to_color_rgb)
532
523
 
533
524
  return rgba_array, node_to_color_names
534
525
 
535
526
  def assign_community_colors(community_dict: Dict[int, int], labeled_array: np.ndarray) -> Tuple[np.ndarray, Dict[int, str]]:
536
- """Ultra-fast version using lookup table approach."""
527
+ """fast version using lookup table approach."""
537
528
 
538
529
  # Same setup as before
539
530
  communities = set(community_dict.values())
@@ -202,10 +202,12 @@ def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
202
202
 
203
203
  return embedding
204
204
 
205
- def create_community_heatmap(community_intensity, node_community, node_centroids, is_3d=True,
206
- figsize=(12, 8), point_size=50, alpha=0.7, colorbar_label="Community Intensity"):
205
+ def create_community_heatmap(community_intensity, node_community, node_centroids, shape=None, is_3d=True,
206
+ labeled_array=None, figsize=(12, 8), point_size=50, alpha=0.7,
207
+ colorbar_label="Community Intensity", title="Community Intensity Heatmap"):
207
208
  """
208
209
  Create a 2D or 3D heatmap showing nodes colored by their community intensities.
210
+ Can return either matplotlib plot or numpy RGB array for overlay purposes.
209
211
 
210
212
  Parameters:
211
213
  -----------
@@ -220,25 +222,39 @@ def create_community_heatmap(community_intensity, node_community, node_centroids
220
222
  Dictionary mapping node IDs to centroids
221
223
  Centroids should be [Z, Y, X] for 3D or [1, Y, X] for pseudo-3D
222
224
 
225
+ shape : tuple, optional
226
+ Shape of the output array in [Z, Y, X] format
227
+ If None, will be inferred from node_centroids
228
+
223
229
  is_3d : bool, default=True
224
- If True, create 3D plot. If False, create 2D plot.
230
+ If True, create 3D plot/array. If False, create 2D plot/array.
231
+
232
+ labeled_array : np.ndarray, optional
233
+ If provided, returns numpy RGB array overlay using this labeled array template
234
+ instead of matplotlib plot. Uses lookup table approach for efficiency.
225
235
 
226
236
  figsize : tuple, default=(12, 8)
227
- Figure size (width, height)
237
+ Figure size (width, height) - only used for matplotlib
228
238
 
229
239
  point_size : int, default=50
230
- Size of scatter plot points
240
+ Size of scatter plot points - only used for matplotlib
231
241
 
232
242
  alpha : float, default=0.7
233
- Transparency of points (0-1)
243
+ Transparency of points (0-1) - only used for matplotlib
234
244
 
235
245
  colorbar_label : str, default="Community Intensity"
236
- Label for the colorbar
246
+ Label for the colorbar - only used for matplotlib
247
+
248
+ title : str, default="Community Intensity Heatmap"
249
+ Title for the plot
237
250
 
238
251
  Returns:
239
252
  --------
240
- fig, ax : matplotlib figure and axis objects
253
+ If labeled_array is None: fig, ax (matplotlib figure and axis objects)
254
+ If labeled_array is provided: np.ndarray (RGB heatmap array with community intensity colors)
241
255
  """
256
+ import numpy as np
257
+ import matplotlib.pyplot as plt
242
258
 
243
259
  # Convert numpy int64 keys to regular ints for consistency
244
260
  community_intensity_clean = {}
@@ -254,6 +270,10 @@ def create_community_heatmap(community_intensity, node_community, node_centroids
254
270
 
255
271
  for node_id, centroid in node_centroids.items():
256
272
  try:
273
+ # Convert node_id to regular int if it's numpy
274
+ if hasattr(node_id, 'item'):
275
+ node_id = node_id.item()
276
+
257
277
  # Get community for this node
258
278
  community_id = node_community[node_id]
259
279
 
@@ -266,74 +286,392 @@ def create_community_heatmap(community_intensity, node_community, node_centroids
266
286
 
267
287
  node_positions.append(centroid)
268
288
  node_intensities.append(intensity)
269
- except:
289
+ except KeyError:
290
+ # Skip nodes that don't have community assignments or community intensities
270
291
  pass
271
292
 
272
293
  # Convert to numpy arrays
273
294
  positions = np.array(node_positions)
274
295
  intensities = np.array(node_intensities)
275
296
 
276
- # Determine min and max intensities for color scaling
277
- min_intensity = np.min(intensities)
278
- max_intensity = np.max(intensities)
297
+ # Determine shape if not provided
298
+ if shape is None:
299
+ if len(positions) > 0:
300
+ max_coords = np.max(positions, axis=0).astype(int)
301
+ shape = tuple(max_coords + 1)
302
+ else:
303
+ shape = (100, 100, 100) if is_3d else (1, 100, 100)
279
304
 
280
- # Create figure
281
- fig = plt.figure(figsize=figsize)
305
+ # Determine min and max intensities for scaling
306
+ if len(intensities) > 0:
307
+ min_intensity = np.min(intensities)
308
+ max_intensity = np.max(intensities)
309
+ else:
310
+ min_intensity, max_intensity = 0, 1
282
311
 
283
- if is_3d:
284
- # 3D plot
285
- ax = fig.add_subplot(111, projection='3d')
312
+ if labeled_array is not None:
313
+ # Create numpy RGB array output using labeled array and lookup table approach
314
+
315
+ # Create mapping from node ID to community intensity value
316
+ node_to_community_intensity = {}
317
+ for node_id, centroid in node_centroids.items():
318
+ # Convert node_id to regular int if it's numpy
319
+ if hasattr(node_id, 'item'):
320
+ node_id = node_id.item()
321
+
322
+ try:
323
+ # Get community for this node
324
+ community_id = node_community[node_id]
325
+
326
+ # Convert community_id to regular int if it's numpy
327
+ if hasattr(community_id, 'item'):
328
+ community_id = community_id.item()
329
+
330
+ # Get intensity for this community
331
+ if community_id in community_intensity_clean:
332
+ node_to_community_intensity[node_id] = community_intensity_clean[community_id]
333
+ except KeyError:
334
+ # Skip nodes that don't have community assignments
335
+ pass
336
+
337
+ # Create colormap function (RdBu_r - red for high, blue for low, yellow/white for middle)
338
+ def intensity_to_rgb(intensity, min_val, max_val):
339
+ """Convert intensity value to RGB using RdBu_r colormap logic"""
340
+ if max_val == min_val:
341
+ # All same value, use neutral color
342
+ return np.array([255, 255, 255], dtype=np.uint8) # White
343
+
344
+ # Normalize to -1 to 1 range (like RdBu_r colormap)
345
+ normalized = 2 * (intensity - min_val) / (max_val - min_val) - 1
346
+ normalized = np.clip(normalized, -1, 1)
347
+
348
+ if normalized > 0:
349
+ # Positive values: white to red
350
+ r = 255
351
+ g = int(255 * (1 - normalized))
352
+ b = int(255 * (1 - normalized))
353
+ else:
354
+ # Negative values: white to blue
355
+ r = int(255 * (1 + normalized))
356
+ g = int(255 * (1 + normalized))
357
+ b = 255
358
+
359
+ return np.array([r, g, b], dtype=np.uint8)
286
360
 
287
- # Extract coordinates (assuming [Z, Y, X] format)
288
- z_coords = positions[:, 0]
289
- y_coords = positions[:, 1]
290
- x_coords = positions[:, 2]
361
+ # Create lookup table for RGB colors
362
+ max_label = max(max(labeled_array.flat), max(node_to_community_intensity.keys()) if node_to_community_intensity else 0)
363
+ color_lut = np.zeros((max_label + 1, 3), dtype=np.uint8) # Default to black (0,0,0)
291
364
 
292
- # Create scatter plot
293
- scatter = ax.scatter(x_coords, y_coords, z_coords,
294
- c=intensities, s=point_size, alpha=alpha,
295
- cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
365
+ # Fill lookup table with RGB colors based on community intensity
366
+ for node_id, intensity in node_to_community_intensity.items():
367
+ rgb_color = intensity_to_rgb(intensity, min_intensity, max_intensity)
368
+ color_lut[int(node_id)] = rgb_color
296
369
 
297
- ax.set_xlabel('X')
298
- ax.set_ylabel('Y')
299
- ax.set_zlabel('Z')
300
- ax.set_title('3D Community Intensity Heatmap')
370
+ # Apply lookup table to labeled array - single vectorized operation
371
+ if is_3d:
372
+ # Return full 3D RGB array [Z, Y, X, 3]
373
+ heatmap_array = color_lut[labeled_array]
374
+ else:
375
+ # Return 2D RGB array
376
+ if labeled_array.ndim == 3:
377
+ # Take middle slice for 2D representation
378
+ middle_slice = labeled_array.shape[0] // 2
379
+ heatmap_array = color_lut[labeled_array[middle_slice]]
380
+ else:
381
+ # Already 2D
382
+ heatmap_array = color_lut[labeled_array]
301
383
 
384
+ return heatmap_array
385
+
302
386
  else:
303
- # 2D plot (using Y, X coordinates, ignoring Z/first dimension)
304
- ax = fig.add_subplot(111)
387
+ # Create matplotlib plot
388
+ fig = plt.figure(figsize=figsize)
305
389
 
306
- # Extract Y, X coordinates
307
- y_coords = positions[:, 1]
308
- x_coords = positions[:, 2]
390
+ if is_3d:
391
+ # 3D plot
392
+ ax = fig.add_subplot(111, projection='3d')
393
+
394
+ # Extract coordinates (assuming [Z, Y, X] format)
395
+ z_coords = positions[:, 0]
396
+ y_coords = positions[:, 1]
397
+ x_coords = positions[:, 2]
398
+
399
+ # Create scatter plot
400
+ scatter = ax.scatter(x_coords, y_coords, z_coords,
401
+ c=intensities, s=point_size, alpha=alpha,
402
+ cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
403
+
404
+ ax.set_xlabel('X')
405
+ ax.set_ylabel('Y')
406
+ ax.set_zlabel('Z')
407
+ ax.set_title(f'{title}')
408
+
409
+ # Set axis limits based on shape
410
+ ax.set_xlim(0, shape[2])
411
+ ax.set_ylim(0, shape[1])
412
+ ax.set_zlim(0, shape[0])
413
+
414
+ else:
415
+ # 2D plot (using Y, X coordinates, ignoring Z/first dimension)
416
+ ax = fig.add_subplot(111)
417
+
418
+ # Extract Y, X coordinates
419
+ y_coords = positions[:, 1]
420
+ x_coords = positions[:, 2]
421
+
422
+ # Create scatter plot
423
+ scatter = ax.scatter(x_coords, y_coords,
424
+ c=intensities, s=point_size, alpha=alpha,
425
+ cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
426
+
427
+ ax.set_xlabel('X')
428
+ ax.set_ylabel('Y')
429
+ ax.set_title(f'{title}')
430
+ ax.grid(True, alpha=0.3)
431
+
432
+ # Set axis limits based on shape
433
+ ax.set_xlim(0, shape[2])
434
+ ax.set_ylim(0, shape[1])
435
+
436
+ # Set origin to top-left (invert Y-axis)
437
+ ax.invert_yaxis()
309
438
 
310
- # Create scatter plot
311
- scatter = ax.scatter(x_coords, y_coords,
312
- c=intensities, s=point_size, alpha=alpha,
313
- cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
439
+ # Add colorbar
440
+ cbar = plt.colorbar(scatter, ax=ax, shrink=0.8)
441
+ cbar.set_label(colorbar_label)
314
442
 
315
- ax.set_xlabel('X')
316
- ax.set_ylabel('Y')
317
- ax.set_title('2D Community Intensity Heatmap')
318
- ax.grid(True, alpha=0.3)
443
+ # Add text annotations for min/max values
444
+ cbar.ax.text(1.05, 0, f'Min: {min_intensity:.3f}\n(Blue)',
445
+ transform=cbar.ax.transAxes, va='bottom')
446
+ cbar.ax.text(1.05, 1, f'Max: {max_intensity:.3f}\n(Red)',
447
+ transform=cbar.ax.transAxes, va='top')
319
448
 
320
- # Set origin to top-left (invert Y-axis)
321
- ax.invert_yaxis()
449
+ plt.tight_layout()
450
+ plt.show()
451
+
452
+
453
+ def create_node_heatmap(node_intensity, node_centroids, shape=None, is_3d=True,
454
+ labeled_array=None, figsize=(12, 8), point_size=50, alpha=0.7,
455
+ colorbar_label="Node Intensity", title="Node Clustering Intensity Heatmap"):
456
+ """
457
+ Create a 2D or 3D heatmap showing nodes colored by their individual intensities.
458
+ Can return either matplotlib plot or numpy array for overlay purposes.
322
459
 
323
- # Add colorbar
324
- cbar = plt.colorbar(scatter, ax=ax, shrink=0.8)
325
- cbar.set_label(colorbar_label)
460
+ Parameters:
461
+ -----------
462
+ node_intensity : dict
463
+ Dictionary mapping node IDs to intensity values
464
+ Keys can be np.int64 or regular ints
465
+
466
+ node_centroids : dict
467
+ Dictionary mapping node IDs to centroids
468
+ Centroids should be [Z, Y, X] for 3D or [1, Y, X] for pseudo-3D
469
+
470
+ shape : tuple, optional
471
+ Shape of the output array in [Z, Y, X] format
472
+ If None, will be inferred from node_centroids
473
+
474
+ is_3d : bool, default=True
475
+ If True, create 3D plot/array. If False, create 2D plot/array.
476
+
477
+ labeled_array : np.ndarray, optional
478
+ If provided, returns numpy array overlay using this labeled array template
479
+ instead of matplotlib plot. Uses lookup table approach for efficiency.
480
+
481
+ figsize : tuple, default=(12, 8)
482
+ Figure size (width, height) - only used for matplotlib
483
+
484
+ point_size : int, default=50
485
+ Size of scatter plot points - only used for matplotlib
486
+
487
+ alpha : float, default=0.7
488
+ Transparency of points (0-1) - only used for matplotlib
489
+
490
+ colorbar_label : str, default="Node Intensity"
491
+ Label for the colorbar - only used for matplotlib
492
+
493
+ Returns:
494
+ --------
495
+ If labeled_array is None: fig, ax (matplotlib figure and axis objects)
496
+ If labeled_array is provided: np.ndarray (heatmap array with intensity values)
497
+ """
498
+ import numpy as np
499
+ import matplotlib.pyplot as plt
500
+
501
+ # Convert numpy int64 keys to regular ints for consistency
502
+ node_intensity_clean = {}
503
+ for k, v in node_intensity.items():
504
+ if hasattr(k, 'item'): # numpy scalar
505
+ node_intensity_clean[k.item()] = v
506
+ else:
507
+ node_intensity_clean[k] = v
326
508
 
327
- # Add text annotations for min/max values
328
- cbar.ax.text(1.05, 0, f'Min: {min_intensity:.3f}\n(Blue)',
329
- transform=cbar.ax.transAxes, va='bottom')
330
- cbar.ax.text(1.05, 1, f'Max: {max_intensity:.3f}\n(Red)',
331
- transform=cbar.ax.transAxes, va='top')
509
+ # Prepare data for plotting/array creation
510
+ node_positions = []
511
+ node_intensities = []
332
512
 
333
- plt.tight_layout()
334
- plt.show()
335
-
336
-
513
+ for node_id, centroid in node_centroids.items():
514
+ try:
515
+ # Convert node_id to regular int if it's numpy
516
+ if hasattr(node_id, 'item'):
517
+ node_id = node_id.item()
518
+
519
+ # Get intensity for this node
520
+ intensity = node_intensity_clean[node_id]
521
+
522
+ node_positions.append(centroid)
523
+ node_intensities.append(intensity)
524
+ except KeyError:
525
+ # Skip nodes that don't have intensity values
526
+ pass
527
+
528
+ # Convert to numpy arrays
529
+ positions = np.array(node_positions)
530
+ intensities = np.array(node_intensities)
531
+
532
+ # Determine shape if not provided
533
+ if shape is None:
534
+ if len(positions) > 0:
535
+ max_coords = np.max(positions, axis=0).astype(int)
536
+ shape = tuple(max_coords + 1)
537
+ else:
538
+ shape = (100, 100, 100) if is_3d else (1, 100, 100)
539
+
540
+ # Determine min and max intensities for scaling
541
+ if len(intensities) > 0:
542
+ min_intensity = np.min(intensities)
543
+ max_intensity = np.max(intensities)
544
+ else:
545
+ min_intensity, max_intensity = 0, 1
546
+
547
+ if labeled_array is not None:
548
+ # Create numpy RGB array output using labeled array and lookup table approach
549
+
550
+ # Create mapping from node ID to intensity value (keep original float values)
551
+ node_to_intensity = {}
552
+ for node_id, centroid in node_centroids.items():
553
+ # Convert node_id to regular int if it's numpy
554
+ if hasattr(node_id, 'item'):
555
+ node_id = node_id.item()
556
+
557
+ # Only include nodes that have intensity values
558
+ if node_id in node_intensity_clean:
559
+ node_to_intensity[node_id] = node_intensity_clean[node_id]
560
+
561
+ # Create colormap function (RdBu_r - red for high, blue for low, yellow/white for middle)
562
+ def intensity_to_rgb(intensity, min_val, max_val):
563
+ """Convert intensity value to RGB using RdBu_r colormap logic"""
564
+ if max_val == min_val:
565
+ # All same value, use neutral color
566
+ return np.array([255, 255, 255], dtype=np.uint8) # White
567
+
568
+ # Normalize to -1 to 1 range (like RdBu_r colormap)
569
+ normalized = 2 * (intensity - min_val) / (max_val - min_val) - 1
570
+ normalized = np.clip(normalized, -1, 1)
571
+
572
+ if normalized > 0:
573
+ # Positive values: white to red
574
+ r = 255
575
+ g = int(255 * (1 - normalized))
576
+ b = int(255 * (1 - normalized))
577
+ else:
578
+ # Negative values: white to blue
579
+ r = int(255 * (1 + normalized))
580
+ g = int(255 * (1 + normalized))
581
+ b = 255
582
+
583
+ return np.array([r, g, b], dtype=np.uint8)
584
+
585
+ # Create lookup table for RGB colors
586
+ max_label = max(max(labeled_array.flat), max(node_to_intensity.keys()) if node_to_intensity else 0)
587
+ color_lut = np.zeros((max_label + 1, 3), dtype=np.uint8) # Default to black (0,0,0)
588
+
589
+ # Fill lookup table with RGB colors based on intensity
590
+ for node_id, intensity in node_to_intensity.items():
591
+ rgb_color = intensity_to_rgb(intensity, min_intensity, max_intensity)
592
+ color_lut[int(node_id)] = rgb_color
593
+
594
+ # Apply lookup table to labeled array - single vectorized operation
595
+ if is_3d:
596
+ # Return full 3D RGB array [Z, Y, X, 3]
597
+ heatmap_array = color_lut[labeled_array]
598
+ else:
599
+ # Return 2D RGB array
600
+ if labeled_array.ndim == 3:
601
+ # Take middle slice for 2D representation
602
+ middle_slice = labeled_array.shape[0] // 2
603
+ heatmap_array = color_lut[labeled_array[middle_slice]]
604
+ else:
605
+ # Already 2D
606
+ heatmap_array = color_lut[labeled_array]
607
+
608
+ return heatmap_array
609
+
610
+ else:
611
+ # Create matplotlib plot
612
+ fig = plt.figure(figsize=figsize)
613
+
614
+ if is_3d:
615
+ # 3D plot
616
+ ax = fig.add_subplot(111, projection='3d')
617
+
618
+ # Extract coordinates (assuming [Z, Y, X] format)
619
+ z_coords = positions[:, 0]
620
+ y_coords = positions[:, 1]
621
+ x_coords = positions[:, 2]
622
+
623
+ # Create scatter plot
624
+ scatter = ax.scatter(x_coords, y_coords, z_coords,
625
+ c=intensities, s=point_size, alpha=alpha,
626
+ cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
627
+
628
+ ax.set_xlabel('X')
629
+ ax.set_ylabel('Y')
630
+ ax.set_zlabel('Z')
631
+ ax.set_title(f'{title}')
632
+
633
+ # Set axis limits based on shape
634
+ ax.set_xlim(0, shape[2])
635
+ ax.set_ylim(0, shape[1])
636
+ ax.set_zlim(0, shape[0])
637
+
638
+ else:
639
+ # 2D plot (using Y, X coordinates, ignoring Z/first dimension)
640
+ ax = fig.add_subplot(111)
641
+
642
+ # Extract Y, X coordinates
643
+ y_coords = positions[:, 1]
644
+ x_coords = positions[:, 2]
645
+
646
+ # Create scatter plot
647
+ scatter = ax.scatter(x_coords, y_coords,
648
+ c=intensities, s=point_size, alpha=alpha,
649
+ cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
650
+
651
+ ax.set_xlabel('X')
652
+ ax.set_ylabel('Y')
653
+ ax.set_title(f'{title}')
654
+ ax.grid(True, alpha=0.3)
655
+
656
+ # Set axis limits based on shape
657
+ ax.set_xlim(0, shape[2])
658
+ ax.set_ylim(0, shape[1])
659
+
660
+ # Set origin to top-left (invert Y-axis)
661
+ ax.invert_yaxis()
662
+
663
+ # Add colorbar
664
+ cbar = plt.colorbar(scatter, ax=ax, shrink=0.8)
665
+ cbar.set_label(colorbar_label)
666
+
667
+ # Add text annotations for min/max values
668
+ cbar.ax.text(1.05, 0, f'Min: {min_intensity:.3f}\n(Blue)',
669
+ transform=cbar.ax.transAxes, va='bottom')
670
+ cbar.ax.text(1.05, 1, f'Max: {max_intensity:.3f}\n(Red)',
671
+ transform=cbar.ax.transAxes, va='top')
672
+
673
+ plt.tight_layout()
674
+ plt.show()
337
675
 
338
676
  # Example usage:
339
677
  if __name__ == "__main__":
@@ -350,5 +688,4 @@ if __name__ == "__main__":
350
688
  fig, ax = plot_dict_heatmap(sample_dict, sample_id_set,
351
689
  title="Sample Heatmap Visualization")
352
690
 
353
- plt.show()
354
-
691
+ plt.show()