pyg-nightly 2.6.0.dev20240319__py3-none-any.whl → 2.7.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226)
  1. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/METADATA +31 -47
  2. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/RECORD +226 -199
  3. {pyg_nightly-2.6.0.dev20240319.dist-info → pyg_nightly-2.7.0.dev20250114.dist-info}/WHEEL +1 -1
  4. torch_geometric/__init__.py +28 -1
  5. torch_geometric/_compile.py +8 -1
  6. torch_geometric/_onnx.py +14 -0
  7. torch_geometric/config_mixin.py +113 -0
  8. torch_geometric/config_store.py +28 -19
  9. torch_geometric/data/__init__.py +24 -1
  10. torch_geometric/data/batch.py +2 -2
  11. torch_geometric/data/collate.py +8 -2
  12. torch_geometric/data/data.py +16 -8
  13. torch_geometric/data/database.py +61 -15
  14. torch_geometric/data/dataset.py +14 -6
  15. torch_geometric/data/feature_store.py +25 -42
  16. torch_geometric/data/graph_store.py +1 -5
  17. torch_geometric/data/hetero_data.py +18 -9
  18. torch_geometric/data/in_memory_dataset.py +2 -4
  19. torch_geometric/data/large_graph_indexer.py +677 -0
  20. torch_geometric/data/lightning/datamodule.py +4 -4
  21. torch_geometric/data/separate.py +6 -1
  22. torch_geometric/data/storage.py +17 -7
  23. torch_geometric/data/summary.py +14 -4
  24. torch_geometric/data/temporal.py +1 -2
  25. torch_geometric/datasets/__init__.py +17 -2
  26. torch_geometric/datasets/actor.py +9 -11
  27. torch_geometric/datasets/airfrans.py +15 -18
  28. torch_geometric/datasets/airports.py +10 -12
  29. torch_geometric/datasets/amazon.py +8 -11
  30. torch_geometric/datasets/amazon_book.py +9 -10
  31. torch_geometric/datasets/amazon_products.py +9 -10
  32. torch_geometric/datasets/aminer.py +8 -9
  33. torch_geometric/datasets/aqsol.py +10 -13
  34. torch_geometric/datasets/attributed_graph_dataset.py +10 -12
  35. torch_geometric/datasets/ba_multi_shapes.py +10 -12
  36. torch_geometric/datasets/ba_shapes.py +5 -6
  37. torch_geometric/datasets/bitcoin_otc.py +1 -1
  38. torch_geometric/datasets/brca_tgca.py +1 -1
  39. torch_geometric/datasets/cornell.py +145 -0
  40. torch_geometric/datasets/dblp.py +2 -1
  41. torch_geometric/datasets/dbp15k.py +2 -2
  42. torch_geometric/datasets/fake.py +1 -3
  43. torch_geometric/datasets/flickr.py +2 -1
  44. torch_geometric/datasets/freebase.py +1 -1
  45. torch_geometric/datasets/gdelt_lite.py +3 -2
  46. torch_geometric/datasets/ged_dataset.py +3 -2
  47. torch_geometric/datasets/git_mol_dataset.py +263 -0
  48. torch_geometric/datasets/gnn_benchmark_dataset.py +11 -10
  49. torch_geometric/datasets/hgb_dataset.py +8 -8
  50. torch_geometric/datasets/imdb.py +2 -1
  51. torch_geometric/datasets/karate.py +3 -2
  52. torch_geometric/datasets/last_fm.py +2 -1
  53. torch_geometric/datasets/linkx_dataset.py +4 -3
  54. torch_geometric/datasets/lrgb.py +3 -5
  55. torch_geometric/datasets/malnet_tiny.py +4 -3
  56. torch_geometric/datasets/mnist_superpixels.py +2 -3
  57. torch_geometric/datasets/molecule_gpt_dataset.py +485 -0
  58. torch_geometric/datasets/molecule_net.py +15 -3
  59. torch_geometric/datasets/motif_generator/base.py +0 -1
  60. torch_geometric/datasets/neurograph.py +1 -3
  61. torch_geometric/datasets/ogb_mag.py +1 -1
  62. torch_geometric/datasets/opf.py +239 -0
  63. torch_geometric/datasets/ose_gvcs.py +1 -1
  64. torch_geometric/datasets/pascal.py +11 -9
  65. torch_geometric/datasets/pascal_pf.py +1 -1
  66. torch_geometric/datasets/pcpnet_dataset.py +1 -1
  67. torch_geometric/datasets/pcqm4m.py +10 -3
  68. torch_geometric/datasets/ppi.py +1 -1
  69. torch_geometric/datasets/qm9.py +8 -7
  70. torch_geometric/datasets/rcdd.py +4 -4
  71. torch_geometric/datasets/reddit.py +2 -1
  72. torch_geometric/datasets/reddit2.py +2 -1
  73. torch_geometric/datasets/rel_link_pred_dataset.py +3 -3
  74. torch_geometric/datasets/s3dis.py +5 -3
  75. torch_geometric/datasets/shapenet.py +3 -3
  76. torch_geometric/datasets/shrec2016.py +2 -2
  77. torch_geometric/datasets/snap_dataset.py +7 -1
  78. torch_geometric/datasets/tag_dataset.py +350 -0
  79. torch_geometric/datasets/upfd.py +2 -1
  80. torch_geometric/datasets/web_qsp_dataset.py +246 -0
  81. torch_geometric/datasets/webkb.py +2 -2
  82. torch_geometric/datasets/wikics.py +1 -1
  83. torch_geometric/datasets/wikidata.py +3 -2
  84. torch_geometric/datasets/wikipedia_network.py +2 -2
  85. torch_geometric/datasets/willow_object_class.py +1 -1
  86. torch_geometric/datasets/word_net.py +2 -2
  87. torch_geometric/datasets/yelp.py +2 -1
  88. torch_geometric/datasets/zinc.py +1 -1
  89. torch_geometric/device.py +42 -0
  90. torch_geometric/distributed/local_feature_store.py +3 -2
  91. torch_geometric/distributed/local_graph_store.py +2 -1
  92. torch_geometric/distributed/partition.py +9 -8
  93. torch_geometric/edge_index.py +616 -438
  94. torch_geometric/explain/algorithm/base.py +0 -1
  95. torch_geometric/explain/algorithm/graphmask_explainer.py +1 -2
  96. torch_geometric/explain/algorithm/pg_explainer.py +1 -1
  97. torch_geometric/explain/explanation.py +2 -2
  98. torch_geometric/graphgym/checkpoint.py +2 -1
  99. torch_geometric/graphgym/logger.py +4 -4
  100. torch_geometric/graphgym/loss.py +1 -1
  101. torch_geometric/graphgym/utils/agg_runs.py +6 -6
  102. torch_geometric/index.py +826 -0
  103. torch_geometric/inspector.py +8 -3
  104. torch_geometric/io/fs.py +28 -2
  105. torch_geometric/io/npz.py +2 -1
  106. torch_geometric/io/off.py +2 -2
  107. torch_geometric/io/sdf.py +2 -2
  108. torch_geometric/io/tu.py +4 -5
  109. torch_geometric/loader/__init__.py +4 -0
  110. torch_geometric/loader/cluster.py +10 -4
  111. torch_geometric/loader/graph_saint.py +2 -1
  112. torch_geometric/loader/ibmb_loader.py +12 -4
  113. torch_geometric/loader/mixin.py +1 -1
  114. torch_geometric/loader/neighbor_loader.py +1 -1
  115. torch_geometric/loader/neighbor_sampler.py +2 -2
  116. torch_geometric/loader/prefetch.py +1 -1
  117. torch_geometric/loader/rag_loader.py +107 -0
  118. torch_geometric/loader/utils.py +8 -7
  119. torch_geometric/loader/zip_loader.py +10 -0
  120. torch_geometric/metrics/__init__.py +11 -2
  121. torch_geometric/metrics/link_pred.py +159 -34
  122. torch_geometric/nn/aggr/__init__.py +4 -0
  123. torch_geometric/nn/aggr/attention.py +0 -2
  124. torch_geometric/nn/aggr/base.py +2 -4
  125. torch_geometric/nn/aggr/patch_transformer.py +143 -0
  126. torch_geometric/nn/aggr/set_transformer.py +1 -1
  127. torch_geometric/nn/aggr/variance_preserving.py +33 -0
  128. torch_geometric/nn/attention/__init__.py +5 -1
  129. torch_geometric/nn/attention/qformer.py +71 -0
  130. torch_geometric/nn/conv/collect.jinja +7 -4
  131. torch_geometric/nn/conv/cugraph/base.py +8 -12
  132. torch_geometric/nn/conv/edge_conv.py +3 -2
  133. torch_geometric/nn/conv/fused_gat_conv.py +1 -1
  134. torch_geometric/nn/conv/gat_conv.py +35 -7
  135. torch_geometric/nn/conv/gatv2_conv.py +36 -6
  136. torch_geometric/nn/conv/general_conv.py +1 -1
  137. torch_geometric/nn/conv/graph_conv.py +21 -3
  138. torch_geometric/nn/conv/gravnet_conv.py +3 -2
  139. torch_geometric/nn/conv/hetero_conv.py +3 -3
  140. torch_geometric/nn/conv/hgt_conv.py +1 -1
  141. torch_geometric/nn/conv/message_passing.py +138 -87
  142. torch_geometric/nn/conv/mixhop_conv.py +1 -1
  143. torch_geometric/nn/conv/propagate.jinja +9 -1
  144. torch_geometric/nn/conv/rgcn_conv.py +5 -5
  145. torch_geometric/nn/conv/spline_conv.py +4 -4
  146. torch_geometric/nn/conv/x_conv.py +3 -2
  147. torch_geometric/nn/dense/linear.py +11 -6
  148. torch_geometric/nn/fx.py +3 -3
  149. torch_geometric/nn/model_hub.py +3 -1
  150. torch_geometric/nn/models/__init__.py +10 -2
  151. torch_geometric/nn/models/deep_graph_infomax.py +1 -2
  152. torch_geometric/nn/models/dimenet_utils.py +5 -7
  153. torch_geometric/nn/models/g_retriever.py +230 -0
  154. torch_geometric/nn/models/git_mol.py +336 -0
  155. torch_geometric/nn/models/glem.py +385 -0
  156. torch_geometric/nn/models/gnnff.py +0 -1
  157. torch_geometric/nn/models/graph_unet.py +12 -3
  158. torch_geometric/nn/models/jumping_knowledge.py +63 -4
  159. torch_geometric/nn/models/lightgcn.py +1 -1
  160. torch_geometric/nn/models/metapath2vec.py +5 -5
  161. torch_geometric/nn/models/molecule_gpt.py +222 -0
  162. torch_geometric/nn/models/node2vec.py +2 -3
  163. torch_geometric/nn/models/schnet.py +2 -1
  164. torch_geometric/nn/models/signed_gcn.py +3 -3
  165. torch_geometric/nn/module_dict.py +2 -2
  166. torch_geometric/nn/nlp/__init__.py +9 -0
  167. torch_geometric/nn/nlp/llm.py +322 -0
  168. torch_geometric/nn/nlp/sentence_transformer.py +134 -0
  169. torch_geometric/nn/nlp/vision_transformer.py +33 -0
  170. torch_geometric/nn/norm/batch_norm.py +1 -1
  171. torch_geometric/nn/parameter_dict.py +2 -2
  172. torch_geometric/nn/pool/__init__.py +21 -5
  173. torch_geometric/nn/pool/cluster_pool.py +145 -0
  174. torch_geometric/nn/pool/connect/base.py +0 -1
  175. torch_geometric/nn/pool/edge_pool.py +1 -1
  176. torch_geometric/nn/pool/graclus.py +4 -2
  177. torch_geometric/nn/pool/pool.py +8 -2
  178. torch_geometric/nn/pool/select/base.py +0 -1
  179. torch_geometric/nn/pool/voxel_grid.py +3 -2
  180. torch_geometric/nn/resolver.py +1 -1
  181. torch_geometric/nn/sequential.jinja +10 -23
  182. torch_geometric/nn/sequential.py +204 -78
  183. torch_geometric/nn/summary.py +1 -1
  184. torch_geometric/nn/to_hetero_with_bases_transformer.py +19 -19
  185. torch_geometric/profile/__init__.py +2 -0
  186. torch_geometric/profile/nvtx.py +66 -0
  187. torch_geometric/profile/profiler.py +30 -19
  188. torch_geometric/resolver.py +1 -1
  189. torch_geometric/sampler/base.py +34 -13
  190. torch_geometric/sampler/neighbor_sampler.py +11 -10
  191. torch_geometric/sampler/utils.py +1 -1
  192. torch_geometric/template.py +1 -0
  193. torch_geometric/testing/__init__.py +6 -2
  194. torch_geometric/testing/decorators.py +53 -20
  195. torch_geometric/testing/feature_store.py +1 -1
  196. torch_geometric/transforms/__init__.py +2 -0
  197. torch_geometric/transforms/add_metapaths.py +5 -5
  198. torch_geometric/transforms/add_positional_encoding.py +1 -1
  199. torch_geometric/transforms/delaunay.py +65 -14
  200. torch_geometric/transforms/face_to_edge.py +32 -3
  201. torch_geometric/transforms/gdc.py +7 -6
  202. torch_geometric/transforms/laplacian_lambda_max.py +3 -3
  203. torch_geometric/transforms/mask.py +5 -1
  204. torch_geometric/transforms/node_property_split.py +1 -2
  205. torch_geometric/transforms/pad.py +7 -6
  206. torch_geometric/transforms/random_link_split.py +1 -1
  207. torch_geometric/transforms/remove_self_loops.py +36 -0
  208. torch_geometric/transforms/svd_feature_reduction.py +1 -1
  209. torch_geometric/transforms/to_sparse_tensor.py +1 -1
  210. torch_geometric/transforms/two_hop.py +1 -1
  211. torch_geometric/transforms/virtual_node.py +2 -1
  212. torch_geometric/typing.py +43 -6
  213. torch_geometric/utils/__init__.py +5 -1
  214. torch_geometric/utils/_negative_sampling.py +1 -1
  215. torch_geometric/utils/_normalize_edge_index.py +46 -0
  216. torch_geometric/utils/_scatter.py +38 -12
  217. torch_geometric/utils/_subgraph.py +4 -0
  218. torch_geometric/utils/_tree_decomposition.py +2 -2
  219. torch_geometric/utils/augmentation.py +1 -1
  220. torch_geometric/utils/convert.py +12 -8
  221. torch_geometric/utils/geodesic.py +24 -22
  222. torch_geometric/utils/hetero.py +1 -1
  223. torch_geometric/utils/map.py +8 -2
  224. torch_geometric/utils/smiles.py +65 -27
  225. torch_geometric/utils/sparse.py +39 -25
  226. torch_geometric/visualization/graph.py +3 -4
@@ -1,6 +1,7 @@
1
1
  import warnings
2
2
  from typing import Optional, Tuple, Union
3
3
 
4
+ import numpy as np
4
5
  import torch
5
6
  from torch import Tensor
6
7
  from torch.utils.dlpack import from_dlpack
@@ -13,7 +14,7 @@ def map_index(
13
14
  inclusive: bool = False,
14
15
  ) -> Tuple[Tensor, Optional[Tensor]]:
15
16
  r"""Maps indices in :obj:`src` to the positional value of their
16
- corresponding occurence in :obj:`index`.
17
+ corresponding occurrence in :obj:`index`.
17
18
  Indices must be strictly positive.
18
19
 
19
20
  Args:
@@ -110,7 +111,12 @@ def map_index(
110
111
  result = pd.merge(left_ser, right_ser, how='left', left_on='left_ser',
111
112
  right_index=True)
112
113
 
113
- out = torch.from_numpy(result['right_ser'].values).to(index.device)
114
+ out_numpy = result['right_ser'].values
115
+ if (index.device.type == 'mps' # MPS does not support `float64`
116
+ and issubclass(out_numpy.dtype.type, np.floating)):
117
+ out_numpy = out_numpy.astype(np.float32)
118
+
119
+ out = torch.from_numpy(out_numpy).to(index.device)
114
120
 
115
121
  if out.is_floating_point() and inclusive:
116
122
  raise ValueError("Found invalid entries in 'src' that do not have "
@@ -77,32 +77,18 @@ e_map: Dict[str, List[Any]] = {
77
77
  }
78
78
 
79
79
 
80
- def from_smiles(smiles: str, with_hydrogen: bool = False,
81
- kekulize: bool = False) -> 'torch_geometric.data.Data':
82
- r"""Converts a SMILES string to a :class:`torch_geometric.data.Data`
83
- instance.
80
+ def from_rdmol(mol: Any) -> 'torch_geometric.data.Data':
81
+ r"""Converts a :class:`rdkit.Chem.Mol` instance to a
82
+ :class:`torch_geometric.data.Data` instance.
84
83
 
85
84
  Args:
86
- smiles (str): The SMILES string.
87
- with_hydrogen (bool, optional): If set to :obj:`True`, will store
88
- hydrogens in the molecule graph. (default: :obj:`False`)
89
- kekulize (bool, optional): If set to :obj:`True`, converts aromatic
90
- bonds to single/double bonds. (default: :obj:`False`)
85
+ mol (rdkit.Chem.Mol): The :class:`rdkit` molecule.
91
86
  """
92
- from rdkit import Chem, RDLogger
87
+ from rdkit import Chem
93
88
 
94
89
  from torch_geometric.data import Data
95
90
 
96
- RDLogger.DisableLog('rdApp.*')
97
-
98
- mol = Chem.MolFromSmiles(smiles)
99
-
100
- if mol is None:
101
- mol = Chem.MolFromSmiles('')
102
- if with_hydrogen:
103
- mol = Chem.AddHs(mol)
104
- if kekulize:
105
- Chem.Kekulize(mol)
91
+ assert isinstance(mol, Chem.Mol)
106
92
 
107
93
  xs: List[List[int]] = []
108
94
  for atom in mol.GetAtoms():
@@ -142,16 +128,51 @@ def from_smiles(smiles: str, with_hydrogen: bool = False,
142
128
  perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
143
129
  edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]
144
130
 
145
- return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)
131
+ return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
146
132
 
147
133
 
148
- def to_smiles(data: 'torch_geometric.data.Data',
149
- kekulize: bool = False) -> Any:
150
- """Converts a :class:`torch_geometric.data.Data` instance to a SMILES
151
- string.
134
+ def from_smiles(
135
+ smiles: str,
136
+ with_hydrogen: bool = False,
137
+ kekulize: bool = False,
138
+ ) -> 'torch_geometric.data.Data':
139
+ r"""Converts a SMILES string to a :class:`torch_geometric.data.Data`
140
+ instance.
152
141
 
153
142
  Args:
154
- data (torch_geometric.data.Data): The molecular graph.
143
+ smiles (str): The SMILES string.
144
+ with_hydrogen (bool, optional): If set to :obj:`True`, will store
145
+ hydrogens in the molecule graph. (default: :obj:`False`)
146
+ kekulize (bool, optional): If set to :obj:`True`, converts aromatic
147
+ bonds to single/double bonds. (default: :obj:`False`)
148
+ """
149
+ from rdkit import Chem, RDLogger
150
+
151
+ RDLogger.DisableLog('rdApp.*') # type: ignore
152
+
153
+ mol = Chem.MolFromSmiles(smiles)
154
+
155
+ if mol is None:
156
+ mol = Chem.MolFromSmiles('')
157
+ if with_hydrogen:
158
+ mol = Chem.AddHs(mol)
159
+ if kekulize:
160
+ Chem.Kekulize(mol)
161
+
162
+ data = from_rdmol(mol)
163
+ data.smiles = smiles
164
+ return data
165
+
166
+
167
+ def to_rdmol(
168
+ data: 'torch_geometric.data.Data',
169
+ kekulize: bool = False,
170
+ ) -> Any:
171
+ """Converts a :class:`torch_geometric.data.Data` instance to a
172
+ :class:`rdkit.Chem.Mol` instance.
173
+
174
+ Args:
175
+ data (torch_geometric.data.Data): The molecular graph data.
155
176
  kekulize (bool, optional): If set to :obj:`True`, converts aromatic
156
177
  bonds to single/double bonds. (default: :obj:`False`)
157
178
  """
@@ -172,7 +193,7 @@ def to_smiles(data: 'torch_geometric.data.Data',
172
193
  data.x[i, 5])])
173
194
  atom.SetHybridization(Chem.rdchem.HybridizationType.values[int(
174
195
  data.x[i, 6])])
175
- atom.SetIsAromatic(int(data.x[i, 7]))
196
+ atom.SetIsAromatic(bool(data.x[i, 7]))
176
197
  mol.AddAtom(atom)
177
198
 
178
199
  edges = [tuple(i) for i in data.edge_index.t().tolist()]
@@ -207,4 +228,21 @@ def to_smiles(data: 'torch_geometric.data.Data',
207
228
  Chem.SanitizeMol(mol)
208
229
  Chem.AssignStereochemistry(mol)
209
230
 
231
+ return mol
232
+
233
+
234
+ def to_smiles(
235
+ data: 'torch_geometric.data.Data',
236
+ kekulize: bool = False,
237
+ ) -> str:
238
+ """Converts a :class:`torch_geometric.data.Data` instance to a SMILES
239
+ string.
240
+
241
+ Args:
242
+ data (torch_geometric.data.Data): The molecular graph.
243
+ kekulize (bool, optional): If set to :obj:`True`, converts aromatic
244
+ bonds to single/double bonds. (default: :obj:`False`)
245
+ """
246
+ from rdkit import Chem
247
+ mol = to_rdmol(data, kekulize=kekulize)
210
248
  return Chem.MolToSmiles(mol, isomericSmiles=True)
@@ -6,6 +6,7 @@ import torch
6
6
  from torch import Tensor
7
7
 
8
8
  import torch_geometric.typing
9
+ from torch_geometric.index import index2ptr, ptr2index
9
10
  from torch_geometric.typing import SparseTensor
10
11
  from torch_geometric.utils import coalesce, cumsum
11
12
 
@@ -197,15 +198,23 @@ def to_torch_coo_tensor(
197
198
  # edge_attr = edge_attr.expand(edge_index.size(1))
198
199
  edge_attr = torch.ones(edge_index.size(1), device=edge_index.device)
199
200
 
200
- adj = torch.sparse_coo_tensor(
201
+ if not torch_geometric.typing.WITH_PT21:
202
+ adj = torch.sparse_coo_tensor(
203
+ indices=edge_index,
204
+ values=edge_attr,
205
+ size=tuple(size) + edge_attr.size()[1:],
206
+ device=edge_index.device,
207
+ )
208
+ adj = adj._coalesced_(True)
209
+ return adj
210
+
211
+ return torch.sparse_coo_tensor(
201
212
  indices=edge_index,
202
213
  values=edge_attr,
203
214
  size=tuple(size) + edge_attr.size()[1:],
204
215
  device=edge_index.device,
216
+ is_coalesced=True,
205
217
  )
206
- adj = adj._coalesced_(True)
207
-
208
- return adj
209
218
 
210
219
 
211
220
  def to_torch_csr_tensor(
@@ -483,65 +492,70 @@ def set_sparse_value(adj: Tensor, value: Tensor) -> Tensor:
483
492
  raise ValueError(f"Unexpected sparse tensor layout (got '{adj.layout}')")
484
493
 
485
494
 
486
- def ptr2index(ptr: Tensor, output_size: Optional[int] = None) -> Tensor:
487
- index = torch.arange(ptr.numel() - 1, dtype=ptr.dtype, device=ptr.device)
488
- return index.repeat_interleave(ptr.diff(), output_size=output_size)
489
-
490
-
491
- def index2ptr(index: Tensor, size: Optional[int] = None) -> Tensor:
492
- if size is None:
493
- size = int(index.max()) + 1 if index.numel() > 0 else 0
494
-
495
- return torch._convert_indices_from_coo_to_csr(
496
- index, size, out_int32=index.dtype == torch.int32)
497
-
498
-
499
495
  def cat_coo(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor:
500
496
  assert dim in {0, 1, (0, 1)}
501
497
  assert tensors[0].layout == torch.sparse_coo
502
498
 
503
499
  indices, values = [], []
504
500
  num_rows = num_cols = 0
501
+ is_coalesced = True
505
502
 
506
503
  if dim == 0:
507
504
  for i, tensor in enumerate(tensors):
508
505
  if i == 0:
509
- indices.append(tensor.indices())
506
+ indices.append(tensor._indices())
510
507
  else:
511
508
  offset = torch.tensor([[num_rows], [0]], device=tensor.device)
512
- indices.append(tensor.indices() + offset)
513
- values.append(tensor.values())
509
+ indices.append(tensor._indices() + offset)
510
+ values.append(tensor._values())
514
511
  num_rows += tensor.size(0)
515
512
  num_cols = max(num_cols, tensor.size(1))
513
+ if not tensor.is_coalesced():
514
+ is_coalesced = False
516
515
 
517
516
  elif dim == 1:
518
517
  for i, tensor in enumerate(tensors):
519
518
  if i == 0:
520
- indices.append(tensor.indices())
519
+ indices.append(tensor._indices())
521
520
  else:
522
521
  offset = torch.tensor([[0], [num_cols]], device=tensor.device)
523
522
  indices.append(tensor.indices() + offset)
524
- values.append(tensor.values())
523
+ values.append(tensor._values())
525
524
  num_rows = max(num_rows, tensor.size(0))
526
525
  num_cols += tensor.size(1)
526
+ is_coalesced = False
527
527
 
528
528
  else:
529
529
  for i, tensor in enumerate(tensors):
530
530
  if i == 0:
531
- indices.append(tensor.indices())
531
+ indices.append(tensor._indices())
532
532
  else:
533
533
  offset = torch.tensor([[num_rows], [num_cols]],
534
534
  device=tensor.device)
535
- indices.append(tensor.indices() + offset)
536
- values.append(tensor.values())
535
+ indices.append(tensor._indices() + offset)
536
+ values.append(tensor._values())
537
537
  num_rows += tensor.size(0)
538
538
  num_cols += tensor.size(1)
539
+ if not tensor.is_coalesced():
540
+ is_coalesced = False
541
+
542
+ if not torch_geometric.typing.WITH_PT21:
543
+ out = torch.sparse_coo_tensor(
544
+ indices=torch.cat(indices, dim=-1),
545
+ values=torch.cat(values),
546
+ size=(num_rows, num_cols) + values[-1].size()[1:],
547
+ device=tensor.device,
548
+ )
549
+ if is_coalesced:
550
+ out = out._coalesced_(True)
551
+ return out
539
552
 
540
553
  return torch.sparse_coo_tensor(
541
554
  indices=torch.cat(indices, dim=-1),
542
555
  values=torch.cat(values),
543
556
  size=(num_rows, num_cols) + values[-1].size()[1:],
544
557
  device=tensor.device,
558
+ is_coalesced=True if is_coalesced else None,
545
559
  )
546
560
 
547
561
 
@@ -132,7 +132,7 @@ def _visualize_graph_via_networkx(
132
132
  xy=pos[src],
133
133
  xytext=pos[dst],
134
134
  arrowprops=dict(
135
- arrowstyle="->",
135
+ arrowstyle="<-",
136
136
  alpha=data['alpha'],
137
137
  shrinkA=sqrt(node_size) / 2.0,
138
138
  shrinkB=sqrt(node_size) / 2.0,
@@ -140,9 +140,8 @@ def _visualize_graph_via_networkx(
140
140
  ),
141
141
  )
142
142
 
143
- nodes = nx.draw_networkx_nodes(g, pos, node_size=node_size,
144
- node_color='white', margins=0.1)
145
- nodes.set_edgecolor('black')
143
+ nx.draw_networkx_nodes(g, pos, node_size=node_size, node_color='white',
144
+ margins=0.1, edgecolors='black')
146
145
  nx.draw_networkx_labels(g, pos, font_size=10)
147
146
 
148
147
  if path is not None: