nettracer3d 0.8.1__py3-none-any.whl → 0.8.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nettracer3d might be problematic. Click here for more details.
- nettracer3d/cellpose_manager.py +161 -0
- nettracer3d/community_extractor.py +97 -20
- nettracer3d/neighborhoods.py +222 -23
- nettracer3d/nettracer.py +166 -68
- nettracer3d/nettracer_gui.py +584 -266
- nettracer3d/network_analysis.py +222 -230
- nettracer3d/proximity.py +191 -30
- nettracer3d-0.8.2.dist-info/METADATA +117 -0
- {nettracer3d-0.8.1.dist-info → nettracer3d-0.8.2.dist-info}/RECORD +13 -12
- nettracer3d-0.8.1.dist-info/METADATA +0 -80
- {nettracer3d-0.8.1.dist-info → nettracer3d-0.8.2.dist-info}/WHEEL +0 -0
- {nettracer3d-0.8.1.dist-info → nettracer3d-0.8.2.dist-info}/entry_points.txt +0 -0
- {nettracer3d-0.8.1.dist-info → nettracer3d-0.8.2.dist-info}/licenses/LICENSE +0 -0
- {nettracer3d-0.8.1.dist-info → nettracer3d-0.8.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
import subprocess
|
|
2
|
+
import sys
|
|
3
|
+
import threading
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from PyQt6.QtWidgets import QMessageBox, QWidget
|
|
6
|
+
|
|
7
|
+
class CellposeGUILauncher:
    """Simple launcher for cellpose GUI in PyQt6 applications.

    cellpose is started as a child process (``python -m cellpose``) with
    ``subprocess.Popen``, which returns as soon as the process is spawned, so
    the caller is never blocked and no helper thread is needed. Launching
    directly in the caller's thread means that failures can be reported with
    a message box safely: Qt widgets must only be touched from the GUI thread.
    Failures are also reflected in the boolean return values.
    """

    def __init__(self, parent_widget=None):
        """
        Initialize the launcher.

        Args:
            parent_widget: PyQt6 widget used as parent for message boxes
                (optional). When falsy, messages are printed instead.
        """
        self.parent_widget = parent_widget
        # Popen handle of the most recently launched cellpose process, or None.
        self.cellpose_process = None

    def _report_error(self, message):
        """Show *message* in an error box when a parent widget exists, else print it."""
        if self.parent_widget:
            self.show_error(message)
        else:
            print(message)

    def _spawn(self, cmd_args, working_directory):
        """
        Start ``<python> -m cellpose <cmd_args...>`` and return the Popen handle.

        Args:
            cmd_args (list): Extra command-line arguments for cellpose.
            working_directory (str or path-like, optional): cwd for the child;
                a falsy value means "inherit the current directory".

        Raises:
            Whatever ``subprocess.Popen`` raises (e.g. ``OSError`` subclasses
            for a missing executable or a nonexistent working directory).
        """
        cmd = [sys.executable, "-m", "cellpose", *cmd_args]
        return subprocess.Popen(
            cmd,
            cwd=working_directory or None,
            # Discard output instead of piping it: nothing ever reads the
            # pipes, and a chatty child would block once the OS pipe buffer
            # fills up.
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )

    def launch_cellpose_gui(self, image_path=None, working_directory=None):
        """
        Launch the cellpose GUI as a background process.

        Args:
            image_path (str, optional): Image file passed via ``--image_path``.
                Silently ignored when the path does not exist.
                NOTE(review): confirm the installed cellpose version still
                opens its GUI (rather than running headless) when this flag
                is given.
            working_directory (str, optional): Directory to start cellpose in.

        Returns:
            bool: True if the cellpose process was started, False otherwise.
        """
        cmd_args = []
        if image_path and Path(image_path).exists():
            cmd_args.extend(["--image_path", str(image_path)])

        try:
            self.cellpose_process = self._spawn(cmd_args, working_directory)
        except Exception as e:  # UI boundary: report any launch failure
            self._report_error(f"Failed to launch cellpose GUI: {str(e)}")
            return False

        # Only announce success once the process actually exists.
        if self.parent_widget:
            self.show_info("Cellpose GUI launched!")
        else:
            print("Cellpose GUI launched!")
        return True

    def launch_with_directory(self, directory_path):
        """
        Launch cellpose GUI with a specific directory.

        Passes ``--dir <directory_path>`` and also uses the directory as the
        child's working directory.

        Args:
            directory_path (str): Directory containing images

        Returns:
            bool: True if the cellpose process was started, False otherwise.
        """
        cmd_args = ["--dir", str(directory_path)]
        return self.launch_cellpose_gui_with_args(cmd_args, working_directory=directory_path)

    def launch_cellpose_gui_with_args(self, additional_args=None, working_directory=None):
        """
        Launch cellpose GUI with custom arguments.

        Args:
            additional_args (list, optional): Additional command line arguments
            working_directory (str, optional): Working directory for cellpose

        Returns:
            bool: True if the cellpose process was started, False otherwise.
        """
        try:
            self.cellpose_process = self._spawn(list(additional_args or []), working_directory)
        except Exception as e:  # UI boundary: report any launch failure
            self._report_error(f"Failed to launch cellpose GUI: {str(e)}")
            return False
        return True

    def is_cellpose_running(self):
        """
        Check if cellpose process is still running.

        Returns:
            bool: True if a launched cellpose process has not exited yet
        """
        if self.cellpose_process is None:
            return False

        # poll() returns None while the child is still alive.
        return self.cellpose_process.poll() is None

    def close_cellpose(self):
        """Terminate the cellpose process if running (force-kill after 5 s)."""
        if not self.is_cellpose_running():
            return
        try:
            self.cellpose_process.terminate()
            self.cellpose_process.wait(timeout=5)  # Wait up to 5 seconds
        except subprocess.TimeoutExpired:
            self.cellpose_process.kill()  # Force kill if it doesn't terminate
            self.cellpose_process.wait()  # Reap the child so no zombie is left
        except Exception as e:
            print(f"Error closing cellpose: {e}")

    def show_info(self, message):
        """Show info message if parent widget available."""
        if self.parent_widget:
            QMessageBox.information(self.parent_widget, "Cellpose Launcher", message)

    def show_error(self, message):
        """Show error message if parent widget available."""
        if self.parent_widget:
            QMessageBox.critical(self.parent_widget, "Cellpose Error", message)
|
|
@@ -393,28 +393,105 @@ def find_hub_nodes(G: nx.Graph, proportion: float = 0.1) -> List:
|
|
|
393
393
|
return output
|
|
394
394
|
|
|
395
395
|
def get_color_name_mapping():
    """Build the table of descriptive color names and their RGB values.

    Colors are declared per hue family as ``(name, (r, g, b))`` pairs and then
    merged in a fixed family order. Insertion order is significant: a
    nearest-color search that only replaces its best match on a strictly
    smaller distance keeps the first of several equally close names.

    NOTE(review): 'fuchsia' has the same RGB as 'magenta', and 'aqua' the same
    as 'bright_cyan', so such a first-match search never yields the later name.

    Returns:
        dict: color name (str) -> (R, G, B) tuple of ints in 0..255.
    """
    reds = (
        ('crimson_red', (220, 20, 60)),
        ('bright_red', (255, 0, 0)),
        ('dark_red', (139, 0, 0)),
        ('coral_red', (255, 127, 80)),
        ('rose_red', (255, 102, 102)),
        ('burgundy', (128, 0, 32)),
        ('cherry_red', (222, 49, 99)),
    )
    greens = (
        ('forest_green', (34, 139, 34)),
        ('lime_green', (50, 205, 50)),
        ('bright_green', (0, 255, 0)),
        ('dark_green', (0, 100, 0)),
        ('mint_green', (152, 255, 152)),
        ('sage_green', (159, 183, 121)),
        ('emerald_green', (80, 200, 120)),
        ('olive_green', (128, 128, 0)),
    )
    blues = (
        ('royal_blue', (65, 105, 225)),
        ('bright_blue', (0, 0, 255)),
        ('navy_blue', (0, 0, 128)),
        ('sky_blue', (135, 206, 235)),
        ('steel_blue', (70, 130, 180)),
        ('powder_blue', (176, 224, 230)),
        ('midnight_blue', (25, 25, 112)),
        ('cobalt_blue', (0, 71, 171)),
    )
    purples = (
        ('deep_purple', (75, 0, 130)),
        ('royal_purple', (120, 81, 169)),
        ('lavender', (230, 230, 250)),
        ('plum_purple', (221, 160, 221)),
        ('violet_purple', (238, 130, 238)),
        ('magenta', (255, 0, 255)),
        ('orchid', (218, 112, 214)),
    )
    yellows_and_golds = (
        ('bright_yellow', (255, 255, 0)),
        ('golden_yellow', (255, 215, 0)),
        ('lemon_yellow', (255, 247, 0)),
        ('amber', (255, 191, 0)),
        ('mustard_yellow', (255, 219, 88)),
        ('cream', (255, 253, 208)),
        ('wheat', (245, 222, 179)),
    )
    oranges = (
        ('bright_orange', (255, 165, 0)),
        ('burnt_orange', (204, 85, 0)),
        ('peach', (255, 218, 185)),
        ('tangerine', (255, 163, 67)),
        ('pumpkin_orange', (255, 117, 24)),
        ('apricot', (251, 206, 177)),
    )
    pinks = (
        ('hot_pink', (255, 105, 180)),
        ('light_pink', (255, 192, 203)),
        ('deep_pink', (255, 20, 147)),
        ('salmon_pink', (250, 128, 114)),
        ('blush_pink', (255, 182, 193)),
        ('fuchsia', (255, 0, 255)),
    )
    cyans_and_teals = (
        ('bright_cyan', (0, 255, 255)),
        ('dark_teal', (0, 128, 128)),
        ('turquoise', (64, 224, 208)),
        ('aqua', (0, 255, 255)),
        ('seafoam', (159, 226, 191)),
        ('teal_blue', (54, 117, 136)),
    )
    browns_and_earth_tones = (
        ('chocolate_brown', (210, 105, 30)),
        ('saddle_brown', (139, 69, 19)),
        ('light_brown', (205, 133, 63)),
        ('tan', (210, 180, 140)),
        ('beige', (245, 245, 220)),
        ('coffee_brown', (111, 78, 55)),
        ('rust_brown', (183, 65, 14)),
    )
    grays_and_neutrals = (
        ('charcoal_gray', (54, 69, 79)),
        ('light_gray', (211, 211, 211)),
        ('silver', (192, 192, 192)),
        ('slate_gray', (112, 128, 144)),
        ('ash_gray', (178, 190, 181)),
        ('smoke_gray', (152, 152, 152)),
    )
    distinctive_extras = (
        ('lime_yellow', (191, 255, 0)),
        ('electric_blue', (125, 249, 255)),
        ('neon_green', (57, 255, 20)),
        ('wine_red', (114, 47, 55)),
        ('copper', (184, 115, 51)),
        ('ivory', (255, 255, 240)),
        ('periwinkle', (204, 204, 255)),
        ('mint', (189, 252, 201)),
    )

    families = (
        reds, greens, blues, purples, yellows_and_golds, oranges, pinks,
        cyans_and_teals, browns_and_earth_tones, grays_and_neutrals,
        distinctive_extras,
    )
    # dict() over the concatenated pairs preserves declaration order.
    return dict(pair for family in families for pair in family)
|
|
419
496
|
|
|
420
497
|
def rgb_to_color_name(rgb: Tuple[int, int, int]) -> str:
|
|
@@ -440,7 +517,7 @@ def rgb_to_color_name(rgb: Tuple[int, int, int]) -> str:
|
|
|
440
517
|
distance = np.sqrt(np.sum((rgb_array - np.array(color_rgb)) ** 2))
|
|
441
518
|
if distance < min_distance:
|
|
442
519
|
min_distance = distance
|
|
443
|
-
closest_color = color_name
|
|
520
|
+
closest_color = color_name + f" {str(rgb_array)}"
|
|
444
521
|
|
|
445
522
|
return closest_color
|
|
446
523
|
|
nettracer3d/neighborhoods.py
CHANGED
|
@@ -1,52 +1,176 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
from sklearn.cluster import KMeans
|
|
3
|
+
from sklearn.metrics import calinski_harabasz_score
|
|
3
4
|
import matplotlib.pyplot as plt
|
|
4
5
|
from typing import Dict, Set
|
|
5
6
|
import umap
|
|
6
|
-
|
|
7
|
+
from matplotlib.colors import LinearSegmentedColormap
|
|
8
|
+
from sklearn.cluster import DBSCAN
|
|
9
|
+
from sklearn.neighbors import NearestNeighbors
|
|
7
10
|
|
|
8
11
|
|
|
9
12
|
import os
|
|
10
13
|
os.environ['LOKY_MAX_CPU_COUNT'] = '4'
|
|
11
14
|
|
|
12
|
-
def
|
|
15
|
+
def cluster_arrays_dbscan(data_input, seed=42):
    """
    Simple DBSCAN clustering of 1D arrays with sensible defaults.

    Parameters:
    -----------
    data_input : dict or List[List[float]]
        Dictionary {key: array} or list of arrays to cluster
    seed : int
        Accepted for signature parity with ``cluster_arrays``; the DBSCAN
        fit and the parameter heuristics below are deterministic, so it is
        not used.

    Returns:
    --------
    list: [[key1, key2], [key3, key4, key5]] - List of clusters, each containing keys/indices
    Note: Outliers (DBSCAN noise points) are excluded from the output; empty
    input yields [].
    """

    # Handle both dict and list inputs
    if isinstance(data_input, dict):
        keys = list(data_input.keys())
        array_values = list(data_input.values())
    else:
        keys = list(range(len(data_input)))
        array_values = data_input

    # Convert to numpy
    data = np.array(array_values)
    n_samples = len(data)

    # Simple heuristics for DBSCAN parameters
    min_samples = max(3, int(np.sqrt(n_samples) * 0.2))  # Roughly sqrt(n)/5, minimum 3

    # With fewer points than min_samples no point can be a core point, so
    # DBSCAN would label everything as noise. Short-circuit; this also covers
    # empty input, which NearestNeighbors/DBSCAN reject with an error.
    if n_samples < min_samples:
        print(f"Found 0 main clusters and {n_samples} outliers (go in neighborhood 0)")
        return []

    # Estimate eps using 4th nearest neighbor distance (common heuristic).
    # n_samples >= 3 here, so k >= 2 and k + 1 <= n_samples.
    k = min(4, n_samples - 1)
    nbrs = NearestNeighbors(n_neighbors=k + 1)
    nbrs.fit(data)
    distances, _ = nbrs.kneighbors(data)
    # Querying the fitted points themselves puts a zero self-distance in
    # column 0, so column k is the distance to the k-th nearest other point.
    # Use 80th percentile of k-nearest distances as eps
    eps = np.percentile(distances[:, k], 80)
    # DBSCAN requires eps > 0; heavily duplicated data (many identical
    # vectors) can drive the percentile to exactly 0.
    if eps <= 0:
        eps = 1e-8

    print(f"Using DBSCAN with eps={eps:.4f}, min_samples={min_samples}")

    # Perform DBSCAN clustering
    dbscan = DBSCAN(eps=eps, min_samples=min_samples)
    labels = dbscan.fit_predict(data)

    # Organize results into clusters, excluding noise points (label -1)
    main_clusters = [label for label in np.unique(labels) if label != -1]
    clusters = [
        [keys[i] for i in np.where(labels == label)[0]]
        for label in main_clusters
    ]

    n_main_clusters = len(main_clusters)
    n_outliers = np.sum(labels == -1)

    print(f"Found {n_main_clusters} main clusters and {n_outliers} outliers (go in neighborhood 0)")

    return clusters
|
|
41
82
|
|
|
83
|
+
def cluster_arrays(data_input, n_clusters=None, seed=42):
    """
    Simple clustering of 1D arrays with key tracking and automatic cluster count detection.

    Parameters:
    -----------
    data_input : dict or List[List[float]]
        Dictionary {key: array} or list of arrays to cluster
    n_clusters : int or None
        How many groups you want. If None, the count is chosen automatically
        by ``_find_optimal_clusters`` (searching k up to len(data_input) // 3).
    seed : int
        Random seed for reproducibility (KMeans random_state)

    Returns:
    --------
    list: [[key1, key2], [key3, key4, key5]] - List of clusters, each containing keys/indices.
    Empty input yields [].
    """

    # Handle both dict and list inputs
    if isinstance(data_input, dict):
        keys = list(data_input.keys())
        array_values = list(data_input.values())
    else:
        keys = list(range(len(data_input)))
        array_values = data_input

    # Nothing to cluster; KMeans would reject an empty array.
    if not keys:
        return []

    # Convert to numpy
    data = np.array(array_values)

    # Auto-detect optimal number of clusters if not specified
    if n_clusters is None:
        n_clusters = _find_optimal_clusters(data, seed, max_k=(len(keys) // 3))
        # Guard: if the search yields no value, fall back to a single cluster
        # rather than passing None to KMeans.
        if n_clusters is None:
            n_clusters = 1
        print(f"Auto-detected optimal number of clusters: {n_clusters}")

    # Perform clustering
    kmeans = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10)
    labels = kmeans.fit_predict(data)

    # Organize results into clusters - simple list of lists with keys only
    clusters = [[] for _ in range(n_clusters)]
    for key, label in zip(keys, labels):
        clusters[label].append(key)

    return clusters
|
|
129
|
+
|
|
130
|
+
def _find_optimal_clusters(data, seed, max_k):
|
|
131
|
+
"""
|
|
132
|
+
Find optimal number of clusters using Calinski-Harabasz index.
|
|
133
|
+
"""
|
|
134
|
+
n_samples = len(data)
|
|
135
|
+
|
|
136
|
+
# Need at least 2 samples to cluster
|
|
137
|
+
if n_samples < 2:
|
|
138
|
+
return 1
|
|
139
|
+
|
|
140
|
+
# Limit max_k to reasonable bounds
|
|
141
|
+
max_k = min(max_k, n_samples - 1, 20)
|
|
142
|
+
print(f"Max_k: {max_k}, n_samples: {n_samples}")
|
|
143
|
+
|
|
144
|
+
if max_k < 2:
|
|
145
|
+
return 1
|
|
146
|
+
|
|
147
|
+
# Use Calinski-Harabasz index to find optimal k
|
|
148
|
+
ch_scores = []
|
|
149
|
+
k_range = range(2, max_k + 1)
|
|
150
|
+
|
|
151
|
+
for k in k_range:
|
|
152
|
+
try:
|
|
153
|
+
kmeans = KMeans(n_clusters=k, random_state=seed, n_init=10)
|
|
154
|
+
labels = kmeans.fit_predict(data)
|
|
155
|
+
|
|
156
|
+
# Check if we got the expected number of clusters
|
|
157
|
+
if len(np.unique(labels)) == k:
|
|
158
|
+
score = calinski_harabasz_score(data, labels)
|
|
159
|
+
ch_scores.append(score)
|
|
160
|
+
else:
|
|
161
|
+
ch_scores.append(0) # Penalize solutions that didn't achieve k clusters
|
|
162
|
+
|
|
163
|
+
except Exception:
|
|
164
|
+
ch_scores.append(0)
|
|
48
165
|
|
|
49
|
-
|
|
166
|
+
# Find k with highest Calinski-Harabasz score
|
|
167
|
+
if ch_scores and max(ch_scores) > 0:
|
|
168
|
+
optimal_k = k_range[np.argmax(ch_scores)]
|
|
169
|
+
print(f"Using {optimal_k} neighborhoods")
|
|
170
|
+
return optimal_k
|
|
171
|
+
|
|
172
|
+
def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighborhood Heatmap",
|
|
173
|
+
center_at_one=False):
|
|
50
174
|
"""
|
|
51
175
|
Create a heatmap from a dictionary of numpy arrays.
|
|
52
176
|
|
|
@@ -60,6 +184,11 @@ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighb
|
|
|
60
184
|
Figure size (width, height)
|
|
61
185
|
title : str, optional
|
|
62
186
|
Title for the heatmap
|
|
187
|
+
center_at_one : bool, optional
|
|
188
|
+
If True, uses a diverging colormap centered at 1 with nonlinear scaling:
|
|
189
|
+
- 0 to 1: blue to white (underrepresentation to normal)
|
|
190
|
+
- 1+: white to red (overrepresentation)
|
|
191
|
+
If False (default), uses standard white-to-red scaling from 0 to 1
|
|
63
192
|
|
|
64
193
|
Returns:
|
|
65
194
|
--------
|
|
@@ -67,7 +196,6 @@ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighb
|
|
|
67
196
|
"""
|
|
68
197
|
|
|
69
198
|
data_dict = {k: unsorted_data_dict[k] for k in sorted(unsorted_data_dict.keys())}
|
|
70
|
-
|
|
71
199
|
# Convert dict to 2D array for heatmap
|
|
72
200
|
# Each row represents one key from the dict
|
|
73
201
|
keys = list(data_dict.keys())
|
|
@@ -76,8 +204,70 @@ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighb
|
|
|
76
204
|
# Create the plot
|
|
77
205
|
fig, ax = plt.subplots(figsize=figsize)
|
|
78
206
|
|
|
79
|
-
|
|
80
|
-
|
|
207
|
+
if center_at_one:
|
|
208
|
+
# Custom colormap and scaling for center_at_one mode
|
|
209
|
+
# Find the actual data range
|
|
210
|
+
data_min = np.min(data_matrix)
|
|
211
|
+
data_max = np.max(data_matrix)
|
|
212
|
+
|
|
213
|
+
# Create a custom colormap: blue -> white -> red
|
|
214
|
+
colors = ['#2166ac', '#4393c3', '#92c5de', '#d1e5f0', '#f7f7f7',
|
|
215
|
+
'#fddbc7', '#f4a582', '#d6604d', '#b2182b']
|
|
216
|
+
n_bins = 256
|
|
217
|
+
cmap = LinearSegmentedColormap.from_list('custom_diverging', colors, N=n_bins)
|
|
218
|
+
|
|
219
|
+
# Create nonlinear transformation
|
|
220
|
+
# Map 0->1 with more resolution, 1+ with less resolution
|
|
221
|
+
def transform_data(data):
|
|
222
|
+
transformed = np.zeros_like(data)
|
|
223
|
+
|
|
224
|
+
# For values 0 to 1: use square root for faster approach to middle
|
|
225
|
+
mask_low = data <= 1
|
|
226
|
+
transformed[mask_low] = 0.5 * np.sqrt(data[mask_low])
|
|
227
|
+
|
|
228
|
+
# For values > 1: use slower logarithmic scaling
|
|
229
|
+
mask_high = data > 1
|
|
230
|
+
if np.any(mask_high):
|
|
231
|
+
# Scale from 0.5 to 1.0 based on log of excess above 1
|
|
232
|
+
max_excess = np.max(data[mask_high] - 1) if np.any(mask_high) else 0
|
|
233
|
+
if max_excess > 0:
|
|
234
|
+
excess_normalized = np.log1p(data[mask_high] - 1) / np.log1p(max_excess)
|
|
235
|
+
transformed[mask_high] = 0.5 + 0.5 * excess_normalized
|
|
236
|
+
else:
|
|
237
|
+
transformed[mask_high] = 0.5
|
|
238
|
+
|
|
239
|
+
return transformed
|
|
240
|
+
|
|
241
|
+
# Transform the data for visualization
|
|
242
|
+
transformed_matrix = transform_data(data_matrix)
|
|
243
|
+
|
|
244
|
+
# Create heatmap with custom colormap
|
|
245
|
+
im = ax.imshow(transformed_matrix, cmap=cmap, aspect='auto', vmin=0, vmax=1)
|
|
246
|
+
|
|
247
|
+
# Create custom colorbar with original values
|
|
248
|
+
cbar = ax.figure.colorbar(im, ax=ax)
|
|
249
|
+
|
|
250
|
+
# Set colorbar ticks to show meaningful values
|
|
251
|
+
if data_max > 1:
|
|
252
|
+
tick_values = [0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 2.0]
|
|
253
|
+
tick_values = [v for v in tick_values if data_min <= v <= data_max]
|
|
254
|
+
else:
|
|
255
|
+
tick_values = [0, 0.25, 0.5, 0.75, 1.0]
|
|
256
|
+
tick_values = [v for v in tick_values if data_min <= v <= data_max]
|
|
257
|
+
|
|
258
|
+
# Transform tick values for colorbar positioning
|
|
259
|
+
transformed_ticks = transform_data(np.array(tick_values))
|
|
260
|
+
cbar.set_ticks(transformed_ticks)
|
|
261
|
+
cbar.set_ticklabels([f'{v:.2f}' for v in tick_values])
|
|
262
|
+
cbar.ax.set_ylabel('Representation Ratio', rotation=-90, va="bottom")
|
|
263
|
+
|
|
264
|
+
else:
|
|
265
|
+
# Default behavior: white-to-red colormap
|
|
266
|
+
im = ax.imshow(data_matrix, cmap='Reds', aspect='auto', vmin=0, vmax=1)
|
|
267
|
+
|
|
268
|
+
# Add standard colorbar
|
|
269
|
+
cbar = ax.figure.colorbar(im, ax=ax)
|
|
270
|
+
cbar.ax.set_ylabel('Intensity', rotation=-90, va="bottom")
|
|
81
271
|
|
|
82
272
|
# Set ticks and labels
|
|
83
273
|
ax.set_xticks(np.arange(len(id_set)))
|
|
@@ -91,23 +281,32 @@ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighb
|
|
|
91
281
|
# Add text annotations showing the actual values
|
|
92
282
|
for i in range(len(keys)):
|
|
93
283
|
for j in range(len(id_set)):
|
|
284
|
+
# Use original data values for annotations
|
|
94
285
|
text = ax.text(j, i, f'{data_matrix[i, j]:.3f}',
|
|
95
286
|
ha="center", va="center", color="black", fontsize=8)
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
ret_dict = {}
|
|
290
|
+
|
|
291
|
+
for i, row in enumerate(data_matrix):
|
|
292
|
+
ret_dict[keys[i]] = row
|
|
100
293
|
|
|
101
294
|
# Set labels and title
|
|
102
|
-
|
|
295
|
+
if center_at_one:
|
|
296
|
+
ax.set_xlabel('Representation Factor of Node Type')
|
|
297
|
+
else:
|
|
298
|
+
ax.set_xlabel('Proportion of Node Type')
|
|
299
|
+
|
|
103
300
|
ax.set_ylabel('Neighborhood')
|
|
104
301
|
ax.set_title(title)
|
|
105
302
|
|
|
106
303
|
# Adjust layout to prevent label cutoff
|
|
107
304
|
plt.tight_layout()
|
|
108
|
-
|
|
109
305
|
plt.show()
|
|
110
306
|
|
|
307
|
+
return ret_dict
|
|
308
|
+
|
|
309
|
+
|
|
111
310
|
|
|
112
311
|
def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
|
|
113
312
|
class_names: Set[str],
|