nettracer3d 0.7.9__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nettracer3d might be problematic. Click here for more details.
- nettracer3d/community_extractor.py +17 -26
- nettracer3d/neighborhoods.py +395 -58
- nettracer3d/nettracer.py +230 -39
- nettracer3d/nettracer_gui.py +1195 -202
- nettracer3d/node_draw.py +22 -12
- nettracer3d/proximity.py +83 -6
- nettracer3d/segmenter.py +1 -1
- nettracer3d/segmenter_GPU.py +1 -1
- nettracer3d/simple_network.py +43 -25
- {nettracer3d-0.7.9.dist-info → nettracer3d-0.8.1.dist-info}/METADATA +5 -3
- nettracer3d-0.8.1.dist-info/RECORD +23 -0
- nettracer3d-0.7.9.dist-info/RECORD +0 -23
- {nettracer3d-0.7.9.dist-info → nettracer3d-0.8.1.dist-info}/WHEEL +0 -0
- {nettracer3d-0.7.9.dist-info → nettracer3d-0.8.1.dist-info}/entry_points.txt +0 -0
- {nettracer3d-0.7.9.dist-info → nettracer3d-0.8.1.dist-info}/licenses/LICENSE +0 -0
- {nettracer3d-0.7.9.dist-info → nettracer3d-0.8.1.dist-info}/top_level.txt +0 -0
|
@@ -492,48 +492,39 @@ def generate_distinct_colors(n_colors: int) -> List[Tuple[int, int, int]]:
|
|
|
492
492
|
return colors
|
|
493
493
|
|
|
494
494
|
def assign_node_colors(node_list: List[int], labeled_array: np.ndarray) -> Tuple[np.ndarray, Dict[int, str]]:
|
|
495
|
-
"""
|
|
496
|
-
Assign distinct colors to nodes and create an RGBA image.
|
|
497
|
-
|
|
498
|
-
Args:
|
|
499
|
-
node_list: List of node IDs
|
|
500
|
-
labeled_array: 3D numpy array with labels corresponding to node IDs
|
|
495
|
+
"""fast version using lookup table approach."""
|
|
501
496
|
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
"""
|
|
505
|
-
|
|
506
|
-
# Sort communities by size (descending)
|
|
507
|
-
sorted_nodes= sorted(node_list, reverse=True)
|
|
497
|
+
# Sort nodes by size (descending)
|
|
498
|
+
sorted_nodes = sorted(node_list, reverse=True)
|
|
508
499
|
|
|
509
500
|
# Generate distinct colors
|
|
510
501
|
colors = generate_distinct_colors(len(node_list))
|
|
511
|
-
random.shuffle(colors)
|
|
502
|
+
random.shuffle(colors) # Randomly sorted to make adjacent structures likely stand out
|
|
512
503
|
|
|
513
504
|
# Convert RGB colors to RGBA by adding alpha channel
|
|
514
|
-
colors_rgba = [(r, g, b, 255) for r, g, b in colors]
|
|
505
|
+
colors_rgba = np.array([(r, g, b, 255) for r, g, b in colors], dtype=np.uint8)
|
|
515
506
|
|
|
516
|
-
# Create mapping from
|
|
507
|
+
# Create mapping from node to color
|
|
517
508
|
node_to_color = {node: colors_rgba[i] for i, node in enumerate(sorted_nodes)}
|
|
518
509
|
|
|
519
|
-
# Create
|
|
520
|
-
|
|
510
|
+
# Create lookup table
|
|
511
|
+
max_label = max(max(labeled_array.flat), max(node_list) if node_list else 0)
|
|
512
|
+
color_lut = np.zeros((max_label + 1, 4), dtype=np.uint8) # Transparent by default
|
|
521
513
|
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
# Convert the RGB portion of community_to_color back to RGB for color naming
|
|
514
|
+
for node_id, color in node_to_color.items():
|
|
515
|
+
color_lut[node_id] = color
|
|
516
|
+
|
|
517
|
+
# Single vectorized operation - eliminates all loops!
|
|
518
|
+
rgba_array = color_lut[labeled_array]
|
|
519
|
+
|
|
520
|
+
# Convert colors for naming
|
|
530
521
|
node_to_color_rgb = {k: tuple(v[:3]) for k, v in node_to_color.items()}
|
|
531
522
|
node_to_color_names = convert_node_colors_to_names(node_to_color_rgb)
|
|
532
523
|
|
|
533
524
|
return rgba_array, node_to_color_names
|
|
534
525
|
|
|
535
526
|
def assign_community_colors(community_dict: Dict[int, int], labeled_array: np.ndarray) -> Tuple[np.ndarray, Dict[int, str]]:
|
|
536
|
-
"""
|
|
527
|
+
"""fast version using lookup table approach."""
|
|
537
528
|
|
|
538
529
|
# Same setup as before
|
|
539
530
|
communities = set(community_dict.values())
|
nettracer3d/neighborhoods.py
CHANGED
|
@@ -202,10 +202,12 @@ def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
|
|
|
202
202
|
|
|
203
203
|
return embedding
|
|
204
204
|
|
|
205
|
-
def create_community_heatmap(community_intensity, node_community, node_centroids, is_3d=True,
|
|
206
|
-
figsize=(12, 8), point_size=50, alpha=0.7,
|
|
205
|
+
def create_community_heatmap(community_intensity, node_community, node_centroids, shape=None, is_3d=True,
|
|
206
|
+
labeled_array=None, figsize=(12, 8), point_size=50, alpha=0.7,
|
|
207
|
+
colorbar_label="Community Intensity", title="Community Intensity Heatmap"):
|
|
207
208
|
"""
|
|
208
209
|
Create a 2D or 3D heatmap showing nodes colored by their community intensities.
|
|
210
|
+
Can return either matplotlib plot or numpy RGB array for overlay purposes.
|
|
209
211
|
|
|
210
212
|
Parameters:
|
|
211
213
|
-----------
|
|
@@ -220,25 +222,39 @@ def create_community_heatmap(community_intensity, node_community, node_centroids
|
|
|
220
222
|
Dictionary mapping node IDs to centroids
|
|
221
223
|
Centroids should be [Z, Y, X] for 3D or [1, Y, X] for pseudo-3D
|
|
222
224
|
|
|
225
|
+
shape : tuple, optional
|
|
226
|
+
Shape of the output array in [Z, Y, X] format
|
|
227
|
+
If None, will be inferred from node_centroids
|
|
228
|
+
|
|
223
229
|
is_3d : bool, default=True
|
|
224
|
-
If True, create 3D plot. If False, create 2D plot.
|
|
230
|
+
If True, create 3D plot/array. If False, create 2D plot/array.
|
|
231
|
+
|
|
232
|
+
labeled_array : np.ndarray, optional
|
|
233
|
+
If provided, returns numpy RGB array overlay using this labeled array template
|
|
234
|
+
instead of matplotlib plot. Uses lookup table approach for efficiency.
|
|
225
235
|
|
|
226
236
|
figsize : tuple, default=(12, 8)
|
|
227
|
-
Figure size (width, height)
|
|
237
|
+
Figure size (width, height) - only used for matplotlib
|
|
228
238
|
|
|
229
239
|
point_size : int, default=50
|
|
230
|
-
Size of scatter plot points
|
|
240
|
+
Size of scatter plot points - only used for matplotlib
|
|
231
241
|
|
|
232
242
|
alpha : float, default=0.7
|
|
233
|
-
Transparency of points (0-1)
|
|
243
|
+
Transparency of points (0-1) - only used for matplotlib
|
|
234
244
|
|
|
235
245
|
colorbar_label : str, default="Community Intensity"
|
|
236
|
-
Label for the colorbar
|
|
246
|
+
Label for the colorbar - only used for matplotlib
|
|
247
|
+
|
|
248
|
+
title : str, default="Community Intensity Heatmap"
|
|
249
|
+
Title for the plot
|
|
237
250
|
|
|
238
251
|
Returns:
|
|
239
252
|
--------
|
|
240
|
-
fig, ax
|
|
253
|
+
If labeled_array is None: fig, ax (matplotlib figure and axis objects)
|
|
254
|
+
If labeled_array is provided: np.ndarray (RGB heatmap array with community intensity colors)
|
|
241
255
|
"""
|
|
256
|
+
import numpy as np
|
|
257
|
+
import matplotlib.pyplot as plt
|
|
242
258
|
|
|
243
259
|
# Convert numpy int64 keys to regular ints for consistency
|
|
244
260
|
community_intensity_clean = {}
|
|
@@ -254,6 +270,10 @@ def create_community_heatmap(community_intensity, node_community, node_centroids
|
|
|
254
270
|
|
|
255
271
|
for node_id, centroid in node_centroids.items():
|
|
256
272
|
try:
|
|
273
|
+
# Convert node_id to regular int if it's numpy
|
|
274
|
+
if hasattr(node_id, 'item'):
|
|
275
|
+
node_id = node_id.item()
|
|
276
|
+
|
|
257
277
|
# Get community for this node
|
|
258
278
|
community_id = node_community[node_id]
|
|
259
279
|
|
|
@@ -266,74 +286,392 @@ def create_community_heatmap(community_intensity, node_community, node_centroids
|
|
|
266
286
|
|
|
267
287
|
node_positions.append(centroid)
|
|
268
288
|
node_intensities.append(intensity)
|
|
269
|
-
except:
|
|
289
|
+
except KeyError:
|
|
290
|
+
# Skip nodes that don't have community assignments or community intensities
|
|
270
291
|
pass
|
|
271
292
|
|
|
272
293
|
# Convert to numpy arrays
|
|
273
294
|
positions = np.array(node_positions)
|
|
274
295
|
intensities = np.array(node_intensities)
|
|
275
296
|
|
|
276
|
-
# Determine
|
|
277
|
-
|
|
278
|
-
|
|
297
|
+
# Determine shape if not provided
|
|
298
|
+
if shape is None:
|
|
299
|
+
if len(positions) > 0:
|
|
300
|
+
max_coords = np.max(positions, axis=0).astype(int)
|
|
301
|
+
shape = tuple(max_coords + 1)
|
|
302
|
+
else:
|
|
303
|
+
shape = (100, 100, 100) if is_3d else (1, 100, 100)
|
|
279
304
|
|
|
280
|
-
#
|
|
281
|
-
|
|
305
|
+
# Determine min and max intensities for scaling
|
|
306
|
+
if len(intensities) > 0:
|
|
307
|
+
min_intensity = np.min(intensities)
|
|
308
|
+
max_intensity = np.max(intensities)
|
|
309
|
+
else:
|
|
310
|
+
min_intensity, max_intensity = 0, 1
|
|
282
311
|
|
|
283
|
-
if
|
|
284
|
-
#
|
|
285
|
-
|
|
312
|
+
if labeled_array is not None:
|
|
313
|
+
# Create numpy RGB array output using labeled array and lookup table approach
|
|
314
|
+
|
|
315
|
+
# Create mapping from node ID to community intensity value
|
|
316
|
+
node_to_community_intensity = {}
|
|
317
|
+
for node_id, centroid in node_centroids.items():
|
|
318
|
+
# Convert node_id to regular int if it's numpy
|
|
319
|
+
if hasattr(node_id, 'item'):
|
|
320
|
+
node_id = node_id.item()
|
|
321
|
+
|
|
322
|
+
try:
|
|
323
|
+
# Get community for this node
|
|
324
|
+
community_id = node_community[node_id]
|
|
325
|
+
|
|
326
|
+
# Convert community_id to regular int if it's numpy
|
|
327
|
+
if hasattr(community_id, 'item'):
|
|
328
|
+
community_id = community_id.item()
|
|
329
|
+
|
|
330
|
+
# Get intensity for this community
|
|
331
|
+
if community_id in community_intensity_clean:
|
|
332
|
+
node_to_community_intensity[node_id] = community_intensity_clean[community_id]
|
|
333
|
+
except KeyError:
|
|
334
|
+
# Skip nodes that don't have community assignments
|
|
335
|
+
pass
|
|
336
|
+
|
|
337
|
+
# Create colormap function (RdBu_r - red for high, blue for low, yellow/white for middle)
|
|
338
|
+
def intensity_to_rgb(intensity, min_val, max_val):
|
|
339
|
+
"""Convert intensity value to RGB using RdBu_r colormap logic"""
|
|
340
|
+
if max_val == min_val:
|
|
341
|
+
# All same value, use neutral color
|
|
342
|
+
return np.array([255, 255, 255], dtype=np.uint8) # White
|
|
343
|
+
|
|
344
|
+
# Normalize to -1 to 1 range (like RdBu_r colormap)
|
|
345
|
+
normalized = 2 * (intensity - min_val) / (max_val - min_val) - 1
|
|
346
|
+
normalized = np.clip(normalized, -1, 1)
|
|
347
|
+
|
|
348
|
+
if normalized > 0:
|
|
349
|
+
# Positive values: white to red
|
|
350
|
+
r = 255
|
|
351
|
+
g = int(255 * (1 - normalized))
|
|
352
|
+
b = int(255 * (1 - normalized))
|
|
353
|
+
else:
|
|
354
|
+
# Negative values: white to blue
|
|
355
|
+
r = int(255 * (1 + normalized))
|
|
356
|
+
g = int(255 * (1 + normalized))
|
|
357
|
+
b = 255
|
|
358
|
+
|
|
359
|
+
return np.array([r, g, b], dtype=np.uint8)
|
|
286
360
|
|
|
287
|
-
#
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
x_coords = positions[:, 2]
|
|
361
|
+
# Create lookup table for RGB colors
|
|
362
|
+
max_label = max(max(labeled_array.flat), max(node_to_community_intensity.keys()) if node_to_community_intensity else 0)
|
|
363
|
+
color_lut = np.zeros((max_label + 1, 3), dtype=np.uint8) # Default to black (0,0,0)
|
|
291
364
|
|
|
292
|
-
#
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
365
|
+
# Fill lookup table with RGB colors based on community intensity
|
|
366
|
+
for node_id, intensity in node_to_community_intensity.items():
|
|
367
|
+
rgb_color = intensity_to_rgb(intensity, min_intensity, max_intensity)
|
|
368
|
+
color_lut[int(node_id)] = rgb_color
|
|
296
369
|
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
370
|
+
# Apply lookup table to labeled array - single vectorized operation
|
|
371
|
+
if is_3d:
|
|
372
|
+
# Return full 3D RGB array [Z, Y, X, 3]
|
|
373
|
+
heatmap_array = color_lut[labeled_array]
|
|
374
|
+
else:
|
|
375
|
+
# Return 2D RGB array
|
|
376
|
+
if labeled_array.ndim == 3:
|
|
377
|
+
# Take middle slice for 2D representation
|
|
378
|
+
middle_slice = labeled_array.shape[0] // 2
|
|
379
|
+
heatmap_array = color_lut[labeled_array[middle_slice]]
|
|
380
|
+
else:
|
|
381
|
+
# Already 2D
|
|
382
|
+
heatmap_array = color_lut[labeled_array]
|
|
301
383
|
|
|
384
|
+
return heatmap_array
|
|
385
|
+
|
|
302
386
|
else:
|
|
303
|
-
#
|
|
304
|
-
|
|
387
|
+
# Create matplotlib plot
|
|
388
|
+
fig = plt.figure(figsize=figsize)
|
|
305
389
|
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
390
|
+
if is_3d:
|
|
391
|
+
# 3D plot
|
|
392
|
+
ax = fig.add_subplot(111, projection='3d')
|
|
393
|
+
|
|
394
|
+
# Extract coordinates (assuming [Z, Y, X] format)
|
|
395
|
+
z_coords = positions[:, 0]
|
|
396
|
+
y_coords = positions[:, 1]
|
|
397
|
+
x_coords = positions[:, 2]
|
|
398
|
+
|
|
399
|
+
# Create scatter plot
|
|
400
|
+
scatter = ax.scatter(x_coords, y_coords, z_coords,
|
|
401
|
+
c=intensities, s=point_size, alpha=alpha,
|
|
402
|
+
cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
|
|
403
|
+
|
|
404
|
+
ax.set_xlabel('X')
|
|
405
|
+
ax.set_ylabel('Y')
|
|
406
|
+
ax.set_zlabel('Z')
|
|
407
|
+
ax.set_title(f'{title}')
|
|
408
|
+
|
|
409
|
+
# Set axis limits based on shape
|
|
410
|
+
ax.set_xlim(0, shape[2])
|
|
411
|
+
ax.set_ylim(0, shape[1])
|
|
412
|
+
ax.set_zlim(0, shape[0])
|
|
413
|
+
|
|
414
|
+
else:
|
|
415
|
+
# 2D plot (using Y, X coordinates, ignoring Z/first dimension)
|
|
416
|
+
ax = fig.add_subplot(111)
|
|
417
|
+
|
|
418
|
+
# Extract Y, X coordinates
|
|
419
|
+
y_coords = positions[:, 1]
|
|
420
|
+
x_coords = positions[:, 2]
|
|
421
|
+
|
|
422
|
+
# Create scatter plot
|
|
423
|
+
scatter = ax.scatter(x_coords, y_coords,
|
|
424
|
+
c=intensities, s=point_size, alpha=alpha,
|
|
425
|
+
cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
|
|
426
|
+
|
|
427
|
+
ax.set_xlabel('X')
|
|
428
|
+
ax.set_ylabel('Y')
|
|
429
|
+
ax.set_title(f'{title}')
|
|
430
|
+
ax.grid(True, alpha=0.3)
|
|
431
|
+
|
|
432
|
+
# Set axis limits based on shape
|
|
433
|
+
ax.set_xlim(0, shape[2])
|
|
434
|
+
ax.set_ylim(0, shape[1])
|
|
435
|
+
|
|
436
|
+
# Set origin to top-left (invert Y-axis)
|
|
437
|
+
ax.invert_yaxis()
|
|
309
438
|
|
|
310
|
-
#
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
|
|
439
|
+
# Add colorbar
|
|
440
|
+
cbar = plt.colorbar(scatter, ax=ax, shrink=0.8)
|
|
441
|
+
cbar.set_label(colorbar_label)
|
|
314
442
|
|
|
315
|
-
|
|
316
|
-
ax.
|
|
317
|
-
|
|
318
|
-
ax.
|
|
443
|
+
# Add text annotations for min/max values
|
|
444
|
+
cbar.ax.text(1.05, 0, f'Min: {min_intensity:.3f}\n(Blue)',
|
|
445
|
+
transform=cbar.ax.transAxes, va='bottom')
|
|
446
|
+
cbar.ax.text(1.05, 1, f'Max: {max_intensity:.3f}\n(Red)',
|
|
447
|
+
transform=cbar.ax.transAxes, va='top')
|
|
319
448
|
|
|
320
|
-
|
|
321
|
-
|
|
449
|
+
plt.tight_layout()
|
|
450
|
+
plt.show()
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
def create_node_heatmap(node_intensity, node_centroids, shape=None, is_3d=True,
|
|
454
|
+
labeled_array=None, figsize=(12, 8), point_size=50, alpha=0.7,
|
|
455
|
+
colorbar_label="Node Intensity", title="Node Clustering Intensity Heatmap"):
|
|
456
|
+
"""
|
|
457
|
+
Create a 2D or 3D heatmap showing nodes colored by their individual intensities.
|
|
458
|
+
Can return either matplotlib plot or numpy array for overlay purposes.
|
|
322
459
|
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
460
|
+
Parameters:
|
|
461
|
+
-----------
|
|
462
|
+
node_intensity : dict
|
|
463
|
+
Dictionary mapping node IDs to intensity values
|
|
464
|
+
Keys can be np.int64 or regular ints
|
|
465
|
+
|
|
466
|
+
node_centroids : dict
|
|
467
|
+
Dictionary mapping node IDs to centroids
|
|
468
|
+
Centroids should be [Z, Y, X] for 3D or [1, Y, X] for pseudo-3D
|
|
469
|
+
|
|
470
|
+
shape : tuple, optional
|
|
471
|
+
Shape of the output array in [Z, Y, X] format
|
|
472
|
+
If None, will be inferred from node_centroids
|
|
473
|
+
|
|
474
|
+
is_3d : bool, default=True
|
|
475
|
+
If True, create 3D plot/array. If False, create 2D plot/array.
|
|
476
|
+
|
|
477
|
+
labeled_array : np.ndarray, optional
|
|
478
|
+
If provided, returns numpy array overlay using this labeled array template
|
|
479
|
+
instead of matplotlib plot. Uses lookup table approach for efficiency.
|
|
480
|
+
|
|
481
|
+
figsize : tuple, default=(12, 8)
|
|
482
|
+
Figure size (width, height) - only used for matplotlib
|
|
483
|
+
|
|
484
|
+
point_size : int, default=50
|
|
485
|
+
Size of scatter plot points - only used for matplotlib
|
|
486
|
+
|
|
487
|
+
alpha : float, default=0.7
|
|
488
|
+
Transparency of points (0-1) - only used for matplotlib
|
|
489
|
+
|
|
490
|
+
colorbar_label : str, default="Node Intensity"
|
|
491
|
+
Label for the colorbar - only used for matplotlib
|
|
492
|
+
|
|
493
|
+
Returns:
|
|
494
|
+
--------
|
|
495
|
+
If labeled_array is None: fig, ax (matplotlib figure and axis objects)
|
|
496
|
+
If labeled_array is provided: np.ndarray (heatmap array with intensity values)
|
|
497
|
+
"""
|
|
498
|
+
import numpy as np
|
|
499
|
+
import matplotlib.pyplot as plt
|
|
500
|
+
|
|
501
|
+
# Convert numpy int64 keys to regular ints for consistency
|
|
502
|
+
node_intensity_clean = {}
|
|
503
|
+
for k, v in node_intensity.items():
|
|
504
|
+
if hasattr(k, 'item'): # numpy scalar
|
|
505
|
+
node_intensity_clean[k.item()] = v
|
|
506
|
+
else:
|
|
507
|
+
node_intensity_clean[k] = v
|
|
326
508
|
|
|
327
|
-
#
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
cbar.ax.text(1.05, 1, f'Max: {max_intensity:.3f}\n(Red)',
|
|
331
|
-
transform=cbar.ax.transAxes, va='top')
|
|
509
|
+
# Prepare data for plotting/array creation
|
|
510
|
+
node_positions = []
|
|
511
|
+
node_intensities = []
|
|
332
512
|
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
513
|
+
for node_id, centroid in node_centroids.items():
|
|
514
|
+
try:
|
|
515
|
+
# Convert node_id to regular int if it's numpy
|
|
516
|
+
if hasattr(node_id, 'item'):
|
|
517
|
+
node_id = node_id.item()
|
|
518
|
+
|
|
519
|
+
# Get intensity for this node
|
|
520
|
+
intensity = node_intensity_clean[node_id]
|
|
521
|
+
|
|
522
|
+
node_positions.append(centroid)
|
|
523
|
+
node_intensities.append(intensity)
|
|
524
|
+
except KeyError:
|
|
525
|
+
# Skip nodes that don't have intensity values
|
|
526
|
+
pass
|
|
527
|
+
|
|
528
|
+
# Convert to numpy arrays
|
|
529
|
+
positions = np.array(node_positions)
|
|
530
|
+
intensities = np.array(node_intensities)
|
|
531
|
+
|
|
532
|
+
# Determine shape if not provided
|
|
533
|
+
if shape is None:
|
|
534
|
+
if len(positions) > 0:
|
|
535
|
+
max_coords = np.max(positions, axis=0).astype(int)
|
|
536
|
+
shape = tuple(max_coords + 1)
|
|
537
|
+
else:
|
|
538
|
+
shape = (100, 100, 100) if is_3d else (1, 100, 100)
|
|
539
|
+
|
|
540
|
+
# Determine min and max intensities for scaling
|
|
541
|
+
if len(intensities) > 0:
|
|
542
|
+
min_intensity = np.min(intensities)
|
|
543
|
+
max_intensity = np.max(intensities)
|
|
544
|
+
else:
|
|
545
|
+
min_intensity, max_intensity = 0, 1
|
|
546
|
+
|
|
547
|
+
if labeled_array is not None:
|
|
548
|
+
# Create numpy RGB array output using labeled array and lookup table approach
|
|
549
|
+
|
|
550
|
+
# Create mapping from node ID to intensity value (keep original float values)
|
|
551
|
+
node_to_intensity = {}
|
|
552
|
+
for node_id, centroid in node_centroids.items():
|
|
553
|
+
# Convert node_id to regular int if it's numpy
|
|
554
|
+
if hasattr(node_id, 'item'):
|
|
555
|
+
node_id = node_id.item()
|
|
556
|
+
|
|
557
|
+
# Only include nodes that have intensity values
|
|
558
|
+
if node_id in node_intensity_clean:
|
|
559
|
+
node_to_intensity[node_id] = node_intensity_clean[node_id]
|
|
560
|
+
|
|
561
|
+
# Create colormap function (RdBu_r - red for high, blue for low, yellow/white for middle)
|
|
562
|
+
def intensity_to_rgb(intensity, min_val, max_val):
|
|
563
|
+
"""Convert intensity value to RGB using RdBu_r colormap logic"""
|
|
564
|
+
if max_val == min_val:
|
|
565
|
+
# All same value, use neutral color
|
|
566
|
+
return np.array([255, 255, 255], dtype=np.uint8) # White
|
|
567
|
+
|
|
568
|
+
# Normalize to -1 to 1 range (like RdBu_r colormap)
|
|
569
|
+
normalized = 2 * (intensity - min_val) / (max_val - min_val) - 1
|
|
570
|
+
normalized = np.clip(normalized, -1, 1)
|
|
571
|
+
|
|
572
|
+
if normalized > 0:
|
|
573
|
+
# Positive values: white to red
|
|
574
|
+
r = 255
|
|
575
|
+
g = int(255 * (1 - normalized))
|
|
576
|
+
b = int(255 * (1 - normalized))
|
|
577
|
+
else:
|
|
578
|
+
# Negative values: white to blue
|
|
579
|
+
r = int(255 * (1 + normalized))
|
|
580
|
+
g = int(255 * (1 + normalized))
|
|
581
|
+
b = 255
|
|
582
|
+
|
|
583
|
+
return np.array([r, g, b], dtype=np.uint8)
|
|
584
|
+
|
|
585
|
+
# Create lookup table for RGB colors
|
|
586
|
+
max_label = max(max(labeled_array.flat), max(node_to_intensity.keys()) if node_to_intensity else 0)
|
|
587
|
+
color_lut = np.zeros((max_label + 1, 3), dtype=np.uint8) # Default to black (0,0,0)
|
|
588
|
+
|
|
589
|
+
# Fill lookup table with RGB colors based on intensity
|
|
590
|
+
for node_id, intensity in node_to_intensity.items():
|
|
591
|
+
rgb_color = intensity_to_rgb(intensity, min_intensity, max_intensity)
|
|
592
|
+
color_lut[int(node_id)] = rgb_color
|
|
593
|
+
|
|
594
|
+
# Apply lookup table to labeled array - single vectorized operation
|
|
595
|
+
if is_3d:
|
|
596
|
+
# Return full 3D RGB array [Z, Y, X, 3]
|
|
597
|
+
heatmap_array = color_lut[labeled_array]
|
|
598
|
+
else:
|
|
599
|
+
# Return 2D RGB array
|
|
600
|
+
if labeled_array.ndim == 3:
|
|
601
|
+
# Take middle slice for 2D representation
|
|
602
|
+
middle_slice = labeled_array.shape[0] // 2
|
|
603
|
+
heatmap_array = color_lut[labeled_array[middle_slice]]
|
|
604
|
+
else:
|
|
605
|
+
# Already 2D
|
|
606
|
+
heatmap_array = color_lut[labeled_array]
|
|
607
|
+
|
|
608
|
+
return heatmap_array
|
|
609
|
+
|
|
610
|
+
else:
|
|
611
|
+
# Create matplotlib plot
|
|
612
|
+
fig = plt.figure(figsize=figsize)
|
|
613
|
+
|
|
614
|
+
if is_3d:
|
|
615
|
+
# 3D plot
|
|
616
|
+
ax = fig.add_subplot(111, projection='3d')
|
|
617
|
+
|
|
618
|
+
# Extract coordinates (assuming [Z, Y, X] format)
|
|
619
|
+
z_coords = positions[:, 0]
|
|
620
|
+
y_coords = positions[:, 1]
|
|
621
|
+
x_coords = positions[:, 2]
|
|
622
|
+
|
|
623
|
+
# Create scatter plot
|
|
624
|
+
scatter = ax.scatter(x_coords, y_coords, z_coords,
|
|
625
|
+
c=intensities, s=point_size, alpha=alpha,
|
|
626
|
+
cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
|
|
627
|
+
|
|
628
|
+
ax.set_xlabel('X')
|
|
629
|
+
ax.set_ylabel('Y')
|
|
630
|
+
ax.set_zlabel('Z')
|
|
631
|
+
ax.set_title(f'{title}')
|
|
632
|
+
|
|
633
|
+
# Set axis limits based on shape
|
|
634
|
+
ax.set_xlim(0, shape[2])
|
|
635
|
+
ax.set_ylim(0, shape[1])
|
|
636
|
+
ax.set_zlim(0, shape[0])
|
|
637
|
+
|
|
638
|
+
else:
|
|
639
|
+
# 2D plot (using Y, X coordinates, ignoring Z/first dimension)
|
|
640
|
+
ax = fig.add_subplot(111)
|
|
641
|
+
|
|
642
|
+
# Extract Y, X coordinates
|
|
643
|
+
y_coords = positions[:, 1]
|
|
644
|
+
x_coords = positions[:, 2]
|
|
645
|
+
|
|
646
|
+
# Create scatter plot
|
|
647
|
+
scatter = ax.scatter(x_coords, y_coords,
|
|
648
|
+
c=intensities, s=point_size, alpha=alpha,
|
|
649
|
+
cmap='RdBu_r', vmin=min_intensity, vmax=max_intensity)
|
|
650
|
+
|
|
651
|
+
ax.set_xlabel('X')
|
|
652
|
+
ax.set_ylabel('Y')
|
|
653
|
+
ax.set_title(f'{title}')
|
|
654
|
+
ax.grid(True, alpha=0.3)
|
|
655
|
+
|
|
656
|
+
# Set axis limits based on shape
|
|
657
|
+
ax.set_xlim(0, shape[2])
|
|
658
|
+
ax.set_ylim(0, shape[1])
|
|
659
|
+
|
|
660
|
+
# Set origin to top-left (invert Y-axis)
|
|
661
|
+
ax.invert_yaxis()
|
|
662
|
+
|
|
663
|
+
# Add colorbar
|
|
664
|
+
cbar = plt.colorbar(scatter, ax=ax, shrink=0.8)
|
|
665
|
+
cbar.set_label(colorbar_label)
|
|
666
|
+
|
|
667
|
+
# Add text annotations for min/max values
|
|
668
|
+
cbar.ax.text(1.05, 0, f'Min: {min_intensity:.3f}\n(Blue)',
|
|
669
|
+
transform=cbar.ax.transAxes, va='bottom')
|
|
670
|
+
cbar.ax.text(1.05, 1, f'Max: {max_intensity:.3f}\n(Red)',
|
|
671
|
+
transform=cbar.ax.transAxes, va='top')
|
|
672
|
+
|
|
673
|
+
plt.tight_layout()
|
|
674
|
+
plt.show()
|
|
337
675
|
|
|
338
676
|
# Example usage:
|
|
339
677
|
if __name__ == "__main__":
|
|
@@ -350,5 +688,4 @@ if __name__ == "__main__":
|
|
|
350
688
|
fig, ax = plot_dict_heatmap(sample_dict, sample_id_set,
|
|
351
689
|
title="Sample Heatmap Visualization")
|
|
352
690
|
|
|
353
|
-
plt.show()
|
|
354
|
-
|
|
691
|
+
plt.show()
|