nettracer3d 1.2.7__py3-none-any.whl → 1.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nettracer3d might be problematic; see the package's registry page for more details.

@@ -0,0 +1,2267 @@
1
+ import numpy as np
2
+ import networkx as nx
3
+ from PyQt6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QMenu,
4
+ QSizePolicy, QApplication, QScrollArea, QLabel, QFrame,
5
+ QFileDialog, QMessageBox)
6
+ from PyQt6.QtCore import Qt, pyqtSignal, QThread, pyqtSlot, QTimer, QPointF, QRectF
7
+ from PyQt6.QtGui import QColor, QPen, QBrush
8
+ import pyqtgraph as pg
9
+ from pyqtgraph import ScatterPlotItem, PlotCurveItem, GraphicsLayoutWidget, ROI
10
+ import colorsys
11
+ import random
12
+ import copy
13
+
14
+
15
class GraphLoadThread(QThread):
    """Thread for loading graph layouts without blocking the UI.

    ``run`` computes node positions, node colours/sizes, edge geometry,
    pre-built pyqtgraph brushes, label data and batched edge coordinates,
    then emits everything as one dict through ``finished``.
    """

    # NOTE(review): this name shadows QThread's built-in ``finished()``
    # signal; callers connect to this object-carrying variant. Renaming it
    # would break callers, so it is left as-is.
    finished = pyqtSignal(object)  # Emits the computed layout data (dict)

    def __init__(self, graph, geometric, component, centroids, communities,
                 community_dict, identities, identity_dict, weight, z_size,
                 shell, node_size, edge_size):
        """Store the graph and every option that affects layout/rendering."""
        super().__init__()
        self.graph = graph
        self.geometric = geometric
        self.component = component
        self.centroids = centroids
        self.communities = communities
        self.community_dict = community_dict
        self.identities = identities
        self.identity_dict = identity_dict
        self.weight = weight
        self.z_size = z_size
        self.shell = shell
        self.node_size = node_size
        self.edge_size = edge_size

    def run(self):
        """Compute layout and colours in the background, then emit the result.

        Emitted keys: 'pos', 'colors', 'sizes', 'edges', 'node_spots',
        'brush_cache', 'label_data', 'edge_pens'.
        """
        result = {}

        # Node positions: centroid-based layout takes precedence, then the
        # component-separated spring layout, then the default layout.
        if self.geometric:
            result['pos'] = self._compute_geometric_layout()
        elif self.component:
            nodes = list(self.graph.nodes())
            result['pos'] = self._spring_layout_numpy_super(nodes, len(nodes))
        else:
            result['pos'] = self._compute_fast_spring_layout()

        result['colors'], result['sizes'] = self._compute_node_attributes()
        result['edges'] = self._compute_edge_data(result['pos'])
        result['node_spots'], result['brush_cache'] = self._prepare_node_spots(
            result['pos'], result['colors'], result['sizes'])
        result['label_data'] = self._prepare_label_data(result['pos'])
        result['edge_pens'] = self._prepare_edge_pens(result['edges'])

        self.finished.emit(result)

    # ------------------------------------------------------------------
    # Layouts
    # ------------------------------------------------------------------
    def _compute_fast_spring_layout(self):
        """Default layout: networkx for small graphs, numpy for large/shell.

        Returns a {node: position} dict (empty for an empty graph).
        """
        nodes = list(self.graph.nodes())
        n = len(nodes)

        if n == 0:
            return {}

        # For small graphs networkx is fast enough.
        if n < 500 and not self.shell:
            return nx.spring_layout(self.graph, seed=42, iterations=50)

        try:
            if self.shell:
                return self._shell_layout_numpy_super(nodes, n)
            return self._spring_layout_numpy(nodes, n)
        except Exception:
            # Best effort: previously the exception was swallowed and None
            # was returned, which crashed the later stages of run(). Fall
            # back to networkx so a usable layout is always produced.
            return nx.spring_layout(self.graph, seed=42, iterations=50)

    def _shell_layout_numpy_super(self, nodes, n):
        """Shell layout with physically separated connected components."""
        if n == 0:
            return {}

        components = list(nx.connected_components(self.graph))

        if len(components) == 1:
            return self._layout_component_shell(nodes, n)

        # Degrees/adjacency are computed once for the whole graph. Every edge
        # touching a component lies inside it, so the per-component values
        # (and neighbour order) are identical to a per-component rescan.
        degree, adjacency = self._build_adjacency(nodes)

        component_layouts = []
        for component in components:
            comp_nodes = list(component)
            comp_pos = self._layout_component_shell(
                comp_nodes, len(comp_nodes), degree, adjacency)
            component_layouts.append((comp_nodes, comp_pos))

        return self._arrange_component_grid(nodes, component_layouts)

    def _build_adjacency(self, nodes):
        """Return (degree, adjacency) dicts for the subgraph induced by *nodes*.

        Edges are visited in ``self.graph.edges()`` order, so neighbour lists
        and therefore BFS order are deterministic.
        """
        node_set = set(nodes)
        degree = {node: 0 for node in nodes}
        adjacency = {node: [] for node in nodes}
        for u, v in self.graph.edges():
            if u in node_set and v in node_set:
                degree[u] += 1
                degree[v] += 1
                adjacency[u].append(v)
                adjacency[v].append(u)
        return degree, adjacency

    def _layout_component_shell(self, nodes, n, degree=None, adjacency=None):
        """Shell layout for one component, centred on its highest-degree node.

        Nodes are placed on concentric circles of radius 1, 2, ... according
        to their BFS distance from the central node. ``degree``/``adjacency``
        may be supplied pre-computed; otherwise they are derived from the
        edges among *nodes*.
        """
        if n == 0:
            return {}
        if n == 1:
            return {nodes[0]: np.array([0.0, 0.0])}

        if degree is None or adjacency is None:
            degree, adjacency = self._build_adjacency(nodes)

        # Most central node = highest degree (first one wins ties).
        central_node = nodes[int(np.argmax([degree[node] for node in nodes]))]

        # Breadth-first search produces the shells.
        visited = {central_node}
        shells = []
        current_shell = [central_node]
        while current_shell:
            shells.append(current_shell)
            next_shell = []
            for node in current_shell:
                for neighbor in adjacency[node]:
                    if neighbor not in visited:
                        visited.add(neighbor)
                        next_shell.append(neighbor)
            current_shell = next_shell

        radius = 1.0
        pos = {central_node: np.array([0.0, 0.0])}
        for shell_idx, shell in enumerate(shells[1:], start=1):
            angles = np.linspace(0, 2 * np.pi, len(shell), endpoint=False)
            for angle, node in zip(angles, shell):
                pos[node] = np.array([radius * shell_idx * np.cos(angle),
                                      radius * shell_idx * np.sin(angle)])

        # Centre the layout; the row-index dict replaces an O(n^2) list.index.
        positions = np.array(list(pos.values()))
        positions -= positions.mean(axis=0)
        row_of = {node: row for row, node in enumerate(pos)}
        return {node: positions[row_of[node]] for node in nodes}

    def _spring_layout_numpy_super(self, nodes, n, iterations=50):
        """Spring layout with physically separated connected components."""
        if n == 0:
            return {}

        components = list(nx.connected_components(self.graph))

        if len(components) == 1:
            return self._spring_layout_numpy(nodes, n, iterations)

        # Bucket edges by component in a single pass. Both endpoints of an
        # edge share a component, and the original edge order is preserved
        # within each bucket.
        component_of = {}
        for idx, component in enumerate(components):
            for node in component:
                component_of[node] = idx
        component_edges = [[] for _ in components]
        for u, v in self.graph.edges():
            component_edges[component_of[u]].append((u, v))

        component_layouts = []
        for component, edges in zip(components, component_edges):
            comp_nodes = list(component)
            comp_pos = self._layout_component(comp_nodes, len(comp_nodes), edges, iterations)
            component_layouts.append((comp_nodes, comp_pos))

        return self._arrange_component_grid(nodes, component_layouts)

    def _arrange_component_grid(self, nodes, component_layouts):
        """Tile per-component layouts on a square-ish grid and centre the result.

        ``component_layouts`` is a list of (component_nodes, {node: pos}).
        Returns {node: position} for every node in *nodes*.
        """
        bounds = []
        for comp_nodes, comp_pos in component_layouts:
            coords = np.array([comp_pos[node] for node in comp_nodes])
            bounds.append(coords.max(axis=0) - coords.min(axis=0))

        grid_cols = int(np.ceil(np.sqrt(len(component_layouts))))

        max_width = max(b[0] for b in bounds)
        max_height = max(b[1] for b in bounds)
        # Single-node (or perfectly flat) components have zero extent. Without
        # a floor, every grid cell would collapse onto the same coordinate.
        fallback = max(max_width, max_height) or 1.0
        spacing_x = (max_width or fallback) * 1.5  # 50% padding between components
        spacing_y = (max_height or fallback) * 1.5

        final_positions = {}
        for idx, (comp_nodes, comp_pos) in enumerate(component_layouts):
            offset = np.array([(idx % grid_cols) * spacing_x,
                               (idx // grid_cols) * spacing_y])
            for node in comp_nodes:
                final_positions[node] = comp_pos[node] + offset

        all_pos = np.array([final_positions[node] for node in nodes])
        all_pos -= all_pos.mean(axis=0)
        return {node: all_pos[i] for i, node in enumerate(nodes)}

    def _layout_component(self, nodes, n, edges, iterations):
        """Spring layout for a single component (seed varies with its size)."""
        return self._run_spring_layout(nodes, n, edges, iterations, seed=42 + len(nodes))

    def _spring_layout_numpy(self, nodes, n, iterations=50):
        """Spring layout of the whole graph treated as one component."""
        return self._run_spring_layout(nodes, n, list(self.graph.edges()), iterations, seed=42)

    def _run_spring_layout(self, nodes, n, edges, iterations, seed):
        """Vectorised force-directed layout shared by both spring variants.

        Repulsion is limited to node pairs closer than ``4*k``, using a
        KD-tree. Displacement per step is capped by a linearly cooling
        temperature. With no edges, the random initial positions are
        returned uncentred, as before.
        """
        from scipy.spatial import cKDTree

        # A local RandomState yields exactly the draws of
        # np.random.seed(seed); np.random.rand(...) without touching the
        # process-wide RNG from this worker thread.
        pos = np.random.RandomState(seed).rand(n, 2)

        if len(edges) == 0:
            return {node: pos[i] for i, node in enumerate(nodes)}

        node_to_idx = {node: i for i, node in enumerate(nodes)}
        edge_indices = np.array([[node_to_idx[u], node_to_idx[v]] for u, v in edges])
        src, dst = edge_indices[:, 0], edge_indices[:, 1]

        k = np.sqrt(1.0 / n)  # ideal edge length
        t = 0.1               # initial temperature (max step size)
        dt = t / (iterations + 1)
        cutoff_distance = 4 * k

        for _ in range(iterations):
            displacement = np.zeros_like(pos)

            # Repulsion between nearby pairs only.
            pairs = cKDTree(pos).query_pairs(r=cutoff_distance, output_type='ndarray')
            if len(pairs) > 0:
                i_idx, j_idx = pairs[:, 0], pairs[:, 1]
                delta = pos[i_idx] - pos[j_idx]
                distance = np.maximum(np.linalg.norm(delta, axis=1, keepdims=True), 0.01)
                force_magnitude = (k * k) / distance
                force = delta * (force_magnitude / distance)
                np.add.at(displacement, i_idx, force)
                np.add.at(displacement, j_idx, -force)

            # Attraction along edges.
            edge_delta = pos[src] - pos[dst]
            edge_distance = np.maximum(np.sqrt((edge_delta ** 2).sum(axis=1, keepdims=True)), 0.01)
            edge_force_magnitude = (edge_distance * edge_distance) / k
            edge_force = edge_delta * (edge_force_magnitude / edge_distance)
            np.add.at(displacement, src, -edge_force)
            np.add.at(displacement, dst, edge_force)

            # Cap each node's step at the current temperature.
            disp_magnitude = np.maximum(np.sqrt((displacement ** 2).sum(axis=1, keepdims=True)), 0.01)
            pos += displacement * np.minimum(t / disp_magnitude, 1.0)
            t -= dt

        pos -= pos.mean(axis=0)
        return {node: pos[i] for i, node in enumerate(nodes)}

    def _compute_geometric_layout(self):
        """Place nodes at their centroid (x, -y); nodes without one go to the origin.

        Assumes each centroid is a (z, y, x) triple, as unpacked below.
        """
        pos = {}
        for node in self.graph.nodes():
            if node in self.centroids:
                z, y, x = self.centroids[node]
                pos[node] = np.array([x, -y])
            else:
                pos[node] = np.array([0.0, 0.0])
        return pos

    # ------------------------------------------------------------------
    # Render-ready data
    # ------------------------------------------------------------------
    def _prepare_node_spots(self, pos, colors, sizes):
        """Build ScatterPlotItem spot dicts plus per-node brush caches.

        Returns (spots, brush_cache), where
        brush_cache = {node: {'normal': brush, 'selected': brush}}.
        """
        nodes = list(self.graph.nodes())
        pos_array = np.array([pos[node] for node in nodes])

        # Brushes are built here so the UI thread does not have to. The
        # selection brush is constant, and normal brushes are reused per colour.
        selected_brush = pg.mkBrush(255, 255, 0, 255)  # Yellow for selected
        normal_brushes = {}

        spots = []
        brush_cache = {}
        for idx, node in enumerate(nodes):
            hex_digits = colors[idx].lstrip('#')
            rgb = tuple(int(hex_digits[k:k + 2], 16) for k in (0, 2, 4))

            normal_brush = normal_brushes.get(rgb)
            if normal_brush is None:
                normal_brush = normal_brushes[rgb] = pg.mkBrush(*rgb, 200)

            brush_cache[node] = {'normal': normal_brush, 'selected': selected_brush}

            # Only pyqtgraph-valid keys go into a spot.
            spots.append({
                'pos': pos_array[idx],
                'size': sizes[idx],
                'brush': normal_brush,  # start unselected
                'data': node,
            })

        return spots, brush_cache

    def _prepare_label_data(self, pos):
        """Return [{'node', 'text', 'pos'}] for every positioned node."""
        return [
            {'node': node, 'text': str(node), 'pos': pos[node]}
            for node in self.graph.nodes()
            if node in pos
        ]

    def _prepare_edge_pens(self, edges):
        """Batch edge coordinates by (binned) line thickness.

        Each batch is {'x', 'y', 'thickness'}. Edges are NaN-separated
        segments, so one PlotCurveItem-style draw can render a whole batch.
        """
        if not edges:
            return []

        # Unweighted: every edge has the same thickness, so use one batch.
        if not self.weight:
            x_coords = []
            y_coords = []
            for x, y, _ in edges:
                x_coords.extend([x[0], x[1], np.nan])
                y_coords.extend([y[0], y[1], np.nan])
            return [{
                'x': np.array(x_coords),
                'y': np.array(y_coords),
                'thickness': self.edge_size,
            }]

        weights = [w for _, _, w in edges]
        min_weight = min(weights)
        max_weight = max(weights)
        # Equal weights -> range 1, so every edge gets the minimum thickness.
        weight_range = max_weight - min_weight if max_weight > min_weight else 1

        num_bins = 10
        thickness_min = self.edge_size / 2
        thickness_max = 3.0 * self.edge_size  # Maximum thickness cap

        edge_batches = {}  # {thickness: {'x': [...], 'y': [...]}}
        for x, y, weight in edges:
            normalized = (weight - min_weight) / weight_range
            thickness = thickness_min + normalized * (thickness_max - thickness_min)

            # Round to 1/num_bins to limit the number of distinct batches.
            thickness_bin = min(round(thickness * num_bins) / num_bins, thickness_max)

            batch = edge_batches.setdefault(thickness_bin, {'x': [], 'y': []})
            batch['x'].extend([x[0], x[1], np.nan])
            batch['y'].extend([y[0], y[1], np.nan])

        return [
            {'x': np.array(coords['x']), 'y': np.array(coords['y']), 'thickness': thickness}
            for thickness, coords in edge_batches.items()
        ]

    def _compute_node_attributes(self):
        """Return (colors, sizes) aligned with ``self.graph.nodes()`` order.

        Identity colouring takes precedence over community colouring.
        Unmapped categories become grey; without either mode all nodes
        get the default blue.
        """
        nodes = list(self.graph.nodes())
        sizes = self._compute_all_node_sizes_vectorized(nodes)

        if self.identities and self.identity_dict:
            color_map = self._generate_community_colors(self.identity_dict)
            colors = [color_map.get(self.identity_dict.get(node, 'Unknown'), '#808080')
                      for node in nodes]
        elif self.communities and self.community_dict:
            color_map = self._generate_community_colors(self.community_dict)
            colors = [color_map.get(self.community_dict.get(node, -1), '#808080')
                      for node in nodes]
        else:
            colors = ['#4A90E2'] * len(nodes)

        return colors, sizes

    def _compute_all_node_sizes_vectorized(self, nodes):
        """Per-node marker sizes, optionally encoding depth (z).

        When geometric layout, centroids and z-sizing are all enabled, nodes
        are sized 5..25, with lower z giving a larger marker. Otherwise every
        node gets ``self.node_size``.
        """
        if not self.geometric or not self.centroids or not self.z_size:
            return [self.node_size] * len(nodes)

        # Global z range across all graph nodes (matches original behaviour).
        all_z = np.array([
            self.centroids[node][0]
            for node in self.graph.nodes()
            if node in self.centroids
        ])

        # NOTE(review): these fallbacks use a literal 10 rather than
        # self.node_size; confirm that this is intended.
        if len(all_z) == 0:
            return [10] * len(nodes)

        z_min, z_max = all_z.min(), all_z.max()
        sizes = [10] * len(nodes)
        if z_max <= z_min:
            return sizes

        node_indices = [i for i, node in enumerate(nodes) if node in self.centroids]
        if not node_indices:
            return sizes

        z_array = np.array([self.centroids[nodes[i]][0] for i in node_indices])
        computed_sizes = 5 + (1 - (z_array - z_min) / (z_max - z_min)) * 20

        for idx, node_idx in enumerate(node_indices):
            sizes[node_idx] = float(computed_sizes[idx])
        return sizes

    # ------------------------------------------------------------------
    # Colours
    # ------------------------------------------------------------------
    def _generate_identity_colors(self):
        """Map each identity category to a colour (fixed palette up to 12)."""
        unique_categories = list(set(self.identity_dict.values()))
        num_categories = len(unique_categories)

        if num_categories <= 12:
            base_colors = [
                '#FF0000', '#0066FF', '#00CC00', '#FF8800',
                '#8800FF', '#FFFF00', '#FF0088', '#00FFFF',
                '#88FF00', '#FF4400', '#0088FF', '#CC00FF'
            ]
            colors = base_colors[:num_categories]
        else:
            colors = []
            for i in range(num_categories):
                hue = (i * 360 / num_categories) % 360
                sat = 0.85 if i % 2 == 0 else 0.95
                val = 0.95 if i % 3 != 0 else 0.85
                rgb = colorsys.hsv_to_rgb(hue / 360, sat, val)
                colors.append('#{:02x}{:02x}{:02x}'.format(
                    int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255)))

        return dict(zip(unique_categories, colors))

    def _generate_community_colors(self, my_dict):
        """Map each distinct value of *my_dict* to a distinct hex colour.

        Hues are assigned in a seeded shuffle of the sorted values, so the
        mapping is reproducible. The value 0 is always brown (#8B4513).
        """
        unique_communities = sorted(set(my_dict.values()))
        shuffled = random.Random(42).sample(unique_communities, len(unique_communities))
        colors_rgb = self._generate_distinct_colors_rgb(len(unique_communities))
        color_map = {comm: colors_rgb[i] for i, comm in enumerate(shuffled)}
        if 0 in unique_communities:
            color_map[0] = "#8B4513"
        return color_map

    def _generate_distinct_colors_rgb(self, n_colors):
        """Return *n_colors* hex colours evenly spaced in hue (full S and V)."""
        colors = []
        for i in range(n_colors):
            rgb = colorsys.hsv_to_rgb(i / n_colors, 1.0, 1.0)
            colors.append('#{:02x}{:02x}{:02x}'.format(
                int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255)))
        return colors

    def _compute_edge_data(self, pos):
        """Return [(xs, ys, weight)] for edges whose endpoints are positioned.

        Weight defaults to 1.0 when the edge has no 'weight' attribute.
        """
        edges = []
        for u, v, data in self.graph.edges(data=True):
            if u in pos and v in pos:
                edges.append(([pos[u][0], pos[v][0]],
                              [pos[u][1], pos[v][1]],
                              data.get('weight', 1.0)))
        return edges
+ return edges
728
+
729
+
730
+ class NetworkGraphWidget(QWidget):
731
+ """Interactive NetworkX graph visualization widget"""
732
+
733
+ node_selected = pyqtSignal(object) # Emits list of selected nodes
734
+
735
def __init__(self, parent=None, weight=False, geometric=False, component=False,
             centroids=None, communities=False, community_dict=None,
             identities=False, identity_dict=None, labels=False, z_size=False,
             shell=False, node_size=10, black_edges=False, edge_size=1, popout=False):
    """Initialise widget state and build the UI.

    The boolean/size arguments are stored verbatim as display options.
    ``centroids``, ``community_dict`` and ``identity_dict`` are replaced by
    a fresh empty dict when falsy. ``parent`` is kept as ``parent_window``.
    """
    super().__init__(parent)
    self.parent_window = parent

    # Display / layout options, stored as given.
    self.weight, self.geometric, self.component = weight, geometric, component
    self.communities, self.identities = communities, identities
    self.labels, self.z_size, self.shell = labels, z_size, shell
    self.node_size, self.edge_size = node_size, edge_size
    self.black_edges, self.popout = black_edges, popout

    # Optional lookup tables (falsy argument -> new empty dict).
    self.centroids = centroids or {}
    self.community_dict = community_dict or {}
    self.identity_dict = identity_dict or {}

    # Graph data.
    self.graph = None
    self.node_positions = {}
    self.node_colors, self.node_sizes = [], []
    self.node_items, self.label_items = {}, {}
    self.edge_items = []
    self.label_data = []  # kept so labels can be rendered on demand
    self.selected_nodes = set()
    self.node_click = False
    self.rendered = False

    # Caches used for fast incremental updates.
    self.cached_spots = []            # full spot dicts, including brushes
    self.cached_node_to_index = {}    # node -> spot index
    self.cached_brushes = {}          # node -> {'normal': brush, 'selected': brush}
    self.last_selected_set = set()    # previous selection state
    self.cached_sizes_for_lod = []    # base sizes for level-of-detail scaling

    # Interaction mode.
    self.selection_mode, self.zoom_mode = True, False

    # Area selection.
    self.selection_rect = None
    self.selection_start_pos = None
    self.is_area_selecting = False
    self.click_timer = None

    # Middle-button panning while in selection mode.
    self.temp_pan_active, self.last_mouse_pos = False, None

    # Wheel-zoom timer used while in selection mode.
    self.wheel_timer, self.was_in_selection_before_wheel = None, False

    # Background layout thread (created later).
    self.load_thread = None

    self._setup_ui()
801
+
802
def _setup_ui(self):
    """Build the widget tree and wire up its signals.

    Left column: the pyqtgraph plot (nodes scatter plus placeholder text)
    stacked above the control panel. Right column: an initially collapsed
    legend container. Event filters are installed on the viewport and
    the scene for custom mouse handling.
    """
    root_layout = QHBoxLayout()  # horizontal so a legend can sit beside the graph
    root_layout.setContentsMargins(0, 0, 0, 0)
    root_layout.setSpacing(2)

    # Left column: plot area and control panel.
    left_column = QWidget()
    left_layout = QVBoxLayout()
    left_layout.setContentsMargins(0, 0, 0, 0)
    left_layout.setSpacing(2)

    self.graphics_widget = pg.GraphicsLayoutWidget()
    self.graphics_widget.setBackground('w')

    self.plot = self.graphics_widget.addPlot()
    plot = self.plot
    plot.setAspectLocked(True)
    for axis_name in ('left', 'bottom'):
        plot.hideAxis(axis_name)
    plot.showGrid(x=False, y=False)

    # Placeholder message centred in the empty plot.
    self.loading_text = pg.TextItem(
        text="No network detected",
        color=(100, 100, 100),
        anchor=(0.5, 0.5),
    )
    self.loading_text.setPos(0, 0)
    plot.addItem(self.loading_text)

    scene = plot.scene()
    scene.sigMouseMoved.connect(self._on_mouse_moved)  # used for area selection

    # The default pan/zoom mouse handling stays off here.
    plot.setMouseEnabled(x=False, y=False)
    view_box = plot.vb
    view_box.setMenuEnabled(False)
    view_box.setMouseMode(pg.ViewBox.PanMode)

    # Scatter item that holds the nodes.
    self.scatter = ScatterPlotItem(size=10, pen=pg.mkPen(None),
                                   brush=pg.mkBrush(74, 144, 226, 200))
    plot.addItem(self.scatter)

    self.scatter.sigClicked.connect(self._on_node_clicked)
    scene.sigMouseClicked.connect(self._on_plot_clicked)
    plot.sigRangeChanged.connect(self._on_view_changed)  # level-of-detail hook

    # Level-of-detail state.
    self.base_node_sizes = []
    self.current_zoom_factor = 1.0

    left_layout.addWidget(self.graphics_widget, stretch=1)
    left_layout.addWidget(self._create_control_panel())

    left_column.setLayout(left_layout)
    root_layout.addWidget(left_column, stretch=1)

    # Right column: legend placeholder, collapsed (max width 0) for now.
    self.legend_container = QWidget()
    self.legend_layout = QVBoxLayout()
    self.legend_layout.setContentsMargins(0, 0, 0, 0)
    self.legend_container.setLayout(self.legend_layout)
    self.legend_container.setMaximumWidth(0)
    root_layout.addWidget(self.legend_container)

    self.setLayout(root_layout)
    self.setSizePolicy(QSizePolicy.Policy.Expanding,
                       QSizePolicy.Policy.Expanding)

    # Route mouse events through this widget's eventFilter.
    self.graphics_widget.viewport().installEventFilter(self)
    scene.installEventFilter(self)
883
+
884
+ def _create_identity_legend(self):
885
+ """Create a legend panel for node identities"""
886
+
887
+ def _generate_distinct_colors_rgb(n_colors: int):
888
+ """
889
+ Generate visually distinct RGB colors using HSV color space.
890
+ Colors are generated with maximum saturation and value, varying only in hue.
891
+ """
892
+ colors = []
893
+ for i in range(n_colors):
894
+ hue = i / n_colors
895
+ rgb = colorsys.hsv_to_rgb(hue, 1.0, 1.0) # S=1, V=1 for max saturation/brightness
896
+ hex_color = '#{:02x}{:02x}{:02x}'.format(
897
+ int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255)
898
+ )
899
+ colors.append(hex_color)
900
+ return colors
901
+
902
+ if self.identities:
903
+ from collections import Counter
904
+
905
+ unique_identities = sorted(set(self.identity_dict.values()))
906
+ community_sizes = Counter(self.identity_dict.values())
907
+ sorted_communities = random.Random(42).sample(unique_identities, len(unique_identities))
908
+ colors_rgb = _generate_distinct_colors_rgb(len(unique_identities))
909
+ color_map = {comm: colors_rgb[i] for i, comm in enumerate(sorted_communities)}
910
+ if 0 in unique_identities:
911
+ color_map[0] = "#8B4513"
912
+ elif self.communities:
913
+ from collections import Counter
914
+
915
+ unique_identities = sorted(set(self.community_dict.values()))
916
+ community_sizes = Counter(self.community_dict.values())
917
+ sorted_communities = random.Random(42).sample(unique_identities, len(unique_identities))
918
+ colors_rgb = _generate_distinct_colors_rgb(len(unique_identities))
919
+ color_map = {comm: colors_rgb[i] for i, comm in enumerate(sorted_communities)}
920
+ if 0 in unique_identities:
921
+ color_map[0] = "#8B4513"
922
+
923
+ # Create legend widget
924
+ legend_widget = QWidget()
925
+ legend_layout = QVBoxLayout()
926
+ legend_layout.setContentsMargins(5, 5, 5, 5)
927
+ legend_layout.setSpacing(2)
928
+
929
+ # Add title
930
+ if self.identities:
931
+ title = QLabel("Node Identities")
932
+ elif self.communities:
933
+ title = QLabel("Node Community/Neighborhood")
934
+ title.setStyleSheet("font-weight: bold; font-size: 11pt; padding: 3px;")
935
+ legend_layout.addWidget(title)
936
+
937
+ # Create scrollable area for legend items
938
+ scroll = QScrollArea()
939
+ scroll.setWidgetResizable(True)
940
+ scroll.setMaximumWidth(200)
941
+ scroll.setMinimumWidth(150)
942
+ scroll.setFrameShape(QFrame.Shape.StyledPanel)
943
+
944
+ # Container for legend items
945
+ items_widget = QWidget()
946
+ items_layout = QVBoxLayout()
947
+ items_layout.setContentsMargins(2, 2, 2, 2)
948
+ items_layout.setSpacing(3)
949
+
950
+ # Add each identity with colored box
951
+ for identity in unique_identities:
952
+ item_widget = QWidget()
953
+ item_layout = QHBoxLayout()
954
+ item_layout.setContentsMargins(0, 0, 0, 0)
955
+ item_layout.setSpacing(5)
956
+
957
+ # Color box
958
+ color_box = QLabel()
959
+ color_box.setFixedSize(16, 16)
960
+ color_box.setStyleSheet(f"background-color: {color_map[identity]}; border: 1px solid #888;")
961
+
962
+ # Label
963
+ label = QLabel(str(identity))
964
+ label.setStyleSheet("font-size: 9pt;")
965
+
966
+ item_layout.addWidget(color_box)
967
+ item_layout.addWidget(label)
968
+ item_layout.addStretch()
969
+
970
+ item_widget.setLayout(item_layout)
971
+ items_layout.addWidget(item_widget)
972
+
973
+ if '#808080' in self.node_colors:
974
+ item_widget = QWidget()
975
+ item_layout = QHBoxLayout()
976
+ item_layout.setContentsMargins(0, 0, 0, 0)
977
+ item_layout.setSpacing(5)
978
+
979
+ # Color box
980
+ color_box = QLabel()
981
+ color_box.setFixedSize(16, 16)
982
+ color_box.setStyleSheet(f"background-color: #808080; border: 1px solid #888;")
983
+
984
+ # Label
985
+ label = QLabel('Unassigned')
986
+ label.setStyleSheet("font-size: 9pt;")
987
+
988
+ item_layout.addWidget(color_box)
989
+ item_layout.addWidget(label)
990
+ item_layout.addStretch()
991
+
992
+ item_widget.setLayout(item_layout)
993
+ items_layout.addWidget(item_widget)
994
+
995
+ items_layout.addStretch()
996
+ items_widget.setLayout(items_layout)
997
+ scroll.setWidget(items_widget)
998
+
999
+ legend_layout.addWidget(scroll)
1000
+ legend_widget.setLayout(legend_layout)
1001
+
1002
+ return legend_widget
1003
+
1004
def _create_control_panel(self):
    """Create the toolbar of emoji tool buttons.

    Stores the buttons as ``select_btn``, ``pan_btn``, ``zoom_btn``,
    ``home_btn``, ``refresh_btn``, ``settings_btn`` and ``clear_btn``, plus
    ``popout_btn`` when ``self.popout`` is truthy. The select, pan and zoom
    buttons are checkable, and selection starts checked.

    Returns:
        QWidget: the toolbar, at most 40 px high.
    """
    panel = QWidget()
    panel_layout = QHBoxLayout()
    panel_layout.setContentsMargins(2, 2, 2, 2)
    panel_layout.setSpacing(2)

    def make_button(text, tooltip, slot, checkable=False, checked=None):
        """Create a 32x32 button wired to ``slot``; optionally checkable with an initial state."""
        button = QPushButton(text)
        button.setToolTip(tooltip)
        if checkable:
            button.setCheckable(True)
            if checked is not None:
                button.setChecked(checked)
        button.setMaximumSize(32, 32)
        button.clicked.connect(slot)
        return button

    self.select_btn = make_button("🖱️", "Selection Tool", self._toggle_selection_mode,
                                  checkable=True, checked=True)
    self.pan_btn = make_button("✋", "Pan Tool", self._toggle_pan_mode,
                               checkable=True, checked=False)
    # Left release calls _zoom_at_point(..., 2), which shrinks the visible
    # range (zoom in). Right release uses 0.5 (zoom out). The tooltip
    # previously had these reversed.
    self.zoom_btn = make_button("🔍", "Zoom Tool (Left Click: Zoom In, Right Click: Zoom Out)",
                                self._toggle_zoom_mode, checkable=True)
    self.home_btn = make_button("🏠", "Reset View", self._reset_view)
    self.refresh_btn = make_button("🔄", "Refresh Graph", self.load_graph)
    self.settings_btn = make_button("⚙", "Render Settings", self.settings)
    self.clear_btn = make_button("🗑️", "Clear Graph", self._clear_graph)

    for button in (self.select_btn, self.pan_btn, self.zoom_btn, self.home_btn,
                   self.refresh_btn, self.settings_btn, self.clear_btn):
        panel_layout.addWidget(button)

    if self.popout:
        self.popout_btn = make_button("⤴", "Full Screen", self._popout_graph)
        panel_layout.addWidget(self.popout_btn)

    panel_layout.addStretch()

    panel.setLayout(panel_layout)
    panel.setMaximumHeight(40)

    return panel
+
1076
def settings(self):
    """Open the parent window's render-settings dialog for this widget.

    Delegates to ``parent_window.show_netshow_dialog``, passing this widget
    as ``called``.
    """
    open_dialog = self.parent_window.show_netshow_dialog
    open_dialog(called=self)
+
1080
def set_graph(self, graph):
    """Store the NetworkX graph to visualise and refresh the placeholder text.

    Any existing placeholder is removed. If nothing has been rendered yet, a
    grey message is placed at the origin. With a graph, it prompts the user
    to press refresh. With ``graph is None``, it reads "No network detected".
    """
    self.graph = graph

    existing = getattr(self, 'loading_text', None)
    if existing is not None:
        self.plot.removeItem(existing)
        self.loading_text = None

    if self.rendered:
        return

    if self.graph is None:
        message = "No network detected"
    else:
        message = "Press 🔄 to load your graph"

    self.loading_text = pg.TextItem(text=message, color=(100, 100, 100), anchor=(0.5, 0.5))
    self.loading_text.setPos(0, 0)  # Center of view
    self.plot.addItem(self.loading_text)
+
1109
def load_graph(self):
    """Clear the plot and start computing the graph layout on a worker thread.

    Refreshes settings via ``get_properties``. If there is no graph or it has
    no nodes, shows "No network detected" and returns. Otherwise shows
    "Loading graph..." and starts a ``GraphLoadThread`` whose ``finished``
    signal is handled by ``_on_graph_loaded``.
    """
    self._clear_graph()
    self.get_properties()

    # _clear_graph installs its own placeholder; drop it before adding ours
    if getattr(self, 'loading_text', None) is not None:
        self.plot.removeItem(self.loading_text)
        self.loading_text = None

    has_nodes = self.graph is not None and len(self.graph.nodes()) > 0

    placeholder = pg.TextItem(
        text="Loading graph..." if has_nodes else "No network detected",
        color=(100, 100, 100),
        anchor=(0.5, 0.5),
    )
    placeholder.setPos(0, 0)  # Center of view
    self.loading_text = placeholder
    self.plot.addItem(placeholder)

    if not has_nodes:
        return

    worker = GraphLoadThread(
        self.graph, self.geometric, self.component, self.centroids,
        self.communities, self.community_dict,
        self.identities, self.identity_dict, self.weight, self.z_size,
        self.shell, self.node_size, self.edge_size,
    )
    self.load_thread = worker
    worker.finished.connect(self._on_graph_loaded)
    worker.start()
+
1154
@pyqtSlot(object)
def _on_graph_loaded(self, result):
    """Receive the layout computed by ``GraphLoadThread`` and draw it.

    Args:
        result: dict emitted by the loader thread. This slot reads 'pos',
            'colors', 'sizes', 'node_spots' and 'brush_cache';
            ``_render_prepared_data`` reads further keys.

    After drawing, the view is fitted to the graph with signals blocked.
    The identity/community legend is rebuilt or hidden, and any nodes
    already clicked in the parent window are re-selected.
    """
    placeholder = getattr(self, 'loading_text', None)
    if placeholder is not None:
        self.plot.removeItem(placeholder)
        self.loading_text = None

    sizes = result['sizes']
    self.node_positions = result['pos']
    self.node_colors = result['colors']
    self.node_sizes = sizes
    self.base_node_sizes = sizes.copy()

    # Cached render state reused by incremental selection redraws
    self.cached_spots = result['node_spots']
    self.cached_brushes = result['brush_cache']
    self.cached_sizes_for_lod = sizes.copy()

    index_by_node = {}
    for position, spot in enumerate(self.cached_spots):
        index_by_node[spot['data']] = position
    self.cached_node_to_index = index_by_node

    self._render_prepared_data(result)

    self.rendered = True
    # Block signals while fitting the view so _on_view_changed is not triggered
    self.plot.blockSignals(True)
    self._reset_view()
    self.plot.blockSignals(False)

    wants_legend = (self.identities and self.identity_dict) or (self.communities and self.community_dict)
    if wants_legend:
        layout = self.legend_layout
        # Remove the previous legend, last item first
        for position in range(layout.count() - 1, -1, -1):
            layout.itemAt(position).widget().setParent(None)

        legend = self._create_identity_legend()
        if legend:
            layout.addWidget(legend)
            self.legend_container.setMaximumWidth(200)
    else:
        self.legend_container.setMaximumWidth(0)  # Hide if no identities

    if len(self.parent_window.clicked_values['nodes']) > 0:
        self.select_nodes(self.parent_window.clicked_values['nodes'])
+
1201
+
1202
+
1203
+ def _render_prepared_data(self, result):
1204
+ """Render pre-computed data (minimal main thread work)"""
1205
+ # Clear old items
1206
+ self.scatter.clear()
1207
+ for item in self.edge_items:
1208
+ self.plot.removeItem(item)
1209
+ self.edge_items.clear()
1210
+ for label_item in self.label_items.values():
1211
+ self.plot.removeItem(label_item)
1212
+ self.label_items.clear()
1213
+
1214
+ if self.black_edges:
1215
+ edge_color = (0, 0, 0)
1216
+ else:
1217
+ edge_color = (150, 150, 150, 100)
1218
+
1219
+ # Render edges - batched by weight for efficiency
1220
+ edge_batches = result['edge_pens']
1221
+ if edge_batches:
1222
+ for batch in edge_batches:
1223
+ edge_line = PlotCurveItem(
1224
+ x=batch['x'],
1225
+ y=batch['y'],
1226
+ pen=pg.mkPen(color=edge_color, width=batch['thickness']),
1227
+ connect='finite' # Break lines at NaN
1228
+ )
1229
+ self.plot.addItem(edge_line)
1230
+ self.edge_items.append(edge_line)
1231
+
1232
+ # Render nodes - use cached spots directly
1233
+ self.scatter.setData(spots=self.cached_spots)
1234
+ self.scatter.setZValue(10)
1235
+
1236
+ # Build node items mapping
1237
+ nodes = list(self.graph.nodes())
1238
+ self.node_items = {node: i for i, node in enumerate(nodes)}
1239
+
1240
+ # Store label data for later rendering
1241
+ self.label_data = result['label_data']
1242
+
1243
+ # Only render labels immediately if graph is small (< 100 nodes)
1244
+ if self.labels and len(self.label_data) < 100:
1245
+ self._update_labels_in_viewport(len(self.label_data))
1246
+
1247
+ def _render_nodes(self):
1248
+ """OPTIMIZED: Only update brushes for nodes that changed selection state"""
1249
+
1250
+ if not self.cached_spots or not self.cached_brushes:
1251
+ return
1252
+
1253
+ # Find nodes whose selection state changed
1254
+ newly_selected = self.selected_nodes - self.last_selected_set
1255
+ newly_deselected = self.last_selected_set - self.selected_nodes
1256
+
1257
+ # If nothing changed, skip update
1258
+ if not newly_selected and not newly_deselected:
1259
+ return
1260
+
1261
+ # Update only changed nodes using cached brushes
1262
+ for node in newly_selected:
1263
+ if node in self.cached_node_to_index:
1264
+ idx = self.cached_node_to_index[node]
1265
+ self.cached_spots[idx]['brush'] = self.cached_brushes[node]['selected']
1266
+
1267
+ for node in newly_deselected:
1268
+ if node in self.cached_node_to_index:
1269
+ idx = self.cached_node_to_index[node]
1270
+ self.cached_spots[idx]['brush'] = self.cached_brushes[node]['normal']
1271
+
1272
+ # Update the scatter plot with modified spots
1273
+ self.scatter.setData(spots=self.cached_spots)
1274
+
1275
+ # Update last selection state
1276
+ self.last_selected_set = self.selected_nodes.copy()
1277
+
1278
+ def _render_labels_batch(self, label_data_subset):
1279
+ """Render labels in batches (but still slow for large graphs)"""
1280
+ # Clear existing labels first
1281
+ for label_item in self.label_items.values():
1282
+ self.plot.removeItem(label_item)
1283
+ self.label_items.clear()
1284
+
1285
+ if not label_data_subset:
1286
+ return
1287
+
1288
+ # Batch size for yielding control back to event loop
1289
+ batch_size = 50
1290
+
1291
+ for i, label_info in enumerate(label_data_subset):
1292
+ text_item = pg.TextItem(
1293
+ text=label_info['text'],
1294
+ color=(0, 0, 0),
1295
+ anchor=(0.5, 0.5)
1296
+ )
1297
+ text_item.setPos(label_info['pos'][0], label_info['pos'][1])
1298
+ text_item.setZValue(20)
1299
+ self.plot.addItem(text_item)
1300
+ self.label_items[label_info['node']] = text_item
1301
+
1302
+ # Yield control periodically to keep UI responsive
1303
+ if (i + 1) % batch_size == 0:
1304
+ QApplication.processEvents()
1305
+
1306
+ def _update_labels_for_zoom(self):
1307
+ """Show/hide labels based on zoom level and graph size with viewport awareness"""
1308
+ if not self.labels or not self.label_data:
1309
+ return
1310
+
1311
+ num_nodes = len(self.label_data)
1312
+
1313
+ # Determine zoom threshold based on graph size
1314
+ if num_nodes < 100:
1315
+ zoom_threshold = 0 # Always show labels
1316
+ elif num_nodes < 500:
1317
+ zoom_threshold = 1.5
1318
+ else:
1319
+ zoom_threshold = 3.0
1320
+
1321
+ # Check if we're above the zoom threshold
1322
+ should_show_labels = self.current_zoom_factor > zoom_threshold
1323
+
1324
+ if should_show_labels:
1325
+ # Update labels based on current viewport
1326
+ self._update_labels_in_viewport(num_nodes)
1327
+ else:
1328
+ # Clear all labels when zoomed out below threshold
1329
+ if self.label_items:
1330
+ for label_item in self.label_items.values():
1331
+ self.plot.removeItem(label_item)
1332
+ self.label_items.clear()
1333
+
1334
+ def _update_labels_in_viewport(self, num_nodes):
1335
+ """Update labels to show only those in current viewport"""
1336
+ # Get current view range
1337
+ view_range = self.plot.viewRange()
1338
+ x_min, x_max = view_range[0]
1339
+ y_min, y_max = view_range[1]
1340
+
1341
+ # Find which labels should be visible
1342
+ visible_node_set = set()
1343
+ labels_to_render = []
1344
+
1345
+ for label_info in self.label_data:
1346
+ if x_min <= label_info['pos'][0] <= x_max and y_min <= label_info['pos'][1] <= y_max:
1347
+ visible_node_set.add(label_info['node'])
1348
+ labels_to_render.append(label_info)
1349
+
1350
+ # For small graphs or when not many visible labels, render all in viewport
1351
+ max_visible_labels = 200 if num_nodes >= 500 else 1000
1352
+
1353
+ if len(labels_to_render) > max_visible_labels:
1354
+ # Too many labels to render - skip
1355
+ if self.label_items:
1356
+ for label_item in self.label_items.values():
1357
+ self.plot.removeItem(label_item)
1358
+ self.label_items.clear()
1359
+ return
1360
+
1361
+ # Get currently rendered nodes
1362
+ current_node_set = set(self.label_items.keys())
1363
+
1364
+ # Remove labels for nodes no longer in viewport
1365
+ nodes_to_remove = current_node_set - visible_node_set
1366
+ for node in nodes_to_remove:
1367
+ if node in self.label_items:
1368
+ self.plot.removeItem(self.label_items[node])
1369
+ del self.label_items[node]
1370
+
1371
+ # Add labels for new nodes in viewport
1372
+ nodes_to_add = visible_node_set - current_node_set
1373
+ for label_info in labels_to_render:
1374
+ node = label_info['node']
1375
+ if node in nodes_to_add:
1376
+ text_item = pg.TextItem(
1377
+ text=label_info['text'],
1378
+ color=(0, 0, 0),
1379
+ anchor=(0.5, 0.5)
1380
+ )
1381
+ text_item.setPos(label_info['pos'][0], label_info['pos'][1])
1382
+ text_item.setZValue(20)
1383
+ self.plot.addItem(text_item)
1384
+ self.label_items[node] = text_item
1385
+
1386
+ def _render_labels(self):
1387
+ """Render node labels if enabled (legacy method - kept for compatibility)"""
1388
+ # Use the smart rendering instead
1389
+ if self.label_data:
1390
+ self._update_labels_for_zoom()
1391
+
1392
+ def _hex_to_rgb(self, hex_color):
1393
+ """Convert hex color to RGB tuple"""
1394
+ hex_color = hex_color.lstrip('#')
1395
+ return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
1396
+
1397
+ def _on_plot_clicked(self, ev):
1398
+ """Handle clicks on the plot background"""
1399
+ if not self.selection_mode or self.node_click or not self.popout:
1400
+ self.node_click = False
1401
+ return
1402
+
1403
+ # Only handle left button clicks
1404
+ if ev.button() != Qt.MouseButton.LeftButton:
1405
+ return
1406
+
1407
+ # Get the position in scene coordinates
1408
+ scene_pos = ev.scenePos()
1409
+
1410
+
1411
+ # Click was on background
1412
+ modifiers = ev.modifiers()
1413
+ ctrl_pressed = modifiers & Qt.KeyboardModifier.ControlModifier
1414
+
1415
+ if not ctrl_pressed:
1416
+ # Deselect all nodes
1417
+ self.selected_nodes.clear()
1418
+ self._render_nodes()
1419
+ self.push_selection()
1420
+ self.node_selected.emit([])
1421
+ # Ctrl+Click on background does nothing (as requested)
1422
+
1423
+ def _on_mouse_moved(self, pos):
1424
+ """Handle mouse movement for area selection"""
1425
+ if self.is_area_selecting and self.selection_rect:
1426
+ # Update selection rectangle
1427
+ view_pos = self.plot.vb.mapSceneToView(pos)
1428
+ start_pos = self.selection_start_pos
1429
+
1430
+ # Update rectangle size
1431
+ width = view_pos.x() - start_pos.x()
1432
+ height = view_pos.y() - start_pos.y()
1433
+
1434
+ self.selection_rect.setSize([width, height])
1435
+
1436
def _start_area_selection(self, scene_pos):
    """Begin a rubber-band selection anchored at ``scene_pos`` (scene coordinates).

    Marks area selection as active and records the anchor in view
    coordinates. Any existing selection rectangle is replaced with a
    zero-size, dashed blue ``pg.ROI`` that is neither movable, resizable nor
    mouse-interactive.
    """
    self.is_area_selecting = True
    self.node_click = False

    anchor = self.plot.vb.mapSceneToView(scene_pos)
    self.selection_start_pos = anchor

    if self.selection_rect:
        self.plot.removeItem(self.selection_rect)

    dashed_pen = pg.mkPen('b', width=2, style=Qt.PenStyle.DashLine)
    rect = pg.ROI(
        [anchor.x(), anchor.y()],
        [0, 0],
        pen=dashed_pen,
        movable=False,
        resizable=False,
    )
    # Let clicks pass through the rectangle to the scene
    rect.setAcceptedMouseButtons(Qt.MouseButton.NoButton)
    self.selection_rect = rect
    self.plot.addItem(rect)
+
1460
+ def _finish_area_selection(self, ev):
1461
+ """Finish area selection and select nodes in rectangle"""
1462
+ if not self.is_area_selecting or not self.selection_rect:
1463
+ return
1464
+
1465
+ # Get rectangle bounds
1466
+ rect_pos = self.selection_rect.pos()
1467
+ rect_size = self.selection_rect.size()
1468
+
1469
+ x_min = rect_pos[0]
1470
+ y_min = rect_pos[1]
1471
+ x_max = x_min + rect_size[0]
1472
+ y_max = y_min + rect_size[1]
1473
+
1474
+ # Normalize bounds (in case user dragged backwards)
1475
+ if x_min > x_max:
1476
+ x_min, x_max = x_max, x_min
1477
+ if y_min > y_max:
1478
+ y_min, y_max = y_max, y_min
1479
+
1480
+ # Remove selection rectangle
1481
+ if self.selection_rect:
1482
+ self.plot.removeItem(self.selection_rect)
1483
+ self.selection_rect = None
1484
+
1485
+ self.is_area_selecting = False
1486
+ self.selection_start_pos = None
1487
+
1488
+ # ZOOM MODE: Zoom to rectangle
1489
+ if self.zoom_mode:
1490
+ # Add small padding (5%)
1491
+ x_range = x_max - x_min
1492
+ y_range = y_max - y_min
1493
+ padding = 0.05
1494
+
1495
+ self.plot.setXRange(x_min - padding * x_range,
1496
+ x_max + padding * x_range, padding=0)
1497
+ self.plot.setYRange(y_min - padding * y_range,
1498
+ y_max + padding * y_range, padding=0)
1499
+ return
1500
+
1501
+ # SELECTION MODE: Select nodes in rectangle
1502
+ if self.selection_mode:
1503
+ # Find nodes in rectangle
1504
+ selected_in_rect = []
1505
+ for node, pos in self.node_positions.items():
1506
+ if x_min <= pos[0] <= x_max and y_min <= pos[1] <= y_max:
1507
+ selected_in_rect.append(node)
1508
+
1509
+ # Add to selection
1510
+ modifiers = ev.modifiers()
1511
+ ctrl_pressed = modifiers & Qt.KeyboardModifier.ControlModifier
1512
+
1513
+ if not ctrl_pressed:
1514
+ self.selected_nodes = set()
1515
+ self.selected_nodes.update(selected_in_rect)
1516
+ self.push_selection()
1517
+
1518
+ # Update visual representation
1519
+ self._render_nodes()
1520
+
1521
+ # Emit signal
1522
+ self.node_selected.emit(list(self.selected_nodes))
1523
+
1524
+ def _toggle_selection_mode(self):
1525
+ """Toggle selection mode"""
1526
+ self.selection_mode = self.select_btn.isChecked()
1527
+
1528
+ if self.selection_mode:
1529
+ self.pan_btn.setChecked(False)
1530
+ self.zoom_btn.setChecked(False)
1531
+ self.zoom_mode = False
1532
+ # Disable panning, but allow wheel events to be handled manually
1533
+ self.plot.setCursor(Qt.CursorShape.ArrowCursor)
1534
+ self.plot.vb.setMenuEnabled(False)
1535
+ self.plot.setMouseEnabled(x=False, y=False)
1536
+ else:
1537
+ # If nothing else is checked, check pan by default
1538
+ if not self.pan_btn.isChecked() and not self.zoom_btn.isChecked():
1539
+ self.pan_btn.click()
1540
+
1541
+ def _toggle_pan_mode(self):
1542
+ """Toggle pan mode"""
1543
+ if self.pan_btn.isChecked():
1544
+ self.select_btn.setChecked(False)
1545
+ self.zoom_btn.setChecked(False)
1546
+ self.selection_mode = False
1547
+ self.zoom_mode = False
1548
+ # Enable panning
1549
+ self.plot.vb.setMenuEnabled(True)
1550
+ self.plot.setCursor(Qt.CursorShape.OpenHandCursor)
1551
+ self.plot.setMouseEnabled(x=True, y=True)
1552
+ else:
1553
+ # Disable panning
1554
+ if not self.select_btn.isChecked() and not self.zoom_btn.isChecked():
1555
+ self.select_btn.click()
1556
+
1557
+ def _toggle_zoom_mode(self):
1558
+ """Toggle zoom mode"""
1559
+ self.zoom_mode = self.zoom_btn.isChecked()
1560
+
1561
+ if self.zoom_mode:
1562
+ self.select_btn.setChecked(False)
1563
+ self.pan_btn.setChecked(False)
1564
+ self.selection_mode = False
1565
+ # Disable default panning for zoom mode
1566
+ self.plot.setCursor(Qt.CursorShape.CrossCursor)
1567
+ self.plot.vb.setMenuEnabled(False)
1568
+ self.plot.setMouseEnabled(x=False, y=False)
1569
+ else:
1570
+ # If nothing else is checked, check pan by default
1571
+ if not self.pan_btn.isChecked() and not self.select_btn.isChecked():
1572
+ self.select_btn.click()
1573
+
1574
def eventFilter(self, obj, event):
    """Filter graphics-scene mouse events to implement the custom tools.

    Only events whose target is ``self.plot.scene()`` are handled. All other
    objects are passed to the base-class filter.

    - Middle press/release: temporary panning while in selection or zoom
      mode; the event is always allowed to propagate.
    - Left press (selection/zoom mode): remember the position and (re)start
      a 200 ms single-shot timer. If the timer fires (``_on_long_press``), or
      the mouse moves more than 10 scene units before it fires, a
      rubber-band area selection starts.
    - Left release (selection/zoom mode): stop the timer. If a rubber band
      is active, finish it and consume the event. Otherwise, in zoom mode,
      zoom in by 2x at the cursor.
    - Right release: open the context menu in selection mode, or zoom out
      by 2x at the cursor in zoom mode.
    - Wheel (selection/zoom mode): temporarily re-enable mouse zooming and
      let the event propagate.

    Returns:
        True when the event is consumed. False for middle-button and wheel
        events, so they propagate. Otherwise, the base-class
        ``eventFilter`` result.
    """
    from PyQt6.QtCore import QEvent
    from PyQt6.QtGui import QMouseEvent  # NOTE(review): imported but unused here

    # Only handle events for the graphics scene
    if obj != self.plot.scene():
        return super().eventFilter(obj, event)

    if event.type() == QEvent.Type.GraphicsSceneMousePress:
        # Handle middle mouse button for temporary panning in selection mode
        if event.button() == Qt.MouseButton.MiddleButton:
            if self.selection_mode or self.zoom_mode:
                self._start_temp_pan()
            return False  # Let the event propagate for panning

        # SELECTION/ZOOM MODE: left press arms the long-press timer for area selection
        elif event.button() == Qt.MouseButton.LeftButton and (self.selection_mode or self.zoom_mode):
            # Store position and start timer for long press detection
            self.last_mouse_pos = event.scenePos()
            if not self.click_timer:
                # Timer is created lazily on first press and reused afterwards
                self.click_timer = QTimer()
                self.click_timer.setSingleShot(True)
                self.click_timer.timeout.connect(self._on_long_press)
            self.click_timer.start(200)  # 200ms threshold for area selection

    elif event.type() == QEvent.Type.GraphicsSceneMouseMove:
        # A drag while the long-press timer is still pending also starts an area selection
        if (self.selection_mode or self.zoom_mode) and self.click_timer and self.click_timer.isActive():
            if self.last_mouse_pos:
                # Check if mouse moved significantly (in scene units)
                current_pos = event.scenePos()
                delta_x = abs(current_pos.x() - self.last_mouse_pos.x())
                delta_y = abs(current_pos.y() - self.last_mouse_pos.y())

                if delta_x > 10 or delta_y > 10:  # Moved significantly
                    self.click_timer.stop()
                    if not self.is_area_selecting:
                        #if not self.was_in_selection_before_wheel or self.zoom_mode:
                        self._start_area_selection(self.last_mouse_pos)

    elif event.type() == QEvent.Type.GraphicsSceneMouseRelease:
        # Handle middle mouse release
        if event.button() == Qt.MouseButton.MiddleButton:
            if self.temp_pan_active:
                self._end_temp_pan()
            return False  # Let event propagate

        # Handle left button release in selection/zoom mode
        elif event.button() == Qt.MouseButton.LeftButton and (self.selection_mode or self.zoom_mode):
            if self.click_timer and self.click_timer.isActive():
                self.click_timer.stop()

            if self.is_area_selecting:
                self._finish_area_selection(event)
                return True  # Consume event
        elif event.button() == Qt.MouseButton.RightButton and self.selection_mode:
            mouse_point = self.plot.getViewBox().mapSceneToView(event.scenePos())
            x, y = mouse_point.x(), mouse_point.y()  # NOTE(review): x, y are unused
            self.create_context_menu(event)

        # ZOOM MODE: plain left/right release zooms at the cursor
        # (_zoom_at_point divides the visible range by the factor)
        if self.zoom_mode:
            if event.button() == Qt.MouseButton.LeftButton:
                # Zoom in
                self._zoom_at_point(event.scenePos(), 2)
                return True
            elif event.button() == Qt.MouseButton.RightButton:
                # Zoom out
                self._zoom_at_point(event.scenePos(), 0.5)
                return True

    elif event.type() == QEvent.Type.GraphicsSceneWheel:
        # Handle wheel events in selection/zoom mode
        if self.selection_mode or self.zoom_mode:
            self._handle_wheel_in_selection(event)
        return False  # Let event propagate for actual zooming

    return super().eventFilter(obj, event)
+
1655
+ def _on_long_press(self):
1656
+ """Handle long press for area selection"""
1657
+ if (self.selection_mode or self.zoom_mode) and self.last_mouse_pos:
1658
+ # Start area selection at stored mouse position
1659
+ self._start_area_selection(self.last_mouse_pos)
1660
+
1661
+ def _start_temp_pan(self):
1662
+ """Start temporary panning mode with middle mouse"""
1663
+ self.temp_pan_active = True
1664
+ # Temporarily enable panning
1665
+ self.plot.setMouseEnabled(x=True, y=True)
1666
+
1667
+ def _end_temp_pan(self):
1668
+ """End temporary panning mode"""
1669
+ self.temp_pan_active = False
1670
+ # Disable mouse if we're in selection mode
1671
+ if self.selection_mode or self.zoom_mode:
1672
+ self.plot.setMouseEnabled(x=False, y=False)
1673
+
1674
+ def _handle_wheel_in_selection(self, event):
1675
+ """Handle wheel events in selection mode - temporarily enable pan for zoom"""
1676
+ # Temporarily enable mouse for zooming
1677
+ if not self.was_in_selection_before_wheel:
1678
+ self.was_in_selection_before_wheel = True
1679
+ self.plot.setMouseEnabled(x=True, y=True)
1680
+
1681
+ # Reset or create timer
1682
+ if not self.wheel_timer:
1683
+ self.wheel_timer = QTimer()
1684
+ self.wheel_timer.setSingleShot(True)
1685
+ self.wheel_timer.timeout.connect(self._end_wheel_zoom)
1686
+
1687
+ # Restart timer
1688
+ self.wheel_timer.start(1)
1689
+
1690
+ def _end_wheel_zoom(self):
1691
+ """End wheel zoom and return to selection mode"""
1692
+ if self.was_in_selection_before_wheel and (self.selection_mode or self.zoom_mode):
1693
+ self.plot.setMouseEnabled(x=False, y=False)
1694
+ self.was_in_selection_before_wheel = False
1695
+
1696
+ def _zoom_at_point(self, scene_pos, scale_factor):
1697
+ """Zoom in or out at a specific point"""
1698
+ # Convert scene position to view coordinates
1699
+ view_pos = self.plot.vb.mapSceneToView(scene_pos)
1700
+
1701
+ # Get current view range
1702
+ view_range = self.plot.viewRange()
1703
+ x_range = view_range[0]
1704
+ y_range = view_range[1]
1705
+
1706
+ # Calculate current center and size
1707
+ x_center = (x_range[0] + x_range[1]) / 2
1708
+ y_center = (y_range[0] + y_range[1]) / 2
1709
+ x_size = x_range[1] - x_range[0]
1710
+ y_size = y_range[1] - y_range[0]
1711
+
1712
+ # Calculate new size
1713
+ new_x_size = x_size / scale_factor
1714
+ new_y_size = y_size / scale_factor
1715
+
1716
+ # Calculate offset to zoom toward the point
1717
+ x_offset = (view_pos.x() - x_center) * (1 - 1/scale_factor)
1718
+ y_offset = (view_pos.y() - y_center) * (1 - 1/scale_factor)
1719
+
1720
+ # Set new range
1721
+ new_x_center = x_center + x_offset
1722
+ new_y_center = y_center + y_offset
1723
+
1724
+ self.plot.setXRange(new_x_center - new_x_size/2, new_x_center + new_x_size/2, padding=0)
1725
+ self.plot.setYRange(new_y_center - new_y_size/2, new_y_center + new_y_size/2, padding=0)
1726
+
1727
+ def _reset_view(self):
1728
+ """Reset view to show entire graph"""
1729
+ if not self.node_positions:
1730
+ return
1731
+
1732
+ nodes = list(self.node_positions.keys())
1733
+ if not nodes:
1734
+ return
1735
+
1736
+ pos_array = np.array([self.node_positions[n] for n in nodes])
1737
+
1738
+ # Get bounds
1739
+ x_min, y_min = pos_array.min(axis=0)
1740
+ x_max, y_max = pos_array.max(axis=0)
1741
+
1742
+ # Add padding
1743
+ if self.shell or self.component:
1744
+ padding = 0.75
1745
+ else:
1746
+ padding = 0.1
1747
+ x_range = x_max - x_min
1748
+ y_range = y_max - y_min
1749
+
1750
+ self.plot.setXRange(x_min - padding * x_range,
1751
+ x_max + padding * x_range, padding=0)
1752
+ self.plot.setYRange(y_min - padding * y_range,
1753
+ y_max + padding * y_range, padding=0)
1754
+
1755
def _clear_graph(self):
    """Remove everything drawn for the current graph and reset cached state.

    Steps, in order:
    1. Stop any running ``GraphLoadThread`` and discard its result.
    2. Remove the placeholder, scatter, edges, labels, legend widgets,
       stray ``pg.TextItem``s and the selection rectangle.
    3. Clear positions, selection and render caches, and mark the widget
       as not rendered.
    4. Show a new placeholder: "No network detected", or a prompt to press
       refresh when a non-empty graph is set.
    """
    thread = self.load_thread
    if thread is not None and thread.isRunning():
        # Disconnect first so a half-finished layout is never rendered
        thread.finished.disconnect()
        # NOTE(review): Qt documents QThread.terminate() as dangerous (the
        # thread may be killed while holding locks or mid-write). Cooperative
        # cancellation checked inside GraphLoadThread.run would be safer.
        thread.terminate()
        thread.wait()
        self.load_thread = None

    placeholder = getattr(self, 'loading_text', None)
    if placeholder is not None:
        self.plot.removeItem(placeholder)
        self.loading_text = None

    self.scatter.clear()

    for item in self.edge_items:
        self.plot.removeItem(item)
    self.edge_items.clear()

    # Best-effort label removal: a stale item must not abort the clear,
    # but KeyboardInterrupt/SystemExit are no longer swallowed.
    label_items = getattr(self, 'label_items', None)
    if label_items:
        for label_item in list(label_items.values()):
            try:
                self.plot.removeItem(label_item)
            except Exception:
                pass
        label_items.clear()

    # Remove legend widgets. Layout items without a widget (e.g. spacers)
    # are skipped explicitly instead of relying on a caught AttributeError.
    legend_layout = getattr(self, 'legend_layout', None)
    if legend_layout is not None:
        try:
            for index in reversed(range(legend_layout.count())):
                layout_item = legend_layout.itemAt(index)
                widget = layout_item.widget() if layout_item is not None else None
                if widget is not None:
                    widget.setParent(None)
        except Exception:
            # Best effort, e.g. a Qt wrapper whose C++ object is already deleted
            pass

    # Sweep any TextItems still attached to the plot that were not tracked
    stray_text = [item for item in self.plot.items if isinstance(item, pg.TextItem)]
    for item in stray_text:
        self.plot.removeItem(item)

    if self.selection_rect:
        self.plot.removeItem(self.selection_rect)
        self.selection_rect = None

    self.node_positions.clear()
    self.node_items.clear()
    self.selected_nodes.clear()
    self.rendered = False
    if getattr(self, 'label_data', None):
        self.label_data.clear()

    # Render caches rebuilt by _on_graph_loaded
    self.cached_spots.clear()
    self.cached_node_to_index.clear()
    self.cached_brushes.clear()
    self.last_selected_set.clear()
    self.cached_sizes_for_lod.clear()

    has_nodes = self.graph is not None and len(self.graph.nodes()) > 0
    message = "Press 🔄 to load your graph" if has_nodes else "No network detected"
    self.loading_text = pg.TextItem(
        text=message,
        color=(100, 100, 100),
        anchor=(0.5, 0.5)
    )
    self.loading_text.setPos(0, 0)  # Center of view
    self.plot.addItem(self.loading_text)
+
1840
def _popout_graph(self):
    """Open a copy of this graph view in its own top-level window.

    A new ``NetworkGraphWidget`` is built with this widget's current display
    settings and given the same graph. It is shown in a 1000x800 window titled
    "Network Graph", then loaded. The new widget is appended to
    ``parent_window.temp_graph_widgets``, presumably to keep it alive.
    """
    settings = dict(
        weight=self.weight,
        geometric=self.geometric,
        component=self.component,
        centroids=self.centroids,
        communities=self.communities,
        community_dict=self.community_dict,
        labels=self.labels,
        identities=self.identities,
        identity_dict=self.identity_dict,
        z_size=self.z_size,
        shell=self.shell,
        node_size=self.node_size,
        black_edges=self.black_edges,
        edge_size=self.edge_size,
    )
    popout = NetworkGraphWidget(parent=self.parent_window, **settings)

    popout.set_graph(self.graph)
    popout.show_in_window(title="Network Graph", width=1000, height=800)
    popout.load_graph()

    self.parent_window.temp_graph_widgets.append(popout)
1864
+
1865
+
1866
def select_nodes(self, nodes, add_to_selection=False):
    """
    Select nodes from code rather than by clicking.

    Parameters:
    -----------
    nodes : list
        Node IDs to select. IDs that are not keys of ``self.node_items`` are
        silently ignored.
    add_to_selection : bool
        When True the nodes are added to the current selection; when False
        the current selection is discarded first.

    The nodes are redrawn afterwards and ``node_selected`` is emitted with the
    full selection as a list.
    """
    if not add_to_selection:
        self.selected_nodes.clear()

    for known_node in filter(self.node_items.__contains__, nodes):
        self.selected_nodes.add(known_node)

    self._render_nodes()
    self.node_selected.emit(list(self.selected_nodes))
1890
+
1891
def clear_selection(self):
    """Deselect every node, redraw, and emit ``node_selected`` with ``[]``."""
    selection = self.selected_nodes
    selection.clear()
    self._render_nodes()
    self.node_selected.emit([])
1896
+
1897
+ def _on_node_clicked(self, scatter, points, ev):
1898
+ """Handle node click events"""
1899
+ if not self.selection_mode or len(points) == 0:
1900
+ return
1901
+
1902
+ # Get clicked node
1903
+ point = points[0]
1904
+ clicked_node = point.data()
1905
+
1906
+ # Check if Ctrl is pressed
1907
+ modifiers = ev.modifiers()
1908
+ ctrl_pressed = modifiers & Qt.KeyboardModifier.ControlModifier
1909
+
1910
+ if ctrl_pressed:
1911
+ # Toggle selection for this node
1912
+ if clicked_node in self.selected_nodes:
1913
+ self.selected_nodes.remove(clicked_node)
1914
+ else:
1915
+ self.selected_nodes.add(clicked_node)
1916
+ else:
1917
+ # Clear previous selection and select only this node
1918
+ self.selected_nodes.clear()
1919
+ self.selected_nodes.add(clicked_node)
1920
+
1921
+ self.push_selection()
1922
+
1923
+ # Update visual representation
1924
+ self._render_nodes()
1925
+ self.node_click = True
1926
+
1927
+ # Emit signal with all selected nodes
1928
+ self.node_selected.emit(list(self.selected_nodes))
1929
+
1930
def push_selection(self):
    """Mirror the current graph selection into the parent window.

    Stores the selected node IDs as ``parent_window.clicked_values['nodes']``.
    It then calls ``evaluate_mini(subgraph_push=True)`` and
    ``handle_info('node')`` on the parent window.
    """
    parent = self.parent_window
    parent.clicked_values['nodes'] = list(self.selected_nodes)
    parent.evaluate_mini(subgraph_push=True)
    parent.handle_info('node')
1934
+
1935
def get_selected_nodes(self):
    """Return a new list holding the currently selected node IDs."""
    return [*self.selected_nodes]
1938
+
1939
def get_selected_node(self):
    """
    Return one selected node, kept for backwards compatibility.

    Gives the first node produced by iterating the selection, or ``None``
    when nothing is selected.
    """
    return next(iter(self.selected_nodes), None)
1947
+
1948
+
1949
def handle_find_action(self):
    """Jump the main viewer to the most recently clicked node.

    Reads the last entry of ``parent_window.clicked_values['nodes']``. If that
    node has a centroid, it:

    - makes channel 0 active and visible;
    - moves the slice slider to the centroid's first coordinate (Z).

    The current graph selection is then pushed back to the parent window.

    Does nothing if no node has been clicked. Any other failure is reported
    with a traceback rather than raised, so this context-menu action never
    propagates an exception into the Qt event loop.
    """
    try:
        val = self.parent_window.clicked_values['nodes'][-1]
    except (KeyError, IndexError, TypeError):
        # No node clicked yet: nothing to find (expected, not an error).
        return

    try:
        self.parent_window.handle_info(sort='node')
        centroids = self.centroids if self.centroids is not None else {}
        if val in centroids:
            centroid = centroids[val]
            self.parent_window.set_active_channel(0)
            # Toggle on the nodes channel if it's not already visible
            if not self.parent_window.channel_visible[0]:
                self.parent_window.channel_buttons[0].setChecked(True)
                self.parent_window.toggle_channel(0)
            # Navigate to the Z-slice (centroids are [Z, Y, X])
            self.parent_window.slice_slider.setValue(int(centroid[0]))
            print(f"Found node {val} at [Z,Y,X] -> {centroid}")
        self.push_selection()
    except Exception:
        # Best effort: report the failure without crashing the UI.
        import traceback
        traceback.print_exc()
1968
+
1969
+
1970
def save_table_as(self, file_type):
    """Save this widget's table or graph to a file chosen by the user.

    ``'csv'`` and ``'xlsx'`` export the backing DataFrame of the table paired
    with this widget: the selection table for the selection graph widget, the
    network table otherwise. ``'gexf'``, ``'graphml'`` and ``'net'`` (Pajek)
    export ``self.graph`` via networkx.

    The format's extension is appended unless the chosen name already ends
    with it (case-insensitively). Success and failure are both reported in a
    message box. An unrecognised ``file_type`` is reported as an error without
    opening the save dialog.

    Parameters
    ----------
    file_type : str
        One of ``'csv'``, ``'xlsx'``, ``'gexf'``, ``'graphml'``, ``'net'``.
    """
    is_selection = self is self.parent_window.selection_graph_widget
    table_name = "Selection" if is_selection else "Network"

    def table_data():
        # Looked up only when a table format is requested, so graph exports
        # do not depend on the table model being populated.
        table = (self.parent_window.selection_table if is_selection
                 else self.parent_window.network_table)
        return table.model()._data

    # file_type -> (dialog filter, required extension, writer(path))
    formats = {
        'csv': ("CSV Files (*.csv)", '.csv',
                lambda path: table_data().to_csv(path, index=False)),
        'xlsx': ("Excel Files (*.xlsx)", '.xlsx',
                 lambda path: table_data().to_excel(path, index=False)),
        'gexf': ("Gephi Graph (*.gexf)", '.gexf',
                 lambda path: nx.write_gexf(self.graph, path, encoding='utf-8',
                                            prettyprint=True)),
        'graphml': ("GraphML (*.graphml)", '.graphml',
                    lambda path: nx.write_graphml(self.graph, path)),
        'net': ("Pajek Network (*.net)", '.net',
                lambda path: nx.write_pajek(self.graph, path)),
    }
    if file_type not in formats:
        # No writer exists for this type; report it instead of claiming success.
        QMessageBox.critical(self, "Error", f"Unsupported file type: {file_type!r}")
        return
    file_filter, extension, write = formats[file_type]

    filename, _ = QFileDialog.getSaveFileName(
        self,
        f"Save {table_name} Table As",
        "",
        file_filter
    )
    if not filename:
        return  # dialog cancelled

    # If the user didn't type the extension, add it.
    if not filename.lower().endswith(extension):
        filename += extension

    try:
        write(filename)
    except Exception as e:
        QMessageBox.critical(
            self,
            "Error",
            f"Failed to save file: {str(e)}"
        )
    else:
        QMessageBox.information(
            self,
            "Success",
            f"{table_name} table successfully saved to {filename}"
        )
2035
+
2036
+
2037
def get_properties(self):
    """Ask the parent window to refresh its graph fields (``update_graph_fields``)."""
    refresh_fields = self.parent_window.update_graph_fields
    refresh_fields()
2040
+
2041
+
2042
def create_context_menu(self, event):
    """Build and show the right-click menu for the graph view.

    The menu contains:

    - node actions: find, neighbors, community, connected component;
    - a "Save As" submenu: CSV, Excel, Gephi, GraphML, Pajek;
    - on the selection graph widget only, an action that swaps the selection
      table in as the network table.

    The menu is opened at the screen position corresponding to
    ``event.scenePos()``.
    """
    menu = QMenu(self)
    parent = self.parent_window

    node_actions = (
        ("Find Node", self.handle_find_action),
        ("Show Neighbors", parent.handle_show_neighbors),
        ("Show Community", parent.handle_show_communities),
        ("Show Connected Component", parent.handle_show_component),
    )
    for text, handler in node_actions:
        menu.addAction(text).triggered.connect(handler)

    menu.addSeparator()

    save_menu = menu.addMenu("Save As")

    def saver(file_type):
        # Each call creates its own closure, so every action keeps its own type.
        return lambda: self.save_table_as(file_type)

    save_formats = (
        ("CSV", 'csv'),
        ("Excel", 'xlsx'),
        ("Gephi", 'gexf'),
        ("GraphML", 'graphml'),
        ("Pajek", 'net'),
    )
    for text, file_type in save_formats:
        save_menu.addAction(text).triggered.connect(saver(file_type))

    if self == parent.selection_graph_widget:
        swap_action = menu.addAction("Swap with network table (also sets internal network properties - may affect related functions)")
        swap_action.triggered.connect(parent.selection_table.set_selection_to_active)

    # Scene position -> view-widget coordinates -> global screen coordinates.
    view = self.plot.getViewWidget()
    screen_pos = view.mapToGlobal(view.mapFromScene(event.scenePos()))
    menu.exec(screen_pos)
2090
+
2091
def update_params(self, weight=None, geometric=None, component=None, centroids=None,
                  communities=None, community_dict=None,
                  identities=None, identity_dict=None, labels=None, z_size=None,
                  shell=None, node_size=None, black_edges=None, edge_size=None):
    """Update visualization parameters.

    Every argument is optional. An argument left as ``None`` keeps the current
    value of the attribute of the same name. Any other value, including
    ``False`` or ``0``, replaces it. Nothing is redrawn here.

    ``node_size`` follows the same ``None`` = unchanged rule as the other
    parameters, so omitting it no longer resets the node size to 10.
    ``black_edges`` and ``edge_size`` can be updated too.
    """
    updates = {
        'weight': weight,
        'geometric': geometric,
        'component': component,
        'centroids': centroids,
        'communities': communities,
        'community_dict': community_dict,
        'identities': identities,
        'identity_dict': identity_dict,
        'labels': labels,
        'z_size': z_size,
        'shell': shell,
        'node_size': node_size,
        'black_edges': black_edges,
        'edge_size': edge_size,
    }
    for name, value in updates.items():
        if value is not None:
            setattr(self, name, value)
2119
+
2120
+ def _on_view_changed(self):
2121
+ """Handle view range changes for level-of-detail adjustments"""
2122
+ if not self.node_positions or len(self.node_positions) == 0:
2123
+ return
2124
+
2125
+ # Calculate current zoom factor based on view range
2126
+ view_range = self.plot.viewRange()
2127
+ x_range = view_range[0][1] - view_range[0][0]
2128
+ y_range = view_range[1][1] - view_range[1][0]
2129
+
2130
+ # Get initial full graph bounds
2131
+ nodes = list(self.node_positions.keys())
2132
+ pos_array = np.array([self.node_positions[n] for n in nodes])
2133
+
2134
+ if len(pos_array) > 0:
2135
+ full_x_range = pos_array[:, 0].max() - pos_array[:, 0].min()
2136
+ full_y_range = pos_array[:, 1].max() - pos_array[:, 1].min()
2137
+
2138
+ if full_x_range > 0 and full_y_range > 0:
2139
+ # Calculate zoom factor (smaller view range = more zoomed in)
2140
+ zoom_x = full_x_range / x_range if x_range > 0 else 1
2141
+ zoom_y = full_y_range / y_range if y_range > 0 else 1
2142
+ zoom_factor = max(zoom_x, zoom_y)
2143
+
2144
+ # Update if zoom changed significantly (>10% change)
2145
+ zoom_changed = abs(zoom_factor - self.current_zoom_factor) / max(self.current_zoom_factor, 0.01) > 0.1
2146
+ if zoom_changed:
2147
+ self.current_zoom_factor = zoom_factor
2148
+ self._update_lod_rendering()
2149
+ else:
2150
+ # Even if zoom didn't change, update labels for panning
2151
+ # (viewport changed but zoom level stayed the same)
2152
+ if self.labels:
2153
+ self._update_labels_for_zoom()
2154
+
2155
+ def _update_lod_rendering(self):
2156
+ """OPTIMIZED: Update rendering based on current zoom level using cached data"""
2157
+ if not self.cached_spots or not self.cached_sizes_for_lod:
2158
+ return
2159
+
2160
+ # Adjust node sizes based on zoom
2161
+ if self.current_zoom_factor > 1.5:
2162
+ scale_factor = 1.0 + np.log10(self.current_zoom_factor) * 0.3
2163
+ else:
2164
+ scale_factor = 1.0
2165
+
2166
+ # Update node sizes in cached spots
2167
+ for i, base_size in enumerate(self.cached_sizes_for_lod):
2168
+ self.cached_spots[i]['size'] = base_size * scale_factor
2169
+
2170
+ # Update edge visibility based on zoom
2171
+ if self.current_zoom_factor < 0.5:
2172
+ edge_alpha = int(50 * self.current_zoom_factor)
2173
+ elif self.current_zoom_factor > 2:
2174
+ edge_alpha = min(150, int(100 + self.current_zoom_factor * 10))
2175
+ else:
2176
+ edge_alpha = 100
2177
+
2178
+ if self.black_edges:
2179
+ edge_color = (0, 0, 0)
2180
+ else:
2181
+ edge_color = (150, 150, 150, edge_alpha)
2182
+
2183
+ # Update edge rendering (batched edge items)
2184
+ if self.edge_items:
2185
+ for edge_item in self.edge_items:
2186
+ current_pen = edge_item.opts['pen']
2187
+ if current_pen is not None:
2188
+ width = current_pen.widthF()
2189
+ new_pen = pg.mkPen(color=edge_color, width=width)
2190
+ edge_item.setPen(new_pen)
2191
+
2192
+ # Update labels based on zoom level
2193
+ if self.labels:
2194
+ self._update_labels_for_zoom()
2195
+
2196
+ # Re-render nodes with new sizes
2197
+ self.scatter.setData(spots=self.cached_spots)
2198
+
2199
def show_in_window(self, title="Network Graph", width=1000, height=800):
    """Host this widget in a new, non-modal top-level window.

    Creates a ``QMainWindow`` at screen position (100, 100) with the given
    title and size, and makes this widget its central widget. The window is
    stored as ``self.popup_window``, shown, and returned.
    """
    from PyQt6.QtWidgets import QMainWindow

    self.popup_window = window = QMainWindow()
    window.setWindowTitle(title)
    window.setGeometry(100, 100, width, height)
    window.setCentralWidget(self)

    # show() does not block (non-modal).
    window.show()
    return window
2213
+
2214
+
2215
# Example usage
if __name__ == "__main__":
    import sys
    from PyQt6.QtWidgets import QApplication, QMainWindow

    class MainWindow(QMainWindow):
        """Demo window showing the karate-club graph.

        Edges get random weights, greedy-modularity community assignments are
        passed in, and labels are enabled.
        """

        def __init__(self):
            super().__init__()
            self.setWindowTitle("Network Graph Viewer")
            self.setGeometry(100, 100, 1000, 800)

            demo_graph = nx.karate_club_graph()

            # Random weights so the weighted-edge rendering has something to show.
            for u, v in demo_graph.edges():
                demo_graph[u][v]['weight'] = np.random.uniform(0.5, 5.0)

            # node -> community index
            community_dict = {
                node: index
                for index, members in enumerate(
                    nx.community.greedy_modularity_communities(demo_graph))
                for node in members
            }

            self.graph_widget = NetworkGraphWidget(
                parent=self,
                weight=True,
                communities=True,
                community_dict=community_dict,
                labels=True  # Enable labels for testing
            )
            self.setCentralWidget(self.graph_widget)

            self.graph_widget.set_graph(demo_graph)
            self.graph_widget.load_graph()
            self.graph_widget.node_selected.connect(self.on_node_selected)

        def on_node_selected(self, nodes):
            print(f"Selected nodes: {nodes}" if nodes else "No nodes selected")

    app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    sys.exit(app.exec())