nettracer3d 0.8.0__py3-none-any.whl → 0.8.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nettracer3d might be problematic. Click here for more details.
- nettracer3d/cellpose_manager.py +161 -0
- nettracer3d/community_extractor.py +97 -20
- nettracer3d/neighborhoods.py +617 -81
- nettracer3d/nettracer.py +282 -74
- nettracer3d/nettracer_gui.py +860 -281
- nettracer3d/network_analysis.py +222 -230
- nettracer3d/node_draw.py +22 -12
- nettracer3d/proximity.py +254 -30
- nettracer3d-0.8.2.dist-info/METADATA +117 -0
- nettracer3d-0.8.2.dist-info/RECORD +24 -0
- nettracer3d-0.8.0.dist-info/METADATA +0 -83
- nettracer3d-0.8.0.dist-info/RECORD +0 -23
- {nettracer3d-0.8.0.dist-info → nettracer3d-0.8.2.dist-info}/WHEEL +0 -0
- {nettracer3d-0.8.0.dist-info → nettracer3d-0.8.2.dist-info}/entry_points.txt +0 -0
- {nettracer3d-0.8.0.dist-info → nettracer3d-0.8.2.dist-info}/licenses/LICENSE +0 -0
- {nettracer3d-0.8.0.dist-info → nettracer3d-0.8.2.dist-info}/top_level.txt +0 -0
nettracer3d/neighborhoods.py
CHANGED
|
@@ -1,52 +1,176 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
from sklearn.cluster import KMeans
|
|
3
|
+
from sklearn.metrics import calinski_harabasz_score
|
|
3
4
|
import matplotlib.pyplot as plt
|
|
4
5
|
from typing import Dict, Set
|
|
5
6
|
import umap
|
|
6
|
-
|
|
7
|
+
from matplotlib.colors import LinearSegmentedColormap
|
|
8
|
+
from sklearn.cluster import DBSCAN
|
|
9
|
+
from sklearn.neighbors import NearestNeighbors
|
|
7
10
|
|
|
8
11
|
|
|
9
12
|
import os
|
|
10
13
|
os.environ['LOKY_MAX_CPU_COUNT'] = '4'
|
|
11
14
|
|
|
12
|
-
def
|
|
15
|
+
def cluster_arrays_dbscan(data_input, seed=42):
|
|
13
16
|
"""
|
|
14
|
-
Simple clustering of 1D arrays with
|
|
17
|
+
Simple DBSCAN clustering of 1D arrays with sensible defaults.
|
|
15
18
|
|
|
16
19
|
Parameters:
|
|
17
20
|
-----------
|
|
18
21
|
data_input : dict or List[List[float]]
|
|
19
22
|
Dictionary {key: array} or list of arrays to cluster
|
|
20
|
-
|
|
21
|
-
|
|
23
|
+
seed : int
|
|
24
|
+
Random seed for reproducibility (used for parameter estimation)
|
|
22
25
|
|
|
23
26
|
Returns:
|
|
24
27
|
--------
|
|
25
|
-
|
|
28
|
+
list: [[key1, key2], [key3, key4, key5]] - List of clusters, each containing keys/indices
|
|
29
|
+
Note: Outliers are excluded from the output
|
|
26
30
|
"""
|
|
27
31
|
|
|
28
32
|
# Handle both dict and list inputs
|
|
29
33
|
if isinstance(data_input, dict):
|
|
30
34
|
keys = list(data_input.keys())
|
|
31
|
-
array_values = list(data_input.values())
|
|
35
|
+
array_values = list(data_input.values())
|
|
32
36
|
else:
|
|
33
|
-
keys = list(range(len(data_input)))
|
|
37
|
+
keys = list(range(len(data_input)))
|
|
34
38
|
array_values = data_input
|
|
35
39
|
|
|
36
|
-
# Convert to numpy
|
|
40
|
+
# Convert to numpy
|
|
37
41
|
data = np.array(array_values)
|
|
38
|
-
|
|
39
|
-
|
|
42
|
+
n_samples = len(data)
|
|
43
|
+
|
|
44
|
+
# Simple heuristics for DBSCAN parameters
|
|
45
|
+
min_samples = max(3, int(np.sqrt(n_samples) * 0.2)) # Roughly sqrt(n)/5, minimum 3
|
|
46
|
+
|
|
47
|
+
# Estimate eps using 4th nearest neighbor distance (common heuristic)
|
|
48
|
+
k = min(4, n_samples - 1)
|
|
49
|
+
if k > 0:
|
|
50
|
+
nbrs = NearestNeighbors(n_neighbors=k + 1)
|
|
51
|
+
nbrs.fit(data)
|
|
52
|
+
distances, _ = nbrs.kneighbors(data)
|
|
53
|
+
# Use 80th percentile of k-nearest distances as eps
|
|
54
|
+
eps = np.percentile(distances[:, k], 80)
|
|
55
|
+
else:
|
|
56
|
+
eps = 0.1 # fallback
|
|
57
|
+
|
|
58
|
+
print(f"Using DBSCAN with eps={eps:.4f}, min_samples={min_samples}")
|
|
59
|
+
|
|
60
|
+
# Perform DBSCAN clustering
|
|
61
|
+
dbscan = DBSCAN(eps=eps, min_samples=min_samples)
|
|
62
|
+
labels = dbscan.fit_predict(data)
|
|
63
|
+
|
|
64
|
+
# Organize results into clusters (excluding outliers)
|
|
65
|
+
clusters = []
|
|
40
66
|
|
|
67
|
+
# Add only the main clusters (non-noise points)
|
|
68
|
+
unique_labels = np.unique(labels)
|
|
69
|
+
main_clusters = [label for label in unique_labels if label != -1]
|
|
70
|
+
|
|
71
|
+
for label in main_clusters:
|
|
72
|
+
cluster_indices = np.where(labels == label)[0]
|
|
73
|
+
cluster_keys = [keys[i] for i in cluster_indices]
|
|
74
|
+
clusters.append(cluster_keys)
|
|
75
|
+
|
|
76
|
+
n_main_clusters = len(main_clusters)
|
|
77
|
+
n_outliers = np.sum(labels == -1)
|
|
78
|
+
|
|
79
|
+
print(f"Found {n_main_clusters} main clusters and {n_outliers} outliers (go in neighborhood 0)")
|
|
80
|
+
|
|
81
|
+
return clusters
|
|
41
82
|
|
|
83
|
+
def cluster_arrays(data_input, n_clusters=None, seed=42):
|
|
84
|
+
"""
|
|
85
|
+
Simple clustering of 1D arrays with key tracking and automatic cluster count detection.
|
|
86
|
+
|
|
87
|
+
Parameters:
|
|
88
|
+
-----------
|
|
89
|
+
data_input : dict or List[List[float]]
|
|
90
|
+
Dictionary {key: array} or list of arrays to cluster
|
|
91
|
+
n_clusters : int or None
|
|
92
|
+
How many groups you want. If None, will automatically determine optimal k
|
|
93
|
+
seed : int
|
|
94
|
+
Random seed for reproducibility
|
|
95
|
+
max_k : int
|
|
96
|
+
Maximum number of clusters to consider when auto-detecting (default: 10)
|
|
97
|
+
|
|
98
|
+
Returns:
|
|
99
|
+
--------
|
|
100
|
+
list: [[key1, key2], [key3, key4, key5]] - List of clusters, each containing keys/indices
|
|
101
|
+
"""
|
|
102
|
+
|
|
103
|
+
# Handle both dict and list inputs
|
|
104
|
+
if isinstance(data_input, dict):
|
|
105
|
+
keys = list(data_input.keys())
|
|
106
|
+
array_values = list(data_input.values())
|
|
107
|
+
else:
|
|
108
|
+
keys = list(range(len(data_input)))
|
|
109
|
+
array_values = data_input
|
|
110
|
+
|
|
111
|
+
# Convert to numpy
|
|
112
|
+
data = np.array(array_values)
|
|
113
|
+
|
|
114
|
+
# Auto-detect optimal number of clusters if not specified
|
|
115
|
+
if n_clusters is None:
|
|
116
|
+
n_clusters = _find_optimal_clusters(data, seed, max_k = (len(keys) // 3))
|
|
117
|
+
print(f"Auto-detected optimal number of clusters: {n_clusters}")
|
|
118
|
+
|
|
119
|
+
# Perform clustering
|
|
120
|
+
kmeans = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10)
|
|
121
|
+
labels = kmeans.fit_predict(data)
|
|
122
|
+
|
|
123
|
+
# Organize results into clusters - simple list of lists with keys only
|
|
42
124
|
clusters = [[] for _ in range(n_clusters)]
|
|
43
|
-
|
|
44
125
|
for i, label in enumerate(labels):
|
|
45
126
|
clusters[label].append(keys[i])
|
|
46
|
-
|
|
127
|
+
|
|
47
128
|
return clusters
|
|
129
|
+
|
|
130
|
+
def _find_optimal_clusters(data, seed, max_k):
|
|
131
|
+
"""
|
|
132
|
+
Find optimal number of clusters using Calinski-Harabasz index.
|
|
133
|
+
"""
|
|
134
|
+
n_samples = len(data)
|
|
135
|
+
|
|
136
|
+
# Need at least 2 samples to cluster
|
|
137
|
+
if n_samples < 2:
|
|
138
|
+
return 1
|
|
139
|
+
|
|
140
|
+
# Limit max_k to reasonable bounds
|
|
141
|
+
max_k = min(max_k, n_samples - 1, 20)
|
|
142
|
+
print(f"Max_k: {max_k}, n_samples: {n_samples}")
|
|
48
143
|
|
|
49
|
-
|
|
144
|
+
if max_k < 2:
|
|
145
|
+
return 1
|
|
146
|
+
|
|
147
|
+
# Use Calinski-Harabasz index to find optimal k
|
|
148
|
+
ch_scores = []
|
|
149
|
+
k_range = range(2, max_k + 1)
|
|
150
|
+
|
|
151
|
+
for k in k_range:
|
|
152
|
+
try:
|
|
153
|
+
kmeans = KMeans(n_clusters=k, random_state=seed, n_init=10)
|
|
154
|
+
labels = kmeans.fit_predict(data)
|
|
155
|
+
|
|
156
|
+
# Check if we got the expected number of clusters
|
|
157
|
+
if len(np.unique(labels)) == k:
|
|
158
|
+
score = calinski_harabasz_score(data, labels)
|
|
159
|
+
ch_scores.append(score)
|
|
160
|
+
else:
|
|
161
|
+
ch_scores.append(0) # Penalize solutions that didn't achieve k clusters
|
|
162
|
+
|
|
163
|
+
except Exception:
|
|
164
|
+
ch_scores.append(0)
|
|
165
|
+
|
|
166
|
+
# Find k with highest Calinski-Harabasz score
|
|
167
|
+
if ch_scores and max(ch_scores) > 0:
|
|
168
|
+
optimal_k = k_range[np.argmax(ch_scores)]
|
|
169
|
+
print(f"Using {optimal_k} neighborhoods")
|
|
170
|
+
return optimal_k
|
|
171
|
+
|
|
172
|
+
def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighborhood Heatmap",
|
|
173
|
+
center_at_one=False):
|
|
50
174
|
"""
|
|
51
175
|
Create a heatmap from a dictionary of numpy arrays.
|
|
52
176
|
|
|
@@ -60,6 +184,11 @@ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighb
|
|
|
60
184
|
Figure size (width, height)
|
|
61
185
|
title : str, optional
|
|
62
186
|
Title for the heatmap
|
|
187
|
+
center_at_one : bool, optional
|
|
188
|
+
If True, uses a diverging colormap centered at 1 with nonlinear scaling:
|
|
189
|
+
- 0 to 1: blue to white (underrepresentation to normal)
|
|
190
|
+
- 1+: white to red (overrepresentation)
|
|
191
|
+
If False (default), uses standard white-to-red scaling from 0 to 1
|
|
63
192
|
|
|
64
193
|
Returns:
|
|
65
194
|
--------
|
|
@@ -67,7 +196,6 @@ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighb
|
|
|
67
196
|
"""
|
|
68
197
|
|
|
69
198
|
data_dict = {k: unsorted_data_dict[k] for k in sorted(unsorted_data_dict.keys())}
|
|
70
|
-
|
|
71
199
|
# Convert dict to 2D array for heatmap
|
|
72
200
|
# Each row represents one key from the dict
|
|
73
201
|
keys = list(data_dict.keys())
|
|
@@ -76,8 +204,70 @@ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighb
|
|
|
76
204
|
# Create the plot
|
|
77
205
|
fig, ax = plt.subplots(figsize=figsize)
|
|
78
206
|
|
|
79
|
-
|
|
80
|
-
|
|
207
|
+
if center_at_one:
|
|
208
|
+
# Custom colormap and scaling for center_at_one mode
|
|
209
|
+
# Find the actual data range
|
|
210
|
+
data_min = np.min(data_matrix)
|
|
211
|
+
data_max = np.max(data_matrix)
|
|
212
|
+
|
|
213
|
+
# Create a custom colormap: blue -> white -> red
|
|
214
|
+
colors = ['#2166ac', '#4393c3', '#92c5de', '#d1e5f0', '#f7f7f7',
|
|
215
|
+
'#fddbc7', '#f4a582', '#d6604d', '#b2182b']
|
|
216
|
+
n_bins = 256
|
|
217
|
+
cmap = LinearSegmentedColormap.from_list('custom_diverging', colors, N=n_bins)
|
|
218
|
+
|
|
219
|
+
# Create nonlinear transformation
|
|
220
|
+
# Map 0->1 with more resolution, 1+ with less resolution
|
|
221
|
+
def transform_data(data):
|
|
222
|
+
transformed = np.zeros_like(data)
|
|
223
|
+
|
|
224
|
+
# For values 0 to 1: use square root for faster approach to middle
|
|
225
|
+
mask_low = data <= 1
|
|
226
|
+
transformed[mask_low] = 0.5 * np.sqrt(data[mask_low])
|
|
227
|
+
|
|
228
|
+
# For values > 1: use slower logarithmic scaling
|
|
229
|
+
mask_high = data > 1
|
|
230
|
+
if np.any(mask_high):
|
|
231
|
+
# Scale from 0.5 to 1.0 based on log of excess above 1
|
|
232
|
+
max_excess = np.max(data[mask_high] - 1) if np.any(mask_high) else 0
|
|
233
|
+
if max_excess > 0:
|
|
234
|
+
excess_normalized = np.log1p(data[mask_high] - 1) / np.log1p(max_excess)
|
|
235
|
+
transformed[mask_high] = 0.5 + 0.5 * excess_normalized
|
|
236
|
+
else:
|
|
237
|
+
transformed[mask_high] = 0.5
|
|
238
|
+
|
|
239
|
+
return transformed
|
|
240
|
+
|
|
241
|
+
# Transform the data for visualization
|
|
242
|
+
transformed_matrix = transform_data(data_matrix)
|
|
243
|
+
|
|
244
|
+
# Create heatmap with custom colormap
|
|
245
|
+
im = ax.imshow(transformed_matrix, cmap=cmap, aspect='auto', vmin=0, vmax=1)
|
|
246
|
+
|
|
247
|
+
# Create custom colorbar with original values
|
|
248
|
+
cbar = ax.figure.colorbar(im, ax=ax)
|
|
249
|
+
|
|
250
|
+
# Set colorbar ticks to show meaningful values
|
|
251
|
+
if data_max > 1:
|
|
252
|
+
tick_values = [0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 2.0]
|
|
253
|
+
tick_values = [v for v in tick_values if data_min <= v <= data_max]
|
|
254
|
+
else:
|
|
255
|
+
tick_values = [0, 0.25, 0.5, 0.75, 1.0]
|
|
256
|
+
tick_values = [v for v in tick_values if data_min <= v <= data_max]
|
|
257
|
+
|
|
258
|
+
# Transform tick values for colorbar positioning
|
|
259
|
+
transformed_ticks = transform_data(np.array(tick_values))
|
|
260
|
+
cbar.set_ticks(transformed_ticks)
|
|
261
|
+
cbar.set_ticklabels([f'{v:.2f}' for v in tick_values])
|
|
262
|
+
cbar.ax.set_ylabel('Representation Ratio', rotation=-90, va="bottom")
|
|
263
|
+
|
|
264
|
+
else:
|
|
265
|
+
# Default behavior: white-to-red colormap
|
|
266
|
+
im = ax.imshow(data_matrix, cmap='Reds', aspect='auto', vmin=0, vmax=1)
|
|
267
|
+
|
|
268
|
+
# Add standard colorbar
|
|
269
|
+
cbar = ax.figure.colorbar(im, ax=ax)
|
|
270
|
+
cbar.ax.set_ylabel('Intensity', rotation=-90, va="bottom")
|
|
81
271
|
|
|
82
272
|
# Set ticks and labels
|
|
83
273
|
ax.set_xticks(np.arange(len(id_set)))
|
|
@@ -91,23 +281,32 @@ def plot_dict_heatmap(unsorted_data_dict, id_set, figsize=(12, 8), title="Neighb
|
|
|
91
281
|
# Add text annotations showing the actual values
|
|
92
282
|
for i in range(len(keys)):
|
|
93
283
|
for j in range(len(id_set)):
|
|
284
|
+
# Use original data values for annotations
|
|
94
285
|
text = ax.text(j, i, f'{data_matrix[i, j]:.3f}',
|
|
95
286
|
ha="center", va="center", color="black", fontsize=8)
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
ret_dict = {}
|
|
290
|
+
|
|
291
|
+
for i, row in enumerate(data_matrix):
|
|
292
|
+
ret_dict[keys[i]] = row
|
|
100
293
|
|
|
101
294
|
# Set labels and title
|
|
102
|
-
|
|
295
|
+
if center_at_one:
|
|
296
|
+
ax.set_xlabel('Representation Factor of Node Type')
|
|
297
|
+
else:
|
|
298
|
+
ax.set_xlabel('Proportion of Node Type')
|
|
299
|
+
|
|
103
300
|
ax.set_ylabel('Neighborhood')
|
|
104
301
|
ax.set_title(title)
|
|
105
302
|
|
|
106
303
|
# Adjust layout to prevent label cutoff
|
|
107
304
|
plt.tight_layout()
|
|
108
|
-
|
|
109
305
|
plt.show()
|
|
110
306
|
|
|
307
|
+
return ret_dict
|
|
308
|
+
|
|
309
|
+
|
|
111
310
|
|
|
112
311
|
def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
|
|
113
312
|
class_names: Set[str],
|
|
@@ -202,10 +401,12 @@ def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
|
|
|
202
401
|
|
|
203
402
|
return embedding
|
|
204
403
|
|
|
205
|
-
def create_community_heatmap(community_intensity, node_community, node_centroids, is_3d=True,
|
|
206
|
-
figsize=(12, 8), point_size=50, alpha=0.7,
|
|
404
|
+
def create_community_heatmap(community_intensity, node_community, node_centroids, shape=None, is_3d=True,
|
|
405
|
+
labeled_array=None, figsize=(12, 8), point_size=50, alpha=0.7,
|
|
406
|
+
colorbar_label="Community Intensity", title="Community Intensity Heatmap"):
|
|
207
407
|
"""
|
|
208
408
|
Create a 2D or 3D heatmap showing nodes colored by their community intensities.
|
|
409
|
+
Can return either matplotlib plot or numpy RGB array for overlay purposes.
|
|
209
410
|
|
|
210
411
|
Parameters:
|
|
211
412
|
-----------
|
|
@@ -220,25 +421,39 @@ def create_community_heatmap(community_intensity, node_community, node_centroids
|
|
|
220
421
|
Dictionary mapping node IDs to centroids
|
|
221
422
|
Centroids should be [Z, Y, X] for 3D or [1, Y, X] for pseudo-3D
|
|
222
423
|
|
|
424
|
+
shape : tuple, optional
|
|
425
|
+
Shape of the output array in [Z, Y, X] format
|
|
426
|
+
If None, will be inferred from node_centroids
|
|
427
|
+
|
|
223
428
|
is_3d : bool, default=True
|
|
224
|
-
If True, create 3D plot. If False, create 2D plot.
|
|
429
|
+
If True, create 3D plot/array. If False, create 2D plot/array.
|
|
430
|
+
|
|
431
|
+
labeled_array : np.ndarray, optional
|
|
432
|
+
If provided, returns numpy RGB array overlay using this labeled array template
|
|
433
|
+
instead of matplotlib plot. Uses lookup table approach for efficiency.
|
|
225
434
|
|
|
226
435
|
figsize : tuple, default=(12, 8)
|
|
227
|
-
Figure size (width, height)
|
|
436
|
+
Figure size (width, height) - only used for matplotlib
|
|
228
437
|
|
|
229
438
|
point_size : int, default=50
|
|
230
|
-
Size of scatter plot points
|
|
439
|
+
Size of scatter plot points - only used for matplotlib
|
|
231
440
|
|
|
232
441
|
alpha : float, default=0.7
|
|
233
|
-
Transparency of points (0-1)
|
|
442
|
+
Transparency of points (0-1) - only used for matplotlib
|
|
234
443
|
|
|
235
444
|
colorbar_label : str, default="Community Intensity"
|
|
236
|
-
Label for the colorbar
|
|
445
|
+
Label for the colorbar - only used for matplotlib
|
|
446
|
+
|
|
447
|
+
title : str, default="Community Intensity Heatmap"
|
|
448
|
+
Title for the plot
|
|
237
449
|
|
|
238
450
|
Returns:
|
|
239
451
|
--------
|
|
240
|
-
fig, ax
|
|
452
|
+
If labeled_array is None: fig, ax (matplotlib figure and axis objects)
|
|
453
|
+
If labeled_array is provided: np.ndarray (RGB heatmap array with community intensity colors)
|
|
241
454
|
"""
|
|
455
|
+
import numpy as np
|
|
456
|
+
import matplotlib.pyplot as plt
|
|
242
457
|
|
|
243
458
|
# Convert numpy int64 keys to regular ints for consistency
|
|
244
459
|
community_intensity_clean = {}
|
|
@@ -254,6 +469,10 @@ def create_community_heatmap(community_intensity, node_community, node_centroids
|
|
|
254
469
|
|
|
255
470
|
for node_id, centroid in node_centroids.items():
|
|
256
471
|
try:
|
|
472
|
+
# Convert node_id to regular int if it's numpy
|
|
473
|
+
if hasattr(node_id, 'item'):
|
|
474
|
+
node_id = node_id.item()
|
|
475
|
+
|
|
257
476
|
# Get community for this node
|
|
258
477
|
community_id = node_community[node_id]
|
|
259
478
|
|
|
@@ -266,74 +485,392 @@ def create_community_heatmap(community_intensity, node_community, node_centroids
|
|
|
266
485
|
|
|
267
486
|
node_positions.append(centroid)
|
|
268
487
|
node_intensities.append(intensity)
|
|
269
|
-
except:
|
|
488
|
+
except KeyError:
|
|
489
|
+
# Skip nodes that don't have community assignments or community intensities
|
|
270
490
|
pass
|
|
271
491
|
|
|
272
492
|
# Convert to numpy arrays
|
|
273
493
|
positions = np.array(node_positions)
|
|
274
494
|
intensities = np.array(node_intensities)
|
|
275
495
|
|
|
276
|
-
# Determine
|
|
277
|
-
|
|
278
|
-
|
|
496
|
+
# Determine shape if not provided
|
|
497
|
+
if shape is None:
|
|
498
|
+
if len(positions) > 0:
|
|
499
|
+
max_coords = np.max(positions, axis=0).astype(int)
|
|
500
|
+
shape = tuple(max_coords + 1)
|
|
501
|
+
else:
|
|
502
|
+
shape = (100, 100, 100) if is_3d else (1, 100, 100)
|
|
279
503
|
|
|
280
|
-
#
|
|
281
|
-
|
|
504
|
+
# Determine min and max intensities for scaling
|
|
505
|
+
if len(intensities) > 0:
|
|
506
|
+
min_intensity = np.min(intensities)
|
|
507
|
+
max_intensity = np.max(intensities)
|
|
508
|
+
else:
|
|
509
|
+
min_intensity, max_intensity = 0, 1
|
|
282
510
|
|
|
283
|
-
if
|
|
284
|
-
#
|
|
285
|
-
ax = fig.add_subplot(111, projection='3d')
|
|
511
|
+
if labeled_array is not None:
|
|
512
|
+
# Create numpy RGB array output using labeled array and lookup table approach
|
|
286
513
|
|
|
287
|
-
#
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
514
|
+
# Create mapping from node ID to community intensity value
|
|
515
|
+
node_to_community_intensity = {}
|
|
516
|
+
for node_id, centroid in node_centroids.items():
|
|
517
|
+
# Convert node_id to regular int if it's numpy
|
|
518
|
+
if hasattr(node_id, 'item'):
|
|
519
|
+
node_id = node_id.item()
|
|
520
|
+
|
|
521
|
+
try:
|
|
522
|
+
# Get community for this node
|
|
523
|
+
community_id = node_community[node_id]
|
|
524
|
+
|
|
525
|
+
# Convert community_id to regular int if it's numpy
|
|
526
|
+
if hasattr(community_id, 'item'):
|
|
527
|
+
community_id = community_id.item()
|
|
528
|
+
|
|
529
|
+
# Get intensity for this community
|
|
530
|
+
if community_id in community_intensity_clean:
|
|
531
|
+
node_to_community_intensity[node_id] = community_intensity_clean[community_id]
|
|
532
|
+
except KeyError:
|
|
533
|
+
# Skip nodes that don't have community assignments
|
|
534
|
+
pass
|
|
535
|
+
|
|
536
|
+
# Create colormap function (RdBu_r - red for high, blue for low, yellow/white for middle)
|
|
537
|
+
def intensity_to_rgb(intensity, min_val, max_val):
|
|
538
|
+
"""Convert intensity value to RGB using RdBu_r colormap logic"""
|
|
539
|
+
if max_val == min_val:
|
|
540
|
+
# All same value, use neutral color
|
|
541
|
+
return np.array([255, 255, 255], dtype=np.uint8) # White
|
|
542
|
+
|
|
543
|
+
# Normalize to -1 to 1 range (like RdBu_r colormap)
|
|
544
|
+
normalized = 2 * (intensity - min_val) / (max_val - min_val) - 1
|
|
545
|
+
normalized = np.clip(normalized, -1, 1)
|
|
546
|
+
|
|
547
|
+
if normalized > 0:
|
|
548
|
+
# Positive values: white to red
|
|
549
|
+
r = 255
|
|
550
|
+
g = int(255 * (1 - normalized))
|
|
551
|
+
b = int(255 * (1 - normalized))
|
|
552
|
+
else:
|
|
553
|
+
# Negative values: white to blue
|
|
554
|
+
r = int(255 * (1 + normalized))
|
|
555
|
+
g = int(255 * (1 + normalized))
|
|
556
|
+
b = 255
|
|
557
|
+
|
|
558
|
+
return np.array([r, g, b], dtype=np.uint8)
|
|
291
559
|
|
|
292
|
-
# Create
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
|
|
560
|
+
# Create lookup table for RGB colors
|
|
561
|
+
max_label = max(max(labeled_array.flat), max(node_to_community_intensity.keys()) if node_to_community_intensity else 0)
|
|
562
|
+
color_lut = np.zeros((max_label + 1, 3), dtype=np.uint8) # Default to black (0,0,0)
|
|
296
563
|
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
564
|
+
# Fill lookup table with RGB colors based on community intensity
|
|
565
|
+
for node_id, intensity in node_to_community_intensity.items():
|
|
566
|
+
rgb_color = intensity_to_rgb(intensity, min_intensity, max_intensity)
|
|
567
|
+
color_lut[int(node_id)] = rgb_color
|
|
301
568
|
|
|
569
|
+
# Apply lookup table to labeled array - single vectorized operation
|
|
570
|
+
if is_3d:
|
|
571
|
+
# Return full 3D RGB array [Z, Y, X, 3]
|
|
572
|
+
heatmap_array = color_lut[labeled_array]
|
|
573
|
+
else:
|
|
574
|
+
# Return 2D RGB array
|
|
575
|
+
if labeled_array.ndim == 3:
|
|
576
|
+
# Take middle slice for 2D representation
|
|
577
|
+
middle_slice = labeled_array.shape[0] // 2
|
|
578
|
+
heatmap_array = color_lut[labeled_array[middle_slice]]
|
|
579
|
+
else:
|
|
580
|
+
# Already 2D
|
|
581
|
+
heatmap_array = color_lut[labeled_array]
|
|
582
|
+
|
|
583
|
+
return heatmap_array
|
|
584
|
+
|
|
302
585
|
else:
|
|
303
|
-
#
|
|
304
|
-
|
|
586
|
+
# Create matplotlib plot
|
|
587
|
+
fig = plt.figure(figsize=figsize)
|
|
305
588
|
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
589
|
+
if is_3d:
|
|
590
|
+
# 3D plot
|
|
591
|
+
ax = fig.add_subplot(111, projection='3d')
|
|
592
|
+
|
|
593
|
+
# Extract coordinates (assuming [Z, Y, X] format)
|
|
594
|
+
z_coords = positions[:, 0]
|
|
595
|
+
y_coords = positions[:, 1]
|
|
596
|
+
x_coords = positions[:, 2]
|
|
597
|
+
|
|
598
|
+
# Create scatter plot
|
|
599
|
+
scatter = ax.scatter(x_coords, y_coords, z_coords,
|
|
600
|
+
c=intensities, s=point_size, alpha=alpha,
|
|
601
|
+
cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
|
|
602
|
+
|
|
603
|
+
ax.set_xlabel('X')
|
|
604
|
+
ax.set_ylabel('Y')
|
|
605
|
+
ax.set_zlabel('Z')
|
|
606
|
+
ax.set_title(f'{title}')
|
|
607
|
+
|
|
608
|
+
# Set axis limits based on shape
|
|
609
|
+
ax.set_xlim(0, shape[2])
|
|
610
|
+
ax.set_ylim(0, shape[1])
|
|
611
|
+
ax.set_zlim(0, shape[0])
|
|
612
|
+
|
|
613
|
+
else:
|
|
614
|
+
# 2D plot (using Y, X coordinates, ignoring Z/first dimension)
|
|
615
|
+
ax = fig.add_subplot(111)
|
|
616
|
+
|
|
617
|
+
# Extract Y, X coordinates
|
|
618
|
+
y_coords = positions[:, 1]
|
|
619
|
+
x_coords = positions[:, 2]
|
|
620
|
+
|
|
621
|
+
# Create scatter plot
|
|
622
|
+
scatter = ax.scatter(x_coords, y_coords,
|
|
623
|
+
c=intensities, s=point_size, alpha=alpha,
|
|
624
|
+
cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
|
|
625
|
+
|
|
626
|
+
ax.set_xlabel('X')
|
|
627
|
+
ax.set_ylabel('Y')
|
|
628
|
+
ax.set_title(f'{title}')
|
|
629
|
+
ax.grid(True, alpha=0.3)
|
|
630
|
+
|
|
631
|
+
# Set axis limits based on shape
|
|
632
|
+
ax.set_xlim(0, shape[2])
|
|
633
|
+
ax.set_ylim(0, shape[1])
|
|
634
|
+
|
|
635
|
+
# Set origin to top-left (invert Y-axis)
|
|
636
|
+
ax.invert_yaxis()
|
|
309
637
|
|
|
310
|
-
#
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
|
|
638
|
+
# Add colorbar
|
|
639
|
+
cbar = plt.colorbar(scatter, ax=ax, shrink=0.8)
|
|
640
|
+
cbar.set_label(colorbar_label)
|
|
314
641
|
|
|
315
|
-
|
|
316
|
-
ax.
|
|
317
|
-
|
|
318
|
-
ax.
|
|
642
|
+
# Add text annotations for min/max values
|
|
643
|
+
cbar.ax.text(1.05, 0, f'Min: {min_intensity:.3f}\n(Blue)',
|
|
644
|
+
transform=cbar.ax.transAxes, va='bottom')
|
|
645
|
+
cbar.ax.text(1.05, 1, f'Max: {max_intensity:.3f}\n(Red)',
|
|
646
|
+
transform=cbar.ax.transAxes, va='top')
|
|
319
647
|
|
|
320
|
-
|
|
321
|
-
|
|
648
|
+
plt.tight_layout()
|
|
649
|
+
plt.show()
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
def create_node_heatmap(node_intensity, node_centroids, shape=None, is_3d=True,
|
|
653
|
+
labeled_array=None, figsize=(12, 8), point_size=50, alpha=0.7,
|
|
654
|
+
colorbar_label="Node Intensity", title="Node Clustering Intensity Heatmap"):
|
|
655
|
+
"""
|
|
656
|
+
Create a 2D or 3D heatmap showing nodes colored by their individual intensities.
|
|
657
|
+
Can return either matplotlib plot or numpy array for overlay purposes.
|
|
322
658
|
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
659
|
+
Parameters:
|
|
660
|
+
-----------
|
|
661
|
+
node_intensity : dict
|
|
662
|
+
Dictionary mapping node IDs to intensity values
|
|
663
|
+
Keys can be np.int64 or regular ints
|
|
664
|
+
|
|
665
|
+
node_centroids : dict
|
|
666
|
+
Dictionary mapping node IDs to centroids
|
|
667
|
+
Centroids should be [Z, Y, X] for 3D or [1, Y, X] for pseudo-3D
|
|
668
|
+
|
|
669
|
+
shape : tuple, optional
|
|
670
|
+
Shape of the output array in [Z, Y, X] format
|
|
671
|
+
If None, will be inferred from node_centroids
|
|
672
|
+
|
|
673
|
+
is_3d : bool, default=True
|
|
674
|
+
If True, create 3D plot/array. If False, create 2D plot/array.
|
|
675
|
+
|
|
676
|
+
labeled_array : np.ndarray, optional
|
|
677
|
+
If provided, returns numpy array overlay using this labeled array template
|
|
678
|
+
instead of matplotlib plot. Uses lookup table approach for efficiency.
|
|
679
|
+
|
|
680
|
+
figsize : tuple, default=(12, 8)
|
|
681
|
+
Figure size (width, height) - only used for matplotlib
|
|
682
|
+
|
|
683
|
+
point_size : int, default=50
|
|
684
|
+
Size of scatter plot points - only used for matplotlib
|
|
685
|
+
|
|
686
|
+
alpha : float, default=0.7
|
|
687
|
+
Transparency of points (0-1) - only used for matplotlib
|
|
688
|
+
|
|
689
|
+
colorbar_label : str, default="Node Intensity"
|
|
690
|
+
Label for the colorbar - only used for matplotlib
|
|
691
|
+
|
|
692
|
+
Returns:
|
|
693
|
+
--------
|
|
694
|
+
If labeled_array is None: fig, ax (matplotlib figure and axis objects)
|
|
695
|
+
If labeled_array is provided: np.ndarray (heatmap array with intensity values)
|
|
696
|
+
"""
|
|
697
|
+
import numpy as np
|
|
698
|
+
import matplotlib.pyplot as plt
|
|
326
699
|
|
|
327
|
-
#
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
700
|
+
# Convert numpy int64 keys to regular ints for consistency
|
|
701
|
+
node_intensity_clean = {}
|
|
702
|
+
for k, v in node_intensity.items():
|
|
703
|
+
if hasattr(k, 'item'): # numpy scalar
|
|
704
|
+
node_intensity_clean[k.item()] = v
|
|
705
|
+
else:
|
|
706
|
+
node_intensity_clean[k] = v
|
|
332
707
|
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
708
|
+
# Prepare data for plotting/array creation
|
|
709
|
+
node_positions = []
|
|
710
|
+
node_intensities = []
|
|
711
|
+
|
|
712
|
+
for node_id, centroid in node_centroids.items():
|
|
713
|
+
try:
|
|
714
|
+
# Convert node_id to regular int if it's numpy
|
|
715
|
+
if hasattr(node_id, 'item'):
|
|
716
|
+
node_id = node_id.item()
|
|
717
|
+
|
|
718
|
+
# Get intensity for this node
|
|
719
|
+
intensity = node_intensity_clean[node_id]
|
|
720
|
+
|
|
721
|
+
node_positions.append(centroid)
|
|
722
|
+
node_intensities.append(intensity)
|
|
723
|
+
except KeyError:
|
|
724
|
+
# Skip nodes that don't have intensity values
|
|
725
|
+
pass
|
|
726
|
+
|
|
727
|
+
# Convert to numpy arrays
|
|
728
|
+
positions = np.array(node_positions)
|
|
729
|
+
intensities = np.array(node_intensities)
|
|
730
|
+
|
|
731
|
+
# Determine shape if not provided
|
|
732
|
+
if shape is None:
|
|
733
|
+
if len(positions) > 0:
|
|
734
|
+
max_coords = np.max(positions, axis=0).astype(int)
|
|
735
|
+
shape = tuple(max_coords + 1)
|
|
736
|
+
else:
|
|
737
|
+
shape = (100, 100, 100) if is_3d else (1, 100, 100)
|
|
738
|
+
|
|
739
|
+
# Determine min and max intensities for scaling
|
|
740
|
+
if len(intensities) > 0:
|
|
741
|
+
min_intensity = np.min(intensities)
|
|
742
|
+
max_intensity = np.max(intensities)
|
|
743
|
+
else:
|
|
744
|
+
min_intensity, max_intensity = 0, 1
|
|
745
|
+
|
|
746
|
+
if labeled_array is not None:
|
|
747
|
+
# Create numpy RGB array output using labeled array and lookup table approach
|
|
748
|
+
|
|
749
|
+
# Create mapping from node ID to intensity value (keep original float values)
|
|
750
|
+
node_to_intensity = {}
|
|
751
|
+
for node_id, centroid in node_centroids.items():
|
|
752
|
+
# Convert node_id to regular int if it's numpy
|
|
753
|
+
if hasattr(node_id, 'item'):
|
|
754
|
+
node_id = node_id.item()
|
|
755
|
+
|
|
756
|
+
# Only include nodes that have intensity values
|
|
757
|
+
if node_id in node_intensity_clean:
|
|
758
|
+
node_to_intensity[node_id] = node_intensity_clean[node_id]
|
|
759
|
+
|
|
760
|
+
# Create colormap function (RdBu_r - red for high, blue for low, yellow/white for middle)
|
|
761
|
+
def intensity_to_rgb(intensity, min_val, max_val):
|
|
762
|
+
"""Convert intensity value to RGB using RdBu_r colormap logic"""
|
|
763
|
+
if max_val == min_val:
|
|
764
|
+
# All same value, use neutral color
|
|
765
|
+
return np.array([255, 255, 255], dtype=np.uint8) # White
|
|
766
|
+
|
|
767
|
+
# Normalize to -1 to 1 range (like RdBu_r colormap)
|
|
768
|
+
normalized = 2 * (intensity - min_val) / (max_val - min_val) - 1
|
|
769
|
+
normalized = np.clip(normalized, -1, 1)
|
|
770
|
+
|
|
771
|
+
if normalized > 0:
|
|
772
|
+
# Positive values: white to red
|
|
773
|
+
r = 255
|
|
774
|
+
g = int(255 * (1 - normalized))
|
|
775
|
+
b = int(255 * (1 - normalized))
|
|
776
|
+
else:
|
|
777
|
+
# Negative values: white to blue
|
|
778
|
+
r = int(255 * (1 + normalized))
|
|
779
|
+
g = int(255 * (1 + normalized))
|
|
780
|
+
b = 255
|
|
781
|
+
|
|
782
|
+
return np.array([r, g, b], dtype=np.uint8)
|
|
783
|
+
|
|
784
|
+
# Create lookup table for RGB colors
|
|
785
|
+
max_label = max(max(labeled_array.flat), max(node_to_intensity.keys()) if node_to_intensity else 0)
|
|
786
|
+
color_lut = np.zeros((max_label + 1, 3), dtype=np.uint8) # Default to black (0,0,0)
|
|
787
|
+
|
|
788
|
+
# Fill lookup table with RGB colors based on intensity
|
|
789
|
+
for node_id, intensity in node_to_intensity.items():
|
|
790
|
+
rgb_color = intensity_to_rgb(intensity, min_intensity, max_intensity)
|
|
791
|
+
color_lut[int(node_id)] = rgb_color
|
|
792
|
+
|
|
793
|
+
# Apply lookup table to labeled array - single vectorized operation
|
|
794
|
+
if is_3d:
|
|
795
|
+
# Return full 3D RGB array [Z, Y, X, 3]
|
|
796
|
+
heatmap_array = color_lut[labeled_array]
|
|
797
|
+
else:
|
|
798
|
+
# Return 2D RGB array
|
|
799
|
+
if labeled_array.ndim == 3:
|
|
800
|
+
# Take middle slice for 2D representation
|
|
801
|
+
middle_slice = labeled_array.shape[0] // 2
|
|
802
|
+
heatmap_array = color_lut[labeled_array[middle_slice]]
|
|
803
|
+
else:
|
|
804
|
+
# Already 2D
|
|
805
|
+
heatmap_array = color_lut[labeled_array]
|
|
806
|
+
|
|
807
|
+
return heatmap_array
|
|
808
|
+
|
|
809
|
+
else:
|
|
810
|
+
# Create matplotlib plot
|
|
811
|
+
fig = plt.figure(figsize=figsize)
|
|
812
|
+
|
|
813
|
+
if is_3d:
|
|
814
|
+
# 3D plot
|
|
815
|
+
ax = fig.add_subplot(111, projection='3d')
|
|
816
|
+
|
|
817
|
+
# Extract coordinates (assuming [Z, Y, X] format)
|
|
818
|
+
z_coords = positions[:, 0]
|
|
819
|
+
y_coords = positions[:, 1]
|
|
820
|
+
x_coords = positions[:, 2]
|
|
821
|
+
|
|
822
|
+
# Create scatter plot
|
|
823
|
+
scatter = ax.scatter(x_coords, y_coords, z_coords,
|
|
824
|
+
c=intensities, s=point_size, alpha=alpha,
|
|
825
|
+
cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
|
|
826
|
+
|
|
827
|
+
ax.set_xlabel('X')
|
|
828
|
+
ax.set_ylabel('Y')
|
|
829
|
+
ax.set_zlabel('Z')
|
|
830
|
+
ax.set_title(f'{title}')
|
|
831
|
+
|
|
832
|
+
# Set axis limits based on shape
|
|
833
|
+
ax.set_xlim(0, shape[2])
|
|
834
|
+
ax.set_ylim(0, shape[1])
|
|
835
|
+
ax.set_zlim(0, shape[0])
|
|
836
|
+
|
|
837
|
+
else:
|
|
838
|
+
# 2D plot (using Y, X coordinates, ignoring Z/first dimension)
|
|
839
|
+
ax = fig.add_subplot(111)
|
|
840
|
+
|
|
841
|
+
# Extract Y, X coordinates
|
|
842
|
+
y_coords = positions[:, 1]
|
|
843
|
+
x_coords = positions[:, 2]
|
|
844
|
+
|
|
845
|
+
# Create scatter plot
|
|
846
|
+
scatter = ax.scatter(x_coords, y_coords,
|
|
847
|
+
c=intensities, s=point_size, alpha=alpha,
|
|
848
|
+
cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
|
|
849
|
+
|
|
850
|
+
ax.set_xlabel('X')
|
|
851
|
+
ax.set_ylabel('Y')
|
|
852
|
+
ax.set_title(f'{title}')
|
|
853
|
+
ax.grid(True, alpha=0.3)
|
|
854
|
+
|
|
855
|
+
# Set axis limits based on shape
|
|
856
|
+
ax.set_xlim(0, shape[2])
|
|
857
|
+
ax.set_ylim(0, shape[1])
|
|
858
|
+
|
|
859
|
+
# Set origin to top-left (invert Y-axis)
|
|
860
|
+
ax.invert_yaxis()
|
|
861
|
+
|
|
862
|
+
# Add colorbar
|
|
863
|
+
cbar = plt.colorbar(scatter, ax=ax, shrink=0.8)
|
|
864
|
+
cbar.set_label(colorbar_label)
|
|
865
|
+
|
|
866
|
+
# Add text annotations for min/max values
|
|
867
|
+
cbar.ax.text(1.05, 0, f'Min: {min_intensity:.3f}\n(Blue)',
|
|
868
|
+
transform=cbar.ax.transAxes, va='bottom')
|
|
869
|
+
cbar.ax.text(1.05, 1, f'Max: {max_intensity:.3f}\n(Red)',
|
|
870
|
+
transform=cbar.ax.transAxes, va='top')
|
|
871
|
+
|
|
872
|
+
plt.tight_layout()
|
|
873
|
+
plt.show()
|
|
337
874
|
|
|
338
875
|
# Example usage:
|
|
339
876
|
if __name__ == "__main__":
|
|
@@ -350,5 +887,4 @@ if __name__ == "__main__":
|
|
|
350
887
|
fig, ax = plot_dict_heatmap(sample_dict, sample_id_set,
|
|
351
888
|
title="Sample Heatmap Visualization")
|
|
352
889
|
|
|
353
|
-
plt.show()
|
|
354
|
-
|
|
890
|
+
plt.show()
|