pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
- {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
- torch_geometric/__init__.py +28 -1
- torch_geometric/_compile.py +8 -1
- torch_geometric/_onnx.py +14 -0
- torch_geometric/config_mixin.py +113 -0
- torch_geometric/config_store.py +28 -19
- torch_geometric/data/__init__.py +24 -1
- torch_geometric/data/batch.py +2 -2
- torch_geometric/data/collate.py +8 -2
- torch_geometric/data/data.py +16 -8
- torch_geometric/data/database.py +61 -15
- torch_geometric/data/dataset.py +14 -6
- torch_geometric/data/feature_store.py +25 -42
- torch_geometric/data/graph_store.py +1 -5
- torch_geometric/data/hetero_data.py +18 -9
- torch_geometric/data/in_memory_dataset.py +2 -4
- torch_geometric/data/large_graph_indexer.py +677 -0
- torch_geometric/data/lightning/datamodule.py +4 -4
- torch_geometric/data/separate.py +6 -1
- torch_geometric/data/storage.py +17 -7
- torch_geometric/data/summary.py +14 -4
- torch_geometric/data/temporal.py +1 -2
- torch_geometric/datasets/__init__.py +17 -2
- torch_geometric/datasets/actor.py +9 -11
- torch_geometric/datasets/airfrans.py +15 -18
- torch_geometric/datasets/airports.py +10 -12
- torch_geometric/datasets/amazon.py +8 -11
- torch_geometric/datasets/amazon_book.py +9 -10
- torch_geometric/datasets/amazon_products.py +9 -10
- torch_geometric/datasets/aminer.py +8 -9
- torch_geometric/datasets/aqsol.py +10 -13
- torch_geometric/datasets/attributed_graph_dataset.py +10 -12
- torch_geometric/datasets/ba_multi_shapes.py +10 -12
- torch_geometric/datasets/ba_shapes.py +5 -6
- torch_geometric/datasets/bitcoin_otc.py +1 -1
- torch_geometric/datasets/brca_tgca.py +1 -1
- torch_geometric/datasets/cornell.py +145 -0
- torch_geometric/datasets/dblp.py +2 -1
- torch_geometric/datasets/dbp15k.py +2 -2
- torch_geometric/datasets/fake.py +1 -3
- torch_geometric/datasets/flickr.py +2 -1
- torch_geometric/datasets/freebase.py +1 -1
- torch_geometric/datasets/gdelt_lite.py +3 -2
- torch_geometric/datasets/ged_dataset.py +3 -2
- torch_geometric/datasets/git_mol_dataset.py +263 -0
- torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
- torch_geometric/datasets/hgb_dataset.py +8 -8
- torch_geometric/datasets/imdb.py +2 -1
- torch_geometric/datasets/karate.py +3 -2
- torch_geometric/datasets/last_fm.py +2 -1
- torch_geometric/datasets/linkx_dataset.py +4 -3
- torch_geometric/datasets/lrgb.py +3 -5
- torch_geometric/datasets/malnet_tiny.py +4 -3
- torch_geometric/datasets/mnist_superpixels.py +2 -3
- torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
- torch_geometric/datasets/molecule_net.py +15 -3
- torch_geometric/datasets/motif_generator/base.py +0 -1
- torch_geometric/datasets/neurograph.py +1 -3
- torch_geometric/datasets/ogb_mag.py +1 -1
- torch_geometric/datasets/opf.py +239 -0
- torch_geometric/datasets/ose_gvcs.py +1 -1
- torch_geometric/datasets/pascal.py +11 -9
- torch_geometric/datasets/pascal_pf.py +1 -1
- torch_geometric/datasets/pcpnet_dataset.py +1 -1
- torch_geometric/datasets/pcqm4m.py +10 -3
- torch_geometric/datasets/ppi.py +1 -1
- torch_geometric/datasets/qm9.py +8 -7
- torch_geometric/datasets/rcdd.py +4 -4
- torch_geometric/datasets/reddit.py +2 -1
- torch_geometric/datasets/reddit2.py +2 -1
- torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
- torch_geometric/datasets/s3dis.py +5 -3
- torch_geometric/datasets/shapenet.py +3 -3
- torch_geometric/datasets/shrec2016.py +2 -2
- torch_geometric/datasets/snap_dataset.py +7 -1
- torch_geometric/datasets/tag_dataset.py +350 -0
- torch_geometric/datasets/upfd.py +2 -1
- torch_geometric/datasets/web_qsp_dataset.py +246 -0
- torch_geometric/datasets/webkb.py +2 -2
- torch_geometric/datasets/wikics.py +1 -1
- torch_geometric/datasets/wikidata.py +3 -2
- torch_geometric/datasets/wikipedia_network.py +2 -2
- torch_geometric/datasets/willow_object_class.py +1 -1
- torch_geometric/datasets/word_net.py +2 -2
- torch_geometric/datasets/yelp.py +2 -1
- torch_geometric/datasets/zinc.py +1 -1
- torch_geometric/device.py +42 -0
- torch_geometric/distributed/local_feature_store.py +3 -2
- torch_geometric/distributed/local_graph_store.py +2 -1
- torch_geometric/distributed/partition.py +9 -8
- torch_geometric/edge_index.py +616 -438
- torch_geometric/explain/algorithm/base.py +0 -1
- torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
- torch_geometric/explain/algorithm/pg_explainer.py +1 -1
- torch_geometric/explain/explanation.py +2 -2
- torch_geometric/graphgym/checkpoint.py +2 -1
- torch_geometric/graphgym/logger.py +4 -4
- torch_geometric/graphgym/loss.py +1 -1
- torch_geometric/graphgym/utils/agg_runs.py +6 -6
- torch_geometric/index.py +826 -0
- torch_geometric/inspector.py +8 -3
- torch_geometric/io/fs.py +28 -2
- torch_geometric/io/npz.py +2 -1
- torch_geometric/io/off.py +2 -2
- torch_geometric/io/sdf.py +2 -2
- torch_geometric/io/tu.py +4 -5
- torch_geometric/loader/__init__.py +4 -0
- torch_geometric/loader/cluster.py +10 -4
- torch_geometric/loader/graph_saint.py +2 -1
- torch_geometric/loader/ibmb_loader.py +12 -4
- torch_geometric/loader/mixin.py +1 -1
- torch_geometric/loader/neighbor_loader.py +1 -1
- torch_geometric/loader/neighbor_sampler.py +2 -2
- torch_geometric/loader/prefetch.py +1 -1
- torch_geometric/loader/rag_loader.py +107 -0
- torch_geometric/loader/utils.py +8 -7
- torch_geometric/loader/zip_loader.py +10 -0
- torch_geometric/metrics/__init__.py +11 -2
- torch_geometric/metrics/link_pred.py +159 -34
- torch_geometric/nn/aggr/__init__.py +4 -0
- torch_geometric/nn/aggr/attention.py +0 -2
- torch_geometric/nn/aggr/base.py +2 -4
- torch_geometric/nn/aggr/patch_transformer.py +143 -0
- torch_geometric/nn/aggr/set_transformer.py +1 -1
- torch_geometric/nn/aggr/variance_preserving.py +33 -0
- torch_geometric/nn/attention/__init__.py +5 -1
- torch_geometric/nn/attention/qformer.py +71 -0
- torch_geometric/nn/conv/collect.jinja +7 -4
- torch_geometric/nn/conv/cugraph/base.py +8 -12
- torch_geometric/nn/conv/edge_conv.py +3 -2
- torch_geometric/nn/conv/fused_gat_conv.py +1 -1
- torch_geometric/nn/conv/gat_conv.py +35 -7
- torch_geometric/nn/conv/gatv2_conv.py +36 -6
- torch_geometric/nn/conv/general_conv.py +1 -1
- torch_geometric/nn/conv/graph_conv.py +21 -3
- torch_geometric/nn/conv/gravnet_conv.py +3 -2
- torch_geometric/nn/conv/hetero_conv.py +3 -3
- torch_geometric/nn/conv/hgt_conv.py +1 -1
- torch_geometric/nn/conv/message_passing.py +138 -87
- torch_geometric/nn/conv/mixhop_conv.py +1 -1
- torch_geometric/nn/conv/propagate.jinja +9 -1
- torch_geometric/nn/conv/rgcn_conv.py +5 -5
- torch_geometric/nn/conv/spline_conv.py +4 -4
- torch_geometric/nn/conv/x_conv.py +3 -2
- torch_geometric/nn/dense/linear.py +11 -6
- torch_geometric/nn/fx.py +3 -3
- torch_geometric/nn/model_hub.py +3 -1
- torch_geometric/nn/models/__init__.py +10 -2
- torch_geometric/nn/models/deep_graph_infomax.py +1 -2
- torch_geometric/nn/models/dimenet_utils.py +5 -7
- torch_geometric/nn/models/g_retriever.py +230 -0
- torch_geometric/nn/models/git_mol.py +336 -0
- torch_geometric/nn/models/glem.py +385 -0
- torch_geometric/nn/models/gnnff.py +0 -1
- torch_geometric/nn/models/graph_unet.py +12 -3
- torch_geometric/nn/models/jumping_knowledge.py +63 -4
- torch_geometric/nn/models/lightgcn.py +1 -1
- torch_geometric/nn/models/metapath2vec.py +5 -5
- torch_geometric/nn/models/molecule_gpt.py +222 -0
- torch_geometric/nn/models/node2vec.py +2 -3
- torch_geometric/nn/models/schnet.py +2 -1
- torch_geometric/nn/models/signed_gcn.py +3 -3
- torch_geometric/nn/module_dict.py +2 -2
- torch_geometric/nn/nlp/__init__.py +9 -0
- torch_geometric/nn/nlp/llm.py +322 -0
- torch_geometric/nn/nlp/sentence_transformer.py +134 -0
- torch_geometric/nn/nlp/vision_transformer.py +33 -0
- torch_geometric/nn/norm/batch_norm.py +1 -1
- torch_geometric/nn/parameter_dict.py +2 -2
- torch_geometric/nn/pool/__init__.py +21 -5
- torch_geometric/nn/pool/cluster_pool.py +145 -0
- torch_geometric/nn/pool/connect/base.py +0 -1
- torch_geometric/nn/pool/edge_pool.py +1 -1
- torch_geometric/nn/pool/graclus.py +4 -2
- torch_geometric/nn/pool/pool.py +8 -2
- torch_geometric/nn/pool/select/base.py +0 -1
- torch_geometric/nn/pool/voxel_grid.py +3 -2
- torch_geometric/nn/resolver.py +1 -1
- torch_geometric/nn/sequential.jinja +10 -23
- torch_geometric/nn/sequential.py +204 -78
- torch_geometric/nn/summary.py +1 -1
- torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
- torch_geometric/profile/__init__.py +2 -0
- torch_geometric/profile/nvtx.py +66 -0
- torch_geometric/profile/profiler.py +30 -19
- torch_geometric/resolver.py +1 -1
- torch_geometric/sampler/base.py +34 -13
- torch_geometric/sampler/neighbor_sampler.py +11 -10
- torch_geometric/sampler/utils.py +1 -1
- torch_geometric/template.py +1 -0
- torch_geometric/testing/__init__.py +6 -2
- torch_geometric/testing/decorators.py +53 -20
- torch_geometric/testing/feature_store.py +1 -1
- torch_geometric/transforms/__init__.py +2 -0
- torch_geometric/transforms/add_metapaths.py +5 -5
- torch_geometric/transforms/add_positional_encoding.py +1 -1
- torch_geometric/transforms/delaunay.py +65 -14
- torch_geometric/transforms/face_to_edge.py +32 -3
- torch_geometric/transforms/gdc.py +7 -6
- torch_geometric/transforms/laplacian_lambda_max.py +3 -3
- torch_geometric/transforms/mask.py +5 -1
- torch_geometric/transforms/node_property_split.py +1 -2
- torch_geometric/transforms/pad.py +7 -6
- torch_geometric/transforms/random_link_split.py +1 -1
- torch_geometric/transforms/remove_self_loops.py +36 -0
- torch_geometric/transforms/svd_feature_reduction.py +1 -1
- torch_geometric/transforms/to_sparse_tensor.py +1 -1
- torch_geometric/transforms/two_hop.py +1 -1
- torch_geometric/transforms/virtual_node.py +2 -1
- torch_geometric/typing.py +43 -6
- torch_geometric/utils/__init__.py +5 -1
- torch_geometric/utils/_negative_sampling.py +1 -1
- torch_geometric/utils/_normalize_edge_index.py +46 -0
- torch_geometric/utils/_scatter.py +38 -12
- torch_geometric/utils/_subgraph.py +4 -0
- torch_geometric/utils/_tree_decomposition.py +2 -2
- torch_geometric/utils/augmentation.py +1 -1
- torch_geometric/utils/convert.py +12 -8
- torch_geometric/utils/geodesic.py +24 -22
- torch_geometric/utils/hetero.py +1 -1
- torch_geometric/utils/map.py +8 -2
- torch_geometric/utils/smiles.py +65 -27
- torch_geometric/utils/sparse.py +39 -25
- torch_geometric/visualization/graph.py +3 -4
torch_geometric/utils/map.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import warnings
|
2
2
|
from typing import Optional, Tuple, Union
|
3
3
|
|
4
|
+
import numpy as np
|
4
5
|
import torch
|
5
6
|
from torch import Tensor
|
6
7
|
from torch.utils.dlpack import from_dlpack
|
@@ -13,7 +14,7 @@ def map_index(
|
|
13
14
|
inclusive: bool = False,
|
14
15
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
15
16
|
r"""Maps indices in :obj:`src` to the positional value of their
|
16
|
-
corresponding
|
17
|
+
corresponding occurrence in :obj:`index`.
|
17
18
|
Indices must be strictly positive.
|
18
19
|
|
19
20
|
Args:
|
@@ -110,7 +111,12 @@ def map_index(
|
|
110
111
|
result = pd.merge(left_ser, right_ser, how='left', left_on='left_ser',
|
111
112
|
right_index=True)
|
112
113
|
|
113
|
-
|
114
|
+
out_numpy = result['right_ser'].values
|
115
|
+
if (index.device.type == 'mps' # MPS does not support `float64`
|
116
|
+
and issubclass(out_numpy.dtype.type, np.floating)):
|
117
|
+
out_numpy = out_numpy.astype(np.float32)
|
118
|
+
|
119
|
+
out = torch.from_numpy(out_numpy).to(index.device)
|
114
120
|
|
115
121
|
if out.is_floating_point() and inclusive:
|
116
122
|
raise ValueError("Found invalid entries in 'src' that do not have "
|
torch_geometric/utils/smiles.py
CHANGED
@@ -77,32 +77,18 @@ e_map: Dict[str, List[Any]] = {
|
|
77
77
|
}
|
78
78
|
|
79
79
|
|
80
|
-
def
|
81
|
-
|
82
|
-
|
83
|
-
instance.
|
80
|
+
def from_rdmol(mol: Any) -> 'torch_geometric.data.Data':
|
81
|
+
r"""Converts a :class:`rdkit.Chem.Mol` instance to a
|
82
|
+
:class:`torch_geometric.data.Data` instance.
|
84
83
|
|
85
84
|
Args:
|
86
|
-
|
87
|
-
with_hydrogen (bool, optional): If set to :obj:`True`, will store
|
88
|
-
hydrogens in the molecule graph. (default: :obj:`False`)
|
89
|
-
kekulize (bool, optional): If set to :obj:`True`, converts aromatic
|
90
|
-
bonds to single/double bonds. (default: :obj:`False`)
|
85
|
+
mol (rdkit.Chem.Mol): The :class:`rdkit` molecule.
|
91
86
|
"""
|
92
|
-
from rdkit import Chem
|
87
|
+
from rdkit import Chem
|
93
88
|
|
94
89
|
from torch_geometric.data import Data
|
95
90
|
|
96
|
-
|
97
|
-
|
98
|
-
mol = Chem.MolFromSmiles(smiles)
|
99
|
-
|
100
|
-
if mol is None:
|
101
|
-
mol = Chem.MolFromSmiles('')
|
102
|
-
if with_hydrogen:
|
103
|
-
mol = Chem.AddHs(mol)
|
104
|
-
if kekulize:
|
105
|
-
Chem.Kekulize(mol)
|
91
|
+
assert isinstance(mol, Chem.Mol)
|
106
92
|
|
107
93
|
xs: List[List[int]] = []
|
108
94
|
for atom in mol.GetAtoms():
|
@@ -142,16 +128,51 @@ def from_smiles(smiles: str, with_hydrogen: bool = False,
|
|
142
128
|
perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
|
143
129
|
edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]
|
144
130
|
|
145
|
-
return Data(x=x, edge_index=edge_index, edge_attr=edge_attr
|
131
|
+
return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
|
146
132
|
|
147
133
|
|
148
|
-
def
|
149
|
-
|
150
|
-
|
151
|
-
|
134
|
+
def from_smiles(
|
135
|
+
smiles: str,
|
136
|
+
with_hydrogen: bool = False,
|
137
|
+
kekulize: bool = False,
|
138
|
+
) -> 'torch_geometric.data.Data':
|
139
|
+
r"""Converts a SMILES string to a :class:`torch_geometric.data.Data`
|
140
|
+
instance.
|
152
141
|
|
153
142
|
Args:
|
154
|
-
|
143
|
+
smiles (str): The SMILES string.
|
144
|
+
with_hydrogen (bool, optional): If set to :obj:`True`, will store
|
145
|
+
hydrogens in the molecule graph. (default: :obj:`False`)
|
146
|
+
kekulize (bool, optional): If set to :obj:`True`, converts aromatic
|
147
|
+
bonds to single/double bonds. (default: :obj:`False`)
|
148
|
+
"""
|
149
|
+
from rdkit import Chem, RDLogger
|
150
|
+
|
151
|
+
RDLogger.DisableLog('rdApp.*') # type: ignore
|
152
|
+
|
153
|
+
mol = Chem.MolFromSmiles(smiles)
|
154
|
+
|
155
|
+
if mol is None:
|
156
|
+
mol = Chem.MolFromSmiles('')
|
157
|
+
if with_hydrogen:
|
158
|
+
mol = Chem.AddHs(mol)
|
159
|
+
if kekulize:
|
160
|
+
Chem.Kekulize(mol)
|
161
|
+
|
162
|
+
data = from_rdmol(mol)
|
163
|
+
data.smiles = smiles
|
164
|
+
return data
|
165
|
+
|
166
|
+
|
167
|
+
def to_rdmol(
|
168
|
+
data: 'torch_geometric.data.Data',
|
169
|
+
kekulize: bool = False,
|
170
|
+
) -> Any:
|
171
|
+
"""Converts a :class:`torch_geometric.data.Data` instance to a
|
172
|
+
:class:`rdkit.Chem.Mol` instance.
|
173
|
+
|
174
|
+
Args:
|
175
|
+
data (torch_geometric.data.Data): The molecular graph data.
|
155
176
|
kekulize (bool, optional): If set to :obj:`True`, converts aromatic
|
156
177
|
bonds to single/double bonds. (default: :obj:`False`)
|
157
178
|
"""
|
@@ -172,7 +193,7 @@ def to_smiles(data: 'torch_geometric.data.Data',
|
|
172
193
|
data.x[i, 5])])
|
173
194
|
atom.SetHybridization(Chem.rdchem.HybridizationType.values[int(
|
174
195
|
data.x[i, 6])])
|
175
|
-
atom.SetIsAromatic(
|
196
|
+
atom.SetIsAromatic(bool(data.x[i, 7]))
|
176
197
|
mol.AddAtom(atom)
|
177
198
|
|
178
199
|
edges = [tuple(i) for i in data.edge_index.t().tolist()]
|
@@ -207,4 +228,21 @@ def to_smiles(data: 'torch_geometric.data.Data',
|
|
207
228
|
Chem.SanitizeMol(mol)
|
208
229
|
Chem.AssignStereochemistry(mol)
|
209
230
|
|
231
|
+
return mol
|
232
|
+
|
233
|
+
|
234
|
+
def to_smiles(
|
235
|
+
data: 'torch_geometric.data.Data',
|
236
|
+
kekulize: bool = False,
|
237
|
+
) -> str:
|
238
|
+
"""Converts a :class:`torch_geometric.data.Data` instance to a SMILES
|
239
|
+
string.
|
240
|
+
|
241
|
+
Args:
|
242
|
+
data (torch_geometric.data.Data): The molecular graph.
|
243
|
+
kekulize (bool, optional): If set to :obj:`True`, converts aromatic
|
244
|
+
bonds to single/double bonds. (default: :obj:`False`)
|
245
|
+
"""
|
246
|
+
from rdkit import Chem
|
247
|
+
mol = to_rdmol(data, kekulize=kekulize)
|
210
248
|
return Chem.MolToSmiles(mol, isomericSmiles=True)
|
torch_geometric/utils/sparse.py
CHANGED
@@ -6,6 +6,7 @@ import torch
|
|
6
6
|
from torch import Tensor
|
7
7
|
|
8
8
|
import torch_geometric.typing
|
9
|
+
from torch_geometric.index import index2ptr, ptr2index
|
9
10
|
from torch_geometric.typing import SparseTensor
|
10
11
|
from torch_geometric.utils import coalesce, cumsum
|
11
12
|
|
@@ -197,15 +198,23 @@ def to_torch_coo_tensor(
|
|
197
198
|
# edge_attr = edge_attr.expand(edge_index.size(1))
|
198
199
|
edge_attr = torch.ones(edge_index.size(1), device=edge_index.device)
|
199
200
|
|
200
|
-
|
201
|
+
if not torch_geometric.typing.WITH_PT21:
|
202
|
+
adj = torch.sparse_coo_tensor(
|
203
|
+
indices=edge_index,
|
204
|
+
values=edge_attr,
|
205
|
+
size=tuple(size) + edge_attr.size()[1:],
|
206
|
+
device=edge_index.device,
|
207
|
+
)
|
208
|
+
adj = adj._coalesced_(True)
|
209
|
+
return adj
|
210
|
+
|
211
|
+
return torch.sparse_coo_tensor(
|
201
212
|
indices=edge_index,
|
202
213
|
values=edge_attr,
|
203
214
|
size=tuple(size) + edge_attr.size()[1:],
|
204
215
|
device=edge_index.device,
|
216
|
+
is_coalesced=True,
|
205
217
|
)
|
206
|
-
adj = adj._coalesced_(True)
|
207
|
-
|
208
|
-
return adj
|
209
218
|
|
210
219
|
|
211
220
|
def to_torch_csr_tensor(
|
@@ -483,65 +492,70 @@ def set_sparse_value(adj: Tensor, value: Tensor) -> Tensor:
|
|
483
492
|
raise ValueError(f"Unexpected sparse tensor layout (got '{adj.layout}')")
|
484
493
|
|
485
494
|
|
486
|
-
def ptr2index(ptr: Tensor, output_size: Optional[int] = None) -> Tensor:
|
487
|
-
index = torch.arange(ptr.numel() - 1, dtype=ptr.dtype, device=ptr.device)
|
488
|
-
return index.repeat_interleave(ptr.diff(), output_size=output_size)
|
489
|
-
|
490
|
-
|
491
|
-
def index2ptr(index: Tensor, size: Optional[int] = None) -> Tensor:
|
492
|
-
if size is None:
|
493
|
-
size = int(index.max()) + 1 if index.numel() > 0 else 0
|
494
|
-
|
495
|
-
return torch._convert_indices_from_coo_to_csr(
|
496
|
-
index, size, out_int32=index.dtype == torch.int32)
|
497
|
-
|
498
|
-
|
499
495
|
def cat_coo(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor:
|
500
496
|
assert dim in {0, 1, (0, 1)}
|
501
497
|
assert tensors[0].layout == torch.sparse_coo
|
502
498
|
|
503
499
|
indices, values = [], []
|
504
500
|
num_rows = num_cols = 0
|
501
|
+
is_coalesced = True
|
505
502
|
|
506
503
|
if dim == 0:
|
507
504
|
for i, tensor in enumerate(tensors):
|
508
505
|
if i == 0:
|
509
|
-
indices.append(tensor.
|
506
|
+
indices.append(tensor._indices())
|
510
507
|
else:
|
511
508
|
offset = torch.tensor([[num_rows], [0]], device=tensor.device)
|
512
|
-
indices.append(tensor.
|
513
|
-
values.append(tensor.
|
509
|
+
indices.append(tensor._indices() + offset)
|
510
|
+
values.append(tensor._values())
|
514
511
|
num_rows += tensor.size(0)
|
515
512
|
num_cols = max(num_cols, tensor.size(1))
|
513
|
+
if not tensor.is_coalesced():
|
514
|
+
is_coalesced = False
|
516
515
|
|
517
516
|
elif dim == 1:
|
518
517
|
for i, tensor in enumerate(tensors):
|
519
518
|
if i == 0:
|
520
|
-
indices.append(tensor.
|
519
|
+
indices.append(tensor._indices())
|
521
520
|
else:
|
522
521
|
offset = torch.tensor([[0], [num_cols]], device=tensor.device)
|
523
522
|
indices.append(tensor.indices() + offset)
|
524
|
-
values.append(tensor.
|
523
|
+
values.append(tensor._values())
|
525
524
|
num_rows = max(num_rows, tensor.size(0))
|
526
525
|
num_cols += tensor.size(1)
|
526
|
+
is_coalesced = False
|
527
527
|
|
528
528
|
else:
|
529
529
|
for i, tensor in enumerate(tensors):
|
530
530
|
if i == 0:
|
531
|
-
indices.append(tensor.
|
531
|
+
indices.append(tensor._indices())
|
532
532
|
else:
|
533
533
|
offset = torch.tensor([[num_rows], [num_cols]],
|
534
534
|
device=tensor.device)
|
535
|
-
indices.append(tensor.
|
536
|
-
values.append(tensor.
|
535
|
+
indices.append(tensor._indices() + offset)
|
536
|
+
values.append(tensor._values())
|
537
537
|
num_rows += tensor.size(0)
|
538
538
|
num_cols += tensor.size(1)
|
539
|
+
if not tensor.is_coalesced():
|
540
|
+
is_coalesced = False
|
541
|
+
|
542
|
+
if not torch_geometric.typing.WITH_PT21:
|
543
|
+
out = torch.sparse_coo_tensor(
|
544
|
+
indices=torch.cat(indices, dim=-1),
|
545
|
+
values=torch.cat(values),
|
546
|
+
size=(num_rows, num_cols) + values[-1].size()[1:],
|
547
|
+
device=tensor.device,
|
548
|
+
)
|
549
|
+
if is_coalesced:
|
550
|
+
out = out._coalesced_(True)
|
551
|
+
return out
|
539
552
|
|
540
553
|
return torch.sparse_coo_tensor(
|
541
554
|
indices=torch.cat(indices, dim=-1),
|
542
555
|
values=torch.cat(values),
|
543
556
|
size=(num_rows, num_cols) + values[-1].size()[1:],
|
544
557
|
device=tensor.device,
|
558
|
+
is_coalesced=True if is_coalesced else None,
|
545
559
|
)
|
546
560
|
|
547
561
|
|
@@ -132,7 +132,7 @@ def _visualize_graph_via_networkx(
|
|
132
132
|
xy=pos[src],
|
133
133
|
xytext=pos[dst],
|
134
134
|
arrowprops=dict(
|
135
|
-
arrowstyle="
|
135
|
+
arrowstyle="<-",
|
136
136
|
alpha=data['alpha'],
|
137
137
|
shrinkA=sqrt(node_size) / 2.0,
|
138
138
|
shrinkB=sqrt(node_size) / 2.0,
|
@@ -140,9 +140,8 @@ def _visualize_graph_via_networkx(
|
|
140
140
|
),
|
141
141
|
)
|
142
142
|
|
143
|
-
|
144
|
-
|
145
|
-
nodes.set_edgecolor('black')
|
143
|
+
nx.draw_networkx_nodes(g, pos, node_size=node_size, node_color='white',
|
144
|
+
margins=0.1, edgecolors='black')
|
146
145
|
nx.draw_networkx_labels(g, pos, font_size=10)
|
147
146
|
|
148
147
|
if path is not None:
|