nettracer3d 0.9.4__py3-none-any.whl → 0.9.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -549,54 +549,117 @@ def convert_node_colors_to_names(node_to_color: Dict[int, Tuple[int, int, int]],
549
549
 
550
550
  num_entries = len(node_to_color)
551
551
 
552
- # Calculate dynamic spacing based on number of entries
553
- entry_height = 0.8
554
- total_height = num_entries * entry_height + 1.5 # Extra space for title and margins
552
+ # Calculate text widths to determine optimal figure size
553
+ sorted_nodes = sorted(node_to_color.keys())
554
+
555
+ # Create a temporary figure to measure text widths
556
+ temp_fig, temp_ax = plt.subplots(figsize=(1, 1))
557
+
558
+ max_node_width = 0
559
+ max_color_width = 0
560
+
561
+ for node in sorted_nodes:
562
+ color_name = node_to_names[node]
563
+
564
+ # Measure node ID text width
565
+ node_text = temp_ax.text(0, 0, str(node), fontsize=12, fontweight='bold')
566
+ node_bbox = node_text.get_window_extent(renderer=temp_fig.canvas.get_renderer())
567
+ node_width = node_bbox.width
568
+ max_node_width = max(max_node_width, node_width)
569
+
570
+ # Measure color name text width
571
+ color_text = temp_ax.text(0, 0, color_name.replace('_', ' ').title(), fontsize=11)
572
+ color_bbox = color_text.get_window_extent(renderer=temp_fig.canvas.get_renderer())
573
+ color_width = color_bbox.width
574
+ max_color_width = max(max_color_width, color_width)
575
+
576
+ plt.close(temp_fig)
577
+
578
+ # Convert pixel widths to figure units (approximate conversion)
579
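+ # This is a rough conversion - it assumes 72 pixels per inch; note that matplotlib's default figure DPI is 100 (72 is points per inch), so using temp_fig.dpi would be more accurate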
580
+ dpi = 72
581
+ max_node_width_fig = max_node_width / dpi
582
+ max_color_width_fig = max_color_width / dpi
583
+
584
+ # Calculate optimal figure dimensions
585
+ entry_height = 0.6 # Reduced for tighter spacing
586
+ margin = 0.3
587
+ swatch_width = 0.8
588
+ spacing = 0.2
589
+
590
+ # Calculate total width needed
591
+ total_width = (margin + max_node_width_fig + spacing +
592
+ swatch_width + spacing + max_color_width_fig + margin)
555
593
 
556
- # Create figure and axis with proper scaling
557
- fig, ax = plt.subplots(figsize=figsize)
558
- ax.set_xlim(0, 10)
594
+ # Ensure minimum width for readability
595
+ total_width = max(total_width, 4.0)
596
+
597
+ # Calculate total height
598
+ title_height = 0.8
599
+ total_height = num_entries * entry_height + title_height + 2 * margin
600
+
601
+ # Create the actual figure with calculated dimensions
602
+ fig, ax = plt.subplots(figsize=(total_width, total_height))
603
+
604
+ # Set axis limits to match our calculated dimensions
605
+ ax.set_xlim(0, total_width)
559
606
  ax.set_ylim(0, total_height)
560
607
  ax.axis('off')
561
608
 
562
609
  # Title
563
- ax.text(5, total_height - 0.5, 'Color Legend',
564
- fontsize=16, fontweight='bold', ha='center')
565
-
566
- # Sort nodes for consistent display
567
- sorted_nodes = sorted(node_to_color.keys())
610
+ ax.text(total_width/2, total_height - margin - 0.2, 'Color Legend',
611
+ fontsize=14, fontweight='bold', ha='center', va='top')
568
612
 
569
613
  # Create legend entries
570
614
  for i, node in enumerate(sorted_nodes):
571
- y_pos = total_height - (i + 1) * entry_height - 0.8
615
+ y_pos = total_height - title_height - margin - (i + 1) * entry_height + entry_height/2
572
616
  rgb = node_to_color[node]
573
617
  color_name = node_to_names[node]
574
618
 
575
619
  # Normalize RGB values for matplotlib (0-1 range)
576
620
  norm_rgb = tuple(c/255.0 for c in rgb)
577
621
 
578
- # Draw color swatch (using actual RGB values)
579
- swatch = Rectangle((1.0, y_pos - 0.15), 0.8, 0.3,
580
- facecolor=norm_rgb, edgecolor='black', linewidth=1)
581
- ax.add_patch(swatch)
622
+ # Position calculations
623
+ node_x = margin
624
+ swatch_x = margin + max_node_width_fig + spacing
625
+ color_x = swatch_x + swatch_width + spacing
582
626
 
583
- # Node ID (exactly as it appears in dict keys)
584
- ax.text(0.2, y_pos, str(node), fontsize=12, fontweight='bold',
627
+ # Node ID (left-aligned)
628
+ ax.text(node_x, y_pos, str(node), fontsize=12, fontweight='bold',
585
629
  va='center', ha='left')
586
630
 
587
- # Color name (mapped name, nicely formatted)
588
- ax.text(2.2, y_pos, color_name.replace('_', ' ').title(),
631
+ # Draw color swatch
632
+ swatch_y = y_pos - entry_height/4
633
+ swatch = Rectangle((swatch_x, swatch_y), swatch_width, entry_height/2,
634
+ facecolor=norm_rgb, edgecolor='black', linewidth=1)
635
+ ax.add_patch(swatch)
636
+
637
+ # Color name
638
+ formatted_name = color_name.replace('_', ' ').title()
639
+ # Truncate very long color names to prevent layout issues
640
+ if len(formatted_name) > 25:
641
+ formatted_name = formatted_name[:22] + "..."
642
+
643
+ ax.text(color_x, y_pos, formatted_name,
589
644
  fontsize=11, va='center', ha='left')
590
645
 
591
- # Add border around the legend
592
- border = Rectangle((0.1, 0.1), 9.8, total_height - 0.2,
593
- fill=False, edgecolor='gray', linewidth=2)
646
+ # Add a subtle border around the entire legend
647
+ border_margin = 0.1
648
+ border = Rectangle((border_margin, border_margin),
649
+ total_width - 2*border_margin,
650
+ total_height - 2*border_margin,
651
+ fill=False, edgecolor='lightgray', linewidth=1.5)
594
652
  ax.add_patch(border)
595
653
 
596
- plt.tight_layout()
654
+ # Remove any extra whitespace
655
+ plt.tight_layout(pad=0.1)
656
+
657
+ # Adjust the figure to eliminate whitespace
658
+ ax.margins(0)
659
+ fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
597
660
 
598
661
  if save_path:
599
- plt.savefig(save_path, dpi=300, bbox_inches='tight')
662
+ plt.savefig(save_path, dpi=300, bbox_inches='tight', pad_inches=0.05)
600
663
 
601
664
  plt.show()
602
665
 
@@ -347,8 +347,7 @@ def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
347
347
  id_dictionary: Optional[Dict[int, str]] = None,
348
348
  graph_label = "Community ID",
349
349
  title = 'UMAP Visualization of Community Compositions',
350
- neighborhoods: Optional[Dict[int, int]] = None,
351
- draw_lines: bool = False):
350
+ neighborhoods: Optional[Dict[int, int]] = None):
352
351
  """
353
352
  Convert cluster composition data to UMAP visualization.
354
353
 
@@ -371,8 +370,6 @@ def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
371
370
  neighborhoods : dict, optional
372
371
  Dictionary mapping node IDs to neighborhood IDs {node_id: neighborhood_id}.
373
372
  If provided, points will be colored by neighborhood using community coloration methods.
374
- draw_lines : bool
375
- Whether to draw lines between nodes that share identities (default: False)
376
373
 
377
374
  Returns:
378
375
  --------
@@ -456,111 +453,15 @@ def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
456
453
  plt.figure(figsize=(12, 8))
457
454
 
458
455
  if n_components == 2:
459
- # Draw scatter with different markers for multi-identity nodes if draw_lines is enabled
460
- if draw_lines:
461
- # Separate multi-identity and singleton nodes for different markers
462
- singleton_indices = []
463
- multi_indices = []
464
- singleton_colors = []
465
- multi_colors = []
466
-
467
- for i, cluster_id in enumerate(cluster_ids):
468
- vec = cluster_data[cluster_id]
469
- if np.sum(vec) > 1: # Multi-identity
470
- multi_indices.append(i)
471
- multi_colors.append(point_colors[i] if isinstance(point_colors, list) else point_colors)
472
- else: # Singleton
473
- singleton_indices.append(i)
474
- singleton_colors.append(point_colors[i] if isinstance(point_colors, list) else point_colors)
475
-
476
- # Draw singleton nodes as circles
477
- if singleton_indices:
478
- if use_neighborhood_coloring or use_identity_coloring:
479
- scatter1 = plt.scatter(embedding[singleton_indices, 0], embedding[singleton_indices, 1],
480
- c=singleton_colors, s=100, alpha=0.7, marker='o')
481
- else:
482
- scatter1 = plt.scatter(embedding[singleton_indices, 0], embedding[singleton_indices, 1],
483
- c=[point_colors[i] for i in singleton_indices], cmap='viridis', s=100, alpha=0.7, marker='o')
484
-
485
- # Draw multi-identity nodes as squares
486
- if multi_indices:
487
- if use_neighborhood_coloring or use_identity_coloring:
488
- scatter2 = plt.scatter(embedding[multi_indices, 0], embedding[multi_indices, 1],
489
- c=multi_colors, s=100, alpha=0.7, marker='s')
490
- else:
491
- scatter2 = plt.scatter(embedding[multi_indices, 0], embedding[multi_indices, 1],
492
- c=[point_colors[i] for i in multi_indices], cmap='viridis', s=100, alpha=0.7, marker='s')
493
- scatter = scatter2 # For colorbar reference
494
- else:
495
- scatter = scatter1 if singleton_indices else None
456
+ if use_neighborhood_coloring:
457
+ scatter = plt.scatter(embedding[:, 0], embedding[:, 1],
458
+ c=point_colors, s=100, alpha=0.7)
459
+ elif use_identity_coloring:
460
+ scatter = plt.scatter(embedding[:, 0], embedding[:, 1],
461
+ c=point_colors, s=100, alpha=0.7)
496
462
  else:
497
- # Original behavior when draw_lines is False
498
- if use_neighborhood_coloring:
499
- scatter = plt.scatter(embedding[:, 0], embedding[:, 1],
500
- c=point_colors, s=100, alpha=0.7)
501
- elif use_identity_coloring:
502
- scatter = plt.scatter(embedding[:, 0], embedding[:, 1],
503
- c=point_colors, s=100, alpha=0.7)
504
- else:
505
- scatter = plt.scatter(embedding[:, 0], embedding[:, 1],
506
- c=point_colors, cmap='viridis', s=100, alpha=0.7)
507
-
508
- # Draw lines between nodes with shared identities (only if draw_lines=True)
509
- if draw_lines:
510
- # First pass: identify unique multi-identity configurations and their representatives
511
- multi_config_map = {} # Maps tuple(config) -> {'count': int, 'representative_idx': int}
512
-
513
- for i, cluster_id in enumerate(cluster_ids):
514
- vec = cluster_data[cluster_id]
515
- if np.sum(vec) > 1: # Multi-identity node
516
- config = tuple(vec) # Convert to hashable tuple
517
- if config not in multi_config_map:
518
- multi_config_map[config] = {'count': 1, 'representative_idx': i}
519
- else:
520
- multi_config_map[config]['count'] += 1
521
-
522
- # Second pass: draw lines for each unique configuration
523
- for config, info in multi_config_map.items():
524
- i = info['representative_idx']
525
- count = info['count']
526
- vec1 = np.array(config)
527
-
528
- # For each identity this configuration has, find the closest representative
529
- identity_indices = np.where(vec1 == 1)[0]
530
-
531
- for identity_idx in identity_indices:
532
- best_target = None
533
- best_distance = float('inf')
534
- backup_target = None
535
- backup_distance = float('inf')
536
-
537
- # Find closest node with this specific identity
538
- for j, cluster_id2 in enumerate(cluster_ids):
539
- if i != j: # Don't connect to self
540
- vec2 = cluster_data[cluster_id2]
541
- if vec2[identity_idx] == 1: # Shares this specific identity
542
- distance = np.linalg.norm(embedding[i] - embedding[j])
543
-
544
- # Prefer singleton nodes
545
- if np.sum(vec2) == 1: # Singleton
546
- if distance < best_distance:
547
- best_distance = distance
548
- best_target = j
549
- else: # Multi-identity node (backup)
550
- if distance < backup_distance:
551
- backup_distance = distance
552
- backup_target = j
553
-
554
- # Draw line to best target (prefer singleton, fallback to multi)
555
- target = best_target if best_target is not None else backup_target
556
- if target is not None:
557
- # Calculate relative line weight with reasonable cap
558
- max_count = max(info['count'] for info in multi_config_map.values())
559
- relative_weight = count / max_count # Normalize to 0-1
560
- line_weight = 0.3 + relative_weight * 1.2 # Scale to 0.3-1.5 range
561
- plt.plot([embedding[i, 0], embedding[target, 0]],
562
- [embedding[i, 1], embedding[target, 1]],
563
- alpha=0.3, color='gray', linewidth=line_weight)
463
+ scatter = plt.scatter(embedding[:, 0], embedding[:, 1],
464
+ c=point_colors, cmap='viridis', s=100, alpha=0.7)
564
465
 
565
466
  if label:
566
467
  # Add cluster ID labels
@@ -615,112 +516,15 @@ def visualize_cluster_composition_umap(cluster_data: Dict[int, np.ndarray],
615
516
  fig = plt.figure(figsize=(14, 10))
616
517
  ax = fig.add_subplot(111, projection='3d')
617
518
 
618
- # Draw scatter with different markers for multi-identity nodes if draw_lines is enabled
619
- if draw_lines:
620
- # Separate multi-identity and singleton nodes for different markers
621
- singleton_indices = []
622
- multi_indices = []
623
- singleton_colors = []
624
- multi_colors = []
625
-
626
- for i, cluster_id in enumerate(cluster_ids):
627
- vec = cluster_data[cluster_id]
628
- if np.sum(vec) > 1: # Multi-identity
629
- multi_indices.append(i)
630
- multi_colors.append(point_colors[i] if isinstance(point_colors, list) else point_colors)
631
- else: # Singleton
632
- singleton_indices.append(i)
633
- singleton_colors.append(point_colors[i] if isinstance(point_colors, list) else point_colors)
634
-
635
- # Draw singleton nodes as circles
636
- if singleton_indices:
637
- if use_neighborhood_coloring or use_identity_coloring:
638
- scatter1 = ax.scatter(embedding[singleton_indices, 0], embedding[singleton_indices, 1], embedding[singleton_indices, 2],
639
- c=singleton_colors, s=100, alpha=0.7, marker='o')
640
- else:
641
- scatter1 = ax.scatter(embedding[singleton_indices, 0], embedding[singleton_indices, 1], embedding[singleton_indices, 2],
642
- c=[point_colors[i] for i in singleton_indices], cmap='viridis', s=100, alpha=0.7, marker='o')
643
-
644
- # Draw multi-identity nodes as squares
645
- if multi_indices:
646
- if use_neighborhood_coloring or use_identity_coloring:
647
- scatter2 = ax.scatter(embedding[multi_indices, 0], embedding[multi_indices, 1], embedding[multi_indices, 2],
648
- c=multi_colors, s=100, alpha=0.7, marker='s')
649
- else:
650
- scatter2 = ax.scatter(embedding[multi_indices, 0], embedding[multi_indices, 1], embedding[multi_indices, 2],
651
- c=[point_colors[i] for i in multi_indices], cmap='viridis', s=100, alpha=0.7, marker='s')
652
- scatter = scatter2 # For colorbar reference
653
- else:
654
- scatter = scatter1 if singleton_indices else None
519
+ if use_neighborhood_coloring:
520
+ scatter = ax.scatter(embedding[:, 0], embedding[:, 1], embedding[:, 2],
521
+ c=point_colors, s=100, alpha=0.7)
522
+ elif use_identity_coloring:
523
+ scatter = ax.scatter(embedding[:, 0], embedding[:, 1], embedding[:, 2],
524
+ c=point_colors, s=100, alpha=0.7)
655
525
  else:
656
- # Original behavior when draw_lines is False
657
- if use_neighborhood_coloring:
658
- scatter = ax.scatter(embedding[:, 0], embedding[:, 1], embedding[:, 2],
659
- c=point_colors, s=100, alpha=0.7)
660
- elif use_identity_coloring:
661
- scatter = ax.scatter(embedding[:, 0], embedding[:, 1], embedding[:, 2],
662
- c=point_colors, s=100, alpha=0.7)
663
- else:
664
- scatter = ax.scatter(embedding[:, 0], embedding[:, 1], embedding[:, 2],
665
- c=point_colors, cmap='viridis', s=100, alpha=0.7)
666
-
667
- # Draw lines between nodes with shared identities (only if draw_lines=True)
668
- if draw_lines:
669
- # First pass: identify unique multi-identity configurations and their representatives
670
- multi_config_map = {} # Maps tuple(config) -> {'count': int, 'representative_idx': int}
671
-
672
- for i, cluster_id in enumerate(cluster_ids):
673
- vec = cluster_data[cluster_id]
674
- if np.sum(vec) > 1: # Multi-identity node
675
- config = tuple(vec) # Convert to hashable tuple
676
- if config not in multi_config_map:
677
- multi_config_map[config] = {'count': 1, 'representative_idx': i}
678
- else:
679
- multi_config_map[config]['count'] += 1
680
-
681
- # Second pass: draw lines for each unique configuration
682
- for config, info in multi_config_map.items():
683
- i = info['representative_idx']
684
- count = info['count']
685
- vec1 = np.array(config)
686
-
687
- # For each identity this configuration has, find the closest representative
688
- identity_indices = np.where(vec1 == 1)[0]
689
-
690
- for identity_idx in identity_indices:
691
- best_target = None
692
- best_distance = float('inf')
693
- backup_target = None
694
- backup_distance = float('inf')
695
-
696
- # Find closest node with this specific identity
697
- for j, cluster_id2 in enumerate(cluster_ids):
698
- if i != j: # Don't connect to self
699
- vec2 = cluster_data[cluster_id2]
700
- if vec2[identity_idx] == 1: # Shares this specific identity
701
- distance = np.linalg.norm(embedding[i] - embedding[j])
702
-
703
- # Prefer singleton nodes
704
- if np.sum(vec2) == 1: # Singleton
705
- if distance < best_distance:
706
- best_distance = distance
707
- best_target = j
708
- else: # Multi-identity node (backup)
709
- if distance < backup_distance:
710
- backup_distance = distance
711
- backup_target = j
712
-
713
- # Draw line to best target (prefer singleton, fallback to multi)
714
- target = best_target if best_target is not None else backup_target
715
- if target is not None:
716
- # Calculate relative line weight with reasonable cap
717
- max_count = max(info['count'] for info in multi_config_map.values())
718
- relative_weight = count / max_count # Normalize to 0-1
719
- line_weight = 0.3 + relative_weight * 1.2 # Scale to 0.3-1.5 range
720
- ax.plot([embedding[i, 0], embedding[target, 0]],
721
- [embedding[i, 1], embedding[target, 1]],
722
- [embedding[i, 2], embedding[target, 2]],
723
- alpha=0.3, color='gray', linewidth=line_weight)
526
+ scatter = ax.scatter(embedding[:, 0], embedding[:, 1], embedding[:, 2],
527
+ c=point_colors, cmap='viridis', s=100, alpha=0.7)
724
528
 
725
529
  if label:
726
530
  # Add cluster ID labels