pyg-nightly 2.7.0.dev20250428__py3-none-any.whl → 2.7.0.dev20250429__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyg_nightly-2.7.0.dev20250428.dist-info → pyg_nightly-2.7.0.dev20250429.dist-info}/METADATA +1 -1
- {pyg_nightly-2.7.0.dev20250428.dist-info → pyg_nightly-2.7.0.dev20250429.dist-info}/RECORD +8 -8
- torch_geometric/__init__.py +1 -1
- torch_geometric/explain/explanation.py +86 -2
- torch_geometric/visualization/__init__.py +2 -1
- torch_geometric/visualization/graph.py +247 -1
- {pyg_nightly-2.7.0.dev20250428.dist-info → pyg_nightly-2.7.0.dev20250429.dist-info}/WHEEL +0 -0
- {pyg_nightly-2.7.0.dev20250428.dist-info → pyg_nightly-2.7.0.dev20250429.dist-info}/licenses/LICENSE +0 -0
{pyg_nightly-2.7.0.dev20250428.dist-info → pyg_nightly-2.7.0.dev20250429.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: pyg-nightly
|
3
|
-
Version: 2.7.0.dev20250428
|
3
|
+
Version: 2.7.0.dev20250429
|
4
4
|
Summary: Graph Neural Network Library for PyTorch
|
5
5
|
Keywords: deep-learning,pytorch,geometric-deep-learning,graph-neural-networks,graph-convolutional-networks
|
6
6
|
Author-email: Matthias Fey <matthias@pyg.org>
|
{pyg_nightly-2.7.0.dev20250428.dist-info → pyg_nightly-2.7.0.dev20250429.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
|
|
1
|
-
torch_geometric/__init__.py,sha256=
|
1
|
+
torch_geometric/__init__.py,sha256=Fn0Tk4CEGmrh27S4w7JxkbV3OwwAKExri7NzXE8-dQE,1978
|
2
2
|
torch_geometric/_compile.py,sha256=f-WQeH4VLi5Hn9lrgztFUCSrN_FImjhQa6BxFzcYC38,1338
|
3
3
|
torch_geometric/_onnx.py,sha256=V9ffrIKSqhDw6xUZ12lkuSfNs48cQp2EeJ6Z19GfnVw,349
|
4
4
|
torch_geometric/backend.py,sha256=lVaf7aLoVaB3M-UcByUJ1G4T4FOK6LXAg0CF4W3E8jo,1575
|
@@ -193,7 +193,7 @@ torch_geometric/distributed/utils.py,sha256=FGrr3qw7hx7EQaIjjqasurloCFJ9q_0jt8jd
|
|
193
193
|
torch_geometric/explain/__init__.py,sha256=pRxVB33zsxhED1StRWdHboQWh3e06__g9N298Hzi42Y,359
|
194
194
|
torch_geometric/explain/config.py,sha256=_0j67NAwPwjrWHPncNywCT-oKyMiryJNxufxVN1BFlM,7834
|
195
195
|
torch_geometric/explain/explainer.py,sha256=8_NZTmlT4WO9RgKxpSUQRt3rbVwFURF5bSWOPlfOLjA,10667
|
196
|
-
torch_geometric/explain/explanation.py,sha256=
|
196
|
+
torch_geometric/explain/explanation.py,sha256=Bt8THLn-CSrvEFisdT9DX9fnOMaqficsChSCI9uhyQw,18873
|
197
197
|
torch_geometric/explain/algorithm/__init__.py,sha256=fE29xbd0bPxg-EfrB2BDmmY9QnyO-7TgvYduGHofm5o,496
|
198
198
|
torch_geometric/explain/algorithm/attention_explainer.py,sha256=65iGLmOt00ERtBDVxAoydIchykdWZU24aXzSzUGzQEI,11304
|
199
199
|
torch_geometric/explain/algorithm/base.py,sha256=wwJcREUFKDLFUDjRa9o4X3DWqQgMvhS3Iciwb6Evtjc,6922
|
@@ -633,10 +633,10 @@ torch_geometric/utils/repeat.py,sha256=RxCoRoEisaP6NouXPPW5tY1Rn-tIfrmpJPm0qGP6W
|
|
633
633
|
torch_geometric/utils/smiles.py,sha256=lGQ2BwJ49uBrQfIxxPz8ceTO9Jo-XCjlLxs1ql3xrsA,7130
|
634
634
|
torch_geometric/utils/sparse.py,sha256=uYd0oPrp5XN0c2Zc15f-00rhhVMfLnRMqNcqcmILNKQ,25519
|
635
635
|
torch_geometric/utils/undirected.py,sha256=H_nfpI0_WluOG6VfjPyldvcjL4w5USAKWu2x5nUAUjw,6222
|
636
|
-
torch_geometric/visualization/__init__.py,sha256=
|
637
|
-
torch_geometric/visualization/graph.py,sha256=
|
636
|
+
torch_geometric/visualization/__init__.py,sha256=b-HnVesXjyJ_L1N-DnjiRiRVf7lhwKaBQF_2i5YMVSU,208
|
637
|
+
torch_geometric/visualization/graph.py,sha256=PoI9tjbEXZVkMUg4CvTLbzqtEfzUwMUcsw57DNBEU0s,14311
|
638
638
|
torch_geometric/visualization/influence.py,sha256=CWMvuNA_Nf1sfbJmQgn58yS4OFpeKXeZPe7kEuvkUBw,477
|
639
|
-
pyg_nightly-2.7.0.dev20250428.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
|
640
|
-
pyg_nightly-2.7.0.dev20250428.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
|
641
|
-
pyg_nightly-2.7.0.
|
642
|
-
pyg_nightly-2.7.0.dev20250428.dist-info/RECORD,,
|
639
|
+
pyg_nightly-2.7.0.dev20250429.dist-info/licenses/LICENSE,sha256=ic-27cMJc1kWoMEYncz3Ya3Ur2Bi3bNLWib2DT763-o,1067
|
640
|
+
pyg_nightly-2.7.0.dev20250429.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
|
641
|
+
pyg_nightly-2.7.0.dev20250429.dist-info/METADATA,sha256=ywTA_H23sidl8wVuvJPmwZWso1slTN7-GwOKR6MSf04,62979
|
642
|
+
pyg_nightly-2.7.0.dev20250429.dist-info/RECORD,,
|
torch_geometric/__init__.py
CHANGED
@@ -31,7 +31,7 @@ from .lazy_loader import LazyLoader
|
|
31
31
|
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
|
32
32
|
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')
|
33
33
|
|
34
|
-
__version__ = '2.7.0.
|
34
|
+
__version__ = '2.7.0.dev20250429'
|
35
35
|
|
36
36
|
__all__ = [
|
37
37
|
'Index',
|
torch_geometric/explain/explanation.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
import copy
|
2
|
-
from typing import Dict, List, Optional, Union
|
2
|
+
from typing import Dict, List, Optional, Tuple, Union
|
3
3
|
|
4
4
|
import torch
|
5
5
|
from torch import Tensor
|
@@ -8,7 +8,10 @@ from torch_geometric.data.data import Data, warn_or_raise
|
|
8
8
|
from torch_geometric.data.hetero_data import HeteroData
|
9
9
|
from torch_geometric.explain.config import ThresholdConfig, ThresholdType
|
10
10
|
from torch_geometric.typing import EdgeType, NodeType
|
11
|
-
from torch_geometric.visualization import
|
11
|
+
from torch_geometric.visualization import (
|
12
|
+
visualize_graph,
|
13
|
+
visualize_hetero_graph,
|
14
|
+
)
|
12
15
|
|
13
16
|
|
14
17
|
class ExplanationMixin:
|
@@ -362,6 +365,87 @@ class HeteroExplanation(HeteroData, ExplanationMixin):
|
|
362
365
|
|
363
366
|
return _visualize_score(score, all_feat_labels, path, top_k)
|
364
367
|
|
368
|
+
def visualize_graph(
    self,
    path: Optional[str] = None,
    node_labels: Optional[Dict[NodeType, List[str]]] = None,
    node_size_range: Tuple[float, float] = (50, 500),
    node_opacity_range: Tuple[float, float] = (0.2, 1.0),
    edge_width_range: Tuple[float, float] = (0.1, 2.0),
    edge_opacity_range: Tuple[float, float] = (0.2, 1.0),
) -> None:
    r"""Visualizes the explanation subgraph using networkx, with edge
    opacity corresponding to edge importance and node colors
    corresponding to node types.

    Edge widths/opacities are driven by each edge type's ``edge_mask`` and
    node sizes/opacities by each node type's ``node_mask``; a missing mask
    is treated as all ones.

    Args:
        path (str, optional): The path to where the plot is saved.
            If set to :obj:`None`, will visualize the plot on-the-fly.
            (default: :obj:`None`)
        node_labels (Dict[NodeType, List[str]], optional): The display
            names of nodes for each node type that will be shown in the
            visualization. (default: :obj:`None`)
        node_size_range (Tuple[float, float], optional): The minimum and
            maximum node size in the visualization.
            (default: :obj:`(50, 500)`)
        node_opacity_range (Tuple[float, float], optional): The minimum and
            maximum node opacity in the visualization.
            (default: :obj:`(0.2, 1.0)`)
        edge_width_range (Tuple[float, float], optional): The minimum and
            maximum edge width in the visualization.
            (default: :obj:`(0.1, 2.0)`)
        edge_opacity_range (Tuple[float, float], optional): The minimum and
            maximum edge opacity in the visualization.
            (default: :obj:`(0.2, 1.0)`)

    Raises:
        ValueError: If :obj:`node_labels` names a node type that is not in
            the explanation, or if the number of labels given for a node
            type differs from that node type's number of nodes.
    """
    # Validate node labels if provided.
    if node_labels is not None:
        for node_type, labels in node_labels.items():
            if node_type not in self.node_types:
                raise ValueError(
                    f"Node type '{node_type}' in node_labels "
                    "does not exist in the explanation graph")
            expected = self[node_type].num_nodes
            if len(labels) != expected:
                raise ValueError(f"Number of labels for node type "
                                 f"'{node_type}' (got {len(labels)}) does "
                                 f"not match the number of nodes "
                                 f"(expected {expected})")

    # NOTE(review): labels are validated against the full explanation
    # (`self`), but the plot is built from `get_explanation_subgraph()`.
    # If that subgraph drops or re-indexes nodes, label positions may no
    # longer line up with node indices -- confirm against
    # `get_explanation_subgraph`.
    subgraph = self.get_explanation_subgraph()

    # Prepare edge indices and weights for each edge type.
    edge_index_dict = {}
    edge_weight_dict = {}
    for edge_type in subgraph.edge_types:
        # NOTE(review): edge types whose source or destination node type
        # is named 'x' are skipped; the reason is not visible here.
        if edge_type[0] == 'x' or edge_type[-1] == 'x':
            continue
        edge_store = subgraph[edge_type]
        edge_index_dict[edge_type] = edge_store.edge_index
        # Edges without a mask are treated as fully important.
        edge_weight_dict[edge_type] = edge_store.get(
            'edge_mask', torch.ones(edge_store.edge_index.size(1)))

    # Prepare node weights for each node type.
    node_weight_dict = {}
    for node_type in subgraph.node_types:
        if node_type == 'x':  # Skip the global store
            continue
        node_store = subgraph[node_type]
        # Nodes without a mask are treated as fully important. squeeze(-1)
        # only removes a trailing singleton dimension.
        # NOTE(review): a mask with several columns per node stays 2-D
        # here -- confirm the plotting helper can handle that shape.
        node_weight_dict[node_type] = node_store.get(
            'node_mask', torch.ones(node_store.num_nodes)).squeeze(-1)

    # Call the visualization function.
    visualize_hetero_graph(
        edge_index_dict=edge_index_dict,
        edge_weight_dict=edge_weight_dict,
        path=path,
        node_labels_dict=node_labels,
        node_weight_dict=node_weight_dict,
        node_size_range=node_size_range,
        node_opacity_range=node_opacity_range,
        edge_width_range=edge_width_range,
        edge_opacity_range=edge_opacity_range,
    )
|
448
|
+
|
365
449
|
|
366
450
|
def _visualize_score(
|
367
451
|
score: torch.Tensor,
|
torch_geometric/visualization/graph.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
from math import sqrt
|
2
|
-
from typing import Any, List, Optional
|
2
|
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
3
3
|
|
4
4
|
import torch
|
5
5
|
from torch import Tensor
|
@@ -150,3 +150,249 @@ def _visualize_graph_via_networkx(
|
|
150
150
|
plt.show()
|
151
151
|
|
152
152
|
plt.close()
|
153
|
+
|
154
|
+
|
155
|
+
def visualize_hetero_graph(
    edge_index_dict: Dict[Tuple[str, str, str], Tensor],
    edge_weight_dict: Dict[Tuple[str, str, str], Tensor],
    path: Optional[str] = None,
    backend: Optional[str] = None,
    node_labels_dict: Optional[Dict[str, List[str]]] = None,
    node_weight_dict: Optional[Dict[str, Tensor]] = None,
    node_size_range: Tuple[float, float] = (50, 500),
    node_opacity_range: Tuple[float, float] = (0.2, 1.0),
    edge_width_range: Tuple[float, float] = (0.1, 2.0),
    edge_opacity_range: Tuple[float, float] = (0.2, 1.0),
) -> Any:
    """Visualizes a heterogeneous graph using networkx.

    Edges whose weight is not strictly positive are removed first, and an
    edge type with no remaining edge is dropped altogether. Nodes that no
    longer take part in any kept edge are then discarded, and the optional
    per-node weights and labels are reduced to the surviving nodes (kept in
    ascending node-index order) before the graph is drawn.

    Args:
        edge_index_dict: Edge indices per ``(src_type, relation, dst_type)``
            edge type; row 0 holds source and row 1 target node indices.
        edge_weight_dict: One weight per edge, keyed like
            :obj:`edge_index_dict`.
        path: Where to save the figure; :obj:`None` shows it instead.
        backend: Drawing backend; only :obj:`None` or ``"networkx"`` are
            accepted.
        node_labels_dict: Optional display labels per node type, indexed
            by original node index.
        node_weight_dict: Optional per-node weights per node type, indexed
            by original node index.
        node_size_range: Minimum and maximum node size.
        node_opacity_range: Minimum and maximum node opacity.
        edge_width_range: Minimum and maximum edge width.
        edge_opacity_range: Minimum and maximum edge opacity.

    Returns:
        Whatever the networkx drawing helper returns.

    Raises:
        ValueError: If :obj:`backend` is anything other than :obj:`None`
            or ``"networkx"`` (the drawing helper may raise further errors).
    """
    if not (backend is None or backend == "networkx"):
        raise ValueError("Only 'networkx' backend is supported")

    # Keep only strictly positive edges; fully filtered types disappear.
    kept_index: Dict[Tuple[str, str, str], Tensor] = {}
    kept_weight: Dict[Tuple[str, str, str], Tensor] = {}
    for edge_type in edge_index_dict:
        weight = edge_weight_dict[edge_type]
        positive = weight > 0
        if not positive.sum() > 0:
            continue
        kept_index[edge_type] = edge_index_dict[edge_type][:, positive]
        kept_weight[edge_type] = weight[positive]

    # Original node indices, per node type, that still touch a kept edge.
    alive: Dict[str, Set[int]] = {}
    for (src_type, _, dst_type), index in kept_index.items():
        alive.setdefault(src_type, set())
        alive.setdefault(dst_type, set())
        alive[src_type].update(index[0].tolist())
        alive[dst_type].update(index[1].tolist())

    # Reduce node weights to surviving nodes (ascending index order).
    if node_weight_dict is not None:
        reduced_weights: Dict[str, Tensor] = {}
        for node_type, weights in node_weight_dict.items():
            if node_type not in alive:
                continue
            keep = torch.zeros(len(weights), dtype=torch.bool)
            keep[list(alive[node_type])] = True
            reduced_weights[node_type] = weights[keep]
        node_weight_dict = reduced_weights

    # Reduce node labels the same way, preserving their original order.
    if node_labels_dict is not None:
        node_labels_dict = {
            node_type: [
                label for position, label in enumerate(labels)
                if position in alive[node_type]
            ]
            for node_type, labels in node_labels_dict.items()
            if node_type in alive
        }

    return _visualize_hetero_graph_via_networkx(
        kept_index,
        kept_weight,
        path,
        node_labels_dict,
        node_weight_dict,
        node_size_range,
        node_opacity_range,
        edge_width_range,
        edge_opacity_range,
    )
|
225
|
+
|
226
|
+
|
227
|
+
def _visualize_hetero_graph_via_networkx(
    edge_index_dict: Dict[Tuple[str, str, str], Tensor],
    edge_weight_dict: Dict[Tuple[str, str, str], Tensor],
    path: Optional[str] = None,
    node_labels_dict: Optional[Dict[str, List[str]]] = None,
    node_weight_dict: Optional[Dict[str, Tensor]] = None,
    node_size_range: Tuple[float, float] = (50, 500),
    node_opacity_range: Tuple[float, float] = (0.2, 1.0),
    edge_width_range: Tuple[float, float] = (0.1, 2.0),
    edge_opacity_range: Tuple[float, float] = (0.2, 1.0),
) -> Any:
    r"""Draws a heterogeneous graph with networkx and matplotlib.

    Only nodes that occur in at least one edge are drawn. For each node
    type, the occurring node indices are sorted and renumbered
    consecutively, so entry ``i`` of ``node_labels_dict[node_type]`` and
    ``node_weight_dict[node_type]`` describes the ``i``-th smallest
    occurring index of that type (the layout that
    :func:`visualize_hetero_graph` prepares).

    Node and edge weights are clamped to ``[0, 1]`` and mapped linearly onto
    the corresponding ``(min, max)`` range; nodes without weights use the
    maximum size and opacity. Node colors come from matplotlib's ``tab10``
    colormap, assigned per node type in order of first appearance in
    :obj:`edge_index_dict`, so colors are stable across runs.

    Args:
        edge_index_dict: Edge indices per ``(src_type, relation, dst_type)``
            edge type; row 0 holds source and row 1 target node indices.
        edge_weight_dict: One weight per edge, keyed like
            :obj:`edge_index_dict`.
        path: File to save the figure to. If :obj:`None`, the figure is
            shown via :func:`matplotlib.pyplot.show` instead.
        node_labels_dict: Optional display labels per node type.
        node_weight_dict: Optional per-node weights per node type.
        node_size_range: Minimum and maximum node size.
        node_opacity_range: Minimum and maximum node opacity.
        edge_width_range: Minimum and maximum edge width.
        edge_opacity_range: Minimum and maximum edge opacity.

    Returns:
        :obj:`None`; the figure is saved or shown, then closed.

    Raises:
        ValueError: If the number of weights for a node type differs from
            the number of that type's nodes occurring in the edges.
    """
    import matplotlib.pyplot as plt
    import networkx as nx

    def _lerp(weight: float, value_range: Tuple[float, float]) -> float:
        # Clamp first: weights outside [0, 1] would otherwise leave the
        # requested range, and matplotlib rejects alpha outside [0, 1].
        weight = min(max(weight, 0.0), 1.0)
        return value_range[0] + weight * (value_range[1] - value_range[0])

    # Node indices occurring in the edges, per node type. A dict (rather
    # than a set of node types) keeps first-appearance order, so node ids,
    # colors and the legend do not depend on string hash randomization.
    present_nodes: Dict[str, Set[int]] = {}
    for (src_type, _, dst_type), edge_index in edge_index_dict.items():
        present_nodes.setdefault(src_type, set()).update(
            edge_index[0].tolist())
        present_nodes.setdefault(dst_type, set()).update(
            edge_index[1].tolist())

    g = nx.DiGraph()
    # (node type, original index) -> networkx node id. Each node type gets
    # a consecutive block of ids, ordered by original index. Built once per
    # node type instead of once per edge type.
    node_ids: Dict[str, Dict[int, int]] = {}
    offset = 0
    for node_type, indices in present_nodes.items():
        ordered = sorted(indices)
        num_nodes = len(ordered)
        node_ids[node_type] = {
            old_idx: offset + new_idx
            for new_idx, old_idx in enumerate(ordered)
        }

        weights = None
        if node_weight_dict is not None and node_type in node_weight_dict:
            weights = node_weight_dict[node_type]
            if len(weights) != num_nodes:
                raise ValueError(f"Number of weights for node type "
                                 f"{node_type} ({len(weights)}) does not "
                                 f"match number of nodes ({num_nodes})")

        labels = None
        if node_labels_dict is not None and node_type in node_labels_dict:
            labels = node_labels_dict[node_type]

        for i in range(num_nodes):
            size = node_size_range[1]
            opacity = node_opacity_range[1]
            if weights is not None:
                w = weights[i].item()
                size = _lerp(w, node_size_range)
                opacity = _lerp(w, node_opacity_range)
            label = labels[i] if labels is not None else ""
            g.add_node(offset + i, label=label, type=node_type, size=size,
                       alpha=opacity)

        offset += num_nodes

    # Add edges between the remapped node ids, styled by edge weight.
    for edge_type, edge_index in edge_index_dict.items():
        src_type, _, dst_type = edge_type
        src_ids = node_ids[src_type]
        dst_ids = node_ids[dst_type]
        edge_weight = edge_weight_dict[edge_type]
        for (src, dst), w in zip(edge_index.t().tolist(),
                                 edge_weight.tolist()):
            g.add_edge(src_ids[src], dst_ids[dst],
                       width=_lerp(w, edge_width_range),
                       alpha=_lerp(w, edge_opacity_range))

    # Draw the graph.
    ax = plt.gca()
    pos = nx.arf_layout(g)

    # Draw edges with arrows pointing from source to destination.
    for src, dst, data in g.edges(data=True):
        ax.annotate(
            '',
            xy=pos[src],
            xytext=pos[dst],
            arrowprops=dict(
                arrowstyle="<-",
                alpha=data['alpha'],
                linewidth=data['width'],
                # The arrow path runs from `xytext` (dst) to `xy` (src):
                # shrinkA trims its start and shrinkB its end, each by the
                # radius (in points) of the node sitting at that end.
                shrinkA=sqrt(g.nodes[dst]['size']) / 2.0,
                shrinkB=sqrt(g.nodes[src]['size']) / 2.0,
                connectionstyle="arc3,rad=0.1",
            ),
        )

    # Draw nodes colored by type, using matplotlib's tab10 colormap and
    # cycling after ten node types.
    tab10_cmap = plt.cm.tab10  # type: ignore[attr-defined]
    node_type_colors: Dict[str, Any] = {}
    node_colors = []
    node_sizes = []
    node_alphas = []
    for _, attrs in g.nodes(data=True):
        node_type = attrs['type']
        if node_type not in node_type_colors:
            node_type_colors[node_type] = tab10_cmap(
                len(node_type_colors) % 10)
        node_colors.append(node_type_colors[node_type])
        node_sizes.append(attrs['size'])
        node_alphas.append(attrs['alpha'])

    nx.draw_networkx_nodes(g, pos, node_size=node_sizes,
                           node_color=node_colors, margins=0.1,
                           alpha=node_alphas)

    # Draw labels.
    labels_by_node = nx.get_node_attributes(g, 'label')
    nx.draw_networkx_labels(g, pos, labels_by_node, font_size=10)

    # Add a legend with one entry per drawn node type.
    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w', label=node_type,
                   markerfacecolor=color, markersize=10)
        for node_type, color in node_type_colors.items()
    ]
    ax.legend(handles=legend_elements, loc='upper right',
              bbox_to_anchor=(0.9, 1))

    if path is not None:
        plt.savefig(path, bbox_inches='tight')
    else:
        plt.show()

    plt.close()
|
{pyg_nightly-2.7.0.dev20250428.dist-info → pyg_nightly-2.7.0.dev20250429.dist-info}/WHEEL
RENAMED
File without changes
|
{pyg_nightly-2.7.0.dev20250428.dist-info → pyg_nightly-2.7.0.dev20250429.dist-info}/licenses/LICENSE
RENAMED
File without changes
|