nettracer3d 0.8.1__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nettracer3d has been flagged as potentially problematic; see the advisory details accompanying this diff for more information.

@@ -0,0 +1,161 @@
1
+ import subprocess
2
+ import sys
3
+ import threading
4
+ from pathlib import Path
5
+ from PyQt6.QtWidgets import QMessageBox, QWidget
6
+
7
class CellposeGUILauncher:
    """Simple launcher for cellpose GUI in PyQt6 applications.

    The cellpose GUI is spawned as a child process (``python -m cellpose``)
    from a daemon thread, so the host application's event loop is not
    blocked while cellpose starts up.
    """

    def __init__(self, parent_widget=None):
        """
        Initialize the launcher.

        Args:
            parent_widget: PyQt6 widget for showing message boxes (optional)
        """
        self.parent_widget = parent_widget
        # Popen handle of the most recently launched cellpose process.
        self.cellpose_process = None

    def launch_cellpose_gui(self, image_path=None, working_directory=None):
        """
        Launch cellpose GUI in a separate thread.

        Args:
            image_path (str, optional): Path to image file to load automatically
            working_directory (str, optional): Directory to start cellpose in

        Returns:
            bool: True if launch was initiated successfully
        """
        def run_cellpose():
            """Function to run in separate thread."""
            try:
                # Build command
                cmd = [sys.executable, "-m", "cellpose"]

                # Add image path if provided
                if image_path and Path(image_path).exists():
                    cmd.extend(["--image_path", str(image_path)])

                # Set working directory
                cwd = working_directory if working_directory else None

                # Launch cellpose GUI. Output is discarded: the previous
                # stdout=PIPE/stderr=PIPE handles were never read, so a
                # chatty child could fill the OS pipe buffer and hang
                # (classic Popen PIPE deadlock).
                self.cellpose_process = subprocess.Popen(
                    cmd,
                    cwd=cwd,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL
                )

            except Exception as e:
                if self.parent_widget:
                    # NOTE(review): this runs on the worker thread; Qt widgets
                    # should only be touched from the GUI thread — consider
                    # marshalling through a signal. Behavior kept as-is.
                    self.show_error(f"Failed to launch cellpose GUI: {str(e)}")
                else:
                    print(f"Failed to launch cellpose GUI: {str(e)}")

        try:
            # Start cellpose in separate thread
            thread = threading.Thread(target=run_cellpose, daemon=True)
            thread.start()

            if self.parent_widget:
                self.show_info("Cellpose GUI launched!")
            else:
                print("Cellpose GUI launched!")

            return True

        except Exception as e:
            if self.parent_widget:
                self.show_error(f"Failed to start cellpose thread: {str(e)}")
            else:
                print(f"Failed to start cellpose thread: {str(e)}")
            return False

    def launch_with_directory(self, directory_path):
        """
        Launch cellpose GUI with a specific directory.

        Args:
            directory_path (str): Directory containing images

        Returns:
            bool: True if the launcher thread was started successfully
        """
        cmd_args = ["--dir", str(directory_path)]
        return self.launch_cellpose_gui_with_args(cmd_args, working_directory=directory_path)

    def launch_cellpose_gui_with_args(self, additional_args=None, working_directory=None):
        """
        Launch cellpose GUI with custom arguments.

        Args:
            additional_args (list): List of additional command line arguments
            working_directory (str): Working directory for cellpose

        Returns:
            bool: True if the launcher thread was started successfully
        """
        def run_cellpose_custom():
            try:
                cmd = [sys.executable, "-m", "cellpose"]

                if additional_args:
                    cmd.extend(additional_args)

                cwd = working_directory if working_directory else None

                # Same rationale as launch_cellpose_gui: discard output so an
                # unread pipe can never block the child process.
                self.cellpose_process = subprocess.Popen(
                    cmd,
                    cwd=cwd,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL
                )

            except Exception as e:
                if self.parent_widget:
                    self.show_error(f"Failed to launch cellpose GUI: {str(e)}")
                else:
                    print(f"Failed to launch cellpose GUI: {str(e)}")

        try:
            thread = threading.Thread(target=run_cellpose_custom, daemon=True)
            thread.start()
            return True
        except Exception as e:
            if self.parent_widget:
                self.show_error(f"Failed to start cellpose: {str(e)}")
            return False

    def is_cellpose_running(self):
        """
        Check if cellpose process is still running.

        Returns:
            bool: True if cellpose is still running
        """
        if self.cellpose_process is None:
            return False

        # poll() returns None while the process is still alive.
        return self.cellpose_process.poll() is None

    def close_cellpose(self):
        """Terminate the cellpose process if running."""
        if self.cellpose_process and self.is_cellpose_running():
            try:
                self.cellpose_process.terminate()
                self.cellpose_process.wait(timeout=5)  # Wait up to 5 seconds
            except subprocess.TimeoutExpired:
                self.cellpose_process.kill()  # Force kill if it doesn't terminate
            except Exception as e:
                print(f"Error closing cellpose: {e}")

    def show_info(self, message):
        """Show info message if parent widget available."""
        if self.parent_widget:
            QMessageBox.information(self.parent_widget, "Cellpose Launcher", message)

    def show_error(self, message):
        """Show error message if parent widget available."""
        if self.parent_widget:
            QMessageBox.critical(self.parent_widget, "Cellpose Error", message)
@@ -393,28 +393,105 @@ def find_hub_nodes(G: nx.Graph, proportion: float = 0.1) -> List:
393
393
  return output
394
394
 
395
395
def get_color_name_mapping():
    """Return a dictionary of descriptive color names and their RGB values.

    The palette is assembled from per-family groups; insertion order is
    the group order below, which matters for tie-breaking when callers
    scan the mapping for the nearest color.
    """
    reds = {
        'crimson_red': (220, 20, 60),
        'bright_red': (255, 0, 0),
        'dark_red': (139, 0, 0),
        'coral_red': (255, 127, 80),
        'rose_red': (255, 102, 102),
        'burgundy': (128, 0, 32),
        'cherry_red': (222, 49, 99),
    }
    greens = {
        'forest_green': (34, 139, 34),
        'lime_green': (50, 205, 50),
        'bright_green': (0, 255, 0),
        'dark_green': (0, 100, 0),
        'mint_green': (152, 255, 152),
        'sage_green': (159, 183, 121),
        'emerald_green': (80, 200, 120),
        'olive_green': (128, 128, 0),
    }
    blues = {
        'royal_blue': (65, 105, 225),
        'bright_blue': (0, 0, 255),
        'navy_blue': (0, 0, 128),
        'sky_blue': (135, 206, 235),
        'steel_blue': (70, 130, 180),
        'powder_blue': (176, 224, 230),
        'midnight_blue': (25, 25, 112),
        'cobalt_blue': (0, 71, 171),
    }
    purples = {
        'deep_purple': (75, 0, 130),
        'royal_purple': (120, 81, 169),
        'lavender': (230, 230, 250),
        'plum_purple': (221, 160, 221),
        'violet_purple': (238, 130, 238),
        'magenta': (255, 0, 255),
        'orchid': (218, 112, 214),
    }
    yellows = {
        'bright_yellow': (255, 255, 0),
        'golden_yellow': (255, 215, 0),
        'lemon_yellow': (255, 247, 0),
        'amber': (255, 191, 0),
        'mustard_yellow': (255, 219, 88),
        'cream': (255, 253, 208),
        'wheat': (245, 222, 179),
    }
    oranges = {
        'bright_orange': (255, 165, 0),
        'burnt_orange': (204, 85, 0),
        'peach': (255, 218, 185),
        'tangerine': (255, 163, 67),
        'pumpkin_orange': (255, 117, 24),
        'apricot': (251, 206, 177),
    }
    pinks = {
        'hot_pink': (255, 105, 180),
        'light_pink': (255, 192, 203),
        'deep_pink': (255, 20, 147),
        'salmon_pink': (250, 128, 114),
        'blush_pink': (255, 182, 193),
        'fuchsia': (255, 0, 255),
    }
    cyans_and_teals = {
        'bright_cyan': (0, 255, 255),
        'dark_teal': (0, 128, 128),
        'turquoise': (64, 224, 208),
        'aqua': (0, 255, 255),
        'seafoam': (159, 226, 191),
        'teal_blue': (54, 117, 136),
    }
    browns = {
        'chocolate_brown': (210, 105, 30),
        'saddle_brown': (139, 69, 19),
        'light_brown': (205, 133, 63),
        'tan': (210, 180, 140),
        'beige': (245, 245, 220),
        'coffee_brown': (111, 78, 55),
        'rust_brown': (183, 65, 14),
    }
    grays = {
        'charcoal_gray': (54, 69, 79),
        'light_gray': (211, 211, 211),
        'silver': (192, 192, 192),
        'slate_gray': (112, 128, 144),
        'ash_gray': (178, 190, 181),
        'smoke_gray': (152, 152, 152),
    }
    # Additional distinctive colors that don't fit one family cleanly.
    extras = {
        'lime_yellow': (191, 255, 0),
        'electric_blue': (125, 249, 255),
        'neon_green': (57, 255, 20),
        'wine_red': (114, 47, 55),
        'copper': (184, 115, 51),
        'ivory': (255, 255, 240),
        'periwinkle': (204, 204, 255),
        'mint': (189, 252, 201),
    }
    return {**reds, **greens, **blues, **purples, **yellows, **oranges,
            **pinks, **cyans_and_teals, **browns, **grays, **extras}
419
496
 
420
497
  def rgb_to_color_name(rgb: Tuple[int, int, int]) -> str:
@@ -440,7 +517,7 @@ def rgb_to_color_name(rgb: Tuple[int, int, int]) -> str:
440
517
  distance = np.sqrt(np.sum((rgb_array - np.array(color_rgb)) ** 2))
441
518
  if distance < min_distance:
442
519
  min_distance = distance
443
- closest_color = color_name
520
+ closest_color = color_name + f" {str(rgb_array)}"
444
521
 
445
522
  return closest_color
446
523
 
@@ -1,52 +1,176 @@
1
1
  import numpy as np
2
2
  from sklearn.cluster import KMeans
3
+ from sklearn.metrics import calinski_harabasz_score
3
4
  import matplotlib.pyplot as plt
4
5
  from typing import Dict, Set
5
6
  import umap
6
-
7
+ from matplotlib.colors import LinearSegmentedColormap
8
+ from sklearn.cluster import DBSCAN
9
+ from sklearn.neighbors import NearestNeighbors
7
10
 
8
11
 
9
12
  import os
10
13
  os.environ['LOKY_MAX_CPU_COUNT'] = '4'
11
14
 
12
def cluster_arrays_dbscan(data_input, seed=42):
    """
    Simple DBSCAN clustering of 1D arrays with sensible defaults.

    Parameters:
    -----------
    data_input : dict or List[List[float]]
        Dictionary {key: array} or list of arrays to cluster
    seed : int
        Random seed for reproducibility (used for parameter estimation).
        NOTE(review): currently unused — DBSCAN and the eps heuristic are
        deterministic; kept for interface compatibility.

    Returns:
    --------
    list: [[key1, key2], [key3, key4, key5]] - List of clusters, each containing keys/indices
        Note: Outliers are excluded from the output
    """

    # Handle both dict and list inputs
    if isinstance(data_input, dict):
        keys = list(data_input.keys())
        array_values = list(data_input.values())
    else:
        keys = list(range(len(data_input)))
        array_values = data_input

    # Guard: DBSCAN/NearestNeighbors raise on empty input.
    if not array_values:
        return []

    # Convert to numpy
    data = np.array(array_values)
    n_samples = len(data)

    # Simple heuristics for DBSCAN parameters
    min_samples = max(3, int(np.sqrt(n_samples) * 0.2))  # Roughly sqrt(n)/5, minimum 3

    # Estimate eps using 4th nearest neighbor distance (common heuristic)
    k = min(4, n_samples - 1)
    if k > 0:
        nbrs = NearestNeighbors(n_neighbors=k + 1)
        nbrs.fit(data)
        distances, _ = nbrs.kneighbors(data)
        # Use 80th percentile of k-nearest distances as eps
        eps = np.percentile(distances[:, k], 80)
    else:
        eps = 0.1  # fallback

    # Guard: DBSCAN requires eps > 0; the percentile can be 0 when most
    # points coincide. Fall back to the same default as the k==0 case.
    if eps <= 0:
        eps = 0.1

    print(f"Using DBSCAN with eps={eps:.4f}, min_samples={min_samples}")

    # Perform DBSCAN clustering
    dbscan = DBSCAN(eps=eps, min_samples=min_samples)
    labels = dbscan.fit_predict(data)

    # Organize results into clusters (excluding outliers)
    clusters = []

    # Add only the main clusters (non-noise points; DBSCAN labels noise -1)
    unique_labels = np.unique(labels)
    main_clusters = [label for label in unique_labels if label != -1]

    for label in main_clusters:
        cluster_indices = np.where(labels == label)[0]
        cluster_keys = [keys[i] for i in cluster_indices]
        clusters.append(cluster_keys)

    n_main_clusters = len(main_clusters)
    n_outliers = np.sum(labels == -1)

    print(f"Found {n_main_clusters} main clusters and {n_outliers} outliers (go in neighborhood 0)")

    return clusters
41
82
 
83
def cluster_arrays(data_input, n_clusters=None, seed=42):
    """
    Simple clustering of 1D arrays with key tracking and automatic cluster count detection.

    Parameters:
    -----------
    data_input : dict or List[List[float]]
        Dictionary {key: array} or list of arrays to cluster
    n_clusters : int or None
        How many groups you want. If None, will automatically determine
        optimal k (bounded by roughly one third of the number of samples).
    seed : int
        Random seed for reproducibility

    Returns:
    --------
    list: [[key1, key2], [key3, key4, key5]] - List of clusters, each containing keys/indices
    """

    # Handle both dict and list inputs
    if isinstance(data_input, dict):
        keys = list(data_input.keys())
        array_values = list(data_input.values())
    else:
        keys = list(range(len(data_input)))
        array_values = data_input

    # Convert to numpy
    data = np.array(array_values)

    # Auto-detect optimal number of clusters if not specified.
    # `or 1` guards against the helper yielding a falsy result (e.g. None
    # when no candidate k produced a valid score), which would otherwise
    # be passed straight to KMeans(n_clusters=None) and crash.
    if n_clusters is None:
        n_clusters = _find_optimal_clusters(data, seed, max_k=(len(keys) // 3)) or 1
        print(f"Auto-detected optimal number of clusters: {n_clusters}")

    # Perform clustering
    kmeans = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10)
    labels = kmeans.fit_predict(data)

    # Organize results into clusters - simple list of lists with keys only
    clusters = [[] for _ in range(n_clusters)]
    for i, label in enumerate(labels):
        clusters[label].append(keys[i])

    return clusters
129
+
130
def _find_optimal_clusters(data, seed, max_k):
    """
    Find optimal number of clusters using the Calinski-Harabasz index.

    Parameters:
    -----------
    data : np.ndarray
        Samples to cluster (one row per sample)
    seed : int
        Random seed passed to KMeans
    max_k : int
        Upper bound on the number of clusters to try (further capped at
        n_samples - 1 and 20)

    Returns:
    --------
    int: the k in [1, max_k] with the highest Calinski-Harabasz score,
         or 1 when no valid score could be computed.
    """
    n_samples = len(data)

    # Need at least 2 samples to cluster
    if n_samples < 2:
        return 1

    # Limit max_k to reasonable bounds
    max_k = min(max_k, n_samples - 1, 20)
    print(f"Max_k: {max_k}, n_samples: {n_samples}")

    if max_k < 2:
        return 1

    # Use Calinski-Harabasz index to find optimal k
    ch_scores = []
    k_range = range(2, max_k + 1)

    for k in k_range:
        try:
            kmeans = KMeans(n_clusters=k, random_state=seed, n_init=10)
            labels = kmeans.fit_predict(data)

            # Check if we got the expected number of clusters
            if len(np.unique(labels)) == k:
                score = calinski_harabasz_score(data, labels)
                ch_scores.append(score)
            else:
                ch_scores.append(0)  # Penalize solutions that didn't achieve k clusters

        except Exception:
            ch_scores.append(0)

    # Find k with highest Calinski-Harabasz score
    if ch_scores and max(ch_scores) > 0:
        optimal_k = k_range[np.argmax(ch_scores)]
        print(f"Using {optimal_k} neighborhoods")
        return optimal_k

    # Bug fix: previously the function fell off the end here and returned
    # None implicitly, which the caller would feed to KMeans(n_clusters=None).
    return 1
171
+
172
+ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighborhood Heatmap",
173
+ center_at_one=False):
50
174
  """
51
175
  Create a heatmap from a dictionary of numpy arrays.
52
176
 
@@ -60,6 +184,11 @@ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighb
60
184
  Figure size (width, height)
61
185
  title : str, optional
62
186
  Title for the heatmap
187
+ center_at_one : bool, optional
188
+ If True, uses a diverging colormap centered at 1 with nonlinear scaling:
189
+ - 0 to 1: blue to white (underrepresentation to normal)
190
+ - 1+: white to red (overrepresentation)
191
+ If False (default), uses standard white-to-red scaling from 0 to 1
63
192
 
64
193
  Returns:
65
194
  --------
@@ -67,7 +196,6 @@ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighb
67
196
  """
68
197
 
69
198
  data_dict = {k: unsorted_data_dict[k] for k in sorted(unsorted_data_dict.keys())}
70
-
71
199
  # Convert dict to 2D array for heatmap
72
200
  # Each row represents one key from the dict
73
201
  keys = list(data_dict.keys())
@@ -76,8 +204,70 @@ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighb
76
204
  # Create the plot
77
205
  fig, ax = plt.subplots(figsize=figsize)
78
206
 
79
- # Create heatmap with white-to-red colormap
80
- im = ax.imshow(data_matrix, cmap='Reds', aspect='auto', vmin=0, vmax=1)
207
+ if center_at_one:
208
+ # Custom colormap and scaling for center_at_one mode
209
+ # Find the actual data range
210
+ data_min = np.min(data_matrix)
211
+ data_max = np.max(data_matrix)
212
+
213
+ # Create a custom colormap: blue -> white -> red
214
+ colors = ['#2166ac', '#4393c3', '#92c5de', '#d1e5f0', '#f7f7f7',
215
+ '#fddbc7', '#f4a582', '#d6604d', '#b2182b']
216
+ n_bins = 256
217
+ cmap = LinearSegmentedColormap.from_list('custom_diverging', colors, N=n_bins)
218
+
219
+ # Create nonlinear transformation
220
+ # Map 0->1 with more resolution, 1+ with less resolution
221
+ def transform_data(data):
222
+ transformed = np.zeros_like(data)
223
+
224
+ # For values 0 to 1: use square root for faster approach to middle
225
+ mask_low = data <= 1
226
+ transformed[mask_low] = 0.5 * np.sqrt(data[mask_low])
227
+
228
+ # For values > 1: use slower logarithmic scaling
229
+ mask_high = data > 1
230
+ if np.any(mask_high):
231
+ # Scale from 0.5 to 1.0 based on log of excess above 1
232
+ max_excess = np.max(data[mask_high] - 1) if np.any(mask_high) else 0
233
+ if max_excess > 0:
234
+ excess_normalized = np.log1p(data[mask_high] - 1) / np.log1p(max_excess)
235
+ transformed[mask_high] = 0.5 + 0.5 * excess_normalized
236
+ else:
237
+ transformed[mask_high] = 0.5
238
+
239
+ return transformed
240
+
241
+ # Transform the data for visualization
242
+ transformed_matrix = transform_data(data_matrix)
243
+
244
+ # Create heatmap with custom colormap
245
+ im = ax.imshow(transformed_matrix, cmap=cmap, aspect='auto', vmin=0, vmax=1)
246
+
247
+ # Create custom colorbar with original values
248
+ cbar = ax.figure.colorbar(im, ax=ax)
249
+
250
+ # Set colorbar ticks to show meaningful values
251
+ if data_max > 1:
252
+ tick_values = [0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 2.0]
253
+ tick_values = [v for v in tick_values if data_min <= v <= data_max]
254
+ else:
255
+ tick_values = [0, 0.25, 0.5, 0.75, 1.0]
256
+ tick_values = [v for v in tick_values if data_min <= v <= data_max]
257
+
258
+ # Transform tick values for colorbar positioning
259
+ transformed_ticks = transform_data(np.array(tick_values))
260
+ cbar.set_ticks(transformed_ticks)
261
+ cbar.set_ticklabels([f'{v:.2f}' for v in tick_values])
262
+ cbar.ax.set_ylabel('Representation Ratio', rotation=-90, va="bottom")
263
+
264
+ else:
265
+ # Default behavior: white-to-red colormap
266
+ im = ax.imshow(data_matrix, cmap='Reds', aspect='auto', vmin=0, vmax=1)
267
+
268
+ # Add standard colorbar
269
+ cbar = ax.figure.colorbar(im, ax=ax)
270
+ cbar.ax.set_ylabel('Intensity', rotation=-90, va="bottom")
81
271
 
82
272
  # Set ticks and labels
83
273
  ax.set_xticks(np.arange(len(id_set)))
@@ -91,23 +281,32 @@ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighb
91
281
  # Add text annotations showing the actual values
92
282
  for i in range(len(keys)):
93
283
  for j in range(len(id_set)):
284
+ # Use original data values for annotations
94
285
  text = ax.text(j, i, f'{data_matrix[i, j]:.3f}',
95
286
  ha="center", va="center", color="black", fontsize=8)
96
-
97
- # Add colorbar
98
- cbar = ax.figure.colorbar(im, ax=ax)
99
- cbar.ax.set_ylabel('Intensity', rotation=-90, va="bottom")
287
+
288
+
289
+ ret_dict = {}
290
+
291
+ for i, row in enumerate(data_matrix):
292
+ ret_dict[keys[i]] = row
100
293
 
101
294
  # Set labels and title
102
- ax.set_xlabel('Proportion of Node Type')
295
+ if center_at_one:
296
+ ax.set_xlabel('Representation Factor of Node Type')
297
+ else:
298
+ ax.set_xlabel('Proportion of Node Type')
299
+
103
300
  ax.set_ylabel('Neighborhood')
104
301
  ax.set_title(title)
105
302
 
106
303
  # Adjust layout to prevent label cutoff
107
304
  plt.tight_layout()
108
-
109
305
  plt.show()
110
306
 
307
+ return ret_dict
308
+
309
+
111
310
 
112
311
  def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
113
312
  class_names: Set[str],